Compare commits


56 Commits

Author SHA1 Message Date
Luke Mino-Altherr
043a75acde Fix ruff linting issues
- Remove debug print statements
- Remove trailing whitespace on blank lines
- Remove unused pytest import

Amp-Thread-ID: https://ampcode.com/threads/T-019c3a8d-3b4f-75b4-8513-1c77914782f7
Co-authored-by: Amp <amp@ampcode.com>
2026-02-10 13:01:25 -08:00
Luke Mino-Altherr
dcba47251a Skip hidden files and directories in asset scanner
Amp-Thread-ID: https://ampcode.com/threads/T-019c3a75-046e-758d-ac96-08d45281a0c8
Co-authored-by: Amp <amp@ampcode.com>
2026-02-07 15:38:24 -08:00
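
A minimal sketch of the behavior described in the commit above, assuming the scanner walks directories with os.walk; pruning dirnames in place keeps the walk from ever descending into hidden directories:

import os
from typing import Iterator


def iter_visible_files(root: str) -> Iterator[str]:
    """Yield file paths under root, skipping hidden files and directories."""
    for dirpath, dirnames, filenames in os.walk(root):
        # Prune hidden directories in place so os.walk never descends into them.
        dirnames[:] = [d for d in dirnames if not d.startswith(".")]
        for name in filenames:
            if not name.startswith("."):
                yield os.path.join(dirpath, name)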
Luke Mino-Altherr
8f7362d8b0 Populate mime_type for assets in scanner and API paths
- Add custom MIME type registrations for model files (.safetensors, .pt, .ckpt, .gguf, .yaml)
- Pass mime_type through SeedAssetSpec to bulk_ingest
- Re-register types before use since server.py mimetypes.init() resets them
- Add tests for bulk ingest mime_type handling

Amp-Thread-ID: https://ampcode.com/threads/T-019c3626-c6ad-7139-a570-62da4e656a1a
Co-authored-by: Amp <amp@ampcode.com>
2026-02-07 14:00:06 -08:00
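
A minimal sketch of the MIME registration described in the commit above. The concrete type strings are assumptions (the real mapping lives in the scanner); re-registering right before use guards against a later mimetypes.init() resetting the global table:

import mimetypes

# Assumed mapping for illustration only; the actual type strings are defined in the scanner.
MODEL_MIME_TYPES = {
    ".safetensors": "application/octet-stream",
    ".pt": "application/octet-stream",
    ".ckpt": "application/octet-stream",
    ".gguf": "application/octet-stream",
    ".yaml": "text/yaml",
}


def register_model_mime_types() -> None:
    # Called again right before use, because a later mimetypes.init() (e.g. in
    # server.py) rebuilds the global table and drops these entries.
    for ext, mime in MODEL_MIME_TYPES.items():
        mimetypes.add_type(mime, ext)


register_model_mime_types()
print(mimetypes.guess_type("model.safetensors"))  # ('application/octet-stream', None)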
Luke Mino-Altherr
0121a5532e Fix FK constraint violation in bulk_ingest by filtering dropped assets
Co-authored-by: Amp <amp@ampcode.com>
Amp-Thread-ID: https://ampcode.com/threads/T-019c3626-c6ad-7139-a570-62da4e656a1a
2026-02-06 20:11:54 -08:00
Luke Mino-Altherr
0ff02860bc Enable SQLite foreign key enforcement
Amp-Thread-ID: https://ampcode.com/threads/T-019c3626-c6ad-7139-a570-62da4e656a1a
Co-authored-by: Amp <amp@ampcode.com>
2026-02-06 19:32:06 -08:00
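
SQLite only enforces foreign keys when the pragma is enabled on each connection, so a sketch of the standard SQLAlchemy approach (engine URL is illustrative) looks like this:

from sqlalchemy import create_engine, event

engine = create_engine("sqlite:///assets.db")  # illustrative URL


@event.listens_for(engine, "connect")
def _enable_sqlite_foreign_keys(dbapi_connection, connection_record):
    # The pragma is per-connection, so it must run for every new DBAPI connection.
    cursor = dbapi_connection.cursor()
    cursor.execute("PRAGMA foreign_keys=ON")
    cursor.close()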
Luke Mino-Altherr
0bc3a6a377 Add optional blake3 hashing during asset scanning
- Make blake3 import lazy in hashing.py (only imported when needed)
- Add compute_hashes parameter to AssetSeeder.start(), build_asset_specs(), and seed_assets()
- Fix missing tag clearing: include is_missing states in sync when update_missing_tags=True
- Clear is_missing flag on cache states when files are restored with matching mtime/size
- Fix validation error serialization in routes.py (use json.loads(ve.json()))

Amp-Thread-ID: https://ampcode.com/threads/T-019c3614-56d4-74a8-a717-19922d6dbbee
Co-authored-by: Amp <amp@ampcode.com>
2026-02-06 19:22:56 -08:00
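
A sketch of the lazy-import pattern described above, with an assumed chunk size: blake3 is imported only when a hash is actually computed, so scanning with compute_hashes=False works even without the package installed:

def compute_blake3_hash(path: str) -> str:
    try:
        from blake3 import blake3  # imported lazily so the package stays optional
    except ImportError as exc:
        raise RuntimeError(
            "blake3 is required for hashing; install it or scan with compute_hashes=False"
        ) from exc

    hasher = blake3()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(8 * 1024 * 1024), b""):
            hasher.update(chunk)
    return f"blake3:{hasher.hexdigest()}"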
Luke Mino-Altherr
59bff14c97 Fix magic number and function name typo
- Add MAX_SAFETENSORS_HEADER_SIZE constant in metadata_extract.py
- Fix double 'compute' typo: compute_compute_blake3_hash_async → compute_blake3_hash_async

Amp-Thread-ID: https://ampcode.com/threads/T-019c3550-4dbc-7301-a5e8-e6e23aa2d7b1
Co-authored-by: Amp <amp@ampcode.com>
2026-02-06 15:37:18 -08:00
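
The safetensors format starts with an 8-byte little-endian header length followed by a JSON header; a hedged sketch of capping that read (the constant's actual value in metadata_extract.py is not shown in this diff, so the number below is an assumption):

import json
import struct

MAX_SAFETENSORS_HEADER_SIZE = 100 * 1024 * 1024  # assumed cap for illustration


def read_safetensors_header(path: str) -> dict:
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))  # u64 little-endian header size
        if header_len > MAX_SAFETENSORS_HEADER_SIZE:
            raise ValueError(f"safetensors header too large: {header_len} bytes")
        return json.loads(f.read(header_len))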
Luke Mino-Altherr
9b284163f5 Fix inconsistent nullability handling for size_bytes in routes.py
Since size_bytes is declared as non-nullable (nullable=False, default=0) in
the Asset model, simplify the conditional checks:
- Use 'if item.asset else None' when the asset relationship might be None
- Access size_bytes directly when asset is guaranteed to exist (create endpoints)

Amp-Thread-ID: https://ampcode.com/threads/T-019c354e-cbfb-77d8-acdd-0d066c16006e
Co-authored-by: Amp <amp@ampcode.com>
2026-02-06 15:36:03 -08:00
Luke Mino-Altherr
593cc980e9 Fix type annotation: use Callable[[str], bool] instead of callable
Amp-Thread-ID: https://ampcode.com/threads/T-019c354d-d627-7233-864d-1e6f7a4b8caa
Co-authored-by: Amp <amp@ampcode.com>
2026-02-06 15:34:37 -08:00
Luke Mino-Altherr
72b6c3f065 Fix concurrency issues in AssetSeeder
- Fix race in mark_missing_outside_prefixes: set state to RUNNING inside
  lock before operations, restore to IDLE in finally block to prevent
  concurrent start() calls

- Fix timing consistency: capture perf_counter before _update_progress
  for consistent event timing

Amp-Thread-ID: https://ampcode.com/threads/T-019c354b-e7d7-7309-aa0e-79e5e7dff2b7
Co-authored-by: Amp <amp@ampcode.com>
2026-02-06 15:33:30 -08:00
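
A hypothetical skeleton of the locking pattern described above: the RUNNING state is claimed inside the lock so a concurrent start() is rejected, the long-running work happens outside the lock, and IDLE is always restored in a finally block:

import threading
from enum import Enum


class SeederState(Enum):
    IDLE = "IDLE"
    RUNNING = "RUNNING"


class SeederSketch:
    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._state = SeederState.IDLE

    def mark_missing_outside_prefixes(self) -> int:
        with self._lock:
            if self._state is not SeederState.IDLE:
                return 0  # another operation is already running
            self._state = SeederState.RUNNING
        try:
            return self._do_mark_missing()  # long-running DB work, outside the lock
        finally:
            with self._lock:
                self._state = SeederState.IDLE

    def _do_mark_missing(self) -> int:
        return 0  # placeholder for the real query work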
Luke Mino-Altherr
2af0cf18f5 Consolidate duplicate delete_temp_file_if_exists function
- Remove duplicate from routes.py
- Import from upload.py instead
- Rename to public API (remove leading underscore)

Amp-Thread-ID: https://ampcode.com/threads/T-019c3549-c245-7628-950c-dd6826185394
Co-authored-by: Amp <amp@ampcode.com>
2026-02-06 15:31:28 -08:00
Luke Mino-Altherr
6edcd690b6 chore: remove development scripts
Amp-Thread-ID: https://ampcode.com/threads/T-019c3542-377d-76ea-b10a-e551d317d92f
Co-authored-by: Amp <amp@ampcode.com>
2026-02-06 15:24:06 -08:00
Luke Mino-Altherr
6708c02446 Fix is_missing state updates for asset cache states on startup
- Add bulk_update_is_missing() to efficiently update is_missing flag
- Update sync_cache_states_with_filesystem() to mark non-existent files as is_missing=True
- Call restore_cache_states_by_paths() in batch_insert_seed_assets() to restore
  previously-missing states when files reappear during scanning

Amp-Thread-ID: https://ampcode.com/threads/T-019c3177-e591-7666-ac6b-7e05c71c8ebf
Co-authored-by: Amp <amp@ampcode.com>
2026-02-06 10:09:19 -08:00
Luke Mino-Altherr
58f70f2f92 docs: remove background-asset-seeder.md
Amp-Thread-ID: https://ampcode.com/threads/T-019c316d-13f7-77f8-b92b-ea7276c3e09c
Co-authored-by: Amp <amp@ampcode.com>
2026-02-05 21:35:58 -08:00
Luke Mino-Altherr
32011c403b refactor(bulk_ingest): improve variable naming and add typed dicts
- Rename shorthand variables to explicit names (sp -> spec, aid -> asset_id, etc.)
- Move imports to top of file
- Add TypedDict definitions for AssetRow, CacheStateRow, AssetInfoRow, TagRow, MetadataRow
- Replace bare dict types with typed alternatives

Amp-Thread-ID: https://ampcode.com/threads/T-019c316d-13f7-77f8-b92b-ea7276c3e09c
Co-authored-by: Amp <amp@ampcode.com>
2026-02-05 21:32:01 -08:00
Luke Mino-Altherr
58ddf46c0a Rename bulk_set_needs_verify to bulk_update_needs_verify for readability
Amp-Thread-ID: https://ampcode.com/threads/T-019c3167-e8f1-7409-904b-5fc0edaeef37
Co-authored-by: Amp <amp@ampcode.com>
2026-02-05 21:25:58 -08:00
Luke Mino-Altherr
9222ff6d81 feat: non-destructive asset pruning with is_missing flag
- Add is_missing column to AssetCacheState for soft-delete
- Replace hard-delete pruning with mark_cache_states_missing_outside_prefixes
- Auto-restore missing cache states when files are re-scanned
- Filter out missing cache states from queries by default
- Rename functions for clarity:
  - mark_cache_states_missing_outside_prefixes (was delete_cache_states_outside_prefixes)
  - get_unreferenced_unhashed_asset_ids (was get_orphaned_seed_asset_ids)
  - mark_assets_missing_outside_prefixes (was prune_orphaned_assets)
  - mark_missing_outside_prefixes_safely (was prune_orphans_safely)
- Add restore_cache_states_by_paths for explicit restoration
- Add cleanup_unreferenced_assets for explicit hard-delete when needed
- Update API endpoint /api/assets/prune to use new soft-delete behavior

This preserves user metadata (tags, etc.) when base directories change,
allowing assets to be restored when the original paths become available again.

Amp-Thread-ID: https://ampcode.com/threads/T-019c3114-bf28-73a9-a4d2-85b208fd5462
Co-authored-by: Amp <amp@ampcode.com>
2026-02-05 21:21:46 -08:00
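
Illustrative queries for the soft-delete flag described above; AssetCacheState and the helper names mirror the commit message, but the import path and signatures are assumptions (the real helpers live in database/queries/cache_state.py):

from sqlalchemy import select, update
from sqlalchemy.orm import Session

from app.assets.database.models import AssetCacheState  # assumed import path


def bulk_update_is_missing(session: Session, cache_state_ids: list[int], is_missing: bool) -> int:
    result = session.execute(
        update(AssetCacheState)
        .where(AssetCacheState.id.in_(cache_state_ids))
        .values(is_missing=is_missing)
    )
    return result.rowcount


def list_live_cache_states(session: Session) -> list[AssetCacheState]:
    # Missing cache states are filtered out of normal queries by default.
    return list(
        session.scalars(select(AssetCacheState).where(AssetCacheState.is_missing.is_(False)))
    )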
Luke Mino-Altherr
ebb2f5b0e9 refactor: make scanner helper functions public
Rename _sync_root_safely, _prune_orphans_safely, _collect_paths_for_roots,
_build_asset_specs, and _insert_asset_specs to remove underscore prefix
since they are used by seeder.py as part of the public API.

Amp-Thread-ID: https://ampcode.com/threads/T-019c3037-df32-7138-99d8-b4b824d896b3
Co-authored-by: Amp <amp@ampcode.com>
2026-02-05 19:01:46 -08:00
Luke Mino-Altherr
28c4b58dd6 Make ingest_file_from_path and register_existing_asset private
Amp-Thread-ID: https://ampcode.com/threads/T-019c2fe5-a3de-71cc-a6e5-67fe944a101e
Co-authored-by: Amp <amp@ampcode.com>
2026-02-05 14:26:36 -08:00
Luke Mino-Altherr
56e9a75ca2 Decouple orphan pruning from asset seeding
- Remove automatic pruning from scan loop to prevent partial scans from
  deleting assets belonging to other roots
- Add get_all_known_prefixes() helper to get prefixes for all root types
- Add prune_orphans() method to AssetSeeder for explicit pruning
- Add prune_first parameter to start() for optional pre-scan pruning
- Add POST /api/assets/prune endpoint for explicit pruning via API
- Update main.py startup to use prune_first=True for full startup scans
- Add tests for new prune_orphans functionality

Fixes issue where a models-only scan would delete all input/output assets.

Amp-Thread-ID: https://ampcode.com/threads/T-019c2ba0-e004-7229-81bf-452b2f7f57a1
Co-authored-by: Amp <amp@ampcode.com>
2026-02-04 19:38:29 -08:00
Luke Mino-Altherr
8fb77c080f feat(assets): add background asset seeder for non-blocking startup
- Add AssetSeeder singleton class with thread management and cancellation
- Support IDLE/RUNNING/CANCELLING state machine with thread-safe access
- Emit WebSocket events for scan progress (started, progress, completed, cancelled, error)
- Update main.py to use non-blocking asset_seeder.start() at startup
- Add shutdown() call in finally block for graceful cleanup
- Update POST /api/assets/seed to return 202 Accepted, support ?wait=true
- Add GET /api/assets/seed/status and POST /api/assets/seed/cancel endpoints
- Update test helper to use ?wait=true for synchronous behavior
- Add 17 unit tests covering state transitions, cancellation, and thread safety
- Log scan configuration (models directory, input/output paths) at scan start

Amp-Thread-ID: https://ampcode.com/threads/T-019c2b45-e6e8-740a-b38b-b11daea8d094
Co-authored-by: Amp <amp@ampcode.com>
2026-02-04 17:02:57 -08:00
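
A sketch of how a client could drive the endpoints described above (base URL, port, and the exact state strings are assumptions): start returns 202, a second start while running returns 409, and status can be polled until the scan finishes:

import asyncio

import aiohttp


async def run_seed(base_url: str = "http://127.0.0.1:8188") -> None:
    async with aiohttp.ClientSession() as http:
        async with http.post(f"{base_url}/api/assets/seed", json={"roots": ["models"]}) as resp:
            if resp.status == 409:
                print("a scan is already running")
                return
            assert resp.status == 202  # accepted; the scan runs in a background thread

        while True:
            async with http.get(f"{base_url}/api/assets/seed/status") as resp:
                status = await resp.json()
            if status["state"] != "RUNNING":  # assumed state string
                break
            await asyncio.sleep(1.0)

    print("final state:", status["state"], "progress:", status["progress"])


asyncio.run(run_seed())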
Luke Mino-Altherr
fa3749ced7 Add TypedDict types to scanner and bulk_ingest
Amp-Thread-ID: https://ampcode.com/threads/T-019c2af9-4d41-73e9-b38d-78d06bc28a3f
Co-authored-by: Amp <amp@ampcode.com>
2026-02-04 15:32:26 -08:00
Luke Mino-Altherr
16b5d9112b Fix path traversal validation to return 400 instead of 500
Catch ValueError from resolve_destination_from_tags in the upload
endpoint so that invalid path components like '..' return a 400
BAD_REQUEST error instead of falling through to the 500 handler.

Amp-Thread-ID: https://ampcode.com/threads/T-019c2af2-7c87-7263-88b0-9feca1c31b3c
Co-authored-by: Amp <amp@ampcode.com>
2026-02-04 15:24:51 -08:00
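
A self-contained sketch of the kind of guard behind resolve_destination_from_tags (the helper name below is hypothetical): resolve the candidate path and raise ValueError when it escapes the base directory, which the upload endpoint now maps to a 400 instead of a 500:

import os


def resolve_within_base(base_dir: str, *parts: str) -> str:
    candidate = os.path.realpath(os.path.join(base_dir, *parts))
    base = os.path.realpath(base_dir)
    if os.path.commonpath([base, candidate]) != base:
        raise ValueError(f"path escapes base directory: {os.path.join(*parts)!r}")
    return candidate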
Luke Mino-Altherr
abeec3072b refactor(assets): extract scanner logic into service modules
- Create file_utils.py with shared file utilities:
  - get_mtime_ns() - extract mtime in nanoseconds from stat
  - get_size_and_mtime_ns() - get both size and mtime
  - verify_file_unchanged() - check file matches DB mtime/size
  - list_files_recursively() - recursive directory listing

- Create bulk_ingest.py for bulk operations:
  - BulkInsertResult dataclass
  - batch_insert_seed_assets() - batch insert with conflict handling
  - prune_orphaned_assets() - clean up orphaned assets

- Update scanner.py to use new service modules instead of
  calling database queries directly

- Update ingest.py to use shared get_size_and_mtime_ns()

- Export new functions from services/__init__.py

Amp-Thread-ID: https://ampcode.com/threads/T-019c2ae7-f701-716a-a0dd-1feb988732fb
Co-authored-by: Amp <amp@ampcode.com>
2026-02-04 15:17:31 -08:00
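
Minimal sketches of two of the shared helpers named above (signatures are assumptions): the size/mtime pair is what the cache state records, and verify_file_unchanged is the cheap check used to decide whether a file needs rescanning:

import os


def get_size_and_mtime_ns(path: str) -> tuple[int, int]:
    st = os.stat(path)
    return st.st_size, st.st_mtime_ns


def verify_file_unchanged(path: str, expected_size: int, expected_mtime_ns: int) -> bool:
    try:
        size, mtime_ns = get_size_and_mtime_ns(path)
    except OSError:
        return False
    return size == expected_size and mtime_ns == expected_mtime_ns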
Luke Mino-Altherr
b23302f372 refactor(assets): consolidate duplicated query utilities and remove unused code
- Extract shared helpers to database/queries/common.py:
  - MAX_BIND_PARAMS, calculate_rows_per_statement, iter_chunks, iter_row_chunks
  - build_visible_owner_clause

- Remove duplicate _compute_filename_for_asset, consolidate in path_utils.py

- Remove unused get_asset_info_with_tags (duplicated get_asset_detail)

- Remove redundant __all__ from cache_state.py

- Make internal helpers private (_check_is_scalar)

Amp-Thread-ID: https://ampcode.com/threads/T-019c2ad9-9432-7451-94a8-79287dbbb19e
Co-authored-by: Amp <amp@ampcode.com>
2026-02-04 15:04:30 -08:00
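
A sketch of the shared chunking helpers, assuming SQLite's historical default of 999 bound parameters per statement (the real constant and signatures live in queries/common.py):

from itertools import islice
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")

MAX_BIND_PARAMS = 999  # assumed limit for illustration


def calculate_rows_per_statement(columns_per_row: int) -> int:
    return max(1, MAX_BIND_PARAMS // columns_per_row)


def iter_chunks(items: Iterable[T], size: int) -> Iterator[list[T]]:
    it = iter(items)
    while chunk := list(islice(it, size)):
        yield chunk


# e.g. bulk-inserting rows that bind 6 columns each:
# for chunk in iter_chunks(rows, calculate_rows_per_statement(6)):
#     session.execute(insert(Asset), chunk)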
Luke Mino-Altherr
adf6eb73fd refactor: eliminate manager layer, routes call services directly
- Delete app/assets/manager.py
- Move upload logic (upload_from_temp_path, create_from_hash) to ingest service
- Add HashMismatchError and DependencyMissingError to ingest service
- Add UploadResult schema for upload responses
- Update routes.py to import services directly and do schema conversion inline
- Add asset lookup/listing service functions to asset_management.py

Routes now call the service layer directly, removing an unnecessary
layer of indirection. The manager was only converting between service
dataclasses and Pydantic response schemas.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-04 14:50:11 -08:00
Luke Mino-Altherr
5259959fef refactor: require blake3 package directly in hashing module
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 20:42:11 -08:00
Luke Mino-Altherr
5474d8bf84 chore: consolidate service imports in manager.py
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 20:36:06 -08:00
Luke Mino-Altherr
9290e26e9f refactor: add explicit types to asset service functions
- Add typed result dataclasses: IngestResult, AddTagsResult,
  RemoveTagsResult, SetTagsResult, TagUsage
- Add UserMetadata type alias for user_metadata parameters
- Type helper functions with Session parameters
- Use TypedDicts at query layer to avoid circular imports
- Update manager.py and tests to use attribute access

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 20:32:14 -08:00
Luke Mino-Altherr
37ecc5b663 chore: remove obvious/self-documenting comments from assets package
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 20:14:11 -08:00
Luke Mino-Altherr
80d99e7b63 chore: remove module-level comments and docstrings from assets package
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 20:04:47 -08:00
Luke Mino-Altherr
d8cb122dfb chore: sort imports in assets package
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 20:02:52 -08:00
Luke Mino-Altherr
0f75def5b5 refactor: move scanner.py out of services to top-level assets module
Scanner is used externally by main.py and server.py for startup/maintenance,
not as part of the regular service layer. Moving it to app/assets/scanner.py
makes the public API clearer.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 19:56:29 -08:00
Luke Mino-Altherr
6b1f9f7755 refactor: convert asset tests to table-driven parametrized tests
- test_metadata.py: consolidate 7 filter type classes into parametrized tests
- test_asset.py: parametrize exists, get, and upsert test cases
- test_cache_state.py: parametrize upsert and delete scenarios
- test_crud.py: consolidate error response tests into single parametrized test
- test_list_filter.py: consolidate invalid query tests

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 19:50:16 -08:00
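
An illustrative table-driven test in the style described above; the function under test and the cases are placeholders, not the project's real tests:

import pytest


def looks_like_blake3(value: str) -> bool:
    algo, _, digest = value.partition(":")
    return algo == "blake3" and bool(digest) and all(c in "0123456789abcdef" for c in digest)


@pytest.mark.parametrize(
    ("raw", "expected"),
    [
        ("blake3:abc123", True),
        ("sha256:abc123", False),
        ("blake3:", False),
        ("", False),
    ],
)
def test_looks_like_blake3(raw: str, expected: bool) -> None:
    assert looks_like_blake3(raw) is expected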
Luke Mino-Altherr
3311b13740 chore: remove unused re-exports from conftest.py
The helper functions are already imported directly from helpers.py
by all test files, so the backwards compatibility re-export is dead code.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 19:22:24 -08:00
Luke Mino-Altherr
bf7fbb6317 chore: remove unused get_utc_now import
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 18:47:11 -08:00
Luke Mino-Altherr
5571508e61 refactor: use query functions instead of direct ORM modifications in service layer
Add update_asset_info_name and update_asset_info_updated_at query functions
and update asset_management.py to use them instead of modifying ORM objects
directly. This ensures the service layer only uses explicit operations from
the queries package.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 18:44:23 -08:00
Luke Mino-Altherr
e3b8e512ca refactor: use explicit dataclasses instead of ORM objects in service layer
Replace dict/ORM object returns with explicit dataclasses to fix
DetachedInstanceError when accessing ORM attributes after session closes.

- Add app/assets/services/schemas.py with AssetData, AssetInfoData,
  AssetDetailResult, and RegisterAssetResult dataclasses
- Update asset_management.py and ingest.py to return dataclasses
- Update manager.py to use attribute access on dataclasses
- Fix created_new to be False in create_asset_from_hash (content exists)
- Add DependencyMissingError for better blake3 missing error handling
- Update tests to use attribute access instead of dict subscripting

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 18:39:07 -08:00
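
A sketch of the pattern described above, with illustrative field and model names: copy the needed attributes into a plain dataclass while the session is still open, then return the dataclass so callers never touch a detached ORM instance:

from dataclasses import dataclass
from datetime import datetime

from sqlalchemy.orm import Session


@dataclass(frozen=True)
class AssetInfoData:
    id: str
    name: str
    created_at: datetime


def get_asset_info_data(session: Session, asset_info_id: str) -> AssetInfoData | None:
    info = session.get(AssetInfo, asset_info_id)  # AssetInfo: the project's ORM model (assumed)
    if info is None:
        return None
    # Attribute access happens here, inside the live session, so no
    # DetachedInstanceError can occur after the session closes.
    return AssetInfoData(id=info.id, name=info.name, created_at=info.created_at)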
Luke Mino-Altherr
ea01cd665d fix: resolve test import errors and module collision in assets_test
Extract helper functions from conftest.py to a dedicated helpers.py module
to fix import resolution issues when pytest processes subdirectories.
Rename test_tags.py to test_tags_api.py to avoid module name collision
with queries/test_tags.py.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 14:57:52 -08:00
Luke Mino-Altherr
ccfc5dedd4 fix: handle missing blake3 module gracefully to prevent server crash
Make blake3 an optional import that fails gracefully at import time,
with a clear error message when hashing functions are actually called.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 14:46:33 -08:00
Luke Mino-Altherr
e9ca190098 refactor: remove try-finally wrapper in seed_assets by extracting helpers
Extract focused helper functions to eliminate the try-finally block that
wrapped ~50 lines just for logging. The new helpers (_collect_paths_for_roots,
_build_asset_specs, _insert_asset_specs) make seed_assets a simple linear flow.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 14:33:54 -08:00
Luke Mino-Altherr
ed60e93696 refactor: flatten nested try blocks and if statements in assets package
Extract helper functions to eliminate nested try-except blocks in scanner.py
and remove duplicated type-checking logic in asset_info.py. Simplify nested
conditionals in asset_management.py for clearer control flow.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 14:28:16 -08:00
Luke Mino-Altherr
fef2f01671 refactor: improve function naming for clarity and consistency
Rename functions to use clearer verb-based names:
- pick_best_live_path → select_best_live_path
- escape_like_prefix → escape_sql_like_string
- list_tree → list_files_recursively
- check_asset_file_fast → verify_asset_file_unchanged
- _seed_from_paths_batch → _batch_insert_assets_from_paths
- reconcile_cache_states_for_root → sync_cache_states_with_filesystem
- touch_asset_info_by_id → update_asset_info_access_time
- replace_asset_info_metadata_projection → set_asset_info_metadata
- expand_metadata_to_rows → convert_metadata_to_rows
- _rows_per_stmt → _calculate_rows_per_statement
- ensure_within_base → validate_path_within_base
- _cleanup_temp → _delete_temp_file_if_exists
- validate_hash_format → normalize_and_validate_hash
- get_relative_to_root_category_path_of_asset → get_asset_category_and_relative_path

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 14:20:36 -08:00
Luke Mino-Altherr
481a2fa263 refactor: rename functions to verb-based naming convention
Rename functions across app/assets/ to follow verb-based naming:
- is_scalar → check_is_scalar
- project_kv → expand_metadata_to_rows
- _visible_owner_clause → _build_visible_owner_clause
- _chunk_rows → _iter_row_chunks
- _at_least_one → _validate_at_least_one_field
- _tags_norm → _normalize_tags_field
- _ser_dt → _serialize_datetime
- _ser_updated → _serialize_updated_at
- _error_response → _build_error_response
- _validation_error_response → _build_validation_error_response
- file_sender → stream_file_chunks
- seed_assets_endpoint → seed_assets
- utcnow → get_utc_now
- _safe_sort_field → _validate_sort_field
- _safe_filename → _sanitize_filename
- fast_asset_file_check → check_asset_file_fast
- prefixes_for_root → get_prefixes_for_root
- blake3_hash → compute_blake3_hash
- blake3_hash_async → compute_blake3_hash_async
- _is_within → _check_is_within
- _rel → _compute_relative

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 13:58:14 -08:00
Luke Mino-Altherr
11ca1995a3 fix: remaining ruff linting errors in services tests
- Remove unused os imports in conftest.py and test_ingest.py
- Remove unused Tag import in test_asset_management.py
- Remove unused ensure_tags_exist import in test_ingest.py
- Fix unused info2 variable in test_asset_management.py

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 13:31:38 -08:00
Luke Mino-Altherr
4e02245012 fix: ruff linting errors and add comprehensive test coverage for asset queries
- Fix unused imports in routes.py, asset.py, manager.py, asset_management.py, ingest.py
- Fix whitespace issues in upload.py, asset_info.py, ingest.py
- Fix typo in manager.py (stray character after result["asset"])
- Fix broken import in test_metadata.py (project_kv moved to asset_info.py)
- Add fixture override in queries/conftest.py for unit test isolation

Add 48 new tests covering all previously untested query functions:
- asset.py: upsert_asset, bulk_insert_assets
- cache_state.py: upsert_cache_state, delete_cache_states_outside_prefixes,
  get_orphaned_seed_asset_ids, delete_assets_by_ids, get_cache_states_for_prefixes,
  bulk_set_needs_verify, delete_cache_states_by_ids, delete_orphaned_seed_asset,
  bulk_insert_cache_states_ignore_conflicts, get_cache_states_by_paths_and_asset_ids
- asset_info.py: insert_asset_info, get_or_create_asset_info,
  update_asset_info_timestamps, replace_asset_info_metadata_projection,
  bulk_insert_asset_infos_ignore_conflicts, get_asset_info_ids_by_ids
- tags.py: bulk_insert_tags_and_meta

Total: 119 tests pass (up from 71)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 13:21:12 -08:00
Luke Mino-Altherr
9f9db2c2c2 refactor: extract multipart upload parsing from routes
- Add app/assets/api/upload.py with parse_multipart_upload() for HTTP parsing
- Add ParsedUpload dataclass to schemas_in.py
- Add domain exceptions (AssetValidationError, AssetNotFoundError, HashMismatchError)
- Add manager.process_upload() with domain exceptions (no HTTP status codes)
- Routes map domain exceptions to HTTP responses
- Slim down upload_asset route to ~20 lines (was ~150)

Amp-Thread-ID: https://ampcode.com/threads/T-019c2519-abe1-738a-ad2e-29ece17c0e42
Co-authored-by: Amp <amp@ampcode.com>
2026-02-03 13:08:04 -08:00
Luke Mino-Altherr
e987bd268f Move get_comfy_models_folders to path_utils.py to avoid late import
Amp-Thread-ID: https://ampcode.com/threads/T-019c2510-33fa-7199-ae4b-bc31102277a7
Co-authored-by: Amp <amp@ampcode.com>
2026-02-03 13:08:04 -08:00
Luke Mino-Altherr
2eb100adf9 Refactor helpers.py: move functions to their respective modules
- Move scanner-only functions to scanner.py
- Move query-only functions (is_scalar, project_kv) to asset_info.py
- Move get_query_dict to routes.py
- Create path_utils.py service for path-related functions
- Reduce helpers.py to shared utilities only

Amp-Thread-ID: https://ampcode.com/threads/T-019c2510-33fa-7199-ae4b-bc31102277a7
Co-authored-by: Amp <amp@ampcode.com>
2026-02-03 13:08:04 -08:00
Luke Mino-Altherr
a02f160e20 Move hashing.py to services directory
Amp-Thread-ID: https://ampcode.com/threads/T-019c2510-33fa-7199-ae4b-bc31102277a7
Co-authored-by: Amp <amp@ampcode.com>
2026-02-03 13:08:04 -08:00
Luke Mino-Altherr
c3105b1174 refactor: move bulk_ops to queries and scanner service
- Delete bulk_ops.py, moving logic to appropriate layers
- Add bulk insert query functions:
  - queries/asset.bulk_insert_assets
  - queries/cache_state.bulk_insert_cache_states_ignore_conflicts
  - queries/cache_state.get_cache_states_by_paths_and_asset_ids
  - queries/asset_info.bulk_insert_asset_infos_ignore_conflicts
  - queries/asset_info.get_asset_info_ids_by_ids
  - queries/tags.bulk_insert_tags_and_meta
- Move seed_from_paths_batch orchestration to scanner._seed_from_paths_batch

Amp-Thread-ID: https://ampcode.com/threads/T-019c24fd-157d-776a-ad24-4f19cf5d3afe
Co-authored-by: Amp <amp@ampcode.com>
2026-02-03 13:08:04 -08:00
Luke Mino-Altherr
64d2f51dfc refactor: move scanner to services layer with pure query extraction
- Move app/assets/scanner.py to app/assets/services/scanner.py
- Extract pure queries from fast_db_consistency_pass:
  - get_cache_states_for_prefixes()
  - bulk_set_needs_verify()
  - delete_cache_states_by_ids()
  - delete_orphaned_seed_asset()
- Split prune_orphaned_assets into pure queries:
  - delete_cache_states_outside_prefixes()
  - get_orphaned_seed_asset_ids()
  - delete_assets_by_ids()
- Add reconcile_cache_states_for_root() service function
- Add prune_orphaned_assets() service function
- Remove function injection pattern
- Update imports in main.py, server.py, routes.py

Amp-Thread-ID: https://ampcode.com/threads/T-019c24f1-3385-701b-87e0-8b6bc87e841b
Co-authored-by: Amp <amp@ampcode.com>
2026-02-03 13:08:04 -08:00
Luke Mino-Altherr
fba4570e49 refactor: move in-function imports to top-level and remove keyword-only argument pattern
- Move imports from inside functions to module top-level in:
  - app/assets/database/queries/asset.py
  - app/assets/database/queries/asset_info.py
  - app/assets/database/queries/cache_state.py
  - app/assets/manager.py
  - app/assets/services/asset_management.py
  - app/assets/services/ingest.py

- Remove keyword-only argument markers (*,) from app/assets/ to match codebase conventions

Amp-Thread-ID: https://ampcode.com/threads/T-019c24eb-bfa2-727f-8212-8bc976048604
Co-authored-by: Amp <amp@ampcode.com>
2026-02-03 13:08:04 -08:00
Luke Mino-Altherr
15ee03f65c Refactor asset database: separate business logic from queries
Architecture changes:
- API Routes -> manager.py (thin adapter) -> services/ (business logic) -> queries/ (atomic DB ops)
- Services own session lifecycle via create_session()
- Queries accept Session as parameter, do single-table atomic operations

New app/assets/services/ layer:
- __init__.py - exports all service functions
- ingest.py - ingest_file_from_path(), register_existing_asset()
- asset_management.py - get_asset_detail(), update_asset_metadata(), delete_asset_reference(), set_asset_preview()
- tagging.py - apply_tags(), remove_tags(), list_tags()

Removed from queries/asset_info.py:
- ingest_fs_asset (moved to services/ingest.py as ingest_file_from_path)
- update_asset_info_full (moved to services/asset_management.py as update_asset_metadata)
- create_asset_info_for_existing_asset (moved to services/ingest.py as register_existing_asset)

Updated manager.py:
- Now a thin adapter that transforms API schemas to/from service calls
- Delegates all business logic to services layer
- No longer imports sqlalchemy.orm.Session or models directly

Test updates:
- Fixed test_cache_state.py import of pick_best_live_path (moved to helpers.py)
- Added comprehensive service layer tests (41 new tests)
- All 112 query + service tests pass

Amp-Thread-ID: https://ampcode.com/threads/T-019c24e2-7ae4-707f-ad19-c775ed8b82b5
Co-authored-by: Amp <amp@ampcode.com>
2026-02-03 13:08:04 -08:00
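
A sketch of the layering described above, with placeholder query bodies and a simplified apply_tags signature: the service owns the session via create_session() and composes query functions that each take a Session and perform one atomic operation:

from contextlib import contextmanager

from sqlalchemy.orm import Session, sessionmaker

SessionLocal = sessionmaker()  # bound to the real engine elsewhere


@contextmanager
def create_session():
    session = SessionLocal()
    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()


def ensure_tags_exist(session: Session, tags: list[str]) -> None:
    ...  # query layer: single-table insert-if-absent for each tag


def add_tags_to_asset_info(session: Session, asset_info_id: str, tags: list[str]) -> list[str]:
    ...  # query layer: attach tags to one AssetInfo row
    return tags


def apply_tags(asset_info_id: str, tags: list[str]) -> list[str]:
    # service layer: owns the session lifecycle, composes query functions
    with create_session() as session:
        ensure_tags_exist(session, tags)
        return add_tags_to_asset_info(session, asset_info_id, tags)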
Luke Mino-Altherr
70a600baf0 chore: remove unused Asset import from manager.py
Amp-Thread-ID: https://ampcode.com/threads/T-019c24bb-475b-7442-9ff9-8288edea3345
Co-authored-by: Amp <amp@ampcode.com>
2026-02-03 13:08:04 -08:00
Luke Mino-Altherr
17ad7e393f refactor(assets): split queries.py into modular query modules
Split the ~1000 line app/assets/database/queries.py into focused modules:

- queries/asset.py - Asset entity queries (asset_exists_by_hash, get_asset_by_hash)
- queries/asset_info.py - AssetInfo queries (~15 functions)
- queries/cache_state.py - AssetCacheState queries (list_cache_states_by_asset_id,
  pick_best_live_path, prune_orphaned_assets, fast_db_consistency_pass)
- queries/tags.py - Tag queries (8 functions including ensure_tags_exist,
  add/remove tag functions, list_tags_with_usage)
- queries/__init__.py - Re-exports all public functions for backward compatibility

Also adds comprehensive unit tests using in-memory SQLite:
- tests-unit/assets_test/queries/conftest.py - Session fixture
- tests-unit/assets_test/queries/test_asset.py - 5 tests
- tests-unit/assets_test/queries/test_asset_info.py - 23 tests
- tests-unit/assets_test/queries/test_cache_state.py - 8 tests
- tests-unit/assets_test/queries/test_metadata.py - 12 tests for _apply_metadata_filter
- tests-unit/assets_test/queries/test_tags.py - 23 tests

All 71 unit tests pass. Existing integration tests unaffected.

Amp-Thread-ID: https://ampcode.com/threads/T-019c24bb-475b-7442-9ff9-8288edea3345
Co-authored-by: Amp <amp@ampcode.com>
2026-02-03 13:08:04 -08:00
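
An illustrative shape for the queries/__init__.py re-export mentioned above, so existing `from ...queries import X` call sites keep working (function names taken from the commit messages; the full export list is not shown here):

from .asset import asset_exists_by_hash, get_asset_by_hash
from .asset_info import get_or_create_asset_info, insert_asset_info
from .cache_state import list_cache_states_by_asset_id, pick_best_live_path
from .tags import ensure_tags_exist, list_tags_with_usage

__all__ = [
    "asset_exists_by_hash",
    "get_asset_by_hash",
    "get_or_create_asset_info",
    "insert_asset_info",
    "list_cache_states_by_asset_id",
    "pick_best_live_path",
    "ensure_tags_exist",
    "list_tags_with_usage",
]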
54 changed files with 8174 additions and 2717 deletions

View File

@@ -0,0 +1,37 @@
"""
Add is_missing column to asset_cache_state for non-destructive soft-delete
Revision ID: 0002_add_is_missing
Revises: 0001_assets
Create Date: 2025-02-05 00:00:00
"""
from alembic import op
import sqlalchemy as sa
revision = "0002_add_is_missing"
down_revision = "0001_assets"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"asset_cache_state",
sa.Column(
"is_missing",
sa.Boolean(),
nullable=False,
server_default=sa.text("false"),
),
)
op.create_index(
"ix_asset_cache_state_is_missing",
"asset_cache_state",
["is_missing"],
)
def downgrade() -> None:
op.drop_index("ix_asset_cache_state_is_missing", table_name="asset_cache_state")
op.drop_column("asset_cache_state", "is_missing")

View File

@@ -1,19 +1,39 @@
import logging
import uuid
import urllib.parse
import os
import contextlib
from aiohttp import web
import urllib.parse
import uuid
from typing import Any
from aiohttp import web
from pydantic import ValidationError
import app.assets.manager as manager
from app import user_manager
from app.assets.api import schemas_in
from app.assets.helpers import get_query_dict
from app.assets.scanner import seed_assets
import folder_paths
from app import user_manager
from app.assets.api import schemas_in, schemas_out
from app.assets.api.schemas_in import (
AssetValidationError,
UploadError,
)
from app.assets.api.upload import (
delete_temp_file_if_exists,
parse_multipart_upload,
)
from app.assets.seeder import asset_seeder
from app.assets.services import (
DependencyMissingError,
HashMismatchError,
apply_tags,
asset_exists,
create_from_hash,
delete_asset_reference,
get_asset_detail,
list_assets_page,
list_tags,
remove_tags,
resolve_asset_for_download,
update_asset_metadata,
upload_from_temp_path,
)
ROUTES = web.RouteTableDef()
USER_MANAGER: user_manager.UserManager | None = None
@@ -21,36 +41,80 @@ USER_MANAGER: user_manager.UserManager | None = None
# UUID regex (canonical hyphenated form, case-insensitive)
UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
def get_query_dict(request: web.Request) -> dict[str, Any]:
"""
Gets a dictionary of query parameters from the request.
'request.query' is a MultiMapping[str], needs to be converted to a dictionary to be validated by Pydantic.
"""
query_dict = {
key: request.query.getall(key)
if len(request.query.getall(key)) > 1
else request.query.get(key)
for key in request.query.keys()
}
return query_dict
# Note to any custom node developers reading this code:
# The assets system is not yet fully implemented, do not rely on the code in /app/assets remaining the same.
def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None:
def register_assets_system(
app: web.Application, user_manager_instance: user_manager.UserManager
) -> None:
global USER_MANAGER
USER_MANAGER = user_manager_instance
app.add_routes(ROUTES)
def _error_response(status: int, code: str, message: str, details: dict | None = None) -> web.Response:
return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status)
def _build_error_response(
status: int, code: str, message: str, details: dict | None = None
) -> web.Response:
return web.json_response(
{"error": {"code": code, "message": message, "details": details or {}}},
status=status,
)
def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
return _error_response(400, code, "Validation failed.", {"errors": ve.json()})
def _build_validation_error_response(code: str, ve: ValidationError) -> web.Response:
import json
errors = json.loads(ve.json())
return _build_error_response(400, code, "Validation failed.", {"errors": errors})
def _validate_sort_field(requested: str | None) -> str:
if not requested:
return "created_at"
v = requested.lower()
if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
return v
return "created_at"
@ROUTES.head("/api/assets/hash/{hash}")
async def head_asset_by_hash(request: web.Request) -> web.Response:
hash_str = request.match_info.get("hash", "").strip().lower()
if not hash_str or ":" not in hash_str:
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
return _build_error_response(
400, "INVALID_HASH", "hash must be like 'blake3:<hex>'"
)
algo, digest = hash_str.split(":", 1)
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
exists = manager.asset_exists(asset_hash=hash_str)
if (
algo != "blake3"
or not digest
or any(c for c in digest if c not in "0123456789abcdef")
):
return _build_error_response(
400, "INVALID_HASH", "hash must be like 'blake3:<hex>'"
)
exists = asset_exists(hash_str)
return web.Response(status=200 if exists else 404)
@ROUTES.get("/api/assets")
async def list_assets(request: web.Request) -> web.Response:
async def list_assets_route(request: web.Request) -> web.Response:
"""
GET request to list assets.
"""
@@ -58,66 +122,117 @@ async def list_assets(request: web.Request) -> web.Response:
try:
q = schemas_in.ListAssetsQuery.model_validate(query_dict)
except ValidationError as ve:
return _validation_error_response("INVALID_QUERY", ve)
return _build_validation_error_response("INVALID_QUERY", ve)
payload = manager.list_assets(
sort = _validate_sort_field(q.sort)
order_candidate = (q.order or "desc").lower()
order = order_candidate if order_candidate in {"asc", "desc"} else "desc"
result = list_assets_page(
owner_id=USER_MANAGER.get_request_user_id(request),
include_tags=q.include_tags,
exclude_tags=q.exclude_tags,
name_contains=q.name_contains,
metadata_filter=q.metadata_filter,
limit=q.limit,
offset=q.offset,
sort=q.sort,
order=q.order,
owner_id=USER_MANAGER.get_request_user_id(request),
sort=sort,
order=order,
)
summaries = [
schemas_out.AssetSummary(
id=item.info.id,
name=item.info.name,
asset_hash=item.asset.hash if item.asset else None,
size=int(item.asset.size_bytes) if item.asset else None,
mime_type=item.asset.mime_type if item.asset else None,
tags=item.tags,
created_at=item.info.created_at,
updated_at=item.info.updated_at,
last_access_time=item.info.last_access_time,
)
for item in result.items
]
payload = schemas_out.AssetsList(
assets=summaries,
total=result.total,
has_more=(q.offset + len(summaries)) < result.total,
)
return web.json_response(payload.model_dump(mode="json", exclude_none=True))
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
async def get_asset(request: web.Request) -> web.Response:
async def get_asset_route(request: web.Request) -> web.Response:
"""
GET request to get an asset's info as JSON.
"""
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
result = manager.get_asset(
result = get_asset_detail(
asset_info_id=asset_info_id,
owner_id=USER_MANAGER.get_request_user_id(request),
)
if not result:
return _build_error_response(
404,
"ASSET_NOT_FOUND",
f"AssetInfo {asset_info_id} not found",
{"id": asset_info_id},
)
payload = schemas_out.AssetDetail(
id=result.info.id,
name=result.info.name,
asset_hash=result.asset.hash if result.asset else None,
size=int(result.asset.size_bytes) if result.asset else None,
mime_type=result.asset.mime_type if result.asset else None,
tags=result.tags,
user_metadata=result.info.user_metadata or {},
preview_id=result.info.preview_id,
created_at=result.info.created_at,
last_access_time=result.info.last_access_time,
)
except ValueError as e:
return _error_response(404, "ASSET_NOT_FOUND", str(e), {"id": asset_info_id})
return _build_error_response(
404, "ASSET_NOT_FOUND", str(e), {"id": asset_info_id}
)
except Exception:
logging.exception(
"get_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(payload.model_dump(mode="json"), status=200)
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
async def download_asset_content(request: web.Request) -> web.Response:
# question: do we need disposition? could we just stick with one of these?
disposition = request.query.get("disposition", "attachment").lower().strip()
if disposition not in {"inline", "attachment"}:
disposition = "attachment"
try:
abs_path, content_type, filename = manager.resolve_asset_content_for_download(
result = resolve_asset_for_download(
asset_info_id=str(uuid.UUID(request.match_info["id"])),
owner_id=USER_MANAGER.get_request_user_id(request),
)
abs_path = result.abs_path
content_type = result.content_type
filename = result.download_name
except ValueError as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve))
return _build_error_response(404, "ASSET_NOT_FOUND", str(ve))
except NotImplementedError as nie:
return _error_response(501, "BACKEND_UNSUPPORTED", str(nie))
return _build_error_response(501, "BACKEND_UNSUPPORTED", str(nie))
except FileNotFoundError:
return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.")
return _build_error_response(
404, "FILE_NOT_FOUND", "Underlying file not found on disk."
)
quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'")
cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}'
cd = f"{disposition}; filename=\"{quoted}\"; filename*=UTF-8''{urllib.parse.quote(quoted)}"
file_size = os.path.getsize(abs_path)
logging.info(
@@ -129,7 +244,7 @@ async def download_asset_content(request: web.Request) -> web.Response:
filename,
)
async def file_sender():
async def stream_file_chunks():
chunk_size = 64 * 1024
with open(abs_path, "rb") as f:
while True:
@@ -139,7 +254,7 @@ async def download_asset_content(request: web.Request) -> web.Response:
yield chunk
return web.Response(
body=file_sender(),
body=stream_file_chunks(),
content_type=content_type,
headers={
"Content-Disposition": cd,
@@ -149,16 +264,18 @@ async def download_asset_content(request: web.Request) -> web.Response:
@ROUTES.post("/api/assets/from-hash")
async def create_asset_from_hash(request: web.Request) -> web.Response:
async def create_asset_from_hash_route(request: web.Request) -> web.Response:
try:
payload = await request.json()
body = schemas_in.CreateFromHashBody.model_validate(payload)
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
return _build_validation_error_response("INVALID_BODY", ve)
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
return _build_error_response(
400, "INVALID_JSON", "Request body must be valid JSON."
)
result = manager.create_asset_from_hash(
result = create_from_hash(
hash_str=body.hash,
name=body.name,
tags=body.tags,
@@ -166,228 +283,183 @@ async def create_asset_from_hash(request: web.Request) -> web.Response:
owner_id=USER_MANAGER.get_request_user_id(request),
)
if result is None:
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist")
return web.json_response(result.model_dump(mode="json"), status=201)
return _build_error_response(
404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist"
)
payload_out = schemas_out.AssetCreated(
id=result.info.id,
name=result.info.name,
asset_hash=result.asset.hash,
size=int(result.asset.size_bytes),
mime_type=result.asset.mime_type,
tags=result.tags,
user_metadata=result.info.user_metadata or {},
preview_id=result.info.preview_id,
created_at=result.info.created_at,
last_access_time=result.info.last_access_time,
created_new=result.created_new,
)
return web.json_response(payload_out.model_dump(mode="json"), status=201)
@ROUTES.post("/api/assets")
async def upload_asset(request: web.Request) -> web.Response:
"""Multipart/form-data endpoint for Asset uploads."""
if not (request.content_type or "").lower().startswith("multipart/"):
return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.")
reader = await request.multipart()
file_present = False
file_client_name: str | None = None
tags_raw: list[str] = []
provided_name: str | None = None
user_metadata_raw: str | None = None
provided_hash: str | None = None
provided_hash_exists: bool | None = None
file_written = 0
tmp_path: str | None = None
while True:
field = await reader.next()
if field is None:
break
fname = getattr(field, "name", "") or ""
if fname == "hash":
try:
s = ((await field.text()) or "").strip().lower()
except Exception:
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
if s:
if ":" not in s:
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
provided_hash = f"{algo}:{digest}"
try:
provided_hash_exists = manager.asset_exists(asset_hash=provided_hash)
except Exception:
provided_hash_exists = None # do not fail the whole request here
elif fname == "file":
file_present = True
file_client_name = (field.filename or "").strip()
if provided_hash and provided_hash_exists is True:
# If client supplied a hash that we know exists, drain but do not write to disk
try:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
if not chunk:
break
file_written += len(chunk)
except Exception:
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file.")
continue # Do not create temp file; we will create AssetInfo from the existing content
# Otherwise, store to temp for hashing/ingest
uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
os.makedirs(unique_dir, exist_ok=True)
tmp_path = os.path.join(unique_dir, ".upload.part")
try:
with open(tmp_path, "wb") as f:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
if not chunk:
break
f.write(chunk)
file_written += len(chunk)
except Exception:
try:
if os.path.exists(tmp_path or ""):
os.remove(tmp_path)
finally:
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.")
elif fname == "tags":
tags_raw.append((await field.text()) or "")
elif fname == "name":
provided_name = (await field.text()) or None
elif fname == "user_metadata":
user_metadata_raw = (await field.text()) or None
# If client did not send file, and we are not doing a from-hash fast path -> error
if not file_present and not (provided_hash and provided_hash_exists):
return _error_response(400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'.")
if file_present and file_written == 0 and not (provided_hash and provided_hash_exists):
# Empty upload is only acceptable if we are fast-pathing from existing hash
try:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
finally:
return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.")
try:
spec = schemas_in.UploadAssetSpec.model_validate({
"tags": tags_raw,
"name": provided_name,
"user_metadata": user_metadata_raw,
"hash": provided_hash,
})
except ValidationError as ve:
try:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
finally:
return _validation_error_response("INVALID_BODY", ve)
# Validate models category against configured folders (consistent with previous behavior)
if spec.tags and spec.tags[0] == "models":
if len(spec.tags) < 2 or spec.tags[1] not in folder_paths.folder_names_and_paths:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
return _error_response(
400, "INVALID_BODY", f"unknown models category '{spec.tags[1] if len(spec.tags) >= 2 else ''}'"
)
parsed = await parse_multipart_upload(request, check_hash_exists=asset_exists)
except UploadError as e:
return _build_error_response(e.status, e.code, e.message)
owner_id = USER_MANAGER.get_request_user_id(request)
# Fast path: if a valid provided hash exists, create AssetInfo without writing anything
if spec.hash and provided_hash_exists is True:
try:
result = manager.create_asset_from_hash(
try:
spec = schemas_in.UploadAssetSpec.model_validate(
{
"tags": parsed.tags_raw,
"name": parsed.provided_name,
"user_metadata": parsed.user_metadata_raw,
"hash": parsed.provided_hash,
}
)
except ValidationError as ve:
delete_temp_file_if_exists(parsed.tmp_path)
return _build_error_response(
400, "INVALID_BODY", f"Validation failed: {ve.json()}"
)
if spec.tags and spec.tags[0] == "models":
if (
len(spec.tags) < 2
or spec.tags[1] not in folder_paths.folder_names_and_paths
):
delete_temp_file_if_exists(parsed.tmp_path)
category = spec.tags[1] if len(spec.tags) >= 2 else ""
return _build_error_response(
400, "INVALID_BODY", f"unknown models category '{category}'"
)
try:
# Fast path: if a valid provided hash exists, create AssetInfo without writing anything
if spec.hash and parsed.provided_hash_exists is True:
result = create_from_hash(
hash_str=spec.hash,
name=spec.name or (spec.hash.split(":", 1)[1]),
tags=spec.tags,
user_metadata=spec.user_metadata or {},
owner_id=owner_id,
)
except Exception:
logging.exception("create_asset_from_hash failed for hash=%s, owner_id=%s", spec.hash, owner_id)
return _error_response(500, "INTERNAL", "Unexpected server error.")
if result is None:
delete_temp_file_if_exists(parsed.tmp_path)
return _build_error_response(
404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist"
)
delete_temp_file_if_exists(parsed.tmp_path)
else:
# Otherwise, we must have a temp file path to ingest
if not parsed.tmp_path or not os.path.exists(parsed.tmp_path):
return _build_error_response(
404,
"ASSET_NOT_FOUND",
"Provided hash not found and no file uploaded.",
)
if result is None:
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist")
# Drain temp if we accidentally saved (e.g., hash field came after file)
if tmp_path and os.path.exists(tmp_path):
with contextlib.suppress(Exception):
os.remove(tmp_path)
status = 200 if (not result.created_new) else 201
return web.json_response(result.model_dump(mode="json"), status=status)
# Otherwise, we must have a temp file path to ingest
if not tmp_path or not os.path.exists(tmp_path):
# The only case we reach here without a temp file is: client sent a hash that does not exist and no file
return _error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.")
try:
created = manager.upload_asset_from_temp_path(
spec,
temp_path=tmp_path,
client_filename=file_client_name,
owner_id=owner_id,
expected_asset_hash=spec.hash,
)
status = 201 if created.created_new else 200
return web.json_response(created.model_dump(mode="json"), status=status)
except ValueError as e:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
msg = str(e)
if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH":
return _error_response(
400,
"HASH_MISMATCH",
"Uploaded file hash does not match provided hash.",
result = upload_from_temp_path(
temp_path=parsed.tmp_path,
name=spec.name,
tags=spec.tags,
user_metadata=spec.user_metadata or {},
client_filename=parsed.file_client_name,
owner_id=owner_id,
expected_hash=spec.hash,
)
return _error_response(400, "BAD_REQUEST", "Invalid inputs.")
except AssetValidationError as e:
delete_temp_file_if_exists(parsed.tmp_path)
return _build_error_response(400, e.code, str(e))
except ValueError as e:
delete_temp_file_if_exists(parsed.tmp_path)
return _build_error_response(400, "BAD_REQUEST", str(e))
except HashMismatchError as e:
delete_temp_file_if_exists(parsed.tmp_path)
return _build_error_response(400, "HASH_MISMATCH", str(e))
except DependencyMissingError as e:
delete_temp_file_if_exists(parsed.tmp_path)
return _build_error_response(503, "DEPENDENCY_MISSING", e.message)
except Exception:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
logging.exception("upload_asset_from_temp_path failed for tmp_path=%s, owner_id=%s", tmp_path, owner_id)
return _error_response(500, "INTERNAL", "Unexpected server error.")
delete_temp_file_if_exists(parsed.tmp_path)
logging.exception("upload_asset failed for owner_id=%s", owner_id)
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
payload = schemas_out.AssetCreated(
id=result.info.id,
name=result.info.name,
asset_hash=result.asset.hash,
size=int(result.asset.size_bytes),
mime_type=result.asset.mime_type,
tags=result.tags,
user_metadata=result.info.user_metadata or {},
preview_id=result.info.preview_id,
created_at=result.info.created_at,
last_access_time=result.info.last_access_time,
created_new=result.created_new,
)
status = 201 if result.created_new else 200
return web.json_response(payload.model_dump(mode="json"), status=status)
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
async def update_asset(request: web.Request) -> web.Response:
async def update_asset_route(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
body = schemas_in.UpdateAssetBody.model_validate(await request.json())
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
return _build_validation_error_response("INVALID_BODY", ve)
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
return _build_error_response(
400, "INVALID_JSON", "Request body must be valid JSON."
)
try:
result = manager.update_asset(
result = update_asset_metadata(
asset_info_id=asset_info_id,
name=body.name,
user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request),
)
payload = schemas_out.AssetUpdated(
id=result.info.id,
name=result.info.name,
asset_hash=result.asset.hash if result.asset else None,
tags=result.tags,
user_metadata=result.info.user_metadata or {},
updated_at=result.info.updated_at,
)
except (ValueError, PermissionError) as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
return _build_error_response(
404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}
)
except Exception:
logging.exception(
"update_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(payload.model_dump(mode="json"), status=200)
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
async def delete_asset(request: web.Request) -> web.Response:
async def delete_asset_route(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
delete_content = request.query.get("delete_content")
delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"}
delete_content_param = request.query.get("delete_content")
delete_content = (
True
if delete_content_param is None
else delete_content_param.lower() not in {"0", "false", "no"}
)
try:
deleted = manager.delete_asset_reference(
deleted = delete_asset_reference(
asset_info_id=asset_info_id,
owner_id=USER_MANAGER.get_request_user_id(request),
delete_content_if_orphan=delete_content,
@@ -398,10 +470,12 @@ async def delete_asset(request: web.Request) -> web.Response:
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
if not deleted:
return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.")
return _build_error_response(
404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found."
)
return web.Response(status=204)
@@ -415,12 +489,12 @@ async def get_tags(request: web.Request) -> web.Response:
try:
query = schemas_in.TagsListQuery.model_validate(query_map)
except ValidationError as e:
return web.json_response(
{"error": {"code": "INVALID_QUERY", "message": "Invalid query parameters", "details": e.errors()}},
status=400,
import json
return _build_error_response(
400, "INVALID_QUERY", "Invalid query parameters", {"errors": json.loads(e.json())}
)
result = manager.list_tags(
rows, total = list_tags(
prefix=query.prefix,
limit=query.limit,
offset=query.offset,
@@ -428,87 +502,201 @@ async def get_tags(request: web.Request) -> web.Response:
include_zero=query.include_zero,
owner_id=USER_MANAGER.get_request_user_id(request),
)
return web.json_response(result.model_dump(mode="json"))
tags = [
schemas_out.TagUsage(name=name, count=count, type=tag_type)
for (name, tag_type, count) in rows
]
payload = schemas_out.TagsList(
tags=tags, total=total, has_more=(query.offset + len(tags)) < total
)
return web.json_response(payload.model_dump(mode="json"))
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
async def add_asset_tags(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
payload = await request.json()
data = schemas_in.TagsAdd.model_validate(payload)
json_payload = await request.json()
data = schemas_in.TagsAdd.model_validate(json_payload)
except ValidationError as ve:
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()})
return _build_error_response(
400,
"INVALID_BODY",
"Invalid JSON body for tags add.",
{"errors": ve.errors()},
)
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
return _build_error_response(
400, "INVALID_JSON", "Request body must be valid JSON."
)
try:
result = manager.add_tags_to_asset(
result = apply_tags(
asset_info_id=asset_info_id,
tags=data.tags,
origin="manual",
owner_id=USER_MANAGER.get_request_user_id(request),
)
payload = schemas_out.TagsAdd(
added=result.added,
already_present=result.already_present,
total_tags=result.total_tags,
)
except (ValueError, PermissionError) as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
return _build_error_response(
404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}
)
except Exception:
logging.exception(
"add_tags_to_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
return web.json_response(payload.model_dump(mode="json"), status=200)
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
async def delete_asset_tags(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
payload = await request.json()
data = schemas_in.TagsRemove.model_validate(payload)
json_payload = await request.json()
data = schemas_in.TagsRemove.model_validate(json_payload)
except ValidationError as ve:
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()})
return _build_error_response(
400,
"INVALID_BODY",
"Invalid JSON body for tags remove.",
{"errors": ve.errors()},
)
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
return _build_error_response(
400, "INVALID_JSON", "Request body must be valid JSON."
)
try:
result = manager.remove_tags_from_asset(
result = remove_tags(
asset_info_id=asset_info_id,
tags=data.tags,
owner_id=USER_MANAGER.get_request_user_id(request),
)
payload = schemas_out.TagsRemove(
removed=result.removed,
not_present=result.not_present,
total_tags=result.total_tags,
)
except ValueError as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
return _build_error_response(
404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}
)
except Exception:
logging.exception(
"remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
return web.json_response(payload.model_dump(mode="json"), status=200)
@ROUTES.post("/api/assets/seed")
async def seed_assets_endpoint(request: web.Request) -> web.Response:
"""Trigger asset seeding for specified roots (models, input, output)."""
async def seed_assets(request: web.Request) -> web.Response:
"""Trigger asset seeding for specified roots (models, input, output).
Query params:
wait: If "true", block until scan completes (synchronous behavior for tests)
Returns:
202 Accepted if scan started
409 Conflict if scan already running
200 OK with final stats if wait=true
"""
try:
payload = await request.json()
roots = payload.get("roots", ["models", "input", "output"])
except Exception:
roots = ["models", "input", "output"]
valid_roots = [r for r in roots if r in ("models", "input", "output")]
valid_roots = tuple(r for r in roots if r in ("models", "input", "output"))
if not valid_roots:
return _error_response(400, "INVALID_BODY", "No valid roots specified")
return _build_error_response(400, "INVALID_BODY", "No valid roots specified")
try:
seed_assets(tuple(valid_roots))
except Exception:
logging.exception("seed_assets failed for roots=%s", valid_roots)
return _error_response(500, "INTERNAL", "Seed operation failed")
wait_param = request.query.get("wait", "").lower()
should_wait = wait_param in ("true", "1", "yes")
return web.json_response({"seeded": valid_roots}, status=200)
started = asset_seeder.start(roots=valid_roots)
if not started:
return web.json_response({"status": "already_running"}, status=409)
if should_wait:
asset_seeder.wait()
status = asset_seeder.get_status()
return web.json_response(
{
"status": "completed",
"progress": {
"scanned": status.progress.scanned if status.progress else 0,
"total": status.progress.total if status.progress else 0,
"created": status.progress.created if status.progress else 0,
"skipped": status.progress.skipped if status.progress else 0,
},
"errors": status.errors,
},
status=200,
)
return web.json_response({"status": "started"}, status=202)
@ROUTES.get("/api/assets/seed/status")
async def get_seed_status(request: web.Request) -> web.Response:
"""Get current scan status and progress."""
status = asset_seeder.get_status()
return web.json_response(
{
"state": status.state.value,
"progress": {
"scanned": status.progress.scanned,
"total": status.progress.total,
"created": status.progress.created,
"skipped": status.progress.skipped,
}
if status.progress
else None,
"errors": status.errors,
},
status=200,
)
@ROUTES.post("/api/assets/seed/cancel")
async def cancel_seed(request: web.Request) -> web.Response:
"""Request cancellation of in-progress scan."""
cancelled = asset_seeder.cancel()
if cancelled:
return web.json_response({"status": "cancelling"}, status=200)
return web.json_response({"status": "idle"}, status=200)
@ROUTES.post("/api/assets/prune")
async def mark_missing_assets(request: web.Request) -> web.Response:
"""Mark assets as missing when their cache states point to files outside all known root prefixes.
This is a non-destructive soft-delete operation. Assets and their metadata
are preserved, but cache states are flagged as missing; those states can be restored
if the file reappears in a future scan.
Returns:
200 OK with count of marked assets
409 Conflict if a scan is currently running
"""
marked = asset_seeder.mark_missing_outside_prefixes()
if marked == 0 and asset_seeder.get_status().state.value != "IDLE":
return web.json_response(
{"status": "scan_running", "marked": 0},
status=409,
)
return web.json_response({"status": "completed", "marked": marked}, status=200)

View File

@@ -1,4 +1,5 @@
import json
from dataclasses import dataclass
from typing import Any, Literal
from pydantic import (
@@ -10,6 +11,65 @@ from pydantic import (
model_validator,
)
class UploadError(Exception):
"""Error during upload parsing with HTTP status and code (used in HTTP layer only)."""
def __init__(self, status: int, code: str, message: str):
super().__init__(message)
self.status = status
self.code = code
self.message = message
class AssetValidationError(Exception):
"""Validation error in asset processing (invalid tags, metadata, etc.)."""
def __init__(self, code: str, message: str):
super().__init__(message)
self.code = code
self.message = message
class AssetNotFoundError(Exception):
"""Asset or asset content not found."""
def __init__(self, message: str):
super().__init__(message)
self.message = message
class HashMismatchError(Exception):
"""Uploaded file hash does not match provided hash."""
def __init__(self, message: str):
super().__init__(message)
self.message = message
class DependencyMissingError(Exception):
"""A required dependency is not installed."""
def __init__(self, message: str):
super().__init__(message)
self.message = message
@dataclass
class ParsedUpload:
"""Result of parsing a multipart upload request."""
file_present: bool
file_written: int
file_client_name: str | None
tmp_path: str | None
tags_raw: list[str]
provided_name: str | None
user_metadata_raw: str | None
provided_hash: str | None
provided_hash_exists: bool | None
class ListAssetsQuery(BaseModel):
include_tags: list[str] = Field(default_factory=list)
exclude_tags: list[str] = Field(default_factory=list)
@@ -21,7 +81,9 @@ class ListAssetsQuery(BaseModel):
limit: conint(ge=1, le=500) = 20
offset: conint(ge=0) = 0
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at"
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = (
"created_at"
)
order: Literal["asc", "desc"] = "desc"
@field_validator("include_tags", "exclude_tags", mode="before")
@@ -61,7 +123,7 @@ class UpdateAssetBody(BaseModel):
user_metadata: dict[str, Any] | None = None
@model_validator(mode="after")
def _at_least_one(self):
def _validate_at_least_one_field(self):
if self.name is None and self.user_metadata is None:
raise ValueError("Provide at least one of: name, user_metadata.")
return self
@@ -90,7 +152,7 @@ class CreateFromHashBody(BaseModel):
@field_validator("tags", mode="before")
@classmethod
def _tags_norm(cls, v):
def _normalize_tags_field(cls, v):
if v is None:
return []
if isinstance(v, list):
@@ -163,6 +225,7 @@ class UploadAssetSpec(BaseModel):
Files created via this endpoint are stored on disk using the **content hash** as the filename stem;
the original extension is preserved when available.
"""
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
tags: list[str] = Field(..., min_length=1)
@@ -260,5 +323,7 @@ class UploadAssetSpec(BaseModel):
raise ValueError("first tag must be one of: models, input, output")
if root == "models":
if len(self.tags) < 2:
raise ValueError("models uploads require a category tag as the second tag")
raise ValueError(
"models uploads require a category tag as the second tag"
)
return self
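The tag rules enforced by this validator are easy to misread, so here is a small standalone re-statement of them (not the real UploadAssetSpec model; the tag values are hypothetical):

# First tag must be a root; "models" uploads also need a category tag second.
def check_upload_tags(tags: list[str]) -> None:
    if not tags:
        raise ValueError("at least one tag is required")
    root = tags[0]
    if root not in ("models", "input", "output"):
        raise ValueError("first tag must be one of: models, input, output")
    if root == "models" and len(tags) < 2:
        raise ValueError("models uploads require a category tag as the second tag")

check_upload_tags(["input"])                  # ok
check_upload_tags(["models", "checkpoints"])  # ok
# check_upload_tags(["models"])               # would raise: category tag missing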

View File

@@ -19,7 +19,7 @@ class AssetSummary(BaseModel):
model_config = ConfigDict(from_attributes=True)
@field_serializer("created_at", "updated_at", "last_access_time")
def _ser_dt(self, v: datetime | None, _info):
def _serialize_datetime(self, v: datetime | None, _info):
return v.isoformat() if v else None
@@ -40,7 +40,7 @@ class AssetUpdated(BaseModel):
model_config = ConfigDict(from_attributes=True)
@field_serializer("updated_at")
def _ser_updated(self, v: datetime | None, _info):
def _serialize_updated_at(self, v: datetime | None, _info):
return v.isoformat() if v else None
@@ -59,7 +59,7 @@ class AssetDetail(BaseModel):
model_config = ConfigDict(from_attributes=True)
@field_serializer("created_at", "last_access_time")
def _ser_dt(self, v: datetime | None, _info):
def _serialize_datetime(self, v: datetime | None, _info):
return v.isoformat() if v else None

app/assets/api/upload.py Normal file
View File

@@ -0,0 +1,170 @@
import logging
import os
import uuid
from typing import Callable
from aiohttp import web
import folder_paths
from app.assets.api.schemas_in import ParsedUpload, UploadError
def normalize_and_validate_hash(s: str) -> str:
"""
Validate and normalize a hash string.
Returns canonical 'blake3:<hex>' or raises UploadError.
"""
s = s.strip().lower()
if not s:
raise UploadError(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
if ":" not in s:
raise UploadError(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if (
algo != "blake3"
or not digest
or any(c for c in digest if c not in "0123456789abcdef")
):
raise UploadError(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
return f"{algo}:{digest}"
async def parse_multipart_upload(
request: web.Request,
check_hash_exists: Callable[[str], bool],
) -> ParsedUpload:
"""
Parse a multipart/form-data upload request.
Args:
request: The aiohttp request
check_hash_exists: Callable(hash_str) -> bool to check if a hash exists
Returns:
ParsedUpload with parsed fields and temp file path
Raises:
UploadError: On validation or I/O errors
"""
if not (request.content_type or "").lower().startswith("multipart/"):
raise UploadError(
415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads."
)
reader = await request.multipart()
file_present = False
file_client_name: str | None = None
tags_raw: list[str] = []
provided_name: str | None = None
user_metadata_raw: str | None = None
provided_hash: str | None = None
provided_hash_exists: bool | None = None
file_written = 0
tmp_path: str | None = None
while True:
field = await reader.next()
if field is None:
break
fname = getattr(field, "name", "") or ""
if fname == "hash":
try:
s = ((await field.text()) or "").strip().lower()
except Exception:
raise UploadError(
400, "INVALID_HASH", "hash must be like 'blake3:<hex>'"
)
if s:
provided_hash = normalize_and_validate_hash(s)
try:
provided_hash_exists = check_hash_exists(provided_hash)
except Exception as e:
logging.warning(
"check_hash_exists failed for hash=%s: %s", provided_hash, e
)
provided_hash_exists = None # do not fail the whole request here
elif fname == "file":
file_present = True
file_client_name = (field.filename or "").strip()
if provided_hash and provided_hash_exists is True:
# If client supplied a hash that we know exists, drain but do not write to disk
try:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
if not chunk:
break
file_written += len(chunk)
except Exception:
raise UploadError(
500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file."
)
continue
uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
os.makedirs(unique_dir, exist_ok=True)
tmp_path = os.path.join(unique_dir, ".upload.part")
try:
with open(tmp_path, "wb") as f:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
if not chunk:
break
f.write(chunk)
file_written += len(chunk)
except Exception:
delete_temp_file_if_exists(tmp_path)
raise UploadError(
500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file."
)
elif fname == "tags":
tags_raw.append((await field.text()) or "")
elif fname == "name":
provided_name = (await field.text()) or None
elif fname == "user_metadata":
user_metadata_raw = (await field.text()) or None
if not file_present and not (provided_hash and provided_hash_exists):
raise UploadError(
400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'."
)
if (
file_present
and file_written == 0
and not (provided_hash and provided_hash_exists)
):
delete_temp_file_if_exists(tmp_path)
raise UploadError(400, "EMPTY_UPLOAD", "Uploaded file is empty.")
return ParsedUpload(
file_present=file_present,
file_written=file_written,
file_client_name=file_client_name,
tmp_path=tmp_path,
tags_raw=tags_raw,
provided_name=provided_name,
user_metadata_raw=user_metadata_raw,
provided_hash=provided_hash,
provided_hash_exists=provided_hash_exists,
)
def delete_temp_file_if_exists(tmp_path: str | None) -> None:
"""Safely remove a temp file if it exists."""
if tmp_path:
try:
if os.path.exists(tmp_path):
os.remove(tmp_path)
except OSError as e:
logging.debug("Failed to delete temp file %s: %s", tmp_path, e)

View File

@@ -1,204 +0,0 @@
import os
import uuid
import sqlalchemy
from typing import Iterable
from sqlalchemy.orm import Session
from sqlalchemy.dialects import sqlite
from app.assets.helpers import utcnow
from app.assets.database.models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, AssetInfoMeta
MAX_BIND_PARAMS = 800
def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]:
if not rows:
return []
rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row))
for i in range(0, len(rows), rows_per_stmt):
yield rows[i:i + rows_per_stmt]
def _iter_chunks(seq, n: int):
for i in range(0, len(seq), n):
yield seq[i:i + n]
def _rows_per_stmt(cols: int) -> int:
return max(1, MAX_BIND_PARAMS // max(1, cols))
def seed_from_paths_batch(
session: Session,
*,
specs: list[dict],
owner_id: str = "",
) -> dict:
"""Each spec is a dict with keys:
- abs_path: str
- size_bytes: int
- mtime_ns: int
- info_name: str
- tags: list[str]
- fname: Optional[str]
"""
if not specs:
return {"inserted_infos": 0, "won_states": 0, "lost_states": 0}
now = utcnow()
asset_rows: list[dict] = []
state_rows: list[dict] = []
path_to_asset: dict[str, str] = {}
asset_to_info: dict[str, dict] = {} # asset_id -> prepared info row
path_list: list[str] = []
for sp in specs:
ap = os.path.abspath(sp["abs_path"])
aid = str(uuid.uuid4())
iid = str(uuid.uuid4())
path_list.append(ap)
path_to_asset[ap] = aid
asset_rows.append(
{
"id": aid,
"hash": None,
"size_bytes": sp["size_bytes"],
"mime_type": None,
"created_at": now,
}
)
state_rows.append(
{
"asset_id": aid,
"file_path": ap,
"mtime_ns": sp["mtime_ns"],
}
)
asset_to_info[aid] = {
"id": iid,
"owner_id": owner_id,
"name": sp["info_name"],
"asset_id": aid,
"preview_id": None,
"user_metadata": {"filename": sp["fname"]} if sp["fname"] else None,
"created_at": now,
"updated_at": now,
"last_access_time": now,
"_tags": sp["tags"],
"_filename": sp["fname"],
}
# insert all seed Assets (hash=NULL)
ins_asset = sqlite.insert(Asset)
for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)):
session.execute(ins_asset, chunk)
# try to claim AssetCacheState (file_path)
# Insert with ON CONFLICT DO NOTHING, then query to find which paths were actually inserted
ins_state = (
sqlite.insert(AssetCacheState)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
)
for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
session.execute(ins_state, chunk)
# Query to find which of our paths won (were actually inserted)
winners_by_path: set[str] = set()
for chunk in _iter_chunks(path_list, MAX_BIND_PARAMS):
result = session.execute(
sqlalchemy.select(AssetCacheState.file_path)
.where(AssetCacheState.file_path.in_(chunk))
.where(AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk]))
)
winners_by_path.update(result.scalars().all())
all_paths_set = set(path_list)
losers_by_path = all_paths_set - winners_by_path
lost_assets = [path_to_asset[p] for p in losers_by_path]
if lost_assets: # losers get their Asset removed
for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS):
session.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(id_chunk)))
if not winners_by_path:
return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}
# insert AssetInfo only for winners
# Insert with ON CONFLICT DO NOTHING, then query to find which were actually inserted
winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
ins_info = (
sqlite.insert(AssetInfo)
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
)
for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
session.execute(ins_info, chunk)
# Query to find which info rows were actually inserted (by matching our generated IDs)
all_info_ids = [row["id"] for row in winner_info_rows]
inserted_info_ids: set[str] = set()
for chunk in _iter_chunks(all_info_ids, MAX_BIND_PARAMS):
result = session.execute(
sqlalchemy.select(AssetInfo.id).where(AssetInfo.id.in_(chunk))
)
inserted_info_ids.update(result.scalars().all())
# build and insert tag + meta rows for the AssetInfo
tag_rows: list[dict] = []
meta_rows: list[dict] = []
if inserted_info_ids:
for row in winner_info_rows:
iid = row["id"]
if iid not in inserted_info_ids:
continue
for t in row["_tags"]:
tag_rows.append({
"asset_info_id": iid,
"tag_name": t,
"origin": "automatic",
"added_at": now,
})
if row["_filename"]:
meta_rows.append(
{
"asset_info_id": iid,
"key": "filename",
"ordinal": 0,
"val_str": row["_filename"],
"val_num": None,
"val_bool": None,
"val_json": None,
}
)
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS)
return {
"inserted_infos": len(inserted_info_ids),
"won_states": len(winners_by_path),
"lost_states": len(losers_by_path),
}
def bulk_insert_tags_and_meta(
session: Session,
*,
tag_rows: list[dict],
meta_rows: list[dict],
max_bind_params: int,
) -> None:
"""Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.
- tag_rows keys: asset_info_id, tag_name, origin, added_at
- meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
"""
if tag_rows:
ins_links = (
sqlite.insert(AssetInfoTag)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params):
session.execute(ins_links, chunk)
if meta_rows:
ins_meta = (
sqlite.insert(AssetInfoMeta)
.on_conflict_do_nothing(
index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
)
)
for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params):
session.execute(ins_meta, chunk)

View File

@@ -2,8 +2,8 @@ from __future__ import annotations
import uuid
from datetime import datetime
from typing import Any
from sqlalchemy import (
JSON,
BigInteger,
@@ -20,19 +20,21 @@ from sqlalchemy import (
)
from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship
from app.assets.helpers import utcnow
from app.database.models import to_dict, Base
from app.assets.helpers import get_utc_now
from app.database.models import Base, to_dict
class Asset(Base):
__tablename__ = "assets"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
id: Mapped[str] = mapped_column(
String(36), primary_key=True, default=lambda: str(uuid.uuid4())
)
hash: Mapped[str | None] = mapped_column(String(256), nullable=True)
size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
mime_type: Mapped[str | None] = mapped_column(String(255))
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=utcnow
DateTime(timezone=False), nullable=False, default=get_utc_now
)
infos: Mapped[list[AssetInfo]] = relationship(
@@ -75,17 +77,23 @@ class AssetCacheState(Base):
__tablename__ = "asset_cache_state"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False)
asset_id: Mapped[str] = mapped_column(
String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False
)
file_path: Mapped[str] = mapped_column(Text, nullable=False)
mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
is_missing: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
asset: Mapped[Asset] = relationship(back_populates="cache_states")
__table_args__ = (
Index("ix_asset_cache_state_file_path", "file_path"),
Index("ix_asset_cache_state_asset_id", "asset_id"),
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
Index("ix_asset_cache_state_is_missing", "is_missing"),
CheckConstraint(
"(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"
),
UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
)
@@ -99,15 +107,29 @@ class AssetCacheState(Base):
class AssetInfo(Base):
__tablename__ = "assets_info"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
id: Mapped[str] = mapped_column(
String(36), primary_key=True, default=lambda: str(uuid.uuid4())
)
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
name: Mapped[str] = mapped_column(String(512), nullable=False)
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False)
preview_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL"))
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON(none_as_null=True))
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
asset_id: Mapped[str] = mapped_column(
String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False
)
preview_id: Mapped[str | None] = mapped_column(
String(36), ForeignKey("assets.id", ondelete="SET NULL")
)
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(
JSON(none_as_null=True)
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=get_utc_now
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=get_utc_now
)
last_access_time: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=get_utc_now
)
asset: Mapped[Asset] = relationship(
"Asset",
@@ -143,7 +165,9 @@ class AssetInfo(Base):
)
__table_args__ = (
UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
UniqueConstraint(
"asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"
),
Index("ix_assets_info_owner_name", "owner_id", "name"),
Index("ix_assets_info_owner_id", "owner_id"),
Index("ix_assets_info_asset_id", "asset_id"),
@@ -196,7 +220,7 @@ class AssetInfoTag(Base):
)
origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
added_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=utcnow
DateTime(timezone=False), nullable=False, default=get_utc_now
)
asset_info: Mapped[AssetInfo] = relationship(back_populates="tag_links")
@@ -225,9 +249,7 @@ class Tag(Base):
overlaps="asset_info_links,tag_links,tags,asset_info",
)
__table_args__ = (
Index("ix_tags_tag_type", "tag_type"),
)
__table_args__ = (Index("ix_tags_tag_type", "tag_type"),)
def __repr__(self) -> str:
return f"<Tag {self.name}>"

View File

@@ -1,976 +0,0 @@
import os
import logging
import sqlalchemy as sa
from collections import defaultdict
from datetime import datetime
from typing import Iterable, Any
from sqlalchemy import select, delete, exists, func
from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, contains_eager, noload
from app.assets.database.models import Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag
from app.assets.helpers import (
compute_relative_filename, escape_like_prefix, normalize_tags, project_kv, utcnow
)
from typing import Sequence
def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
"""Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
owner_id = (owner_id or "").strip()
if owner_id == "":
return AssetInfo.owner_id == ""
return AssetInfo.owner_id.in_(["", owner_id])
def pick_best_live_path(states: Sequence[AssetCacheState]) -> str:
"""
Return the best on-disk path among cache states:
1) Prefer a path that exists with needs_verify == False (already verified).
2) Otherwise, pick the first path that exists.
3) Otherwise return empty string.
"""
alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)]
if not alive:
return ""
for s in alive:
if not getattr(s, "needs_verify", False):
return s.file_path
return alive[0].file_path
def apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
) -> sa.sql.Select:
"""include_tags: every tag must be present; exclude_tags: none may be present."""
include_tags = normalize_tags(include_tags)
exclude_tags = normalize_tags(exclude_tags)
if include_tags:
for tag_name in include_tags:
stmt = stmt.where(
exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name == tag_name)
)
)
if exclude_tags:
stmt = stmt.where(
~exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name.in_(exclude_tags))
)
)
return stmt
def apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: dict | None = None,
) -> sa.sql.Select:
"""Apply filters using asset_info_meta projection table."""
if not metadata_filter:
return stmt
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
return sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
*preds,
)
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
if value is None:
no_row_for_key = sa.not_(
sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
)
)
null_row = _exists_for_pred(
key,
AssetInfoMeta.val_json.is_(None),
AssetInfoMeta.val_str.is_(None),
AssetInfoMeta.val_num.is_(None),
AssetInfoMeta.val_bool.is_(None),
)
return sa.or_(no_row_for_key, null_row)
if isinstance(value, bool):
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
if isinstance(value, (int, float)):
from decimal import Decimal
num = value if isinstance(value, Decimal) else Decimal(str(value))
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
if isinstance(value, str):
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
for k, v in metadata_filter.items():
if isinstance(v, list):
ors = [_exists_clause_for_value(k, elem) for elem in v]
if ors:
stmt = stmt.where(sa.or_(*ors))
else:
stmt = stmt.where(_exists_clause_for_value(k, v))
return stmt
def asset_exists_by_hash(
session: Session,
*,
asset_hash: str,
) -> bool:
"""
Check if an asset with a given hash exists in the database.
"""
row = (
session.execute(
select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1)
)
).first()
return row is not None
def asset_info_exists_for_asset_id(
session: Session,
*,
asset_id: str,
) -> bool:
q = (
select(sa.literal(True))
.select_from(AssetInfo)
.where(AssetInfo.asset_id == asset_id)
.limit(1)
)
return (session.execute(q)).first() is not None
def get_asset_by_hash(
session: Session,
*,
asset_hash: str,
) -> Asset | None:
return (
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
).scalars().first()
def get_asset_info_by_id(
session: Session,
*,
asset_info_id: str,
) -> AssetInfo | None:
return session.get(AssetInfo, asset_info_id)
def list_asset_infos_page(
session: Session,
owner_id: str = "",
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
) -> tuple[list[AssetInfo], dict[str, list[str]], int]:
base = (
select(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
.where(visible_owner_clause(owner_id))
)
if name_contains:
escaped, esc = escape_like_prefix(name_contains)
base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
base = apply_tag_filters(base, include_tags, exclude_tags)
base = apply_metadata_filter(base, metadata_filter)
sort = (sort or "created_at").lower()
order = (order or "desc").lower()
sort_map = {
"name": AssetInfo.name,
"created_at": AssetInfo.created_at,
"updated_at": AssetInfo.updated_at,
"last_access_time": AssetInfo.last_access_time,
"size": Asset.size_bytes,
}
sort_col = sort_map.get(sort, AssetInfo.created_at)
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
base = base.order_by(sort_exp).limit(limit).offset(offset)
count_stmt = (
select(sa.func.count())
.select_from(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(visible_owner_clause(owner_id))
)
if name_contains:
escaped, esc = escape_like_prefix(name_contains)
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
total = int((session.execute(count_stmt)).scalar_one() or 0)
infos = (session.execute(base)).unique().scalars().all()
id_list: list[str] = [i.id for i in infos]
tag_map: dict[str, list[str]] = defaultdict(list)
if id_list:
rows = session.execute(
select(AssetInfoTag.asset_info_id, Tag.name)
.join(Tag, Tag.name == AssetInfoTag.tag_name)
.where(AssetInfoTag.asset_info_id.in_(id_list))
.order_by(AssetInfoTag.added_at)
)
for aid, tag_name in rows.all():
tag_map[aid].append(tag_name)
return infos, tag_map, total
def fetch_asset_info_asset_and_tags(
session: Session,
asset_info_id: str,
owner_id: str = "",
) -> tuple[AssetInfo, Asset, list[str]] | None:
stmt = (
select(AssetInfo, Asset, Tag.name)
.join(Asset, Asset.id == AssetInfo.asset_id)
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
.where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
.options(noload(AssetInfo.tags))
.order_by(Tag.name.asc())
)
rows = (session.execute(stmt)).all()
if not rows:
return None
first_info, first_asset, _ = rows[0]
tags: list[str] = []
seen: set[str] = set()
for _info, _asset, tag_name in rows:
if tag_name and tag_name not in seen:
seen.add(tag_name)
tags.append(tag_name)
return first_info, first_asset, tags
def fetch_asset_info_and_asset(
session: Session,
*,
asset_info_id: str,
owner_id: str = "",
) -> tuple[AssetInfo, Asset] | None:
stmt = (
select(AssetInfo, Asset)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
.limit(1)
.options(noload(AssetInfo.tags))
)
row = session.execute(stmt)
pair = row.first()
if not pair:
return None
return pair[0], pair[1]
def list_cache_states_by_asset_id(
session: Session, *, asset_id: str
) -> Sequence[AssetCacheState]:
return (
session.execute(
select(AssetCacheState)
.where(AssetCacheState.asset_id == asset_id)
.order_by(AssetCacheState.id.asc())
)
).scalars().all()
def touch_asset_info_by_id(
session: Session,
*,
asset_info_id: str,
ts: datetime | None = None,
only_if_newer: bool = True,
) -> None:
ts = ts or utcnow()
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
if only_if_newer:
stmt = stmt.where(
sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts)
)
session.execute(stmt.values(last_access_time=ts))
def create_asset_info_for_existing_asset(
session: Session,
*,
asset_hash: str,
name: str,
user_metadata: dict | None = None,
tags: Sequence[str] | None = None,
tag_origin: str = "manual",
owner_id: str = "",
) -> AssetInfo:
"""Create or return an existing AssetInfo for an Asset identified by asset_hash."""
now = utcnow()
asset = get_asset_by_hash(session, asset_hash=asset_hash)
if not asset:
raise ValueError(f"Unknown asset hash {asset_hash}")
info = AssetInfo(
owner_id=owner_id,
name=name,
asset_id=asset.id,
preview_id=None,
created_at=now,
updated_at=now,
last_access_time=now,
)
try:
with session.begin_nested():
session.add(info)
session.flush()
except IntegrityError:
existing = (
session.execute(
select(AssetInfo)
.options(noload(AssetInfo.tags))
.where(
AssetInfo.asset_id == asset.id,
AssetInfo.name == name,
AssetInfo.owner_id == owner_id,
)
.limit(1)
)
).unique().scalars().first()
if not existing:
raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
return existing
# metadata["filename"] hack
new_meta = dict(user_metadata or {})
computed_filename = None
try:
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
if p:
computed_filename = compute_relative_filename(p)
except Exception:
computed_filename = None
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta:
replace_asset_info_metadata_projection(
session,
asset_info_id=info.id,
user_metadata=new_meta,
)
if tags is not None:
set_asset_info_tags(
session,
asset_info_id=info.id,
tags=tags,
origin=tag_origin,
)
return info
def set_asset_info_tags(
session: Session,
*,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
) -> dict:
desired = normalize_tags(tags)
current = set(
tag_name for (tag_name,) in (
session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
).all()
)
to_add = [t for t in desired if t not in current]
to_remove = [t for t in current if t not in desired]
if to_add:
ensure_tags_exist(session, to_add, tag_type="user")
session.add_all([
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow())
for t in to_add
])
session.flush()
if to_remove:
session.execute(
delete(AssetInfoTag)
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
)
session.flush()
return {"added": to_add, "removed": to_remove, "total": desired}
def replace_asset_info_metadata_projection(
session: Session,
*,
asset_info_id: str,
user_metadata: dict | None = None,
) -> None:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info.user_metadata = user_metadata or {}
info.updated_at = utcnow()
session.flush()
session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
session.flush()
if not user_metadata:
return
rows: list[AssetInfoMeta] = []
for k, v in user_metadata.items():
for r in project_kv(k, v):
rows.append(
AssetInfoMeta(
asset_info_id=asset_info_id,
key=r["key"],
ordinal=int(r["ordinal"]),
val_str=r.get("val_str"),
val_num=r.get("val_num"),
val_bool=r.get("val_bool"),
val_json=r.get("val_json"),
)
)
if rows:
session.add_all(rows)
session.flush()
def ingest_fs_asset(
session: Session,
*,
asset_hash: str,
abs_path: str,
size_bytes: int,
mtime_ns: int,
mime_type: str | None = None,
info_name: str | None = None,
owner_id: str = "",
preview_id: str | None = None,
user_metadata: dict | None = None,
tags: Sequence[str] = (),
tag_origin: str = "manual",
require_existing_tags: bool = False,
) -> dict:
"""
Idempotently upsert:
- Asset by content hash (create if missing)
- AssetCacheState(file_path) pointing to asset_id
- Optionally AssetInfo + tag links and metadata projection
Returns flags and ids.
"""
locator = os.path.abspath(abs_path)
now = utcnow()
if preview_id:
if not session.get(Asset, preview_id):
preview_id = None
out: dict[str, Any] = {
"asset_created": False,
"asset_updated": False,
"state_created": False,
"state_updated": False,
"asset_info_id": None,
}
# 1) Asset by hash
asset = (
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
).scalars().first()
if not asset:
vals = {
"hash": asset_hash,
"size_bytes": int(size_bytes),
"mime_type": mime_type,
"created_at": now,
}
res = session.execute(
sqlite.insert(Asset)
.values(**vals)
.on_conflict_do_nothing(index_elements=[Asset.hash])
)
if int(res.rowcount or 0) > 0:
out["asset_created"] = True
asset = (
session.execute(
select(Asset).where(Asset.hash == asset_hash).limit(1)
)
).scalars().first()
if not asset:
raise RuntimeError("Asset row not found after upsert.")
else:
changed = False
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
asset.size_bytes = int(size_bytes)
changed = True
if mime_type and asset.mime_type != mime_type:
asset.mime_type = mime_type
changed = True
if changed:
out["asset_updated"] = True
# 2) AssetCacheState upsert by file_path (unique)
vals = {
"asset_id": asset.id,
"file_path": locator,
"mtime_ns": int(mtime_ns),
}
ins = (
sqlite.insert(AssetCacheState)
.values(**vals)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
)
res = session.execute(ins)
if int(res.rowcount or 0) > 0:
out["state_created"] = True
else:
upd = (
sa.update(AssetCacheState)
.where(AssetCacheState.file_path == locator)
.where(
sa.or_(
AssetCacheState.asset_id != asset.id,
AssetCacheState.mtime_ns.is_(None),
AssetCacheState.mtime_ns != int(mtime_ns),
)
)
.values(asset_id=asset.id, mtime_ns=int(mtime_ns))
)
res2 = session.execute(upd)
if int(res2.rowcount or 0) > 0:
out["state_updated"] = True
# 3) Optional AssetInfo + tags + metadata
if info_name:
try:
with session.begin_nested():
info = AssetInfo(
owner_id=owner_id,
name=info_name,
asset_id=asset.id,
preview_id=preview_id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
session.flush()
out["asset_info_id"] = info.id
except IntegrityError:
pass
existing_info = (
session.execute(
select(AssetInfo)
.where(
AssetInfo.asset_id == asset.id,
AssetInfo.name == info_name,
(AssetInfo.owner_id == owner_id),
)
.limit(1)
)
).unique().scalar_one_or_none()
if not existing_info:
raise RuntimeError("Failed to update or insert AssetInfo.")
if preview_id and existing_info.preview_id != preview_id:
existing_info.preview_id = preview_id
existing_info.updated_at = now
if existing_info.last_access_time < now:
existing_info.last_access_time = now
session.flush()
out["asset_info_id"] = existing_info.id
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
if norm and out["asset_info_id"] is not None:
if not require_existing_tags:
ensure_tags_exist(session, norm, tag_type="user")
existing_tag_names = set(
name for (name,) in (session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
)
missing = [t for t in norm if t not in existing_tag_names]
if missing and require_existing_tags:
raise ValueError(f"Unknown tags: {missing}")
existing_links = set(
tag_name
for (tag_name,) in (
session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
)
).all()
)
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
if to_add:
session.add_all(
[
AssetInfoTag(
asset_info_id=out["asset_info_id"],
tag_name=t,
origin=tag_origin,
added_at=now,
)
for t in to_add
]
)
session.flush()
# metadata["filename"] hack
if out["asset_info_id"] is not None:
primary_path = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
computed_filename = compute_relative_filename(primary_path) if primary_path else None
current_meta = existing_info.user_metadata or {}
new_meta = dict(current_meta)
if user_metadata is not None:
for k, v in user_metadata.items():
new_meta[k] = v
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta != current_meta:
replace_asset_info_metadata_projection(
session,
asset_info_id=out["asset_info_id"],
user_metadata=new_meta,
)
try:
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
except Exception:
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
return out
def update_asset_info_full(
session: Session,
*,
asset_info_id: str,
name: str | None = None,
tags: Sequence[str] | None = None,
user_metadata: dict | None = None,
tag_origin: str = "manual",
asset_info_row: Any = None,
) -> AssetInfo:
if not asset_info_row:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
else:
info = asset_info_row
touched = False
if name is not None and name != info.name:
info.name = name
touched = True
computed_filename = None
try:
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=info.asset_id))
if p:
computed_filename = compute_relative_filename(p)
except Exception:
computed_filename = None
if user_metadata is not None:
new_meta = dict(user_metadata)
if computed_filename:
new_meta["filename"] = computed_filename
replace_asset_info_metadata_projection(
session, asset_info_id=asset_info_id, user_metadata=new_meta
)
touched = True
else:
if computed_filename:
current_meta = info.user_metadata or {}
if current_meta.get("filename") != computed_filename:
new_meta = dict(current_meta)
new_meta["filename"] = computed_filename
replace_asset_info_metadata_projection(
session, asset_info_id=asset_info_id, user_metadata=new_meta
)
touched = True
if tags is not None:
set_asset_info_tags(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=tag_origin,
)
touched = True
if touched and user_metadata is None:
info.updated_at = utcnow()
session.flush()
return info
def delete_asset_info_by_id(
session: Session,
*,
asset_info_id: str,
owner_id: str,
) -> bool:
stmt = sa.delete(AssetInfo).where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
return int((session.execute(stmt)).rowcount or 0) > 0
def list_tags_with_usage(
session: Session,
prefix: str | None = None,
limit: int = 100,
offset: int = 0,
include_zero: bool = True,
order: str = "count_desc",
owner_id: str = "",
) -> tuple[list[tuple[str, str, int]], int]:
counts_sq = (
select(
AssetInfoTag.tag_name.label("tag_name"),
func.count(AssetInfoTag.asset_info_id).label("cnt"),
)
.select_from(AssetInfoTag)
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
.where(visible_owner_clause(owner_id))
.group_by(AssetInfoTag.tag_name)
.subquery()
)
q = (
select(
Tag.name,
Tag.tag_type,
func.coalesce(counts_sq.c.cnt, 0).label("count"),
)
.select_from(Tag)
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
)
if prefix:
escaped, esc = escape_like_prefix(prefix.strip().lower())
q = q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
if order == "name_asc":
q = q.order_by(Tag.name.asc())
else:
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
total_q = select(func.count()).select_from(Tag)
if prefix:
escaped, esc = escape_like_prefix(prefix.strip().lower())
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
total_q = total_q.where(
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
)
rows = (session.execute(q.limit(limit).offset(offset))).all()
total = (session.execute(total_q)).scalar_one()
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
return rows_norm, int(total or 0)
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
wanted = normalize_tags(list(names))
if not wanted:
return
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
ins = (
sqlite.insert(Tag)
.values(rows)
.on_conflict_do_nothing(index_elements=[Tag.name])
)
session.execute(ins)
def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]:
return [
tag_name for (tag_name,) in (
session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
]
def add_tags_to_asset_info(
session: Session,
*,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
create_if_missing: bool = True,
asset_info_row: Any = None,
) -> dict:
if not asset_info_row:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"added": [], "already_present": [], "total_tags": total}
if create_if_missing:
ensure_tags_exist(session, norm, tag_type="user")
current = {
tag_name
for (tag_name,) in (
session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
}
want = set(norm)
to_add = sorted(want - current)
if to_add:
with session.begin_nested() as nested:
try:
session.add_all(
[
AssetInfoTag(
asset_info_id=asset_info_id,
tag_name=t,
origin=origin,
added_at=utcnow(),
)
for t in to_add
]
)
session.flush()
except IntegrityError:
nested.rollback()
after = set(get_asset_tags(session, asset_info_id=asset_info_id))
return {
"added": sorted(((after - current) & want)),
"already_present": sorted(want & current),
"total_tags": sorted(after),
}
def remove_tags_from_asset_info(
session: Session,
*,
asset_info_id: str,
tags: Sequence[str],
) -> dict:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": [], "not_present": [], "total_tags": total}
existing = {
tag_name
for (tag_name,) in (
session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
}
to_remove = sorted(set(t for t in norm if t in existing))
not_present = sorted(set(t for t in norm if t not in existing))
if to_remove:
session.execute(
delete(AssetInfoTag)
.where(
AssetInfoTag.asset_info_id == asset_info_id,
AssetInfoTag.tag_name.in_(to_remove),
)
)
session.flush()
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
def remove_missing_tag_for_asset_id(
session: Session,
*,
asset_id: str,
) -> None:
session.execute(
sa.delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
AssetInfoTag.tag_name == "missing",
)
)
def set_asset_info_preview(
session: Session,
*,
asset_info_id: str,
preview_asset_id: str | None = None,
) -> None:
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if preview_asset_id is None:
info.preview_id = None
else:
# validate preview asset exists
if not session.get(Asset, preview_asset_id):
raise ValueError(f"Preview Asset {preview_asset_id} not found")
info.preview_id = preview_asset_id
info.updated_at = utcnow()
session.flush()

View File

@@ -0,0 +1,105 @@
from app.assets.database.queries.asset import (
asset_exists_by_hash,
bulk_insert_assets,
get_asset_by_hash,
get_existing_asset_ids,
upsert_asset,
)
from app.assets.database.queries.asset_info import (
asset_info_exists_for_asset_id,
bulk_insert_asset_infos_ignore_conflicts,
delete_asset_info_by_id,
fetch_asset_info_and_asset,
fetch_asset_info_asset_and_tags,
get_asset_info_by_id,
get_asset_info_ids_by_ids,
get_or_create_asset_info,
insert_asset_info,
list_asset_infos_page,
set_asset_info_metadata,
set_asset_info_preview,
update_asset_info_access_time,
update_asset_info_name,
update_asset_info_timestamps,
update_asset_info_updated_at,
)
from app.assets.database.queries.cache_state import (
CacheStateRow,
bulk_insert_cache_states_ignore_conflicts,
bulk_update_is_missing,
bulk_update_needs_verify,
delete_assets_by_ids,
delete_cache_states_by_ids,
delete_orphaned_seed_asset,
get_cache_states_by_paths_and_asset_ids,
get_cache_states_for_prefixes,
get_unreferenced_unhashed_asset_ids,
list_cache_states_by_asset_id,
mark_cache_states_missing_outside_prefixes,
restore_cache_states_by_paths,
upsert_cache_state,
)
from app.assets.database.queries.tags import (
AddTagsDict,
RemoveTagsDict,
SetTagsDict,
add_missing_tag_for_asset_id,
add_tags_to_asset_info,
bulk_insert_tags_and_meta,
ensure_tags_exist,
get_asset_tags,
list_tags_with_usage,
remove_missing_tag_for_asset_id,
remove_tags_from_asset_info,
set_asset_info_tags,
)
__all__ = [
"AddTagsDict",
"CacheStateRow",
"RemoveTagsDict",
"SetTagsDict",
"add_missing_tag_for_asset_id",
"add_tags_to_asset_info",
"asset_exists_by_hash",
"asset_info_exists_for_asset_id",
"bulk_insert_asset_infos_ignore_conflicts",
"bulk_insert_assets",
"bulk_insert_cache_states_ignore_conflicts",
"bulk_insert_tags_and_meta",
"bulk_update_is_missing",
"bulk_update_needs_verify",
"delete_asset_info_by_id",
"delete_assets_by_ids",
"delete_cache_states_by_ids",
"delete_orphaned_seed_asset",
"ensure_tags_exist",
"fetch_asset_info_and_asset",
"fetch_asset_info_asset_and_tags",
"get_asset_by_hash",
"get_existing_asset_ids",
"get_asset_info_by_id",
"get_asset_info_ids_by_ids",
"get_asset_tags",
"get_cache_states_by_paths_and_asset_ids",
"get_cache_states_for_prefixes",
"get_or_create_asset_info",
"get_unreferenced_unhashed_asset_ids",
"insert_asset_info",
"list_asset_infos_page",
"list_cache_states_by_asset_id",
"list_tags_with_usage",
"mark_cache_states_missing_outside_prefixes",
"remove_missing_tag_for_asset_id",
"remove_tags_from_asset_info",
"restore_cache_states_by_paths",
"set_asset_info_metadata",
"set_asset_info_preview",
"set_asset_info_tags",
"update_asset_info_access_time",
"update_asset_info_name",
"update_asset_info_timestamps",
"update_asset_info_updated_at",
"upsert_asset",
"upsert_cache_state",
]
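Everything above is re-exported through the package, so call sites can import from this facade rather than the individual query modules; for example:

from app.assets.database.queries import (
    add_tags_to_asset_info,
    list_asset_infos_page,
    upsert_asset,
    upsert_cache_state,
)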

View File

@@ -0,0 +1,103 @@
import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy.dialects import sqlite
from sqlalchemy.orm import Session
from app.assets.database.models import Asset
from app.assets.database.queries.common import calculate_rows_per_statement, iter_chunks
def asset_exists_by_hash(
session: Session,
asset_hash: str,
) -> bool:
"""
Check if an asset with a given hash exists in the database.
"""
row = (
session.execute(
select(sa.literal(True))
.select_from(Asset)
.where(Asset.hash == asset_hash)
.limit(1)
)
).first()
return row is not None
def get_asset_by_hash(
session: Session,
asset_hash: str,
) -> Asset | None:
return (
(session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)))
.scalars()
.first()
)
def upsert_asset(
session: Session,
asset_hash: str,
size_bytes: int,
mime_type: str | None = None,
) -> tuple[Asset, bool, bool]:
"""Upsert an Asset by hash. Returns (asset, created, updated)."""
vals = {"hash": asset_hash, "size_bytes": int(size_bytes)}
if mime_type:
vals["mime_type"] = mime_type
ins = (
sqlite.insert(Asset)
.values(**vals)
.on_conflict_do_nothing(index_elements=[Asset.hash])
)
res = session.execute(ins)
created = int(res.rowcount or 0) > 0
asset = (
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
.scalars()
.first()
)
if not asset:
raise RuntimeError("Asset row not found after upsert.")
updated = False
if not created:
changed = False
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
asset.size_bytes = int(size_bytes)
changed = True
if mime_type and asset.mime_type != mime_type:
asset.mime_type = mime_type
changed = True
if changed:
updated = True
return asset, created, updated
def bulk_insert_assets(
session: Session,
rows: list[dict],
) -> None:
"""Bulk insert Asset rows. Each dict should have: id, hash, size_bytes, mime_type, created_at."""
if not rows:
return
ins = sqlite.insert(Asset).on_conflict_do_nothing(index_elements=[Asset.hash])
for chunk in iter_chunks(rows, calculate_rows_per_statement(5)):
session.execute(ins, chunk)
def get_existing_asset_ids(
session: Session,
asset_ids: list[str],
) -> set[str]:
"""Return the subset of asset_ids that exist in the database."""
if not asset_ids:
return set()
rows = session.execute(
select(Asset.id).where(Asset.id.in_(asset_ids))
).fetchall()
return {row[0] for row in rows}
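A short usage sketch for the two write helpers above. The session is assumed to come from the application's normal session factory, and the chunking note assumes the bind-parameter cap stays at 800 as in the previous module (5 columns per Asset row, so 160 rows per INSERT statement):

import uuid

from sqlalchemy.orm import Session

from app.assets.database.queries.asset import bulk_insert_assets, upsert_asset
from app.assets.helpers import get_utc_now

def example_asset_writes(session: Session) -> None:
    # Single-row upsert: created/updated flags report what actually happened.
    asset, created, updated = upsert_asset(
        session,
        asset_hash="blake3:" + "ab" * 32,   # illustrative hash value
        size_bytes=1234,
        mime_type="application/octet-stream",
    )
    print(asset.id, created, updated)
    # Bulk path: rows are chunked so each INSERT stays under the bind-parameter cap
    # (800 // 5 = 160 rows per statement if that constant is unchanged).
    rows = [
        {"id": str(uuid.uuid4()), "hash": f"blake3:{i:064x}", "size_bytes": 0,
         "mime_type": None, "created_at": get_utc_now()}
        for i in range(1000)
    ]
    bulk_insert_assets(session, rows)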

View File

@@ -0,0 +1,527 @@
from collections import defaultdict
from datetime import datetime
from decimal import Decimal
from typing import Sequence
import sqlalchemy as sa
from sqlalchemy import delete, exists, select
from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, contains_eager, noload
from app.assets.database.models import (
Asset,
AssetInfo,
AssetInfoMeta,
AssetInfoTag,
Tag,
)
from app.assets.database.queries.common import (
MAX_BIND_PARAMS,
build_visible_owner_clause,
calculate_rows_per_statement,
iter_chunks,
)
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
def _check_is_scalar(v):
if v is None:
return True
if isinstance(v, bool):
return True
if isinstance(v, (int, float, Decimal, str)):
return True
return False
def _scalar_to_row(key: str, ordinal: int, value) -> dict:
"""Convert a scalar value to a typed projection row."""
if value is None:
return {
"key": key,
"ordinal": ordinal,
"val_str": None,
"val_num": None,
"val_bool": None,
"val_json": None,
}
if isinstance(value, bool):
return {"key": key, "ordinal": ordinal, "val_bool": bool(value)}
if isinstance(value, (int, float, Decimal)):
num = value if isinstance(value, Decimal) else Decimal(str(value))
return {"key": key, "ordinal": ordinal, "val_num": num}
if isinstance(value, str):
return {"key": key, "ordinal": ordinal, "val_str": value}
return {"key": key, "ordinal": ordinal, "val_json": value}
def convert_metadata_to_rows(key: str, value) -> list[dict]:
"""
Turn a metadata key/value into typed projection rows.
Returns list[dict] with keys:
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
"""
if value is None:
return [_scalar_to_row(key, 0, None)]
if _check_is_scalar(value):
return [_scalar_to_row(key, 0, value)]
if isinstance(value, list):
if all(_check_is_scalar(x) for x in value):
return [_scalar_to_row(key, i, x) for i, x in enumerate(value)]
return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value)]
return [{"key": key, "ordinal": 0, "val_json": value}]
def _apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
) -> sa.sql.Select:
"""include_tags: every tag must be present; exclude_tags: none may be present."""
include_tags = normalize_tags(include_tags)
exclude_tags = normalize_tags(exclude_tags)
if include_tags:
for tag_name in include_tags:
stmt = stmt.where(
exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name == tag_name)
)
)
if exclude_tags:
stmt = stmt.where(
~exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name.in_(exclude_tags))
)
)
return stmt
def _apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: dict | None = None,
) -> sa.sql.Select:
"""Apply filters using asset_info_meta projection table."""
if not metadata_filter:
return stmt
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
return sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
*preds,
)
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
if value is None:
no_row_for_key = sa.not_(
sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
)
)
null_row = _exists_for_pred(
key,
AssetInfoMeta.val_json.is_(None),
AssetInfoMeta.val_str.is_(None),
AssetInfoMeta.val_num.is_(None),
AssetInfoMeta.val_bool.is_(None),
)
return sa.or_(no_row_for_key, null_row)
if isinstance(value, bool):
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
if isinstance(value, (int, float)):
num = value if isinstance(value, Decimal) else Decimal(str(value))
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
if isinstance(value, str):
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
for k, v in metadata_filter.items():
if isinstance(v, list):
ors = [_exists_clause_for_value(k, elem) for elem in v]
if ors:
stmt = stmt.where(sa.or_(*ors))
else:
stmt = stmt.where(_exists_clause_for_value(k, v))
return stmt
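In practice this filter is driven by a plain dict passed down from the list endpoint: a value of None matches rows where the key is absent or stored as null, and a list is treated as an OR over its elements. A hypothetical filter illustrating both:

# Hypothetical metadata_filter: filename is one of two values, nsfw is False,
# and "source" is either missing or stored as null.
example_filter = {
    "filename": ["loras/a.safetensors", "loras/b.safetensors"],
    "nsfw": False,
    "source": None,
}
stmt = _apply_metadata_filter(select(AssetInfo), example_filter)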
def asset_info_exists_for_asset_id(
session: Session,
asset_id: str,
) -> bool:
q = (
select(sa.literal(True))
.select_from(AssetInfo)
.where(AssetInfo.asset_id == asset_id)
.limit(1)
)
return (session.execute(q)).first() is not None
def get_asset_info_by_id(
session: Session,
asset_info_id: str,
) -> AssetInfo | None:
return session.get(AssetInfo, asset_info_id)
def insert_asset_info(
session: Session,
asset_id: str,
owner_id: str,
name: str,
preview_id: str | None = None,
) -> AssetInfo | None:
"""Insert a new AssetInfo. Returns None if unique constraint violated."""
now = get_utc_now()
try:
with session.begin_nested():
info = AssetInfo(
owner_id=owner_id,
name=name,
asset_id=asset_id,
preview_id=preview_id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
session.flush()
return info
except IntegrityError:
return None
def get_or_create_asset_info(
session: Session,
asset_id: str,
owner_id: str,
name: str,
preview_id: str | None = None,
) -> tuple[AssetInfo, bool]:
"""Get existing or create new AssetInfo. Returns (info, created)."""
info = insert_asset_info(
session,
asset_id=asset_id,
owner_id=owner_id,
name=name,
preview_id=preview_id,
)
if info:
return info, True
existing = (
session.execute(
select(AssetInfo)
.where(
AssetInfo.asset_id == asset_id,
AssetInfo.name == name,
AssetInfo.owner_id == owner_id,
)
.limit(1)
)
.unique()
.scalar_one_or_none()
)
if not existing:
raise RuntimeError("Failed to find AssetInfo after insert conflict.")
return existing, False
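# Usage sketch (illustrative; the ids and name below are placeholders, not values
# from this changeset): the helper falls back to a lookup when the nested insert
# hits the (asset_id, owner_id, name) unique constraint.
def _example_get_or_create(session: Session) -> None:
    info, created = get_or_create_asset_info(
        session,
        asset_id="asset-123",        # hypothetical id
        owner_id="",
        name="flux.safetensors",     # hypothetical name
    )
    if created:
        session.commit()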
def update_asset_info_timestamps(
session: Session,
asset_info: AssetInfo,
preview_id: str | None = None,
) -> None:
"""Update timestamps and optionally preview_id on existing AssetInfo."""
now = get_utc_now()
if preview_id and asset_info.preview_id != preview_id:
asset_info.preview_id = preview_id
asset_info.updated_at = now
if asset_info.last_access_time < now:
asset_info.last_access_time = now
session.flush()
def list_asset_infos_page(
session: Session,
owner_id: str = "",
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
) -> tuple[list[AssetInfo], dict[str, list[str]], int]:
base = (
select(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
.where(build_visible_owner_clause(owner_id))
)
if name_contains:
escaped, esc = escape_sql_like_string(name_contains)
base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
base = _apply_tag_filters(base, include_tags, exclude_tags)
base = _apply_metadata_filter(base, metadata_filter)
sort = (sort or "created_at").lower()
order = (order or "desc").lower()
sort_map = {
"name": AssetInfo.name,
"created_at": AssetInfo.created_at,
"updated_at": AssetInfo.updated_at,
"last_access_time": AssetInfo.last_access_time,
"size": Asset.size_bytes,
}
sort_col = sort_map.get(sort, AssetInfo.created_at)
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
base = base.order_by(sort_exp).limit(limit).offset(offset)
count_stmt = (
select(sa.func.count())
.select_from(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(build_visible_owner_clause(owner_id))
)
if name_contains:
escaped, esc = escape_sql_like_string(name_contains)
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = _apply_metadata_filter(count_stmt, metadata_filter)
total = int((session.execute(count_stmt)).scalar_one() or 0)
infos = (session.execute(base)).unique().scalars().all()
id_list: list[str] = [i.id for i in infos]
tag_map: dict[str, list[str]] = defaultdict(list)
if id_list:
rows = session.execute(
select(AssetInfoTag.asset_info_id, Tag.name)
.join(Tag, Tag.name == AssetInfoTag.tag_name)
.where(AssetInfoTag.asset_info_id.in_(id_list))
.order_by(AssetInfoTag.added_at)
)
for aid, tag_name in rows.all():
tag_map[aid].append(tag_name)
return infos, tag_map, total
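# Usage sketch (illustrative only): the same tag/name/metadata filters are applied
# to both the page query and the count query, so `total` reflects the filtered set,
# not just the returned page.
def _example_list_page(session: Session) -> None:
    infos, tag_map, total = list_asset_infos_page(
        session,
        owner_id="",
        include_tags=["models"],
        name_contains="flux",
        limit=20,
        offset=0,
        sort="size",
        order="desc",
    )
    for info in infos:
        print(info.name, tag_map.get(info.id, []), total)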
def fetch_asset_info_asset_and_tags(
session: Session,
asset_info_id: str,
owner_id: str = "",
) -> tuple[AssetInfo, Asset, list[str]] | None:
stmt = (
select(AssetInfo, Asset, Tag.name)
.join(Asset, Asset.id == AssetInfo.asset_id)
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
.where(
AssetInfo.id == asset_info_id,
build_visible_owner_clause(owner_id),
)
.options(noload(AssetInfo.tags))
.order_by(Tag.name.asc())
)
rows = (session.execute(stmt)).all()
if not rows:
return None
first_info, first_asset, _ = rows[0]
tags: list[str] = []
seen: set[str] = set()
for _info, _asset, tag_name in rows:
if tag_name and tag_name not in seen:
seen.add(tag_name)
tags.append(tag_name)
return first_info, first_asset, tags
def fetch_asset_info_and_asset(
session: Session,
asset_info_id: str,
owner_id: str = "",
) -> tuple[AssetInfo, Asset] | None:
stmt = (
select(AssetInfo, Asset)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(
AssetInfo.id == asset_info_id,
build_visible_owner_clause(owner_id),
)
.limit(1)
.options(noload(AssetInfo.tags))
)
row = session.execute(stmt)
pair = row.first()
if not pair:
return None
return pair[0], pair[1]
def update_asset_info_access_time(
session: Session,
asset_info_id: str,
ts: datetime | None = None,
only_if_newer: bool = True,
) -> None:
ts = ts or get_utc_now()
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
if only_if_newer:
stmt = stmt.where(
sa.or_(
AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts
)
)
session.execute(stmt.values(last_access_time=ts))
def update_asset_info_name(
session: Session,
asset_info_id: str,
name: str,
) -> None:
"""Update the name of an AssetInfo."""
now = get_utc_now()
session.execute(
sa.update(AssetInfo)
.where(AssetInfo.id == asset_info_id)
.values(name=name, updated_at=now)
)
def update_asset_info_updated_at(
session: Session,
asset_info_id: str,
ts: datetime | None = None,
) -> None:
"""Update the updated_at timestamp of an AssetInfo."""
ts = ts or get_utc_now()
session.execute(
sa.update(AssetInfo).where(AssetInfo.id == asset_info_id).values(updated_at=ts)
)
def set_asset_info_metadata(
session: Session,
asset_info_id: str,
user_metadata: dict | None = None,
) -> None:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info.user_metadata = user_metadata or {}
info.updated_at = get_utc_now()
session.flush()
session.execute(
delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id)
)
session.flush()
if not user_metadata:
return
rows: list[AssetInfoMeta] = []
for k, v in user_metadata.items():
for r in convert_metadata_to_rows(k, v):
rows.append(
AssetInfoMeta(
asset_info_id=asset_info_id,
key=r["key"],
ordinal=int(r["ordinal"]),
val_str=r.get("val_str"),
val_num=r.get("val_num"),
val_bool=r.get("val_bool"),
val_json=r.get("val_json"),
)
)
if rows:
session.add_all(rows)
session.flush()
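# Usage sketch (illustrative only; the id is a placeholder): the call replaces
# user_metadata wholesale and rebuilds the asset_info_meta projection rows from it.
def _example_set_metadata(session: Session) -> None:
    set_asset_info_metadata(
        session,
        asset_info_id="info-123",  # hypothetical id
        user_metadata={"filename": "vae/ae.safetensors", "epochs": [10, 20]},
    )
    session.commit()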
def delete_asset_info_by_id(
session: Session,
asset_info_id: str,
owner_id: str,
) -> bool:
stmt = sa.delete(AssetInfo).where(
AssetInfo.id == asset_info_id,
build_visible_owner_clause(owner_id),
)
return int((session.execute(stmt)).rowcount or 0) > 0
def set_asset_info_preview(
session: Session,
asset_info_id: str,
preview_asset_id: str | None = None,
) -> None:
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if preview_asset_id is None:
info.preview_id = None
else:
if not session.get(Asset, preview_asset_id):
raise ValueError(f"Preview Asset {preview_asset_id} not found")
info.preview_id = preview_asset_id
info.updated_at = get_utc_now()
session.flush()
def bulk_insert_asset_infos_ignore_conflicts(
session: Session,
rows: list[dict],
) -> None:
"""Bulk insert AssetInfo rows with ON CONFLICT DO NOTHING.
Each dict should have: id, owner_id, name, asset_id, preview_id,
user_metadata, created_at, updated_at, last_access_time
"""
if not rows:
return
ins = sqlite.insert(AssetInfo).on_conflict_do_nothing(
index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name]
)
for chunk in iter_chunks(rows, calculate_rows_per_statement(9)):
session.execute(ins, chunk)
def get_asset_info_ids_by_ids(
session: Session,
info_ids: list[str],
) -> set[str]:
"""Query to find which AssetInfo IDs exist in the database."""
if not info_ids:
return set()
found: set[str] = set()
for chunk in iter_chunks(info_ids, MAX_BIND_PARAMS):
result = session.execute(select(AssetInfo.id).where(AssetInfo.id.in_(chunk)))
found.update(result.scalars().all())
return found

View File

@@ -0,0 +1,351 @@
import os
from typing import NamedTuple, Sequence
import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy.dialects import sqlite
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetCacheState, AssetInfo
from app.assets.database.queries.common import (
MAX_BIND_PARAMS,
calculate_rows_per_statement,
iter_chunks,
)
from app.assets.helpers import escape_sql_like_string
class CacheStateRow(NamedTuple):
"""Row from cache state query with joined asset data."""
state_id: int
file_path: str
mtime_ns: int | None
needs_verify: bool
asset_id: str
asset_hash: str | None
size_bytes: int
def list_cache_states_by_asset_id(
session: Session, *, asset_id: str
) -> Sequence[AssetCacheState]:
return (
(
session.execute(
select(AssetCacheState)
.where(AssetCacheState.asset_id == asset_id)
.order_by(AssetCacheState.id.asc())
)
)
.scalars()
.all()
)
def upsert_cache_state(
session: Session,
asset_id: str,
file_path: str,
mtime_ns: int,
) -> tuple[bool, bool]:
"""Upsert a cache state by file_path. Returns (created, updated).
Also restores cache states that were previously marked as missing.
"""
vals = {
"asset_id": asset_id,
"file_path": file_path,
"mtime_ns": int(mtime_ns),
"is_missing": False,
}
ins = (
sqlite.insert(AssetCacheState)
.values(**vals)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
)
res = session.execute(ins)
created = int(res.rowcount or 0) > 0
if created:
return True, False
upd = (
sa.update(AssetCacheState)
.where(AssetCacheState.file_path == file_path)
.where(
sa.or_(
AssetCacheState.asset_id != asset_id,
AssetCacheState.mtime_ns.is_(None),
AssetCacheState.mtime_ns != int(mtime_ns),
AssetCacheState.is_missing == True, # noqa: E712
)
)
.values(asset_id=asset_id, mtime_ns=int(mtime_ns), is_missing=False)
)
res2 = session.execute(upd)
updated = int(res2.rowcount or 0) > 0
return False, updated
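# Usage sketch (illustrative only; id and path are placeholders): the first call
# for a path inserts a row, a repeat call with the same mtime is a no-op, and a
# changed mtime (or a restored file) updates the existing row in place.
def _example_upsert(session: Session) -> None:
    created, updated = upsert_cache_state(
        session,
        asset_id="asset-123",                      # hypothetical id
        file_path="/abs/path/model.safetensors",   # hypothetical path
        mtime_ns=1_700_000_000_000_000_000,
    )
    # created=True on first insert; later calls return (False, True) only if
    # asset_id, mtime_ns, or the is_missing flag actually changed.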
def mark_cache_states_missing_outside_prefixes(
session: Session, valid_prefixes: list[str]
) -> int:
"""Mark cache states as missing when file_path doesn't match any valid prefix.
This is a non-destructive soft-delete that preserves user metadata.
Cache states can be restored if the file reappears in a future scan.
Args:
session: Database session
valid_prefixes: List of absolute directory prefixes that are valid
Returns:
Number of cache states marked as missing
"""
if not valid_prefixes:
return 0
def make_prefix_condition(prefix: str):
base = prefix if prefix.endswith(os.sep) else prefix + os.sep
escaped, esc = escape_sql_like_string(base)
return AssetCacheState.file_path.like(escaped + "%", escape=esc)
matches_valid_prefix = sa.or_(*[make_prefix_condition(p) for p in valid_prefixes])
result = session.execute(
sa.update(AssetCacheState)
.where(~matches_valid_prefix)
.where(AssetCacheState.is_missing == False) # noqa: E712
.values(is_missing=True)
)
return result.rowcount
def restore_cache_states_by_paths(session: Session, file_paths: list[str]) -> int:
"""Restore cache states that were previously marked as missing.
Called when a file path is re-scanned and found to exist.
Args:
session: Database session
file_paths: List of file paths that exist and should be restored
Returns:
Number of cache states restored
"""
if not file_paths:
return 0
result = session.execute(
sa.update(AssetCacheState)
.where(AssetCacheState.file_path.in_(file_paths))
.where(AssetCacheState.is_missing == True) # noqa: E712
.values(is_missing=False)
)
return result.rowcount
def get_unreferenced_unhashed_asset_ids(session: Session) -> list[str]:
"""Get IDs of unhashed assets (hash=None) with no active cache states.
An asset is considered unreferenced if it has no cache states,
or all its cache states are marked as missing.
Returns:
List of asset IDs that are unreferenced
"""
active_cache_state_exists = (
sa.select(sa.literal(1))
.where(AssetCacheState.asset_id == Asset.id)
.where(AssetCacheState.is_missing == False) # noqa: E712
.correlate(Asset)
.exists()
)
unreferenced_subq = sa.select(Asset.id).where(
Asset.hash.is_(None), ~active_cache_state_exists
)
return [row[0] for row in session.execute(unreferenced_subq).all()]
def delete_assets_by_ids(session: Session, asset_ids: list[str]) -> int:
"""Delete assets and their AssetInfos by ID.
Args:
session: Database session
asset_ids: List of asset IDs to delete
Returns:
Number of assets deleted
"""
if not asset_ids:
return 0
session.execute(sa.delete(AssetInfo).where(AssetInfo.asset_id.in_(asset_ids)))
result = session.execute(sa.delete(Asset).where(Asset.id.in_(asset_ids)))
return result.rowcount
def get_cache_states_for_prefixes(
session: Session,
prefixes: list[str],
*,
include_missing: bool = False,
) -> list[CacheStateRow]:
"""Get all cache states with paths matching any of the given prefixes.
Args:
session: Database session
prefixes: List of absolute directory prefixes to match
include_missing: If False (default), exclude cache states marked as missing
Returns:
List of cache state rows with joined asset data, ordered by asset_id, state_id
"""
if not prefixes:
return []
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
escaped, esc = escape_sql_like_string(base)
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
query = (
sa.select(
AssetCacheState.id,
AssetCacheState.file_path,
AssetCacheState.mtime_ns,
AssetCacheState.needs_verify,
AssetCacheState.asset_id,
Asset.hash,
Asset.size_bytes,
)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(sa.or_(*conds))
)
if not include_missing:
query = query.where(AssetCacheState.is_missing == False) # noqa: E712
rows = session.execute(
query.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
).all()
return [
CacheStateRow(
state_id=row[0],
file_path=row[1],
mtime_ns=row[2],
needs_verify=row[3],
asset_id=row[4],
asset_hash=row[5],
size_bytes=int(row[6] or 0),
)
for row in rows
]
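# Usage sketch (illustrative only; the prefix is a placeholder): rows come back
# joined with Asset and ordered by (asset_id, state_id), which lets callers group
# states per asset in a single pass.
def _example_states_for_models_dir(session: Session) -> dict[str, list[CacheStateRow]]:
    grouped: dict[str, list[CacheStateRow]] = {}
    for row in get_cache_states_for_prefixes(session, ["/abs/models"]):  # hypothetical prefix
        grouped.setdefault(row.asset_id, []).append(row)
    return grouped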
def bulk_update_needs_verify(session: Session, state_ids: list[int], value: bool) -> int:
"""Set needs_verify flag for multiple cache states.
Returns: Number of rows updated
"""
if not state_ids:
return 0
result = session.execute(
sa.update(AssetCacheState)
.where(AssetCacheState.id.in_(state_ids))
.values(needs_verify=value)
)
return result.rowcount
def bulk_update_is_missing(session: Session, state_ids: list[int], value: bool) -> int:
"""Set is_missing flag for multiple cache states.
Returns: Number of rows updated
"""
if not state_ids:
return 0
result = session.execute(
sa.update(AssetCacheState)
.where(AssetCacheState.id.in_(state_ids))
.values(is_missing=value)
)
return result.rowcount
def delete_cache_states_by_ids(session: Session, state_ids: list[int]) -> int:
"""Delete cache states by their IDs.
Returns: Number of rows deleted
"""
if not state_ids:
return 0
result = session.execute(
sa.delete(AssetCacheState).where(AssetCacheState.id.in_(state_ids))
)
return result.rowcount
def delete_orphaned_seed_asset(session: Session, asset_id: str) -> bool:
"""Delete a seed asset (hash is None) and its AssetInfos.
Returns: True if asset was deleted, False if not found
"""
session.execute(sa.delete(AssetInfo).where(AssetInfo.asset_id == asset_id))
asset = session.get(Asset, asset_id)
if asset:
session.delete(asset)
return True
return False
def bulk_insert_cache_states_ignore_conflicts(
session: Session,
rows: list[dict],
) -> None:
"""Bulk insert cache state rows with ON CONFLICT DO NOTHING on file_path.
Each dict should have: asset_id, file_path, mtime_ns
The is_missing field is automatically set to False for new inserts.
"""
if not rows:
return
enriched_rows = [{**row, "is_missing": False} for row in rows]
ins = sqlite.insert(AssetCacheState).on_conflict_do_nothing(
index_elements=[AssetCacheState.file_path]
)
for chunk in iter_chunks(enriched_rows, calculate_rows_per_statement(4)):
session.execute(ins, chunk)
def get_cache_states_by_paths_and_asset_ids(
session: Session,
path_to_asset: dict[str, str],
) -> set[str]:
"""Query cache states to find paths where our asset_id won the insert.
Args:
path_to_asset: Mapping of file_path -> asset_id we tried to insert
Returns:
Set of file_paths where our asset_id is present
"""
if not path_to_asset:
return set()
paths = list(path_to_asset.keys())
winners: set[str] = set()
for chunk in iter_chunks(paths, MAX_BIND_PARAMS):
result = session.execute(
select(AssetCacheState.file_path, AssetCacheState.asset_id).where(
AssetCacheState.file_path.in_(chunk)
)
)
# Compare each row against the exact asset_id we attempted for that path; a
# chunk-wide asset_id.in_(...) check could count a path whose row was won by a
# different asset in the same batch.
winners.update(fp for fp, aid in result.all() if path_to_asset.get(fp) == aid)
return winners

View File

@@ -0,0 +1,37 @@
"""Shared utilities for database query modules."""
from typing import Iterable
import sqlalchemy as sa
from app.assets.database.models import AssetInfo
MAX_BIND_PARAMS = 800
def calculate_rows_per_statement(cols: int) -> int:
"""Calculate how many rows can fit in one statement given column count."""
return max(1, MAX_BIND_PARAMS // max(1, cols))
def iter_chunks(seq, n: int):
"""Yield successive n-sized chunks from seq."""
for i in range(0, len(seq), n):
yield seq[i : i + n]
def iter_row_chunks(rows: list[dict], cols_per_row: int) -> Iterable[list[dict]]:
"""Yield chunks of rows sized to fit within bind param limits."""
if not rows:
return
rows_per_stmt = calculate_rows_per_statement(cols_per_row)
for i in range(0, len(rows), rows_per_stmt):
yield rows[i : i + rows_per_stmt]
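# Worked example (illustrative only): with MAX_BIND_PARAMS = 800, a 9-column row
# yields 800 // 9 = 88 rows per statement, so 1000 rows are sent as 12 statements.
def _example_chunking(rows: list[dict]) -> int:
    statements = 0
    for _chunk in iter_row_chunks(rows, cols_per_row=9):
        statements += 1
    return statements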
def build_visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
"""Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
owner_id = (owner_id or "").strip()
if owner_id == "":
return AssetInfo.owner_id == ""
return AssetInfo.owner_id.in_(["", owner_id])

View File

@@ -0,0 +1,349 @@
from typing import Iterable, Sequence, TypedDict
import sqlalchemy as sa
from sqlalchemy import delete, func, select
from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from app.assets.database.models import AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
from app.assets.database.queries.common import (
build_visible_owner_clause,
iter_row_chunks,
)
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
class AddTagsDict(TypedDict):
added: list[str]
already_present: list[str]
total_tags: list[str]
class RemoveTagsDict(TypedDict):
removed: list[str]
not_present: list[str]
total_tags: list[str]
class SetTagsDict(TypedDict):
added: list[str]
removed: list[str]
total: list[str]
def ensure_tags_exist(
session: Session, names: Iterable[str], tag_type: str = "user"
) -> None:
wanted = normalize_tags(list(names))
if not wanted:
return
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
ins = (
sqlite.insert(Tag)
.values(rows)
.on_conflict_do_nothing(index_elements=[Tag.name])
)
session.execute(ins)
def get_asset_tags(session: Session, asset_info_id: str) -> list[str]:
return [
tag_name
for (tag_name,) in (
session.execute(
select(AssetInfoTag.tag_name).where(
AssetInfoTag.asset_info_id == asset_info_id
)
)
).all()
]
def set_asset_info_tags(
session: Session,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
) -> SetTagsDict:
desired = normalize_tags(tags)
current = set(
tag_name
for (tag_name,) in (
session.execute(
select(AssetInfoTag.tag_name).where(
AssetInfoTag.asset_info_id == asset_info_id
)
)
).all()
)
to_add = [t for t in desired if t not in current]
to_remove = [t for t in current if t not in desired]
if to_add:
ensure_tags_exist(session, to_add, tag_type="user")
session.add_all(
[
AssetInfoTag(
asset_info_id=asset_info_id,
tag_name=t,
origin=origin,
added_at=get_utc_now(),
)
for t in to_add
]
)
session.flush()
if to_remove:
session.execute(
delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id == asset_info_id,
AssetInfoTag.tag_name.in_(to_remove),
)
)
session.flush()
return {"added": to_add, "removed": to_remove, "total": desired}
def add_tags_to_asset_info(
session: Session,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
create_if_missing: bool = True,
asset_info_row: AssetInfo | None = None,
) -> AddTagsDict:
if not asset_info_row:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"added": [], "already_present": [], "total_tags": total}
if create_if_missing:
ensure_tags_exist(session, norm, tag_type="user")
current = {
tag_name
for (tag_name,) in (
session.execute(
sa.select(AssetInfoTag.tag_name).where(
AssetInfoTag.asset_info_id == asset_info_id
)
)
).all()
}
want = set(norm)
to_add = sorted(want - current)
if to_add:
with session.begin_nested() as nested:
try:
session.add_all(
[
AssetInfoTag(
asset_info_id=asset_info_id,
tag_name=t,
origin=origin,
added_at=get_utc_now(),
)
for t in to_add
]
)
session.flush()
except IntegrityError:
nested.rollback()
after = set(get_asset_tags(session, asset_info_id=asset_info_id))
return {
"added": sorted(((after - current) & want)),
"already_present": sorted(want & current),
"total_tags": sorted(after),
}
def remove_tags_from_asset_info(
session: Session,
asset_info_id: str,
tags: Sequence[str],
) -> RemoveTagsDict:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": [], "not_present": [], "total_tags": total}
existing = {
tag_name
for (tag_name,) in (
session.execute(
sa.select(AssetInfoTag.tag_name).where(
AssetInfoTag.asset_info_id == asset_info_id
)
)
).all()
}
to_remove = sorted(set(t for t in norm if t in existing))
not_present = sorted(set(t for t in norm if t not in existing))
if to_remove:
session.execute(
delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id == asset_info_id,
AssetInfoTag.tag_name.in_(to_remove),
)
)
session.flush()
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
def add_missing_tag_for_asset_id(
session: Session,
asset_id: str,
origin: str = "automatic",
) -> None:
select_rows = (
sa.select(
AssetInfo.id.label("asset_info_id"),
sa.literal("missing").label("tag_name"),
sa.literal(origin).label("origin"),
sa.literal(get_utc_now()).label("added_at"),
)
.where(AssetInfo.asset_id == asset_id)
.where(
sa.not_(
sa.exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name == "missing")
)
)
)
)
session.execute(
sqlite.insert(AssetInfoTag)
.from_select(
["asset_info_id", "tag_name", "origin", "added_at"],
select_rows,
)
.on_conflict_do_nothing(
index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]
)
)
def remove_missing_tag_for_asset_id(
session: Session,
asset_id: str,
) -> None:
session.execute(
sa.delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id.in_(
sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)
),
AssetInfoTag.tag_name == "missing",
)
)
def list_tags_with_usage(
session: Session,
prefix: str | None = None,
limit: int = 100,
offset: int = 0,
include_zero: bool = True,
order: str = "count_desc",
owner_id: str = "",
) -> tuple[list[tuple[str, str, int]], int]:
counts_sq = (
select(
AssetInfoTag.tag_name.label("tag_name"),
func.count(AssetInfoTag.asset_info_id).label("cnt"),
)
.select_from(AssetInfoTag)
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
.where(build_visible_owner_clause(owner_id))
.group_by(AssetInfoTag.tag_name)
.subquery()
)
q = (
select(
Tag.name,
Tag.tag_type,
func.coalesce(counts_sq.c.cnt, 0).label("count"),
)
.select_from(Tag)
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
)
if prefix:
escaped, esc = escape_sql_like_string(prefix.strip().lower())
q = q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
if order == "name_asc":
q = q.order_by(Tag.name.asc())
else:
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
total_q = select(func.count()).select_from(Tag)
if prefix:
escaped, esc = escape_sql_like_string(prefix.strip().lower())
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
total_q = total_q.where(
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
)
rows = (session.execute(q.limit(limit).offset(offset))).all()
total = (session.execute(total_q)).scalar_one()
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
return rows_norm, int(total or 0)
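# Usage sketch (illustrative only): counts are scoped by the owner-visibility
# clause, and zero-usage tags are kept unless include_zero=False.
def _example_tag_usage(session: Session) -> None:
    rows, total = list_tags_with_usage(session, prefix="mod", limit=10, include_zero=False, owner_id="")
    for name, tag_type, count in rows:
        print(f"{name} ({tag_type}): {count} of {total}")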
def bulk_insert_tags_and_meta(
session: Session,
tag_rows: list[dict],
meta_rows: list[dict],
) -> None:
"""Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.
Args:
session: Database session
tag_rows: List of dicts with keys: asset_info_id, tag_name, origin, added_at
meta_rows: List of dicts with keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
"""
if tag_rows:
ins_tags = sqlite.insert(AssetInfoTag).on_conflict_do_nothing(
index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]
)
for chunk in iter_row_chunks(tag_rows, cols_per_row=4):
session.execute(ins_tags, chunk)
if meta_rows:
ins_meta = sqlite.insert(AssetInfoMeta).on_conflict_do_nothing(
index_elements=[
AssetInfoMeta.asset_info_id,
AssetInfoMeta.key,
AssetInfoMeta.ordinal,
]
)
for chunk in iter_row_chunks(meta_rows, cols_per_row=7):
session.execute(ins_meta, chunk)

View File

@@ -1,62 +0,0 @@
from typing import Iterable
import sqlalchemy
from sqlalchemy.orm import Session
from sqlalchemy.dialects import sqlite
from app.assets.helpers import normalize_tags, utcnow
from app.assets.database.models import Tag, AssetInfoTag, AssetInfo
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
wanted = normalize_tags(list(names))
if not wanted:
return
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
ins = (
sqlite.insert(Tag)
.values(rows)
.on_conflict_do_nothing(index_elements=[Tag.name])
)
return session.execute(ins)
def add_missing_tag_for_asset_id(
session: Session,
*,
asset_id: str,
origin: str = "automatic",
) -> None:
select_rows = (
sqlalchemy.select(
AssetInfo.id.label("asset_info_id"),
sqlalchemy.literal("missing").label("tag_name"),
sqlalchemy.literal(origin).label("origin"),
sqlalchemy.literal(utcnow()).label("added_at"),
)
.where(AssetInfo.asset_id == asset_id)
.where(
sqlalchemy.not_(
sqlalchemy.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing"))
)
)
)
session.execute(
sqlite.insert(AssetInfoTag)
.from_select(
["asset_info_id", "tag_name", "origin", "added_at"],
select_rows,
)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
def remove_missing_tag_for_asset_id(
session: Session,
*,
asset_id: str,
) -> None:
session.execute(
sqlalchemy.delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id.in_(sqlalchemy.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
AssetInfoTag.tag_name == "missing",
)
)

View File

@@ -1,75 +0,0 @@
from blake3 import blake3
from typing import IO
import os
import asyncio
DEFAULT_CHUNK = 8 * 1024 *1024 # 8MB
# NOTE: this allows hashing different representations of a file-like object
def blake3_hash(
fp: str | IO[bytes],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
"""
Returns a BLAKE3 hex digest for ``fp``, which may be:
- a filename (str/bytes) or PathLike
- an open binary file object
If ``fp`` is a file object, it must be opened in **binary** mode and support
``read``, ``seek``, and ``tell``. The function will seek to the start before
reading and will attempt to restore the original position afterward.
"""
# duck typing to check if input is a file-like object
if hasattr(fp, "read"):
return _hash_file_obj(fp, chunk_size)
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj(f, chunk_size)
async def blake3_hash_async(
fp: str | IO[bytes],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
"""Async wrapper for ``blake3_hash_sync``.
Uses a worker thread so the event loop remains responsive.
"""
# If it is a path, open inside the worker thread to keep I/O off the loop.
if hasattr(fp, "read"):
return await asyncio.to_thread(blake3_hash, fp, chunk_size)
def _worker() -> str:
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj(f, chunk_size)
return await asyncio.to_thread(_worker)
def _hash_file_obj(file_obj: IO, chunk_size: int = DEFAULT_CHUNK) -> str:
"""
Hash an already-open binary file object by streaming in chunks.
- Seeks to the beginning before reading (if supported).
- Restores the original position afterward (if tell/seek are supported).
"""
if chunk_size <= 0:
chunk_size = DEFAULT_CHUNK
# in case file object is already open and not at the beginning, track so can be restored after hashing
orig_pos = file_obj.tell()
try:
# seek to the beginning before reading
if orig_pos != 0:
file_obj.seek(0)
h = blake3()
while True:
chunk = file_obj.read(chunk_size)
if not chunk:
break
h.update(chunk)
return h.hexdigest()
finally:
# restore original position in file object, if needed
if orig_pos != 0:
file_obj.seek(orig_pos)

View File

@@ -1,52 +1,36 @@
import contextlib
import os
from decimal import Decimal
from aiohttp import web
from datetime import datetime, timezone
from pathlib import Path
from typing import Literal, Any
import folder_paths
from typing import Literal, Sequence
RootType = Literal["models", "input", "output"]
ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")
def get_query_dict(request: web.Request) -> dict[str, Any]:
def select_best_live_path(states: Sequence) -> str:
"""
Gets a dictionary of query parameters from the request.
'request.query' is a MultiMapping[str], needs to be converted to a dictionary to be validated by Pydantic.
Return the best on-disk path among cache states:
1) Prefer a path that exists with needs_verify == False (already verified).
2) Otherwise, pick the first path that exists.
3) Otherwise return empty string.
"""
query_dict = {
key: request.query.getall(key) if len(request.query.getall(key)) > 1 else request.query.get(key)
for key in request.query.keys()
}
return query_dict
alive = [
s
for s in states
if getattr(s, "file_path", None) and os.path.isfile(s.file_path)
]
if not alive:
return ""
for s in alive:
if not getattr(s, "needs_verify", False):
return s.file_path
return alive[0].file_path
def list_tree(base_dir: str) -> list[str]:
out: list[str] = []
base_abs = os.path.abspath(base_dir)
if not os.path.isdir(base_abs):
return out
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
for name in filenames:
out.append(os.path.abspath(os.path.join(dirpath, name)))
return out
def prefixes_for_root(root: RootType) -> list[str]:
if root == "models":
bases: list[str] = []
for _bucket, paths in get_comfy_models_folders():
bases.extend(paths)
return [os.path.abspath(p) for p in bases]
if root == "input":
return [os.path.abspath(folder_paths.get_input_directory())]
if root == "output":
return [os.path.abspath(folder_paths.get_output_directory())]
return []
ALLOWED_ROOTS: tuple[Literal["models", "input", "output"], ...] = (
"models",
"input",
"output",
)
def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]:
def escape_sql_like_string(s: str, escape: str = "!") -> tuple[str, str]:
"""Escapes %, _ and the escape char itself in a LIKE prefix.
Returns (escaped_prefix, escape_char). Caller should append '%' and pass escape=escape_char to .like().
"""
@@ -54,173 +38,11 @@ def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]:
s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards
return s, escape
def fast_asset_file_check(
*,
mtime_db: int | None,
size_db: int | None,
stat_result: os.stat_result,
) -> bool:
if mtime_db is None:
return False
actual_mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000))
if int(mtime_db) != int(actual_mtime_ns):
return False
sz = int(size_db or 0)
if sz > 0:
return int(stat_result.st_size) == sz
return True
def utcnow() -> datetime:
def get_utc_now() -> datetime:
"""Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC."""
return datetime.now(timezone.utc).replace(tzinfo=None)
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
"""Build a list of (folder_name, base_paths[]) categories that are configured for model locations.
We trust `folder_paths.folder_names_and_paths` and include a category if
*any* of its base paths lies under the Comfy `models_dir`.
"""
targets: list[tuple[str, list[str]]] = []
models_root = os.path.abspath(folder_paths.models_dir)
for name, values in folder_paths.folder_names_and_paths.items():
paths, _exts = values[0], values[1] # NOTE: this prevents nodepacks that hackily edit folder_... from breaking ComfyUI
if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
targets.append((name, paths))
return targets
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
root = tags[0]
if root == "models":
if len(tags) < 2:
raise ValueError("at least two tags required for model asset")
try:
bases = folder_paths.folder_names_and_paths[tags[1]][0]
except KeyError:
raise ValueError(f"unknown model category '{tags[1]}'")
if not bases:
raise ValueError(f"no base path configured for category '{tags[1]}'")
base_dir = os.path.abspath(bases[0])
raw_subdirs = tags[2:]
else:
base_dir = os.path.abspath(
folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory()
)
raw_subdirs = tags[1:]
for i in raw_subdirs:
if i in (".", ".."):
raise ValueError("invalid path component in tags")
return base_dir, raw_subdirs if raw_subdirs else []
def ensure_within_base(candidate: str, base: str) -> None:
cand_abs = os.path.abspath(candidate)
base_abs = os.path.abspath(base)
try:
if os.path.commonpath([cand_abs, base_abs]) != base_abs:
raise ValueError("destination escapes base directory")
except Exception:
raise ValueError("invalid destination path")
def compute_relative_filename(file_path: str) -> str | None:
"""
Return the model's path relative to the last well-known folder (the model category),
using forward slashes, eg:
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
For non-model paths, returns None.
NOTE: this is a temporary helper, used only for initializing metadata["filename"] field.
"""
try:
root_category, rel_path = get_relative_to_root_category_path_of_asset(file_path)
except ValueError:
return None
p = Path(rel_path)
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
if not parts:
return None
if root_category == "models":
# parts[0] is the category ("checkpoints", "vae", etc) drop it
inside = parts[1:] if len(parts) > 1 else [parts[0]]
return "/".join(inside)
return "/".join(parts) # input/output: keep all parts
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
"""Given an absolute or relative file path, determine which root category the path belongs to:
- 'input' if the file resides under `folder_paths.get_input_directory()`
- 'output' if the file resides under `folder_paths.get_output_directory()`
- 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()`
Returns:
(root_category, relative_path_inside_that_root)
For 'models', the relative path is prefixed with the category name:
e.g. ('models', 'vae/test/sub/ae.safetensors')
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
"""
fp_abs = os.path.abspath(file_path)
def _is_within(child: str, parent: str) -> bool:
try:
return os.path.commonpath([child, parent]) == parent
except Exception:
return False
def _rel(child: str, parent: str) -> str:
return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep)
# 1) input
input_base = os.path.abspath(folder_paths.get_input_directory())
if _is_within(fp_abs, input_base):
return "input", _rel(fp_abs, input_base)
# 2) output
output_base = os.path.abspath(folder_paths.get_output_directory())
if _is_within(fp_abs, output_base):
return "output", _rel(fp_abs, output_base)
# 3) models (check deepest matching base to avoid ambiguity)
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
for bucket, bases in get_comfy_models_folders():
for b in bases:
base_abs = os.path.abspath(b)
if not _is_within(fp_abs, base_abs):
continue
cand = (len(base_abs), bucket, _rel(fp_abs, base_abs))
if best is None or cand[0] > best[0]:
best = cand
if best is not None:
_, bucket, rel_inside = best
combined = os.path.join(bucket, rel_inside)
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}")
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
"""Return a tuple (name, tags) derived from a filesystem path.
Semantics:
- Root category is determined by `get_relative_to_root_category_path_of_asset`.
- The returned `name` is the base filename with extension from the relative path.
- The returned `tags` are:
[root_category] + parent folders of the relative path (in order)
For 'models', this means:
file '/.../ModelsDir/vae/test_tag/ae.safetensors'
-> root_category='models', some_path='vae/test_tag/ae.safetensors'
-> name='ae.safetensors', tags=['models', 'vae', 'test_tag']
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
"""
root_category, some_path = get_relative_to_root_category_path_of_asset(file_path)
p = Path(some_path)
parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)]
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
def normalize_tags(tags: list[str] | None) -> list[str]:
"""
@@ -229,84 +51,3 @@ def normalize_tags(tags: list[str] | None) -> list[str]:
- Removing duplicates.
"""
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
def collect_models_files() -> list[str]:
out: list[str] = []
for folder_name, bases in get_comfy_models_folders():
rel_files = folder_paths.get_filename_list(folder_name) or []
for rel_path in rel_files:
abs_path = folder_paths.get_full_path(folder_name, rel_path)
if not abs_path:
continue
abs_path = os.path.abspath(abs_path)
allowed = False
for b in bases:
base_abs = os.path.abspath(b)
with contextlib.suppress(Exception):
if os.path.commonpath([abs_path, base_abs]) == base_abs:
allowed = True
break
if allowed:
out.append(abs_path)
return out
def is_scalar(v):
if v is None:
return True
if isinstance(v, bool):
return True
if isinstance(v, (int, float, Decimal, str)):
return True
return False
def project_kv(key: str, value):
"""
Turn a metadata key/value into typed projection rows.
Returns list[dict] with keys:
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
"""
rows: list[dict] = []
def _null_row(ordinal: int) -> dict:
return {
"key": key, "ordinal": ordinal,
"val_str": None, "val_num": None, "val_bool": None, "val_json": None
}
if value is None:
rows.append(_null_row(0))
return rows
if is_scalar(value):
if isinstance(value, bool):
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
elif isinstance(value, (int, float, Decimal)):
num = value if isinstance(value, Decimal) else Decimal(str(value))
rows.append({"key": key, "ordinal": 0, "val_num": num})
elif isinstance(value, str):
rows.append({"key": key, "ordinal": 0, "val_str": value})
else:
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows
if isinstance(value, list):
if all(is_scalar(x) for x in value):
for i, x in enumerate(value):
if x is None:
rows.append(_null_row(i))
elif isinstance(x, bool):
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
elif isinstance(x, (int, float, Decimal)):
num = x if isinstance(x, Decimal) else Decimal(str(x))
rows.append({"key": key, "ordinal": i, "val_num": num})
elif isinstance(x, str):
rows.append({"key": key, "ordinal": i, "val_str": x})
else:
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
for i, x in enumerate(value):
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows

View File

@@ -1,516 +0,0 @@
import os
import mimetypes
import contextlib
from typing import Sequence
from app.database.db import create_session
from app.assets.api import schemas_out, schemas_in
from app.assets.database.queries import (
asset_exists_by_hash,
asset_info_exists_for_asset_id,
get_asset_by_hash,
get_asset_info_by_id,
fetch_asset_info_asset_and_tags,
fetch_asset_info_and_asset,
create_asset_info_for_existing_asset,
touch_asset_info_by_id,
update_asset_info_full,
delete_asset_info_by_id,
list_cache_states_by_asset_id,
list_asset_infos_page,
list_tags_with_usage,
get_asset_tags,
add_tags_to_asset_info,
remove_tags_from_asset_info,
pick_best_live_path,
ingest_fs_asset,
set_asset_info_preview,
)
from app.assets.helpers import resolve_destination_from_tags, ensure_within_base
from app.assets.database.models import Asset
def _safe_sort_field(requested: str | None) -> str:
if not requested:
return "created_at"
v = requested.lower()
if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
return v
return "created_at"
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
st = os.stat(path, follow_symlinks=True)
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
def _safe_filename(name: str | None, fallback: str) -> str:
n = os.path.basename((name or "").strip() or fallback)
if n:
return n
return fallback
def asset_exists(*, asset_hash: str) -> bool:
"""
Check if an asset with a given hash exists in database.
"""
with create_session() as session:
return asset_exists_by_hash(session, asset_hash=asset_hash)
def list_assets(
*,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
owner_id: str = "",
) -> schemas_out.AssetsList:
sort = _safe_sort_field(sort)
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
with create_session() as session:
infos, tag_map, total = list_asset_infos_page(
session,
owner_id=owner_id,
include_tags=include_tags,
exclude_tags=exclude_tags,
name_contains=name_contains,
metadata_filter=metadata_filter,
limit=limit,
offset=offset,
sort=sort,
order=order,
)
summaries: list[schemas_out.AssetSummary] = []
for info in infos:
asset = info.asset
tags = tag_map.get(info.id, [])
summaries.append(
schemas_out.AssetSummary(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset else None,
mime_type=asset.mime_type if asset else None,
tags=tags,
created_at=info.created_at,
updated_at=info.updated_at,
last_access_time=info.last_access_time,
)
)
return schemas_out.AssetsList(
assets=summaries,
total=total,
has_more=(offset + len(summaries)) < total,
)
def get_asset(
*,
asset_info_id: str,
owner_id: str = "",
) -> schemas_out.AssetDetail:
with create_session() as session:
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not res:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info, asset, tag_names = res
preview_id = info.preview_id
return schemas_out.AssetDetail(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
mime_type=asset.mime_type if asset else None,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
)
def resolve_asset_content_for_download(
*,
asset_info_id: str,
owner_id: str = "",
) -> tuple[str, str, str]:
with create_session() as session:
pair = fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not pair:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info, asset = pair
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
abs_path = pick_best_live_path(states)
if not abs_path:
raise FileNotFoundError
touch_asset_info_by_id(session, asset_info_id=asset_info_id)
session.commit()
ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
download_name = info.name or os.path.basename(abs_path)
return abs_path, ctype, download_name
def upload_asset_from_temp_path(
spec: schemas_in.UploadAssetSpec,
*,
temp_path: str,
client_filename: str | None = None,
owner_id: str = "",
expected_asset_hash: str | None = None,
) -> schemas_out.AssetCreated:
"""
Create new asset or update existing asset from a temporary file path.
"""
try:
# NOTE: blake3 is not required right now, so this will fail if blake3 is not installed in local environment
import app.assets.hashing as hashing
digest = hashing.blake3_hash(temp_path)
except Exception as e:
raise RuntimeError(f"failed to hash uploaded file: {e}")
asset_hash = "blake3:" + digest
if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower():
raise ValueError("HASH_MISMATCH")
with create_session() as session:
existing = get_asset_by_hash(session, asset_hash=asset_hash)
if existing is not None:
with contextlib.suppress(Exception):
if temp_path and os.path.exists(temp_path):
os.remove(temp_path)
display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest)
info = create_asset_info_for_existing_asset(
session,
asset_hash=asset_hash,
name=display_name,
user_metadata=spec.user_metadata or {},
tags=spec.tags or [],
tag_origin="manual",
owner_id=owner_id,
)
tag_names = get_asset_tags(session, asset_info_id=info.id)
session.commit()
return schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=existing.hash,
size=int(existing.size_bytes) if existing.size_bytes is not None else None,
mime_type=existing.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=False,
)
base_dir, subdirs = resolve_destination_from_tags(spec.tags)
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
os.makedirs(dest_dir, exist_ok=True)
src_for_ext = (client_filename or spec.name or "").strip()
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
ext = _ext if 0 < len(_ext) <= 16 else ""
hashed_basename = f"{digest}{ext}"
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
ensure_within_base(dest_abs, base_dir)
content_type = (
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
or mimetypes.guess_type(hashed_basename, strict=False)[0]
or "application/octet-stream"
)
try:
os.replace(temp_path, dest_abs)
except Exception as e:
raise RuntimeError(f"failed to move uploaded file into place: {e}")
try:
size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs)
except OSError as e:
raise RuntimeError(f"failed to stat destination file: {e}")
with create_session() as session:
result = ingest_fs_asset(
session,
asset_hash=asset_hash,
abs_path=dest_abs,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
mime_type=content_type,
info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest),
owner_id=owner_id,
preview_id=None,
user_metadata=spec.user_metadata or {},
tags=spec.tags,
tag_origin="manual",
require_existing_tags=False,
)
info_id = result["asset_info_id"]
if not info_id:
raise RuntimeError("failed to create asset metadata")
pair = fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id)
if not pair:
raise RuntimeError("inconsistent DB state after ingest")
info, asset = pair
tag_names = get_asset_tags(session, asset_info_id=info.id)
created_result = schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=asset.hash,
size=int(asset.size_bytes),
mime_type=asset.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=result["asset_created"],
)
session.commit()
return created_result
def update_asset(
*,
asset_info_id: str,
name: str | None = None,
tags: list[str] | None = None,
user_metadata: dict | None = None,
owner_id: str = "",
) -> schemas_out.AssetUpdated:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
info = update_asset_info_full(
session,
asset_info_id=asset_info_id,
name=name,
tags=tags,
user_metadata=user_metadata,
tag_origin="manual",
asset_info_row=info_row,
)
tag_names = get_asset_tags(session, asset_info_id=asset_info_id)
result = schemas_out.AssetUpdated(
id=info.id,
name=info.name,
asset_hash=info.asset.hash if info.asset else None,
tags=tag_names,
user_metadata=info.user_metadata or {},
updated_at=info.updated_at,
)
session.commit()
return result
def set_asset_preview(
*,
asset_info_id: str,
preview_asset_id: str | None = None,
owner_id: str = "",
) -> schemas_out.AssetDetail:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
set_asset_info_preview(
session,
asset_info_id=asset_info_id,
preview_asset_id=preview_asset_id,
)
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not res:
raise RuntimeError("State changed during preview update")
info, asset, tags = res
result = schemas_out.AssetDetail(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
mime_type=asset.mime_type if asset else None,
tags=tags,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
)
session.commit()
return result
def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
asset_id = info_row.asset_id if info_row else None
deleted = delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not deleted:
session.commit()
return False
if not delete_content_if_orphan or not asset_id:
session.commit()
return True
still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id)
if still_exists:
session.commit()
return True
states = list_cache_states_by_asset_id(session, asset_id=asset_id)
file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]
asset_row = session.get(Asset, asset_id)
if asset_row is not None:
session.delete(asset_row)
session.commit()
for p in file_paths:
with contextlib.suppress(Exception):
if p and os.path.isfile(p):
os.remove(p)
return True
def create_asset_from_hash(
*,
hash_str: str,
name: str,
tags: list[str] | None = None,
user_metadata: dict | None = None,
owner_id: str = "",
) -> schemas_out.AssetCreated | None:
canonical = hash_str.strip().lower()
with create_session() as session:
asset = get_asset_by_hash(session, asset_hash=canonical)
if not asset:
return None
info = create_asset_info_for_existing_asset(
session,
asset_hash=canonical,
name=_safe_filename(name, fallback=canonical.split(":", 1)[1]),
user_metadata=user_metadata or {},
tags=tags or [],
tag_origin="manual",
owner_id=owner_id,
)
tag_names = get_asset_tags(session, asset_info_id=info.id)
result = schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=asset.hash,
size=int(asset.size_bytes),
mime_type=asset.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=False,
)
session.commit()
return result
def add_tags_to_asset(
*,
asset_info_id: str,
tags: list[str],
origin: str = "manual",
owner_id: str = "",
) -> schemas_out.TagsAdd:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = add_tags_to_asset_info(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=origin,
create_if_missing=True,
asset_info_row=info_row,
)
session.commit()
return schemas_out.TagsAdd(**data)
def remove_tags_from_asset(
*,
asset_info_id: str,
tags: list[str],
owner_id: str = "",
) -> schemas_out.TagsRemove:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = remove_tags_from_asset_info(
session,
asset_info_id=asset_info_id,
tags=tags,
)
session.commit()
return schemas_out.TagsRemove(**data)
def list_tags(
prefix: str | None = None,
limit: int = 100,
offset: int = 0,
order: str = "count_desc",
include_zero: bool = True,
owner_id: str = "",
) -> schemas_out.TagsList:
limit = max(1, min(1000, limit))
offset = max(0, offset)
with create_session() as session:
rows, total = list_tags_with_usage(
session,
prefix=prefix,
limit=limit,
offset=offset,
include_zero=include_zero,
order=order,
owner_id=owner_id,
)
tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows]
return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total)

View File

@@ -1,263 +1,396 @@
import contextlib
import time
import logging
import os
import sqlalchemy
import time
from typing import Literal, TypedDict
import folder_paths
from app.database.db import create_session, dependencies_available
from app.assets.helpers import (
collect_models_files, compute_relative_filename, fast_asset_file_check, get_name_and_tags_from_asset_path,
list_tree,prefixes_for_root, escape_like_prefix,
RootType
from app.assets.database.queries import (
add_missing_tag_for_asset_id,
bulk_update_is_missing,
bulk_update_needs_verify,
delete_cache_states_by_ids,
delete_orphaned_seed_asset,
ensure_tags_exist,
get_cache_states_for_prefixes,
remove_missing_tag_for_asset_id,
)
from app.assets.database.tags import add_missing_tag_for_asset_id, ensure_tags_exist, remove_missing_tag_for_asset_id
from app.assets.database.bulk_ops import seed_from_paths_batch
from app.assets.database.models import Asset, AssetCacheState, AssetInfo
from app.assets.services.bulk_ingest import (
SeedAssetSpec,
batch_insert_seed_assets,
mark_assets_missing_outside_prefixes,
)
from app.assets.services.file_utils import (
get_mtime_ns,
list_files_recursively,
verify_file_unchanged,
)
from app.assets.services.hashing import compute_blake3_hash
from app.assets.services.metadata_extract import extract_file_metadata
from app.assets.services.path_utils import (
compute_relative_filename,
get_comfy_models_folders,
get_name_and_tags_from_asset_path,
)
from app.database.db import create_session, dependencies_available
def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> None:
class _StateInfo(TypedDict):
sid: int
fp: str
exists: bool
fast_ok: bool
needs_verify: bool
class _AssetAccumulator(TypedDict):
hash: str | None
size_db: int
states: list[_StateInfo]
RootType = Literal["models", "input", "output"]
def get_prefixes_for_root(root: RootType) -> list[str]:
if root == "models":
bases: list[str] = []
for _bucket, paths in get_comfy_models_folders():
bases.extend(paths)
return [os.path.abspath(p) for p in bases]
if root == "input":
return [os.path.abspath(folder_paths.get_input_directory())]
if root == "output":
return [os.path.abspath(folder_paths.get_output_directory())]
return []
def get_all_known_prefixes() -> list[str]:
"""Get all known asset prefixes across all root types."""
all_roots: tuple[RootType, ...] = ("models", "input", "output")
return [
os.path.abspath(p) for root in all_roots for p in get_prefixes_for_root(root)
]
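# Usage sketch (illustrative only): per-root prefixes come from folder_paths, so
# "models" may expand to several absolute base directories while "input" and
# "output" each map to a single directory.
def _example_prefixes() -> dict[str, list[str]]:
    return {root: get_prefixes_for_root(root) for root in ("models", "input", "output")}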
def collect_models_files() -> list[str]:
out: list[str] = []
for folder_name, bases in get_comfy_models_folders():
rel_files = folder_paths.get_filename_list(folder_name) or []
for rel_path in rel_files:
abs_path = folder_paths.get_full_path(folder_name, rel_path)
if not abs_path:
continue
abs_path = os.path.abspath(abs_path)
allowed = False
for b in bases:
base_abs = os.path.abspath(b)
with contextlib.suppress(ValueError):
if os.path.commonpath([abs_path, base_abs]) == base_abs:
allowed = True
break
if allowed:
out.append(abs_path)
return out
def sync_cache_states_with_filesystem(
session,
root: RootType,
collect_existing_paths: bool = False,
update_missing_tags: bool = False,
) -> set[str] | None:
"""Reconcile cache states with filesystem for a root.
- Toggle needs_verify per state using fast mtime/size check
- For hashed assets with at least one fast-ok state in this root: delete stale missing states
- For seed assets with all states missing: delete Asset and its AssetInfos
- Optionally add/remove 'missing' tags based on fast-ok in this root
- Optionally return surviving absolute paths
Args:
session: Database session
root: Root type to scan
collect_existing_paths: If True, return set of surviving file paths
update_missing_tags: If True, update 'missing' tags based on file status
Returns:
Set of surviving absolute paths if collect_existing_paths=True, else None
"""
Scan the given roots and seed the assets into the database.
prefixes = get_prefixes_for_root(root)
if not prefixes:
return set() if collect_existing_paths else None
rows = get_cache_states_for_prefixes(
session, prefixes, include_missing=update_missing_tags
)
by_asset: dict[str, _AssetAccumulator] = {}
for row in rows:
acc = by_asset.get(row.asset_id)
if acc is None:
acc = {"hash": row.asset_hash, "size_db": row.size_bytes, "states": []}
by_asset[row.asset_id] = acc
fast_ok = False
try:
exists = True
fast_ok = verify_file_unchanged(
mtime_db=row.mtime_ns,
size_db=acc["size_db"],
stat_result=os.stat(row.file_path, follow_symlinks=True),
)
except FileNotFoundError:
exists = False
except PermissionError:
exists = True
logging.debug("Permission denied accessing %s", row.file_path)
except OSError as e:
exists = False
logging.debug("OSError checking %s: %s", row.file_path, e)
acc["states"].append(
{
"sid": row.state_id,
"fp": row.file_path,
"exists": exists,
"fast_ok": fast_ok,
"needs_verify": row.needs_verify,
}
)
to_set_verify: list[int] = []
to_clear_verify: list[int] = []
stale_state_ids: list[int] = []
to_mark_missing: list[int] = []
to_clear_missing: list[int] = []
survivors: set[str] = set()
for aid, acc in by_asset.items():
a_hash = acc["hash"]
states = acc["states"]
any_fast_ok = any(s["fast_ok"] for s in states)
all_missing = all(not s["exists"] for s in states)
for s in states:
if not s["exists"]:
to_mark_missing.append(s["sid"])
continue
if s["fast_ok"]:
to_clear_missing.append(s["sid"])
if s["needs_verify"]:
to_clear_verify.append(s["sid"])
if not s["fast_ok"] and not s["needs_verify"]:
to_set_verify.append(s["sid"])
if a_hash is None:
if states and all_missing:
delete_orphaned_seed_asset(session, aid)
else:
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
continue
if any_fast_ok:
for s in states:
if not s["exists"]:
stale_state_ids.append(s["sid"])
if update_missing_tags:
try:
remove_missing_tag_for_asset_id(session, asset_id=aid)
except Exception as e:
logging.warning("Failed to remove missing tag for asset %s: %s", aid, e)
elif update_missing_tags:
try:
add_missing_tag_for_asset_id(session, asset_id=aid, origin="automatic")
except Exception as e:
logging.warning("Failed to add missing tag for asset %s: %s", aid, e)
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
delete_cache_states_by_ids(session, stale_state_ids)
stale_set = set(stale_state_ids)
to_mark_missing = [sid for sid in to_mark_missing if sid not in stale_set]
bulk_update_is_missing(session, to_mark_missing, value=True)
bulk_update_is_missing(session, to_clear_missing, value=False)
bulk_update_needs_verify(session, to_set_verify, value=True)
bulk_update_needs_verify(session, to_clear_verify, value=False)
return survivors if collect_existing_paths else None
def sync_root_safely(root: RootType) -> set[str]:
"""Sync a single root's cache states with the filesystem.
Returns survivors (existing paths) or empty set on failure.
"""
try:
with create_session() as sess:
survivors = sync_cache_states_with_filesystem(
sess,
root,
collect_existing_paths=True,
update_missing_tags=True,
)
sess.commit()
return survivors or set()
except Exception as e:
logging.exception("fast DB scan failed for %s: %s", root, e)
return set()
def mark_missing_outside_prefixes_safely(prefixes: list[str]) -> int:
"""Mark cache states as missing when outside the given prefixes.
This is a non-destructive soft-delete. Returns count marked or 0 on failure.
"""
try:
with create_session() as sess:
count = mark_assets_missing_outside_prefixes(sess, prefixes)
sess.commit()
return count
except Exception as e:
logging.exception("marking missing assets failed: %s", e)
return 0
def collect_paths_for_roots(roots: tuple[RootType, ...]) -> list[str]:
"""Collect all file paths for the given roots."""
paths: list[str] = []
if "models" in roots:
paths.extend(collect_models_files())
if "input" in roots:
paths.extend(list_files_recursively(folder_paths.get_input_directory()))
if "output" in roots:
paths.extend(list_files_recursively(folder_paths.get_output_directory()))
return paths
def build_asset_specs(
paths: list[str],
existing_paths: set[str],
enable_metadata_extraction: bool = True,
compute_hashes: bool = False,
) -> tuple[list[SeedAssetSpec], set[str], int]:
"""Build asset specs from paths, returning (specs, tag_pool, skipped_count).
Args:
paths: List of file paths to process
existing_paths: Set of paths that already exist in the database
enable_metadata_extraction: If True, extract tier 1 & 2 metadata from files
compute_hashes: If True, compute blake3 hashes for each file (slow for large files)
"""
specs: list[SeedAssetSpec] = []
tag_pool: set[str] = set()
skipped = 0
for p in paths:
abs_p = os.path.abspath(p)
if abs_p in existing_paths:
skipped += 1
continue
try:
stat_p = os.stat(abs_p, follow_symlinks=False)
except OSError:
continue
if not stat_p.st_size:
continue
name, tags = get_name_and_tags_from_asset_path(abs_p)
rel_fname = compute_relative_filename(abs_p)
# Extract metadata (tier 1: filesystem, tier 2: safetensors header)
metadata = None
if enable_metadata_extraction:
metadata = extract_file_metadata(
abs_p,
stat_result=stat_p,
enable_safetensors=True,
relative_filename=rel_fname,
)
# Compute hash if requested
asset_hash: str | None = None
if compute_hashes:
try:
digest = compute_blake3_hash(abs_p)
asset_hash = "blake3:" + digest
except Exception as e:
logging.warning("Failed to hash %s: %s", abs_p, e)
mime_type = metadata.content_type if metadata else None
specs.append(
{
"abs_path": abs_p,
"size_bytes": stat_p.st_size,
"mtime_ns": get_mtime_ns(stat_p),
"info_name": name,
"tags": tags,
"fname": rel_fname,
"metadata": metadata,
"hash": asset_hash,
"mime_type": mime_type,
}
)
tag_pool.update(tags)
return specs, tag_pool, skipped
def insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int:
"""Insert asset specs into database, returning count of created infos."""
if not specs:
return 0
with create_session() as sess:
if tag_pool:
ensure_tags_exist(sess, tag_pool, tag_type="user")
result = batch_insert_seed_assets(sess, specs=specs, owner_id="")
sess.commit()
return result.inserted_infos
def seed_assets(
roots: tuple[RootType, ...],
enable_logging: bool = False,
compute_hashes: bool = False,
) -> None:
"""Scan the given roots and seed the assets into the database.
Args:
roots: Tuple of root types to scan (models, input, output)
enable_logging: If True, log progress and completion messages
compute_hashes: If True, compute blake3 hashes for each file (slow for large files)
Note: This function does not mark missing assets. Call mark_missing_outside_prefixes_safely
separately if cleanup is needed.
"""
if not dependencies_available():
if enable_logging:
logging.warning("Database dependencies not available, skipping assets scan")
return
t_start = time.perf_counter()
created = 0
skipped_existing = 0
orphans_pruned = 0
paths: list[str] = []
try:
existing_paths: set[str] = set()
for r in roots:
try:
survivors: set[str] = _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True)
if survivors:
existing_paths.update(survivors)
except Exception as e:
logging.exception("fast DB scan failed for %s: %s", r, e)
try:
orphans_pruned = _prune_orphaned_assets(roots)
except Exception as e:
logging.exception("orphan pruning failed: %s", e)
existing_paths: set[str] = set()
for r in roots:
existing_paths.update(sync_root_safely(r))
if "models" in roots:
paths.extend(collect_models_files())
if "input" in roots:
paths.extend(list_tree(folder_paths.get_input_directory()))
if "output" in roots:
paths.extend(list_tree(folder_paths.get_output_directory()))
paths = collect_paths_for_roots(roots)
specs, tag_pool, skipped_existing = build_asset_specs(
paths, existing_paths, compute_hashes=compute_hashes
)
created = insert_asset_specs(specs, tag_pool)
specs: list[dict] = []
tag_pool: set[str] = set()
for p in paths:
abs_p = os.path.abspath(p)
if abs_p in existing_paths:
skipped_existing += 1
continue
try:
stat_p = os.stat(abs_p, follow_symlinks=False)
except OSError:
continue
# skip empty files
if not stat_p.st_size:
continue
name, tags = get_name_and_tags_from_asset_path(abs_p)
specs.append(
{
"abs_path": abs_p,
"size_bytes": stat_p.st_size,
"mtime_ns": getattr(stat_p, "st_mtime_ns", int(stat_p.st_mtime * 1_000_000_000)),
"info_name": name,
"tags": tags,
"fname": compute_relative_filename(abs_p),
}
)
for t in tags:
tag_pool.add(t)
# if no file specs, nothing to do
if not specs:
return
with create_session() as sess:
if tag_pool:
ensure_tags_exist(sess, tag_pool, tag_type="user")
result = seed_from_paths_batch(sess, specs=specs, owner_id="")
created += result["inserted_infos"]
sess.commit()
finally:
if enable_logging:
logging.info(
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, orphans_pruned=%d, total_seen=%d)",
roots,
time.perf_counter() - t_start,
created,
skipped_existing,
orphans_pruned,
len(paths),
)
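# Usage sketch (illustrative): seed_assets composes sync_root_safely,
# collect_paths_for_roots, build_asset_specs and insert_asset_specs into a
# single synchronous pass. Assuming the database is initialised, a minimal
# call could look like:
#
#     seed_assets(("models",), enable_logging=True, compute_hashes=False)
#
# Passing compute_hashes=True also blake3-hashes every file, which is slow
# for large model files (see the docstring above).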
def _prune_orphaned_assets(roots: tuple[RootType, ...]) -> int:
"""Prune cache states outside configured prefixes, then delete orphaned seed assets."""
all_prefixes = [os.path.abspath(p) for r in roots for p in prefixes_for_root(r)]
if not all_prefixes:
return 0
def make_prefix_condition(prefix: str):
base = prefix if prefix.endswith(os.sep) else prefix + os.sep
escaped, esc = escape_like_prefix(base)
return AssetCacheState.file_path.like(escaped + "%", escape=esc)
matches_valid_prefix = sqlalchemy.or_(*[make_prefix_condition(p) for p in all_prefixes])
orphan_subq = (
sqlalchemy.select(Asset.id)
.outerjoin(AssetCacheState, AssetCacheState.asset_id == Asset.id)
.where(Asset.hash.is_(None), AssetCacheState.id.is_(None))
).scalar_subquery()
with create_session() as sess:
sess.execute(sqlalchemy.delete(AssetCacheState).where(~matches_valid_prefix))
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id.in_(orphan_subq)))
result = sess.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(orphan_subq)))
sess.commit()
return result.rowcount
def _fast_db_consistency_pass(
root: RootType,
*,
collect_existing_paths: bool = False,
update_missing_tags: bool = False,
) -> set[str] | None:
"""Fast DB+FS pass for a root:
- Toggle needs_verify per state using fast check
- For hashed assets with at least one fast-ok state in this root: delete stale missing states
- For seed assets with all states missing: delete Asset and its AssetInfos
- Optionally add/remove 'missing' tags based on fast-ok in this root
- Optionally return surviving absolute paths
"""
prefixes = prefixes_for_root(root)
if not prefixes:
return set() if collect_existing_paths else None
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
escaped, esc = escape_like_prefix(base)
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
with create_session() as sess:
rows = (
sess.execute(
sqlalchemy.select(
AssetCacheState.id,
AssetCacheState.file_path,
AssetCacheState.mtime_ns,
AssetCacheState.needs_verify,
AssetCacheState.asset_id,
Asset.hash,
Asset.size_bytes,
)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(sqlalchemy.or_(*conds))
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
)
).all()
by_asset: dict[str, dict] = {}
for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows:
acc = by_asset.get(aid)
if acc is None:
acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []}
by_asset[aid] = acc
fast_ok = False
try:
exists = True
fast_ok = fast_asset_file_check(
mtime_db=mtime_db,
size_db=acc["size_db"],
stat_result=os.stat(fp, follow_symlinks=True),
)
except FileNotFoundError:
exists = False
except OSError:
exists = False
acc["states"].append({
"sid": sid,
"fp": fp,
"exists": exists,
"fast_ok": fast_ok,
"needs_verify": bool(needs_verify),
})
to_set_verify: list[int] = []
to_clear_verify: list[int] = []
stale_state_ids: list[int] = []
survivors: set[str] = set()
for aid, acc in by_asset.items():
a_hash = acc["hash"]
states = acc["states"]
any_fast_ok = any(s["fast_ok"] for s in states)
all_missing = all(not s["exists"] for s in states)
for s in states:
if not s["exists"]:
continue
if s["fast_ok"] and s["needs_verify"]:
to_clear_verify.append(s["sid"])
if not s["fast_ok"] and not s["needs_verify"]:
to_set_verify.append(s["sid"])
if a_hash is None:
if states and all_missing: # remove seed Asset completely, if no valid AssetCache exists
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id == aid))
asset = sess.get(Asset, aid)
if asset:
sess.delete(asset)
else:
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
continue
if any_fast_ok: # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records
for s in states:
if not s["exists"]:
stale_state_ids.append(s["sid"])
if update_missing_tags:
with contextlib.suppress(Exception):
remove_missing_tag_for_asset_id(sess, asset_id=aid)
elif update_missing_tags:
with contextlib.suppress(Exception):
add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
if stale_state_ids:
sess.execute(sqlalchemy.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids)))
if to_set_verify:
sess.execute(
sqlalchemy.update(AssetCacheState)
.where(AssetCacheState.id.in_(to_set_verify))
.values(needs_verify=True)
)
if to_clear_verify:
sess.execute(
sqlalchemy.update(AssetCacheState)
.where(AssetCacheState.id.in_(to_clear_verify))
.values(needs_verify=False)
)
sess.commit()
return survivors if collect_existing_paths else None
if enable_logging:
logging.info(
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, total_seen=%d)",
roots,
time.perf_counter() - t_start,
created,
skipped_existing,
len(paths),
)

433
app/assets/seeder.py Normal file
View File

@@ -0,0 +1,433 @@
"""Background asset seeder with thread management and cancellation support."""
import logging
import os
import threading
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Callable
from app.assets.scanner import (
RootType,
build_asset_specs,
collect_paths_for_roots,
get_all_known_prefixes,
get_prefixes_for_root,
insert_asset_specs,
mark_missing_outside_prefixes_safely,
sync_root_safely,
)
from app.database.db import dependencies_available
class State(Enum):
"""Seeder state machine states."""
IDLE = "IDLE"
RUNNING = "RUNNING"
CANCELLING = "CANCELLING"
@dataclass
class Progress:
"""Progress information for a scan operation."""
scanned: int = 0
total: int = 0
created: int = 0
skipped: int = 0
@dataclass
class ScanStatus:
"""Current status of the asset seeder."""
state: State
progress: Progress | None
errors: list[str] = field(default_factory=list)
ProgressCallback = Callable[[Progress], None]
class AssetSeeder:
"""Singleton class managing background asset scanning.
Thread-safe singleton that spawns ephemeral daemon threads for scanning.
Each scan creates a new thread that exits when complete.
"""
_instance: "AssetSeeder | None" = None
_instance_lock = threading.Lock()
def __new__(cls) -> "AssetSeeder":
with cls._instance_lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self) -> None:
if self._initialized:
return
self._initialized = True
self._lock = threading.Lock()
self._state = State.IDLE
self._progress: Progress | None = None
self._errors: list[str] = []
self._thread: threading.Thread | None = None
self._cancel_event = threading.Event()
self._roots: tuple[RootType, ...] = ()
self._compute_hashes: bool = False
self._progress_callback: ProgressCallback | None = None
def start(
self,
roots: tuple[RootType, ...] = ("models", "input", "output"),
progress_callback: ProgressCallback | None = None,
prune_first: bool = False,
compute_hashes: bool = False,
) -> bool:
"""Start a background scan for the given roots.
Args:
roots: Tuple of root types to scan (models, input, output)
progress_callback: Optional callback called with progress updates
prune_first: If True, prune orphaned assets before scanning
compute_hashes: If True, compute blake3 hashes for each file (slow for large files)
Returns:
True if scan was started, False if already running
"""
with self._lock:
if self._state != State.IDLE:
return False
self._state = State.RUNNING
self._progress = Progress()
self._errors = []
self._roots = roots
self._prune_first = prune_first
self._compute_hashes = compute_hashes
self._progress_callback = progress_callback
self._cancel_event.clear()
self._thread = threading.Thread(
target=self._run_scan,
name="AssetSeeder",
daemon=True,
)
self._thread.start()
return True
def cancel(self) -> bool:
"""Request cancellation of the current scan.
Returns:
True if cancellation was requested, False if not running
"""
with self._lock:
if self._state != State.RUNNING:
return False
self._state = State.CANCELLING
self._cancel_event.set()
return True
def wait(self, timeout: float | None = None) -> bool:
"""Wait for the current scan to complete.
Args:
timeout: Maximum seconds to wait, or None for no timeout
Returns:
True if scan completed, False if timeout expired or no scan running
"""
with self._lock:
thread = self._thread
if thread is None:
return True
thread.join(timeout=timeout)
return not thread.is_alive()
def get_status(self) -> ScanStatus:
"""Get the current status and progress of the seeder."""
with self._lock:
return ScanStatus(
state=self._state,
progress=Progress(
scanned=self._progress.scanned,
total=self._progress.total,
created=self._progress.created,
skipped=self._progress.skipped,
)
if self._progress
else None,
errors=list(self._errors),
)
def shutdown(self, timeout: float = 5.0) -> None:
"""Gracefully shutdown: cancel any running scan and wait for thread.
Args:
timeout: Maximum seconds to wait for thread to exit
"""
self.cancel()
self.wait(timeout=timeout)
with self._lock:
self._thread = None
def mark_missing_outside_prefixes(self) -> int:
"""Mark cache states as missing when outside all known root prefixes.
This is a non-destructive soft-delete operation. Assets and their
metadata are preserved, but cache states are flagged as missing.
They can be restored if the file reappears in a future scan.
This operation is decoupled from scanning to prevent partial scans
from accidentally marking assets belonging to other roots.
Should be called explicitly when cleanup is desired, typically after
a full scan of all roots or during maintenance.
Returns:
Number of cache states marked as missing, or 0 if dependencies
unavailable or a scan is currently running
"""
with self._lock:
if self._state != State.IDLE:
logging.warning(
"Cannot mark missing assets while scan is running"
)
return 0
self._state = State.RUNNING
try:
if not dependencies_available():
logging.warning(
"Database dependencies not available, skipping mark missing"
)
return 0
all_prefixes = get_all_known_prefixes()
marked = mark_missing_outside_prefixes_safely(all_prefixes)
if marked > 0:
logging.info("Marked %d cache states as missing", marked)
return marked
finally:
with self._lock:
self._state = State.IDLE
def _is_cancelled(self) -> bool:
"""Check if cancellation has been requested."""
return self._cancel_event.is_set()
def _emit_event(self, event_type: str, data: dict) -> None:
"""Emit a WebSocket event if server is available."""
try:
from server import PromptServer
if hasattr(PromptServer, "instance") and PromptServer.instance:
PromptServer.instance.send_sync(event_type, data)
except Exception:
pass
def _update_progress(
self,
scanned: int | None = None,
total: int | None = None,
created: int | None = None,
skipped: int | None = None,
) -> None:
"""Update progress counters (thread-safe)."""
callback: ProgressCallback | None = None
progress: Progress | None = None
with self._lock:
if self._progress is None:
return
if scanned is not None:
self._progress.scanned = scanned
if total is not None:
self._progress.total = total
if created is not None:
self._progress.created = created
if skipped is not None:
self._progress.skipped = skipped
if self._progress_callback:
callback = self._progress_callback
progress = Progress(
scanned=self._progress.scanned,
total=self._progress.total,
created=self._progress.created,
skipped=self._progress.skipped,
)
if callback and progress:
try:
callback(progress)
except Exception:
pass
def _add_error(self, message: str) -> None:
"""Add an error message (thread-safe)."""
with self._lock:
self._errors.append(message)
def _log_scan_config(self, roots: tuple[RootType, ...]) -> None:
"""Log the directories that will be scanned."""
import folder_paths
for root in roots:
if root == "models":
logging.info(
"Asset scan [models] directory: %s",
os.path.abspath(folder_paths.models_dir),
)
else:
prefixes = get_prefixes_for_root(root)
if prefixes:
logging.info("Asset scan [%s] directories: %s", root, prefixes)
def _run_scan(self) -> None:
"""Main scan loop running in background thread."""
t_start = time.perf_counter()
roots = self._roots
cancelled = False
total_created = 0
skipped_existing = 0
total_paths = 0
try:
if not dependencies_available():
self._add_error("Database dependencies not available")
self._emit_event(
"assets.seed.error",
{"message": "Database dependencies not available"},
)
return
if self._prune_first:
all_prefixes = get_all_known_prefixes()
marked = mark_missing_outside_prefixes_safely(all_prefixes)
if marked > 0:
logging.info("Marked %d cache states as missing before scan", marked)
if self._is_cancelled():
logging.info("Asset scan cancelled after pruning phase")
cancelled = True
return
self._log_scan_config(roots)
existing_paths: set[str] = set()
for r in roots:
if self._is_cancelled():
logging.info("Asset scan cancelled during sync phase")
cancelled = True
return
existing_paths.update(sync_root_safely(r))
if self._is_cancelled():
logging.info("Asset scan cancelled after sync phase")
cancelled = True
return
paths = collect_paths_for_roots(roots)
total_paths = len(paths)
self._update_progress(total=total_paths)
self._emit_event(
"assets.seed.started",
{"roots": list(roots), "total": total_paths},
)
specs, tag_pool, skipped_existing = build_asset_specs(
paths, existing_paths, compute_hashes=self._compute_hashes
)
self._update_progress(skipped=skipped_existing)
if self._is_cancelled():
logging.info("Asset scan cancelled after building specs")
cancelled = True
return
batch_size = 500
last_progress_time = time.perf_counter()
progress_interval = 1.0
for i in range(0, len(specs), batch_size):
if self._is_cancelled():
logging.info(
"Asset scan cancelled after %d/%d files (created=%d)",
i,
len(specs),
total_created,
)
cancelled = True
return
batch = specs[i : i + batch_size]
batch_tags = {t for spec in batch for t in spec["tags"]}
try:
created = insert_asset_specs(batch, batch_tags)
total_created += created
except Exception as e:
self._add_error(f"Batch insert failed at offset {i}: {e}")
logging.exception("Batch insert failed at offset %d", i)
scanned = i + len(batch)
now = time.perf_counter()
self._update_progress(scanned=scanned, created=total_created)
if now - last_progress_time >= progress_interval:
self._emit_event(
"assets.seed.progress",
{
"scanned": scanned,
"total": len(specs),
"created": total_created,
},
)
last_progress_time = now
self._update_progress(scanned=len(specs), created=total_created)
elapsed = time.perf_counter() - t_start
logging.info(
"Asset scan(roots=%s) completed in %.3fs (created=%d, skipped=%d, total=%d)",
roots,
elapsed,
total_created,
skipped_existing,
len(paths),
)
self._emit_event(
"assets.seed.completed",
{
"scanned": len(specs),
"total": total_paths,
"created": total_created,
"skipped": skipped_existing,
"elapsed": round(elapsed, 3),
},
)
except Exception as e:
self._add_error(f"Scan failed: {e}")
logging.exception("Asset scan failed")
self._emit_event("assets.seed.error", {"message": str(e)})
finally:
if cancelled:
self._emit_event(
"assets.seed.cancelled",
{
"scanned": self._progress.scanned if self._progress else 0,
"total": total_paths,
"created": total_created,
},
)
with self._lock:
self._state = State.IDLE
asset_seeder = AssetSeeder()
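# Usage sketch (illustrative): the module-level singleton above is the intended
# entry point. Assuming database dependencies are available, a background scan
# could be driven like this:
#
#     started = asset_seeder.start(roots=("models",), compute_hashes=False)
#     if started:
#         asset_seeder.wait(timeout=60.0)
#     status = asset_seeder.get_status()
#     # status.state, status.progress and status.errors describe the run
#
# cancel() flips the state machine to CANCELLING; the worker thread checks the
# cancel event between phases and between 500-spec insert batches.
# asset_seeder.mark_missing_outside_prefixes() runs the decoupled cleanup,
# but only when no scan is currently active.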

View File

@@ -0,0 +1,89 @@
from app.assets.services.asset_management import (
asset_exists,
delete_asset_reference,
get_asset_by_hash,
get_asset_detail,
list_assets_page,
resolve_asset_for_download,
set_asset_preview,
update_asset_metadata,
)
from app.assets.services.bulk_ingest import (
BulkInsertResult,
batch_insert_seed_assets,
cleanup_unreferenced_assets,
mark_assets_missing_outside_prefixes,
)
from app.assets.services.file_utils import (
get_mtime_ns,
get_size_and_mtime_ns,
list_files_recursively,
verify_file_unchanged,
)
from app.assets.services.ingest import (
DependencyMissingError,
HashMismatchError,
create_from_hash,
upload_from_temp_path,
)
from app.assets.services.schemas import (
AddTagsResult,
AssetData,
AssetDetailResult,
AssetInfoData,
AssetSummaryData,
DownloadResolutionResult,
IngestResult,
ListAssetsResult,
RegisterAssetResult,
RemoveTagsResult,
SetTagsResult,
TagUsage,
UploadResult,
UserMetadata,
)
from app.assets.services.tagging import (
apply_tags,
list_tags,
remove_tags,
)
__all__ = [
"AddTagsResult",
"AssetData",
"AssetDetailResult",
"AssetInfoData",
"AssetSummaryData",
"BulkInsertResult",
"DependencyMissingError",
"DownloadResolutionResult",
"HashMismatchError",
"IngestResult",
"ListAssetsResult",
"RegisterAssetResult",
"RemoveTagsResult",
"SetTagsResult",
"TagUsage",
"UploadResult",
"UserMetadata",
"apply_tags",
"asset_exists",
"batch_insert_seed_assets",
"create_from_hash",
"delete_asset_reference",
"get_asset_by_hash",
"get_asset_detail",
"get_mtime_ns",
"get_size_and_mtime_ns",
"list_assets_page",
"list_files_recursively",
"list_tags",
"cleanup_unreferenced_assets",
"mark_assets_missing_outside_prefixes",
"remove_tags",
"resolve_asset_for_download",
"set_asset_preview",
"update_asset_metadata",
"upload_from_temp_path",
"verify_file_unchanged",
]

View File

@@ -0,0 +1,292 @@
import contextlib
import mimetypes
import os
from typing import Sequence
from app.assets.database.models import Asset
from app.assets.database.queries import (
asset_exists_by_hash,
asset_info_exists_for_asset_id,
delete_asset_info_by_id,
fetch_asset_info_and_asset,
fetch_asset_info_asset_and_tags,
get_asset_by_hash as queries_get_asset_by_hash,
get_asset_info_by_id,
list_asset_infos_page,
list_cache_states_by_asset_id,
set_asset_info_metadata,
set_asset_info_preview,
set_asset_info_tags,
update_asset_info_access_time,
update_asset_info_name,
update_asset_info_updated_at,
)
from app.assets.helpers import select_best_live_path
from app.assets.services.path_utils import compute_filename_for_asset
from app.assets.services.schemas import (
AssetData,
AssetDetailResult,
AssetSummaryData,
DownloadResolutionResult,
ListAssetsResult,
UserMetadata,
extract_asset_data,
extract_info_data,
)
from app.database.db import create_session
def get_asset_detail(
asset_info_id: str,
owner_id: str = "",
) -> AssetDetailResult | None:
with create_session() as session:
result = fetch_asset_info_asset_and_tags(
session,
asset_info_id=asset_info_id,
owner_id=owner_id,
)
if not result:
return None
info, asset, tags = result
return AssetDetailResult(
info=extract_info_data(info),
asset=extract_asset_data(asset),
tags=tags,
)
def update_asset_metadata(
asset_info_id: str,
name: str | None = None,
tags: Sequence[str] | None = None,
user_metadata: UserMetadata = None,
tag_origin: str = "manual",
owner_id: str = "",
) -> AssetDetailResult:
with create_session() as session:
info = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info.owner_id and info.owner_id != owner_id:
raise PermissionError("not owner")
touched = False
if name is not None and name != info.name:
update_asset_info_name(session, asset_info_id=asset_info_id, name=name)
touched = True
computed_filename = compute_filename_for_asset(session, info.asset_id)
new_meta: dict | None = None
if user_metadata is not None:
new_meta = dict(user_metadata)
elif computed_filename:
current_meta = info.user_metadata or {}
if current_meta.get("filename") != computed_filename:
new_meta = dict(current_meta)
if new_meta is not None:
if computed_filename:
new_meta["filename"] = computed_filename
set_asset_info_metadata(
session, asset_info_id=asset_info_id, user_metadata=new_meta
)
touched = True
if tags is not None:
set_asset_info_tags(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=tag_origin,
)
touched = True
if touched and user_metadata is None:
update_asset_info_updated_at(session, asset_info_id=asset_info_id)
result = fetch_asset_info_asset_and_tags(
session,
asset_info_id=asset_info_id,
owner_id=owner_id,
)
if not result:
raise RuntimeError("State changed during update")
info, asset, tag_list = result
detail = AssetDetailResult(
info=extract_info_data(info),
asset=extract_asset_data(asset),
tags=tag_list,
)
session.commit()
return detail
def delete_asset_reference(
asset_info_id: str,
owner_id: str,
delete_content_if_orphan: bool = True,
) -> bool:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
asset_id = info_row.asset_id if info_row else None
deleted = delete_asset_info_by_id(
session, asset_info_id=asset_info_id, owner_id=owner_id
)
if not deleted:
session.commit()
return False
if not delete_content_if_orphan or not asset_id:
session.commit()
return True
still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id)
if still_exists:
session.commit()
return True
# Orphaned asset - delete it and its files
states = list_cache_states_by_asset_id(session, asset_id=asset_id)
file_paths = [
s.file_path for s in (states or []) if getattr(s, "file_path", None)
]
asset_row = session.get(Asset, asset_id)
if asset_row is not None:
session.delete(asset_row)
session.commit()
# Delete files after commit
for p in file_paths:
with contextlib.suppress(Exception):
if p and os.path.isfile(p):
os.remove(p)
return True
def set_asset_preview(
asset_info_id: str,
preview_asset_id: str | None = None,
owner_id: str = "",
) -> AssetDetailResult:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
set_asset_info_preview(
session,
asset_info_id=asset_info_id,
preview_asset_id=preview_asset_id,
)
result = fetch_asset_info_asset_and_tags(
session, asset_info_id=asset_info_id, owner_id=owner_id
)
if not result:
raise RuntimeError("State changed during preview update")
info, asset, tags = result
detail = AssetDetailResult(
info=extract_info_data(info),
asset=extract_asset_data(asset),
tags=tags,
)
session.commit()
return detail
def asset_exists(asset_hash: str) -> bool:
with create_session() as session:
return asset_exists_by_hash(session, asset_hash=asset_hash)
def get_asset_by_hash(asset_hash: str) -> AssetData | None:
with create_session() as session:
asset = queries_get_asset_by_hash(session, asset_hash=asset_hash)
return extract_asset_data(asset)
def list_assets_page(
owner_id: str = "",
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
) -> ListAssetsResult:
with create_session() as session:
infos, tag_map, total = list_asset_infos_page(
session,
owner_id=owner_id,
include_tags=include_tags,
exclude_tags=exclude_tags,
name_contains=name_contains,
metadata_filter=metadata_filter,
limit=limit,
offset=offset,
sort=sort,
order=order,
)
items: list[AssetSummaryData] = []
for info in infos:
items.append(
AssetSummaryData(
info=extract_info_data(info),
asset=extract_asset_data(info.asset),
tags=tag_map.get(info.id, []),
)
)
return ListAssetsResult(items=items, total=total)
def resolve_asset_for_download(
asset_info_id: str,
owner_id: str = "",
) -> DownloadResolutionResult:
with create_session() as session:
pair = fetch_asset_info_and_asset(
session, asset_info_id=asset_info_id, owner_id=owner_id
)
if not pair:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info, asset = pair
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
abs_path = select_best_live_path(states)
if not abs_path:
raise FileNotFoundError(
f"No live path for AssetInfo {asset_info_id} (asset id={asset.id}, name={info.name})"
)
update_asset_info_access_time(session, asset_info_id=asset_info_id)
session.commit()
ctype = (
asset.mime_type
or mimetypes.guess_type(info.name or abs_path)[0]
or "application/octet-stream"
)
download_name = info.name or os.path.basename(abs_path)
return DownloadResolutionResult(
abs_path=abs_path,
content_type=ctype,
download_name=download_name,
)
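# Usage sketch (illustrative): these service functions each open their own
# session, so callers pass plain values. Listing a page of assets and resolving
# one for download might look like the following (the "checkpoints" tag and the
# id placeholder are assumed example values):
#
#     page = list_assets_page(include_tags=["checkpoints"], limit=20)
#     detail = get_asset_detail("<asset-info-id>")
#     dl = resolve_asset_for_download("<asset-info-id>")
#     # dl.abs_path, dl.content_type and dl.download_name feed the HTTP response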

View File

@@ -0,0 +1,338 @@
from __future__ import annotations
import logging
import os
import uuid
from dataclasses import dataclass
from datetime import datetime
from typing import TYPE_CHECKING, Any, TypedDict
from sqlalchemy.orm import Session
from app.assets.database.queries import (
bulk_insert_asset_infos_ignore_conflicts,
bulk_insert_assets,
bulk_insert_cache_states_ignore_conflicts,
bulk_insert_tags_and_meta,
delete_assets_by_ids,
get_asset_info_ids_by_ids,
get_cache_states_by_paths_and_asset_ids,
get_existing_asset_ids,
get_unreferenced_unhashed_asset_ids,
mark_cache_states_missing_outside_prefixes,
restore_cache_states_by_paths,
)
from app.assets.helpers import get_utc_now
if TYPE_CHECKING:
from app.assets.services.metadata_extract import ExtractedMetadata
class SeedAssetSpec(TypedDict):
"""Spec for seeding an asset from filesystem."""
abs_path: str
size_bytes: int
mtime_ns: int
info_name: str
tags: list[str]
fname: str
metadata: ExtractedMetadata | None
hash: str | None
mime_type: str | None
class AssetRow(TypedDict):
"""Row data for inserting an Asset."""
id: str
hash: str | None
size_bytes: int
mime_type: str | None
created_at: datetime
class CacheStateRow(TypedDict):
"""Row data for inserting a CacheState."""
asset_id: str
file_path: str
mtime_ns: int
class AssetInfoRow(TypedDict):
"""Row data for inserting an AssetInfo."""
id: str
owner_id: str
name: str
asset_id: str
preview_id: str | None
user_metadata: dict[str, Any] | None
created_at: datetime
updated_at: datetime
last_access_time: datetime
class AssetInfoRowInternal(TypedDict):
"""Internal row data for AssetInfo with extra tracking fields."""
id: str
owner_id: str
name: str
asset_id: str
preview_id: str | None
user_metadata: dict[str, Any] | None
created_at: datetime
updated_at: datetime
last_access_time: datetime
_tags: list[str]
_filename: str
_extracted_metadata: ExtractedMetadata | None
class TagRow(TypedDict):
"""Row data for inserting a Tag."""
asset_info_id: str
tag_name: str
origin: str
added_at: datetime
class MetadataRow(TypedDict):
"""Row data for inserting asset metadata."""
asset_info_id: str
key: str
ordinal: int
val_str: str | None
val_num: float | None
val_bool: bool | None
val_json: dict[str, Any] | None
@dataclass
class BulkInsertResult:
"""Result of bulk asset insertion."""
inserted_infos: int
won_states: int
lost_states: int
def batch_insert_seed_assets(
session: Session,
specs: list[SeedAssetSpec],
owner_id: str = "",
) -> BulkInsertResult:
"""Seed assets from filesystem specs in batch.
Each spec is a dict with keys:
- abs_path: str
- size_bytes: int
- mtime_ns: int
- info_name: str
- tags: list[str]
- fname: Optional[str]
This function orchestrates:
1. Insert seed Assets (hash=NULL)
2. Claim cache states with ON CONFLICT DO NOTHING
3. Query to find winners (paths where our asset_id was inserted)
4. Delete Assets for losers (path already claimed by another asset)
5. Insert AssetInfo for winners
6. Insert tags and metadata for successfully inserted AssetInfos
Returns:
BulkInsertResult with inserted_infos, won_states, lost_states
"""
if not specs:
return BulkInsertResult(inserted_infos=0, won_states=0, lost_states=0)
current_time = get_utc_now()
asset_rows: list[AssetRow] = []
cache_state_rows: list[CacheStateRow] = []
path_to_asset_id: dict[str, str] = {}
asset_id_to_info: dict[str, AssetInfoRowInternal] = {}
absolute_path_list: list[str] = []
for spec in specs:
absolute_path = os.path.abspath(spec["abs_path"])
asset_id = str(uuid.uuid4())
asset_info_id = str(uuid.uuid4())
absolute_path_list.append(absolute_path)
path_to_asset_id[absolute_path] = asset_id
mime_type = spec.get("mime_type")
if mime_type is None:
logging.info("batch_insert_seed_assets: no mime_type for %s", absolute_path)
asset_rows.append(
{
"id": asset_id,
"hash": spec.get("hash"),
"size_bytes": spec["size_bytes"],
"mime_type": mime_type,
"created_at": current_time,
}
)
cache_state_rows.append(
{
"asset_id": asset_id,
"file_path": absolute_path,
"mtime_ns": spec["mtime_ns"],
}
)
# Build user_metadata from extracted metadata or fallback to filename
extracted_metadata = spec.get("metadata")
if extracted_metadata:
user_metadata: dict[str, Any] | None = extracted_metadata.to_user_metadata()
elif spec["fname"]:
user_metadata = {"filename": spec["fname"]}
else:
user_metadata = None
asset_id_to_info[asset_id] = {
"id": asset_info_id,
"owner_id": owner_id,
"name": spec["info_name"],
"asset_id": asset_id,
"preview_id": None,
"user_metadata": user_metadata,
"created_at": current_time,
"updated_at": current_time,
"last_access_time": current_time,
"_tags": spec["tags"],
"_filename": spec["fname"],
"_extracted_metadata": extracted_metadata,
}
bulk_insert_assets(session, asset_rows)
# Filter cache states to only those whose assets were actually inserted
# (assets with duplicate hashes are silently dropped by ON CONFLICT DO NOTHING)
inserted_asset_ids = get_existing_asset_ids(
session, [r["asset_id"] for r in cache_state_rows]
)
cache_state_rows = [
r for r in cache_state_rows if r["asset_id"] in inserted_asset_ids
]
bulk_insert_cache_states_ignore_conflicts(session, cache_state_rows)
restore_cache_states_by_paths(session, absolute_path_list)
winning_paths = get_cache_states_by_paths_and_asset_ids(session, path_to_asset_id)
all_paths_set = set(absolute_path_list)
losing_paths = all_paths_set - winning_paths
lost_asset_ids = [path_to_asset_id[path] for path in losing_paths]
if lost_asset_ids:
delete_assets_by_ids(session, lost_asset_ids)
if not winning_paths:
return BulkInsertResult(
inserted_infos=0,
won_states=0,
lost_states=len(losing_paths),
)
winner_info_rows = [
asset_id_to_info[path_to_asset_id[path]] for path in winning_paths
]
database_info_rows: list[AssetInfoRow] = [
{
"id": info_row["id"],
"owner_id": info_row["owner_id"],
"name": info_row["name"],
"asset_id": info_row["asset_id"],
"preview_id": info_row["preview_id"],
"user_metadata": info_row["user_metadata"],
"created_at": info_row["created_at"],
"updated_at": info_row["updated_at"],
"last_access_time": info_row["last_access_time"],
}
for info_row in winner_info_rows
]
bulk_insert_asset_infos_ignore_conflicts(session, database_info_rows)
all_info_ids = [info_row["id"] for info_row in winner_info_rows]
inserted_info_ids = get_asset_info_ids_by_ids(session, all_info_ids)
tag_rows: list[TagRow] = []
metadata_rows: list[MetadataRow] = []
if inserted_info_ids:
for info_row in winner_info_rows:
info_id = info_row["id"]
if info_id not in inserted_info_ids:
continue
for tag in info_row["_tags"]:
tag_rows.append(
{
"asset_info_id": info_id,
"tag_name": tag,
"origin": "automatic",
"added_at": current_time,
}
)
# Use extracted metadata for meta rows if available
extracted_metadata = info_row.get("_extracted_metadata")
if extracted_metadata:
metadata_rows.extend(extracted_metadata.to_meta_rows(info_id))
elif info_row["_filename"]:
# Fallback: just store filename
metadata_rows.append(
{
"asset_info_id": info_id,
"key": "filename",
"ordinal": 0,
"val_str": info_row["_filename"],
"val_num": None,
"val_bool": None,
"val_json": None,
}
)
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=metadata_rows)
return BulkInsertResult(
inserted_infos=len(inserted_info_ids),
won_states=len(winning_paths),
lost_states=len(losing_paths),
)
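# Usage sketch (illustrative): a minimal SeedAssetSpec, assuming `session` is an
# open SQLAlchemy Session and the path/stat values come from the scanner:
#
#     spec: SeedAssetSpec = {
#         "abs_path": "/models/checkpoints/example.safetensors",  # assumed path
#         "size_bytes": 1234,
#         "mtime_ns": 1_700_000_000_000_000_000,
#         "info_name": "example",
#         "tags": ["models", "checkpoints"],
#         "fname": "checkpoints/example.safetensors",
#         "metadata": None,
#         "hash": None,            # seed assets may carry hash=NULL
#         "mime_type": "application/safetensors",
#     }
#     result = batch_insert_seed_assets(session, specs=[spec], owner_id="")
#     # result.inserted_infos / won_states / lost_states summarise the outcome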
def mark_assets_missing_outside_prefixes(
session: Session, valid_prefixes: list[str]
) -> int:
"""Mark cache states as missing when outside valid prefixes.
This is a non-destructive operation that soft-deletes cache states
by setting is_missing=True. User metadata is preserved and assets
can be restored if the file reappears in a future scan.
Note: This does NOT delete
unreferenced unhashed assets. Those are preserved so user metadata
remains intact even when base directories change.
Args:
session: Database session
valid_prefixes: List of absolute directory prefixes that are valid
Returns:
Number of cache states marked as missing
"""
return mark_cache_states_missing_outside_prefixes(session, valid_prefixes)
def cleanup_unreferenced_assets(session: Session) -> int:
"""Hard-delete unhashed assets with no active cache states.
This is a destructive operation intended for explicit cleanup.
Only deletes assets where hash=None and all cache states are missing.
Returns:
Number of assets deleted
"""
unreferenced_ids = get_unreferenced_unhashed_asset_ids(session)
return delete_assets_by_ids(session, unreferenced_ids)
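# Usage sketch (illustrative): maintenance callers typically soft-delete first
# and only hard-delete on explicit cleanup; the prefix below is an assumed
# example value:
#
#     marked = mark_assets_missing_outside_prefixes(session, ["/comfy/models"])
#     removed = cleanup_unreferenced_assets(session)
#     session.commit()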

View File

@@ -0,0 +1,58 @@
import os
def get_mtime_ns(stat_result: os.stat_result) -> int:
"""Extract mtime in nanoseconds from a stat result."""
return getattr(
stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000)
)
def get_size_and_mtime_ns(path: str, follow_symlinks: bool = True) -> tuple[int, int]:
"""Get file size in bytes and mtime in nanoseconds."""
st = os.stat(path, follow_symlinks=follow_symlinks)
return st.st_size, get_mtime_ns(st)
def verify_file_unchanged(
mtime_db: int | None,
size_db: int | None,
stat_result: os.stat_result,
) -> bool:
"""Check if a file is unchanged based on mtime and size.
Returns True if the file's mtime and size match the database values.
Returns False if mtime_db is None or values don't match.
size_db=None means don't check size; 0 is a valid recorded size.
"""
if mtime_db is None:
return False
actual_mtime_ns = get_mtime_ns(stat_result)
if int(mtime_db) != int(actual_mtime_ns):
return False
if size_db is not None:
return int(stat_result.st_size) == int(size_db)
return True
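# Usage sketch (illustrative): the scanner's fast check pairs this with a stat
# call; path, row_mtime_ns and row_size below are assumed local variables
# holding the database values:
#
#     st = os.stat(path, follow_symlinks=True)
#     unchanged = verify_file_unchanged(
#         mtime_db=row_mtime_ns, size_db=row_size, stat_result=st
#     )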
def is_visible(name: str) -> bool:
"""Return True if a file or directory name is visible (not hidden)."""
return not name.startswith(".")
def list_files_recursively(base_dir: str) -> list[str]:
"""Recursively list all files in a directory."""
out: list[str] = []
base_abs = os.path.abspath(base_dir)
if not os.path.isdir(base_abs):
return out
for dirpath, subdirs, filenames in os.walk(
base_abs, topdown=True, followlinks=False
):
subdirs[:] = [d for d in subdirs if is_visible(d)]
for name in filenames:
if not is_visible(name):
continue
out.append(os.path.abspath(os.path.join(dirpath, name)))
return out

View File

@@ -0,0 +1,67 @@
import asyncio
import os
from typing import IO
DEFAULT_CHUNK = 8 * 1024 * 1024
_blake3 = None
def _get_blake3():
global _blake3
if _blake3 is None:
try:
from blake3 import blake3 as _b3
_blake3 = _b3
except ImportError:
raise ImportError(
"blake3 is required for asset hashing. Install with: pip install blake3"
)
return _blake3
def compute_blake3_hash(
fp: str | IO[bytes],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
if hasattr(fp, "read"):
return _hash_file_obj(fp, chunk_size)
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj(f, chunk_size)
async def compute_blake3_hash_async(
fp: str | IO[bytes],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
if hasattr(fp, "read"):
return await asyncio.to_thread(compute_blake3_hash, fp, chunk_size)
def _worker() -> str:
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj(f, chunk_size)
return await asyncio.to_thread(_worker)
def _hash_file_obj(file_obj: IO, chunk_size: int = DEFAULT_CHUNK) -> str:
if chunk_size <= 0:
chunk_size = DEFAULT_CHUNK
orig_pos = file_obj.tell()
try:
if orig_pos != 0:
file_obj.seek(0)
h = _get_blake3()()
while True:
chunk = file_obj.read(chunk_size)
if not chunk:
break
h.update(chunk)
return h.hexdigest()
finally:
if orig_pos != 0:
file_obj.seek(orig_pos)
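# Usage sketch (illustrative): blake3 is imported lazily, so callers should be
# prepared for ImportError when the optional dependency is absent; the path is
# an assumed example:
#
#     try:
#         digest = compute_blake3_hash("/models/example.safetensors")
#         asset_hash = "blake3:" + digest
#     except ImportError:
#         asset_hash = None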

View File

@@ -0,0 +1,388 @@
import contextlib
import logging
import mimetypes
import os
from typing import Sequence
from sqlalchemy import select
from sqlalchemy.orm import Session
import app.assets.services.hashing as hashing
from app.assets.database.models import Asset, AssetInfo, Tag
from app.assets.database.queries import (
add_tags_to_asset_info,
fetch_asset_info_and_asset,
get_asset_by_hash,
get_asset_tags,
get_or_create_asset_info,
remove_missing_tag_for_asset_id,
set_asset_info_metadata,
set_asset_info_tags,
update_asset_info_timestamps,
upsert_asset,
upsert_cache_state,
)
from app.assets.helpers import normalize_tags
from app.assets.services.file_utils import get_size_and_mtime_ns
from app.assets.services.path_utils import (
compute_filename_for_asset,
resolve_destination_from_tags,
validate_path_within_base,
)
from app.assets.services.schemas import (
IngestResult,
RegisterAssetResult,
UploadResult,
UserMetadata,
extract_asset_data,
extract_info_data,
)
from app.database.db import create_session
def _ingest_file_from_path(
abs_path: str,
asset_hash: str,
size_bytes: int,
mtime_ns: int,
mime_type: str | None = None,
info_name: str | None = None,
owner_id: str = "",
preview_id: str | None = None,
user_metadata: UserMetadata = None,
tags: Sequence[str] = (),
tag_origin: str = "manual",
require_existing_tags: bool = False,
) -> IngestResult:
locator = os.path.abspath(abs_path)
asset_created = False
asset_updated = False
state_created = False
state_updated = False
asset_info_id: str | None = None
with create_session() as session:
if preview_id:
if not session.get(Asset, preview_id):
preview_id = None
asset, asset_created, asset_updated = upsert_asset(
session,
asset_hash=asset_hash,
size_bytes=size_bytes,
mime_type=mime_type,
)
state_created, state_updated = upsert_cache_state(
session,
asset_id=asset.id,
file_path=locator,
mtime_ns=mtime_ns,
)
if info_name:
info, info_created = get_or_create_asset_info(
session,
asset_id=asset.id,
owner_id=owner_id,
name=info_name,
preview_id=preview_id,
)
if info_created:
asset_info_id = info.id
else:
update_asset_info_timestamps(
session, asset_info=info, preview_id=preview_id
)
asset_info_id = info.id
norm = normalize_tags(list(tags))
if norm and asset_info_id:
if require_existing_tags:
_validate_tags_exist(session, norm)
add_tags_to_asset_info(
session,
asset_info_id=asset_info_id,
tags=norm,
origin=tag_origin,
create_if_missing=not require_existing_tags,
)
if asset_info_id:
_update_metadata_with_filename(
session,
asset_info_id=asset_info_id,
asset_id=asset.id,
info=info,
user_metadata=user_metadata,
)
try:
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
except Exception:
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
session.commit()
return IngestResult(
asset_created=asset_created,
asset_updated=asset_updated,
state_created=state_created,
state_updated=state_updated,
asset_info_id=asset_info_id,
)
def _register_existing_asset(
asset_hash: str,
name: str,
user_metadata: UserMetadata = None,
tags: list[str] | None = None,
tag_origin: str = "manual",
owner_id: str = "",
) -> RegisterAssetResult:
with create_session() as session:
asset = get_asset_by_hash(session, asset_hash=asset_hash)
if not asset:
raise ValueError(f"No asset with hash {asset_hash}")
info, info_created = get_or_create_asset_info(
session,
asset_id=asset.id,
owner_id=owner_id,
name=name,
preview_id=None,
)
if not info_created:
tag_names = get_asset_tags(session, asset_info_id=info.id)
result = RegisterAssetResult(
info=extract_info_data(info),
asset=extract_asset_data(asset),
tags=tag_names,
created=False,
)
session.commit()
return result
new_meta = dict(user_metadata or {})
computed_filename = compute_filename_for_asset(session, asset.id)
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta:
set_asset_info_metadata(
session,
asset_info_id=info.id,
user_metadata=new_meta,
)
if tags is not None:
set_asset_info_tags(
session,
asset_info_id=info.id,
tags=tags,
origin=tag_origin,
)
tag_names = get_asset_tags(session, asset_info_id=info.id)
session.refresh(info)
result = RegisterAssetResult(
info=extract_info_data(info),
asset=extract_asset_data(asset),
tags=tag_names,
created=True,
)
session.commit()
return result
def _validate_tags_exist(session: Session, tags: list[str]) -> None:
existing_tag_names = set(
name
for (name,) in session.execute(select(Tag.name).where(Tag.name.in_(tags))).all()
)
missing = [t for t in tags if t not in existing_tag_names]
if missing:
raise ValueError(f"Unknown tags: {missing}")
def _update_metadata_with_filename(
session: Session,
asset_info_id: str,
asset_id: str,
info: AssetInfo,
user_metadata: UserMetadata,
) -> None:
computed_filename = compute_filename_for_asset(session, asset_id)
current_meta = info.user_metadata or {}
new_meta = dict(current_meta)
if user_metadata:
for k, v in user_metadata.items():
new_meta[k] = v
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta != current_meta:
set_asset_info_metadata(
session,
asset_info_id=asset_info_id,
user_metadata=new_meta,
)
def _sanitize_filename(name: str | None, fallback: str) -> str:
n = os.path.basename((name or "").strip() or fallback)
return n if n else fallback
class HashMismatchError(Exception):
pass
class DependencyMissingError(Exception):
def __init__(self, message: str):
self.message = message
super().__init__(message)
def upload_from_temp_path(
temp_path: str,
name: str | None = None,
tags: list[str] | None = None,
user_metadata: dict | None = None,
client_filename: str | None = None,
owner_id: str = "",
expected_hash: str | None = None,
) -> UploadResult:
try:
digest = hashing.compute_blake3_hash(temp_path)
except ImportError as e:
raise DependencyMissingError(str(e))
except Exception as e:
raise RuntimeError(f"failed to hash uploaded file: {e}")
asset_hash = "blake3:" + digest
if expected_hash and asset_hash != expected_hash.strip().lower():
raise HashMismatchError("Uploaded file hash does not match provided hash.")
with create_session() as session:
existing = get_asset_by_hash(session, asset_hash=asset_hash)
if existing is not None:
with contextlib.suppress(Exception):
if temp_path and os.path.exists(temp_path):
os.remove(temp_path)
display_name = _sanitize_filename(name or client_filename, fallback=digest)
result = _register_existing_asset(
asset_hash=asset_hash,
name=display_name,
user_metadata=user_metadata or {},
tags=tags or [],
tag_origin="manual",
owner_id=owner_id,
)
return UploadResult(
info=result.info,
asset=result.asset,
tags=result.tags,
created_new=False,
)
base_dir, subdirs = resolve_destination_from_tags(tags)
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
os.makedirs(dest_dir, exist_ok=True)
src_for_ext = (client_filename or name or "").strip()
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
ext = _ext if 0 < len(_ext) <= 16 else ""
hashed_basename = f"{digest}{ext}"
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
validate_path_within_base(dest_abs, base_dir)
content_type = (
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
or mimetypes.guess_type(hashed_basename, strict=False)[0]
or "application/octet-stream"
)
try:
os.replace(temp_path, dest_abs)
except Exception as e:
raise RuntimeError(f"failed to move uploaded file into place: {e}")
try:
size_bytes, mtime_ns = get_size_and_mtime_ns(dest_abs)
except OSError as e:
raise RuntimeError(f"failed to stat destination file: {e}")
ingest_result = _ingest_file_from_path(
asset_hash=asset_hash,
abs_path=dest_abs,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
mime_type=content_type,
info_name=_sanitize_filename(name or client_filename, fallback=digest),
owner_id=owner_id,
preview_id=None,
user_metadata=user_metadata or {},
tags=tags,
tag_origin="manual",
require_existing_tags=False,
)
info_id = ingest_result.asset_info_id
if not info_id:
raise RuntimeError("failed to create asset metadata")
with create_session() as session:
pair = fetch_asset_info_and_asset(
session, asset_info_id=info_id, owner_id=owner_id
)
if not pair:
raise RuntimeError("inconsistent DB state after ingest")
info, asset = pair
tag_names = get_asset_tags(session, asset_info_id=info.id)
return UploadResult(
info=extract_info_data(info),
asset=extract_asset_data(asset),
tags=tag_names,
created_new=ingest_result.asset_created,
)
def create_from_hash(
hash_str: str,
name: str,
tags: list[str] | None = None,
user_metadata: dict | None = None,
owner_id: str = "",
) -> UploadResult | None:
canonical = hash_str.strip().lower()
with create_session() as session:
asset = get_asset_by_hash(session, asset_hash=canonical)
if not asset:
return None
result = _register_existing_asset(
asset_hash=canonical,
name=_sanitize_filename(
name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical
),
user_metadata=user_metadata or {},
tags=tags or [],
tag_origin="manual",
owner_id=owner_id,
)
return UploadResult(
info=result.info,
asset=result.asset,
tags=result.tags,
created_new=False,
)
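# Usage sketch (illustrative): the upload path hashes the temp file first, so
# duplicate content short-circuits to _register_existing_asset instead of
# moving the file. All parameter values below are assumed examples:
#
#     result = upload_from_temp_path(
#         temp_path="/tmp/upload-abc123",
#         name="my-lora.safetensors",
#         tags=["models", "loras"],
#         owner_id="",
#     )
#     # result.created_new is False when an asset with the same blake3 hash
#     # already existed.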

View File

@@ -0,0 +1,338 @@
"""Metadata extraction for asset scanning.
Tier 1: Filesystem metadata (zero parsing)
Tier 2: Safetensors header metadata (fast JSON read only)
"""
from __future__ import annotations
import json
import logging
import mimetypes
import os
import struct
from dataclasses import dataclass
from typing import Any
# Supported safetensors extensions
SAFETENSORS_EXTENSIONS = frozenset({".safetensors", ".sft"})
# Maximum safetensors header size to read (8MB)
MAX_SAFETENSORS_HEADER_SIZE = 8 * 1024 * 1024
def _register_custom_mime_types():
"""Register custom MIME types for model and config files.
Called before each use because mimetypes.init() in server.py resets the database.
Uses a quick check to avoid redundant registrations.
"""
# Quick check if already registered (avoids redundant add_type calls)
test_result, _ = mimetypes.guess_type("test.safetensors")
if test_result == "application/safetensors":
return
mimetypes.add_type("application/safetensors", ".safetensors")
mimetypes.add_type("application/safetensors", ".sft")
mimetypes.add_type("application/pytorch", ".pt")
mimetypes.add_type("application/pytorch", ".pth")
mimetypes.add_type("application/pickle", ".ckpt")
mimetypes.add_type("application/pickle", ".pkl")
mimetypes.add_type("application/gguf", ".gguf")
mimetypes.add_type("application/yaml", ".yaml")
mimetypes.add_type("application/yaml", ".yml")
# Register custom types at module load
_register_custom_mime_types()
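# Usage sketch (illustrative): after registration, the stdlib resolver picks up
# the custom types, e.g.:
#
#     mimetypes.guess_type("model.safetensors")  # -> ("application/safetensors", None)
#     mimetypes.guess_type("model.gguf")         # -> ("application/gguf", None)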
@dataclass
class ExtractedMetadata:
"""Metadata extracted from a file during scanning."""
# Tier 1: Filesystem (always available)
filename: str = ""
content_length: int = 0
content_type: str | None = None
format: str = "" # file extension without dot
# Tier 2: Safetensors header (if available)
base_model: str | None = None
trained_words: list[str] | None = None
air: str | None = None # CivitAI AIR identifier
has_preview_images: bool = False
# Source provenance (populated if embedded in safetensors)
source_url: str | None = None
source_arn: str | None = None
repo_url: str | None = None
preview_url: str | None = None
source_hash: str | None = None
# HuggingFace specific
repo_id: str | None = None
revision: str | None = None
filepath: str | None = None
resolve_url: str | None = None
def to_user_metadata(self) -> dict[str, Any]:
"""Convert to user_metadata dict for AssetInfo.user_metadata JSON field."""
data: dict[str, Any] = {
"filename": self.filename,
"content_length": self.content_length,
"format": self.format,
}
if self.content_type:
data["content_type"] = self.content_type
# Tier 2 fields
if self.base_model:
data["base_model"] = self.base_model
if self.trained_words:
data["trained_words"] = self.trained_words
if self.air:
data["air"] = self.air
if self.has_preview_images:
data["has_preview_images"] = True
# Source provenance
if self.source_url:
data["source_url"] = self.source_url
if self.source_arn:
data["source_arn"] = self.source_arn
if self.repo_url:
data["repo_url"] = self.repo_url
if self.preview_url:
data["preview_url"] = self.preview_url
if self.source_hash:
data["source_hash"] = self.source_hash
# HuggingFace
if self.repo_id:
data["repo_id"] = self.repo_id
if self.revision:
data["revision"] = self.revision
if self.filepath:
data["filepath"] = self.filepath
if self.resolve_url:
data["resolve_url"] = self.resolve_url
return data
def to_meta_rows(self, asset_info_id: str) -> list[dict]:
"""Convert to asset_info_meta rows for typed/indexed querying."""
rows: list[dict] = []
def add_str(key: str, val: str | None, ordinal: int = 0) -> None:
if val:
rows.append({
"asset_info_id": asset_info_id,
"key": key,
"ordinal": ordinal,
"val_str": val[:2048] if len(val) > 2048 else val,
"val_num": None,
"val_bool": None,
"val_json": None,
})
def add_num(key: str, val: int | float | None) -> None:
if val is not None:
rows.append({
"asset_info_id": asset_info_id,
"key": key,
"ordinal": 0,
"val_str": None,
"val_num": val,
"val_bool": None,
"val_json": None,
})
def add_bool(key: str, val: bool | None) -> None:
if val is not None:
rows.append({
"asset_info_id": asset_info_id,
"key": key,
"ordinal": 0,
"val_str": None,
"val_num": None,
"val_bool": val,
"val_json": None,
})
# Tier 1
add_str("filename", self.filename)
add_num("content_length", self.content_length)
add_str("content_type", self.content_type)
add_str("format", self.format)
# Tier 2
add_str("base_model", self.base_model)
add_str("air", self.air)
add_bool("has_preview_images", self.has_preview_images if self.has_preview_images else None)
# trained_words as multiple rows with ordinals
if self.trained_words:
for i, word in enumerate(self.trained_words[:100]): # limit to 100 words
add_str("trained_words", word, ordinal=i)
# Source provenance
add_str("source_url", self.source_url)
add_str("source_arn", self.source_arn)
add_str("repo_url", self.repo_url)
add_str("preview_url", self.preview_url)
add_str("source_hash", self.source_hash)
# HuggingFace
add_str("repo_id", self.repo_id)
add_str("revision", self.revision)
add_str("filepath", self.filepath)
add_str("resolve_url", self.resolve_url)
return rows
def _read_safetensors_header(path: str, max_size: int = MAX_SAFETENSORS_HEADER_SIZE) -> dict[str, Any] | None:
"""Read only the JSON header from a safetensors file.
This is very fast - reads 8 bytes for header length, then the JSON header.
No tensor data is loaded.
Args:
path: Absolute path to safetensors file
max_size: Maximum header size to read (default 8MB)
Returns:
Parsed header dict or None if failed
"""
try:
with open(path, "rb") as f:
header_bytes = f.read(8)
if len(header_bytes) < 8:
return None
length_of_header = struct.unpack("<Q", header_bytes)[0]
if length_of_header > max_size:
return None
header_data = f.read(length_of_header)
if len(header_data) < length_of_header:
return None
return json.loads(header_data.decode("utf-8"))
except (OSError, json.JSONDecodeError, UnicodeDecodeError, struct.error):
return None
def _extract_safetensors_metadata(header: dict[str, Any], meta: ExtractedMetadata) -> None:
"""Extract metadata from safetensors header __metadata__ section.
Modifies meta in-place.
"""
st_meta = header.get("__metadata__", {})
if not isinstance(st_meta, dict):
return
# Common model metadata
meta.base_model = st_meta.get("ss_base_model_version") or st_meta.get("modelspec.base_model") or st_meta.get("base_model")
# Trained words / trigger words
trained_words = st_meta.get("ss_tag_frequency")
if trained_words and isinstance(trained_words, str):
try:
tag_freq = json.loads(trained_words)
# Extract unique tags from all datasets
all_tags: set[str] = set()
for dataset_tags in tag_freq.values():
if isinstance(dataset_tags, dict):
all_tags.update(dataset_tags.keys())
if all_tags:
meta.trained_words = sorted(all_tags)[:100]
except json.JSONDecodeError:
pass
# Direct trained_words field (some formats)
if not meta.trained_words:
tw = st_meta.get("trained_words")
if isinstance(tw, str):
try:
meta.trained_words = json.loads(tw)
except json.JSONDecodeError:
meta.trained_words = [w.strip() for w in tw.split(",") if w.strip()]
elif isinstance(tw, list):
meta.trained_words = tw
# CivitAI AIR
meta.air = st_meta.get("air") or st_meta.get("modelspec.air")
# Preview images (ssmd_cover_images)
cover_images = st_meta.get("ssmd_cover_images")
if cover_images:
meta.has_preview_images = True
# Source provenance fields
meta.source_url = st_meta.get("source_url")
meta.source_arn = st_meta.get("source_arn")
meta.repo_url = st_meta.get("repo_url")
meta.preview_url = st_meta.get("preview_url")
meta.source_hash = st_meta.get("source_hash") or st_meta.get("sshs_model_hash")
# HuggingFace fields
meta.repo_id = st_meta.get("repo_id") or st_meta.get("hf_repo_id")
meta.revision = st_meta.get("revision") or st_meta.get("hf_revision")
meta.filepath = st_meta.get("filepath") or st_meta.get("hf_filepath")
meta.resolve_url = st_meta.get("resolve_url") or st_meta.get("hf_url")
def extract_file_metadata(
abs_path: str,
stat_result: os.stat_result | None = None,
enable_safetensors: bool = True,
relative_filename: str | None = None,
) -> ExtractedMetadata:
"""Extract metadata from a file using tier 1 and optionally tier 2 methods.
Tier 1 (always): Filesystem metadata from path and stat
Tier 2 (optional): Safetensors header parsing if applicable
Args:
abs_path: Absolute path to the file
stat_result: Optional pre-fetched stat result (saves a syscall)
enable_safetensors: Whether to parse safetensors headers (tier 2)
relative_filename: Optional relative filename to use instead of basename
(e.g., "flux/123/model.safetensors" for model paths)
Returns:
ExtractedMetadata with all available fields populated
"""
meta = ExtractedMetadata()
# Tier 1: Filesystem metadata
# Use relative_filename if provided (for backward compatibility with existing behavior)
meta.filename = relative_filename if relative_filename else os.path.basename(abs_path)
_, ext = os.path.splitext(abs_path)
meta.format = ext.lstrip(".").lower() if ext else ""
# MIME type guess (re-register in case mimetypes.init() was called elsewhere)
_register_custom_mime_types()
mime_type, _ = mimetypes.guess_type(abs_path)
meta.content_type = mime_type
# Size from stat
if stat_result is None:
try:
stat_result = os.stat(abs_path, follow_symlinks=True)
except OSError:
pass
if stat_result:
meta.content_length = stat_result.st_size
# Tier 2: Safetensors header (if applicable and enabled)
if enable_safetensors and ext.lower() in SAFETENSORS_EXTENSIONS:
header = _read_safetensors_header(abs_path)
if header:
try:
_extract_safetensors_metadata(header, meta)
except Exception as e:
logging.debug("Failed to extract safetensors metadata from %s: %s", abs_path, e)
return meta
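A hedged usage sketch for extract_file_metadata above. The module path and the example file path are assumptions; the keyword arguments and the ExtractedMetadata helpers come from this file:

import os
from app.assets.scanner.metadata_extract import extract_file_metadata  # assumed module path

abs_path = "/ComfyUI/models/checkpoints/flux/123/model.safetensors"    # hypothetical file
meta = extract_file_metadata(
    abs_path,
    stat_result=os.stat(abs_path),                   # optional; saves one extra stat() call
    relative_filename="flux/123/model.safetensors",  # keeps the category-relative name
)
print(meta.content_type)                          # "application/safetensors" via the custom registrations
user_metadata = meta.to_user_metadata()           # dict for AssetInfo.user_metadata
rows = meta.to_meta_rows(asset_info_id="info-1")  # typed rows for asset_info_meta (hypothetical id)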

View File

@@ -0,0 +1,184 @@
import os
from pathlib import Path
from typing import Literal
import folder_paths
from app.assets.helpers import normalize_tags
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
"""Build a list of (folder_name, base_paths[]) categories that are configured for model locations.
We trust `folder_paths.folder_names_and_paths` and include a category if
*any* of its base paths lies under the Comfy `models_dir`.
"""
targets: list[tuple[str, list[str]]] = []
models_root = os.path.abspath(folder_paths.models_dir)
for name, values in folder_paths.folder_names_and_paths.items():
paths, _exts = (
values[0],
values[1],
) # NOTE: this prevents nodepacks that hackily edit folder_... from breaking ComfyUI
if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
targets.append((name, paths))
return targets
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
root = tags[0]
if root == "models":
if len(tags) < 2:
raise ValueError("at least two tags required for model asset")
try:
bases = folder_paths.folder_names_and_paths[tags[1]][0]
except KeyError:
raise ValueError(f"unknown model category '{tags[1]}'")
if not bases:
raise ValueError(f"no base path configured for category '{tags[1]}'")
base_dir = os.path.abspath(bases[0])
raw_subdirs = tags[2:]
else:
base_dir = os.path.abspath(
folder_paths.get_input_directory()
if root == "input"
else folder_paths.get_output_directory()
)
raw_subdirs = tags[1:]
for i in raw_subdirs:
if i in (".", ".."):
raise ValueError("invalid path component in tags")
return base_dir, raw_subdirs if raw_subdirs else []
def validate_path_within_base(candidate: str, base: str) -> None:
cand_abs = os.path.abspath(candidate)
base_abs = os.path.abspath(base)
try:
common = os.path.commonpath([cand_abs, base_abs])
except Exception:
raise ValueError("invalid destination path")
if common != base_abs:
raise ValueError("destination escapes base directory")
def compute_relative_filename(file_path: str) -> str | None:
"""
Return the model's path relative to the last well-known folder (the model category),
using forward slashes, e.g.:
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
For non-model paths, returns None.
NOTE: this is a temporary helper, used only for initializing the metadata["filename"] field.
"""
try:
root_category, rel_path = get_asset_category_and_relative_path(file_path)
except ValueError:
return None
p = Path(rel_path)
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
if not parts:
return None
if root_category == "models":
# parts[0] is the category ("checkpoints", "vae", etc.); drop it
inside = parts[1:] if len(parts) > 1 else [parts[0]]
return "/".join(inside)
return "/".join(parts) # input/output: keep all parts
def get_asset_category_and_relative_path(
file_path: str,
) -> tuple[Literal["input", "output", "models"], str]:
"""Given an absolute or relative file path, determine which root category the path belongs to:
- 'input' if the file resides under `folder_paths.get_input_directory()`
- 'output' if the file resides under `folder_paths.get_output_directory()`
- 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()`
Returns:
(root_category, relative_path_inside_that_root)
For 'models', the relative path is prefixed with the category name:
e.g. ('models', 'vae/test/sub/ae.safetensors')
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
"""
fp_abs = os.path.abspath(file_path)
def _check_is_within(child: str, parent: str) -> bool:
try:
return os.path.commonpath([child, parent]) == parent
except Exception:
return False
def _compute_relative(child: str, parent: str) -> str:
return os.path.relpath(
os.path.join(os.sep, os.path.relpath(child, parent)), os.sep
)
# 1) input
input_base = os.path.abspath(folder_paths.get_input_directory())
if _check_is_within(fp_abs, input_base):
return "input", _compute_relative(fp_abs, input_base)
# 2) output
output_base = os.path.abspath(folder_paths.get_output_directory())
if _check_is_within(fp_abs, output_base):
return "output", _compute_relative(fp_abs, output_base)
# 3) models (check deepest matching base to avoid ambiguity)
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
for bucket, bases in get_comfy_models_folders():
for b in bases:
base_abs = os.path.abspath(b)
if not _check_is_within(fp_abs, base_abs):
continue
cand = (len(base_abs), bucket, _compute_relative(fp_abs, base_abs))
if best is None or cand[0] > best[0]:
best = cand
if best is not None:
_, bucket, rel_inside = best
combined = os.path.join(bucket, rel_inside)
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
raise ValueError(
f"Path is not within input, output, or configured model bases: {file_path}"
)
def compute_filename_for_asset(session, asset_id: str) -> str | None:
"""Compute the relative filename for an asset from its best live cache state path."""
from app.assets.database.queries import list_cache_states_by_asset_id
from app.assets.helpers import select_best_live_path
primary_path = select_best_live_path(
list_cache_states_by_asset_id(session, asset_id=asset_id)
)
return compute_relative_filename(primary_path) if primary_path else None
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
"""Return a tuple (name, tags) derived from a filesystem path.
Semantics:
- Root category is determined by `get_asset_category_and_relative_path`.
- The returned `name` is the base filename with extension from the relative path.
- The returned `tags` are:
[root_category] + parent folders of the relative path (in order)
For 'models', this means:
file '/.../ModelsDir/vae/test_tag/ae.safetensors'
-> root_category='models', some_path='vae/test_tag/ae.safetensors'
-> name='ae.safetensors', tags=['models', 'vae', 'test_tag']
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
"""
root_category, some_path = get_asset_category_and_relative_path(file_path)
p = Path(some_path)
parent_parts = [
part for part in p.parent.parts if part not in (".", "..", p.anchor)
]
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
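A sketch of the round trip between the two helpers above, assuming a default folder_paths layout where the file actually lives under the configured models directory (otherwise both calls raise ValueError); the module path is also an assumption:

from app.assets.scanner.paths import (  # assumed module path
    get_name_and_tags_from_asset_path,
    resolve_destination_from_tags,
)

name, tags = get_name_and_tags_from_asset_path("/ComfyUI/models/vae/test_tag/ae.safetensors")
# name == "ae.safetensors", tags == ["models", "vae", "test_tag"]

base_dir, subdirs = resolve_destination_from_tags(tags)
# base_dir is the first configured base path for the "vae" category; subdirs == ["test_tag"]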

View File

@@ -0,0 +1,126 @@
from dataclasses import dataclass
from datetime import datetime
from typing import Any, NamedTuple
from app.assets.database.models import Asset, AssetInfo
UserMetadata = dict[str, Any] | None
@dataclass(frozen=True)
class AssetData:
hash: str
size_bytes: int | None
mime_type: str | None
@dataclass(frozen=True)
class AssetInfoData:
id: str
name: str
user_metadata: UserMetadata
preview_id: str | None
created_at: datetime
updated_at: datetime
last_access_time: datetime | None
@dataclass(frozen=True)
class AssetDetailResult:
info: AssetInfoData
asset: AssetData | None
tags: list[str]
@dataclass(frozen=True)
class RegisterAssetResult:
info: AssetInfoData
asset: AssetData
tags: list[str]
created: bool
@dataclass(frozen=True)
class IngestResult:
asset_created: bool
asset_updated: bool
state_created: bool
state_updated: bool
asset_info_id: str | None
@dataclass(frozen=True)
class AddTagsResult:
added: list[str]
already_present: list[str]
total_tags: list[str]
@dataclass(frozen=True)
class RemoveTagsResult:
removed: list[str]
not_present: list[str]
total_tags: list[str]
@dataclass(frozen=True)
class SetTagsResult:
added: list[str]
removed: list[str]
total: list[str]
class TagUsage(NamedTuple):
name: str
tag_type: str
count: int
@dataclass(frozen=True)
class AssetSummaryData:
info: AssetInfoData
asset: AssetData | None
tags: list[str]
@dataclass(frozen=True)
class ListAssetsResult:
items: list[AssetSummaryData]
total: int
@dataclass(frozen=True)
class DownloadResolutionResult:
abs_path: str
content_type: str
download_name: str
@dataclass(frozen=True)
class UploadResult:
info: AssetInfoData
asset: AssetData
tags: list[str]
created_new: bool
def extract_info_data(info: AssetInfo) -> AssetInfoData:
return AssetInfoData(
id=info.id,
name=info.name,
user_metadata=info.user_metadata,
preview_id=info.preview_id,
created_at=info.created_at,
updated_at=info.updated_at,
last_access_time=info.last_access_time,
)
def extract_asset_data(asset: Asset | None) -> AssetData | None:
if asset is None:
return None
return AssetData(
hash=asset.hash,
size_bytes=asset.size_bytes,
mime_type=asset.mime_type,
)
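A short sketch of how these converters might be used to build a result object, assuming AssetInfo exposes an asset relationship; the asset_info_id value is hypothetical:

from app.assets.database.queries import get_asset_info_by_id
from app.database.db import create_session

with create_session() as session:
    info_row = get_asset_info_by_id(session, asset_info_id="some-asset-info-id")
    if info_row is not None:
        detail = AssetDetailResult(
            info=extract_info_data(info_row),
            asset=extract_asset_data(info_row.asset),  # returns None when no Asset is linked
            tags=[],                                   # tags would normally be loaded separately
        )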

View File

@@ -0,0 +1,89 @@
from app.assets.database.queries import (
add_tags_to_asset_info,
get_asset_info_by_id,
list_tags_with_usage,
remove_tags_from_asset_info,
)
from app.assets.services.schemas import AddTagsResult, RemoveTagsResult, TagUsage
from app.database.db import create_session
def apply_tags(
asset_info_id: str,
tags: list[str],
origin: str = "manual",
owner_id: str = "",
) -> AddTagsResult:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = add_tags_to_asset_info(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=origin,
create_if_missing=True,
asset_info_row=info_row,
)
session.commit()
return AddTagsResult(
added=data["added"],
already_present=data["already_present"],
total_tags=data["total_tags"],
)
def remove_tags(
asset_info_id: str,
tags: list[str],
owner_id: str = "",
) -> RemoveTagsResult:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = remove_tags_from_asset_info(
session,
asset_info_id=asset_info_id,
tags=tags,
)
session.commit()
return RemoveTagsResult(
removed=data["removed"],
not_present=data["not_present"],
total_tags=data["total_tags"],
)
def list_tags(
prefix: str | None = None,
limit: int = 100,
offset: int = 0,
order: str = "count_desc",
include_zero: bool = True,
owner_id: str = "",
) -> tuple[list[TagUsage], int]:
limit = max(1, min(1000, limit))
offset = max(0, offset)
with create_session() as session:
rows, total = list_tags_with_usage(
session,
prefix=prefix,
limit=limit,
offset=offset,
include_zero=include_zero,
order=order,
owner_id=owner_id,
)
return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total
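A hedged usage sketch for the tag service functions above. The asset_info_id is hypothetical and must already exist (apply_tags and remove_tags raise ValueError otherwise); the import path is an assumption:

from app.assets.services.tags import apply_tags, remove_tags, list_tags  # assumed module path

added = apply_tags("existing-asset-info-id", ["lora", "character"], origin="manual")
print(added.added, added.already_present)

removed = remove_tags("existing-asset-info-id", ["character"])
print(removed.removed, removed.total_tags)

usage, total = list_tags(prefix="lo", limit=10, order="count_desc")
for tag in usage:
    print(tag.name, tag.tag_type, tag.count)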

View File

@@ -14,7 +14,7 @@ try:
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from sqlalchemy import create_engine
from sqlalchemy import create_engine, event
from sqlalchemy.orm import sessionmaker
_DB_AVAILABLE = True
@@ -75,6 +75,13 @@ def init_db():
# Check if we need to upgrade
engine = create_engine(db_url)
# Enable foreign key enforcement for SQLite
@event.listens_for(engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
conn = engine.connect()
context = MigrationContext.configure(conn)
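The hunk above turns on foreign key enforcement, which SQLite leaves off by default on every new connection. A standalone sketch of the same listener pattern using only public SQLAlchemy APIs (illustrative, not part of the diff):

from sqlalchemy import create_engine, event, text

engine = create_engine("sqlite:///:memory:")

@event.listens_for(engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
    cursor = dbapi_connection.cursor()
    cursor.execute("PRAGMA foreign_keys=ON")  # must be re-issued for each new connection
    cursor.close()

with engine.connect() as conn:
    # PRAGMA foreign_keys reports 1 once enforcement is active on this connection
    assert conn.execute(text("PRAGMA foreign_keys")).scalar() == 1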

10
main.py
View File

@@ -7,7 +7,7 @@ import folder_paths
import time
from comfy.cli_args import args, enables_dynamic_vram
from app.logger import setup_logger
from app.assets.scanner import seed_assets
from app.assets.seeder import asset_seeder
import itertools
import utils.extra_config
import logging
@@ -354,7 +354,8 @@ def setup_database():
if dependencies_available():
init_db()
if not args.disable_assets_autoscan:
seed_assets(["models"], enable_logging=True)
if asset_seeder.start(roots=("models", "input", "output"), prune_first=True, compute_hashes=True):
logging.info("Background asset scan initiated for models, input, output")
except Exception as e:
logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
@@ -438,5 +439,6 @@ if __name__ == "__main__":
event_loop.run_until_complete(x)
except KeyboardInterrupt:
logging.info("\nStopped server")
cleanup_temp()
finally:
asset_seeder.shutdown()
cleanup_temp()

View File

@@ -259,13 +259,3 @@ def autoclean_unit_test_assets(http: requests.Session, api_base: str):
for aid in ids:
with contextlib.suppress(Exception):
http.delete(f"{api_base}/api/assets/{aid}", timeout=30)
def trigger_sync_seed_assets(session: requests.Session, base_url: str) -> None:
"""Force a fast sync/seed pass by calling the seed endpoint."""
session.post(base_url + "/api/assets/seed", json={"roots": ["models", "input", "output"]}, timeout=30)
time.sleep(0.2)
def get_asset_filename(asset_hash: str, extension: str) -> str:
return asset_hash.removeprefix("blake3:") + extension

View File

@@ -0,0 +1,15 @@
"""Helper functions for assets integration tests."""
import requests
def trigger_sync_seed_assets(session: requests.Session, base_url: str) -> None:
"""Force a synchronous sync/seed pass by calling the seed endpoint with wait=true."""
session.post(
base_url + "/api/assets/seed?wait=true",
json={"roots": ["models", "input", "output"]},
timeout=60,
)
def get_asset_filename(asset_hash: str, extension: str) -> str:
return asset_hash.removeprefix("blake3:") + extension

View File

@@ -0,0 +1,20 @@
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from app.assets.database.models import Base
@pytest.fixture
def session():
"""In-memory SQLite session for fast unit tests."""
engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
with Session(engine) as sess:
yield sess
@pytest.fixture(autouse=True)
def autoclean_unit_test_assets():
"""Override parent autouse fixture - query tests don't need server cleanup."""
yield

View File

@@ -0,0 +1,142 @@
import uuid
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset
from app.assets.database.queries import (
asset_exists_by_hash,
get_asset_by_hash,
upsert_asset,
bulk_insert_assets,
)
class TestAssetExistsByHash:
@pytest.mark.parametrize(
"setup_hash,query_hash,expected",
[
(None, "nonexistent", False), # No asset exists
("blake3:abc123", "blake3:abc123", True), # Asset exists with matching hash
(None, "", False), # Null hash in DB doesn't match empty string
],
ids=["nonexistent", "existing", "null_hash_no_match"],
)
def test_exists_by_hash(self, session: Session, setup_hash, query_hash, expected):
# Create an asset when a hash is given, and also for the null-hash case (hash=None, query="")
if setup_hash is not None or query_hash == "":
asset = Asset(hash=setup_hash, size_bytes=100)
session.add(asset)
session.commit()
assert asset_exists_by_hash(session, asset_hash=query_hash) is expected
class TestGetAssetByHash:
@pytest.mark.parametrize(
"setup_hash,query_hash,should_find",
[
(None, "nonexistent", False),
("blake3:def456", "blake3:def456", True),
],
ids=["nonexistent", "existing"],
)
def test_get_by_hash(self, session: Session, setup_hash, query_hash, should_find):
if setup_hash is not None:
asset = Asset(hash=setup_hash, size_bytes=200, mime_type="image/png")
session.add(asset)
session.commit()
result = get_asset_by_hash(session, asset_hash=query_hash)
if should_find:
assert result is not None
assert result.size_bytes == 200
assert result.mime_type == "image/png"
else:
assert result is None
class TestUpsertAsset:
@pytest.mark.parametrize(
"first_size,first_mime,second_size,second_mime,expect_created,expect_updated,final_size,final_mime",
[
# New asset creation
(None, None, 1024, "application/octet-stream", True, False, 1024, "application/octet-stream"),
# Existing asset, same values - no update
(500, "text/plain", 500, "text/plain", False, False, 500, "text/plain"),
# Existing asset with size 0, update with new values
(0, None, 2048, "image/png", False, True, 2048, "image/png"),
# Existing asset, second call with size 0 - no update
(1000, None, 0, None, False, False, 1000, None),
],
ids=["new_asset", "existing_no_change", "update_from_zero", "zero_size_no_update"],
)
def test_upsert_scenarios(
self,
session: Session,
first_size,
first_mime,
second_size,
second_mime,
expect_created,
expect_updated,
final_size,
final_mime,
):
asset_hash = f"blake3:test_{first_size}_{second_size}"
# Seed an existing asset first when first_size is provided, so the call under test is the second upsert
if first_size is not None:
upsert_asset(
session,
asset_hash=asset_hash,
size_bytes=first_size,
mime_type=first_mime,
)
session.commit()
# The upsert call we're testing
asset, created, updated = upsert_asset(
session,
asset_hash=asset_hash,
size_bytes=second_size,
mime_type=second_mime,
)
session.commit()
assert created is expect_created
assert updated is expect_updated
assert asset.size_bytes == final_size
assert asset.mime_type == final_mime
class TestBulkInsertAssets:
def test_inserts_multiple_assets(self, session: Session):
rows = [
{"id": str(uuid.uuid4()), "hash": "blake3:bulk1", "size_bytes": 100, "mime_type": "text/plain", "created_at": None},
{"id": str(uuid.uuid4()), "hash": "blake3:bulk2", "size_bytes": 200, "mime_type": "image/png", "created_at": None},
{"id": str(uuid.uuid4()), "hash": "blake3:bulk3", "size_bytes": 300, "mime_type": None, "created_at": None},
]
bulk_insert_assets(session, rows)
session.commit()
assets = session.query(Asset).all()
assert len(assets) == 3
hashes = {a.hash for a in assets}
assert hashes == {"blake3:bulk1", "blake3:bulk2", "blake3:bulk3"}
def test_empty_list_is_noop(self, session: Session):
bulk_insert_assets(session, [])
session.commit()
assert session.query(Asset).count() == 0
def test_handles_large_batch(self, session: Session):
"""Test chunking logic with more rows than MAX_BIND_PARAMS allows."""
rows = [
{"id": str(uuid.uuid4()), "hash": f"blake3:large{i}", "size_bytes": i, "mime_type": None, "created_at": None}
for i in range(200)
]
bulk_insert_assets(session, rows)
session.commit()
assert session.query(Asset).count() == 200

View File

@@ -0,0 +1,511 @@
import time
import uuid
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta
from app.assets.database.queries import (
asset_info_exists_for_asset_id,
get_asset_info_by_id,
insert_asset_info,
get_or_create_asset_info,
update_asset_info_timestamps,
list_asset_infos_page,
fetch_asset_info_asset_and_tags,
fetch_asset_info_and_asset,
update_asset_info_access_time,
set_asset_info_metadata,
delete_asset_info_by_id,
set_asset_info_preview,
bulk_insert_asset_infos_ignore_conflicts,
get_asset_info_ids_by_ids,
ensure_tags_exist,
add_tags_to_asset_info,
)
from app.assets.helpers import get_utc_now
def _make_asset(session: Session, hash_val: str | None = None, size: int = 1024) -> Asset:
asset = Asset(hash=hash_val, size_bytes=size, mime_type="application/octet-stream")
session.add(asset)
session.flush()
return asset
def _make_asset_info(
session: Session,
asset: Asset,
name: str = "test",
owner_id: str = "",
) -> AssetInfo:
now = get_utc_now()
info = AssetInfo(
owner_id=owner_id,
name=name,
asset_id=asset.id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
session.flush()
return info
class TestAssetInfoExistsForAssetId:
def test_returns_false_when_no_info(self, session: Session):
asset = _make_asset(session, "hash1")
assert asset_info_exists_for_asset_id(session, asset_id=asset.id) is False
def test_returns_true_when_info_exists(self, session: Session):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset)
assert asset_info_exists_for_asset_id(session, asset_id=asset.id) is True
class TestGetAssetInfoById:
def test_returns_none_for_nonexistent(self, session: Session):
assert get_asset_info_by_id(session, asset_info_id="nonexistent") is None
def test_returns_info(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset, name="myfile.txt")
result = get_asset_info_by_id(session, asset_info_id=info.id)
assert result is not None
assert result.name == "myfile.txt"
class TestListAssetInfosPage:
def test_empty_db(self, session: Session):
infos, tag_map, total = list_asset_infos_page(session)
assert infos == []
assert tag_map == {}
assert total == 0
def test_returns_infos_with_tags(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset, name="test.bin")
ensure_tags_exist(session, ["alpha", "beta"])
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["alpha", "beta"])
session.commit()
infos, tag_map, total = list_asset_infos_page(session)
assert len(infos) == 1
assert infos[0].id == info.id
assert set(tag_map[info.id]) == {"alpha", "beta"}
assert total == 1
def test_name_contains_filter(self, session: Session):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, name="model_v1.safetensors")
_make_asset_info(session, asset, name="config.json")
session.commit()
infos, _, total = list_asset_infos_page(session, name_contains="model")
assert total == 1
assert infos[0].name == "model_v1.safetensors"
def test_owner_visibility(self, session: Session):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, name="public", owner_id="")
_make_asset_info(session, asset, name="private", owner_id="user1")
session.commit()
# Empty owner sees only public
infos, _, total = list_asset_infos_page(session, owner_id="")
assert total == 1
assert infos[0].name == "public"
# Owner sees both
infos, _, total = list_asset_infos_page(session, owner_id="user1")
assert total == 2
def test_include_tags_filter(self, session: Session):
asset = _make_asset(session, "hash1")
info1 = _make_asset_info(session, asset, name="tagged")
_make_asset_info(session, asset, name="untagged")
ensure_tags_exist(session, ["wanted"])
add_tags_to_asset_info(session, asset_info_id=info1.id, tags=["wanted"])
session.commit()
infos, _, total = list_asset_infos_page(session, include_tags=["wanted"])
assert total == 1
assert infos[0].name == "tagged"
def test_exclude_tags_filter(self, session: Session):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, name="keep")
info_exclude = _make_asset_info(session, asset, name="exclude")
ensure_tags_exist(session, ["bad"])
add_tags_to_asset_info(session, asset_info_id=info_exclude.id, tags=["bad"])
session.commit()
infos, _, total = list_asset_infos_page(session, exclude_tags=["bad"])
assert total == 1
assert infos[0].name == "keep"
def test_sorting(self, session: Session):
asset = _make_asset(session, "hash1", size=100)
asset2 = _make_asset(session, "hash2", size=500)
_make_asset_info(session, asset, name="small")
_make_asset_info(session, asset2, name="large")
session.commit()
infos, _, _ = list_asset_infos_page(session, sort="size", order="desc")
assert infos[0].name == "large"
infos, _, _ = list_asset_infos_page(session, sort="name", order="asc")
assert infos[0].name == "large"
class TestFetchAssetInfoAssetAndTags:
def test_returns_none_for_nonexistent(self, session: Session):
result = fetch_asset_info_asset_and_tags(session, "nonexistent")
assert result is None
def test_returns_tuple(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset, name="test.bin")
ensure_tags_exist(session, ["tag1"])
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["tag1"])
session.commit()
result = fetch_asset_info_asset_and_tags(session, info.id)
assert result is not None
ret_info, ret_asset, ret_tags = result
assert ret_info.id == info.id
assert ret_asset.id == asset.id
assert ret_tags == ["tag1"]
class TestFetchAssetInfoAndAsset:
def test_returns_none_for_nonexistent(self, session: Session):
result = fetch_asset_info_and_asset(session, asset_info_id="nonexistent")
assert result is None
def test_returns_tuple(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
session.commit()
result = fetch_asset_info_and_asset(session, asset_info_id=info.id)
assert result is not None
ret_info, ret_asset = result
assert ret_info.id == info.id
assert ret_asset.id == asset.id
class TestUpdateAssetInfoAccessTime:
def test_updates_last_access_time(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
original_time = info.last_access_time
session.commit()
time.sleep(0.01)
update_asset_info_access_time(session, asset_info_id=info.id)
session.commit()
session.refresh(info)
assert info.last_access_time > original_time
class TestDeleteAssetInfoById:
def test_deletes_existing(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
session.commit()
result = delete_asset_info_by_id(session, asset_info_id=info.id, owner_id="")
assert result is True
assert get_asset_info_by_id(session, asset_info_id=info.id) is None
def test_returns_false_for_nonexistent(self, session: Session):
result = delete_asset_info_by_id(session, asset_info_id="nonexistent", owner_id="")
assert result is False
def test_respects_owner_visibility(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset, owner_id="user1")
session.commit()
result = delete_asset_info_by_id(session, asset_info_id=info.id, owner_id="user2")
assert result is False
assert get_asset_info_by_id(session, asset_info_id=info.id) is not None
class TestSetAssetInfoPreview:
def test_sets_preview(self, session: Session):
asset = _make_asset(session, "hash1")
preview_asset = _make_asset(session, "preview_hash")
info = _make_asset_info(session, asset)
session.commit()
set_asset_info_preview(session, asset_info_id=info.id, preview_asset_id=preview_asset.id)
session.commit()
session.refresh(info)
assert info.preview_id == preview_asset.id
def test_clears_preview(self, session: Session):
asset = _make_asset(session, "hash1")
preview_asset = _make_asset(session, "preview_hash")
info = _make_asset_info(session, asset)
info.preview_id = preview_asset.id
session.commit()
set_asset_info_preview(session, asset_info_id=info.id, preview_asset_id=None)
session.commit()
session.refresh(info)
assert info.preview_id is None
def test_raises_for_nonexistent_info(self, session: Session):
with pytest.raises(ValueError, match="not found"):
set_asset_info_preview(session, asset_info_id="nonexistent", preview_asset_id=None)
def test_raises_for_nonexistent_preview(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
session.commit()
with pytest.raises(ValueError, match="Preview Asset"):
set_asset_info_preview(session, asset_info_id=info.id, preview_asset_id="nonexistent")
class TestInsertAssetInfo:
def test_creates_new_info(self, session: Session):
asset = _make_asset(session, "hash1")
info = insert_asset_info(
session, asset_id=asset.id, owner_id="user1", name="test.bin"
)
session.commit()
assert info is not None
assert info.name == "test.bin"
assert info.owner_id == "user1"
def test_returns_none_on_conflict(self, session: Session):
asset = _make_asset(session, "hash1")
insert_asset_info(session, asset_id=asset.id, owner_id="user1", name="dup.bin")
session.commit()
# Attempt duplicate with same (asset_id, owner_id, name)
result = insert_asset_info(
session, asset_id=asset.id, owner_id="user1", name="dup.bin"
)
assert result is None
class TestGetOrCreateAssetInfo:
def test_creates_new_info(self, session: Session):
asset = _make_asset(session, "hash1")
info, created = get_or_create_asset_info(
session, asset_id=asset.id, owner_id="user1", name="new.bin"
)
session.commit()
assert created is True
assert info.name == "new.bin"
def test_returns_existing_info(self, session: Session):
asset = _make_asset(session, "hash1")
info1, created1 = get_or_create_asset_info(
session, asset_id=asset.id, owner_id="user1", name="existing.bin"
)
session.commit()
info2, created2 = get_or_create_asset_info(
session, asset_id=asset.id, owner_id="user1", name="existing.bin"
)
session.commit()
assert created1 is True
assert created2 is False
assert info1.id == info2.id
class TestUpdateAssetInfoTimestamps:
def test_updates_timestamps(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
original_updated_at = info.updated_at
session.commit()
time.sleep(0.01)
update_asset_info_timestamps(session, info)
session.commit()
session.refresh(info)
assert info.updated_at > original_updated_at
def test_updates_preview_id(self, session: Session):
asset = _make_asset(session, "hash1")
preview_asset = _make_asset(session, "preview_hash")
info = _make_asset_info(session, asset)
session.commit()
update_asset_info_timestamps(session, info, preview_id=preview_asset.id)
session.commit()
session.refresh(info)
assert info.preview_id == preview_asset.id
class TestSetAssetInfoMetadata:
def test_sets_metadata(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
session.commit()
set_asset_info_metadata(
session, asset_info_id=info.id, user_metadata={"key": "value"}
)
session.commit()
session.refresh(info)
assert info.user_metadata == {"key": "value"}
# Check metadata table
meta = session.query(AssetInfoMeta).filter_by(asset_info_id=info.id).all()
assert len(meta) == 1
assert meta[0].key == "key"
assert meta[0].val_str == "value"
def test_replaces_existing_metadata(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
session.commit()
set_asset_info_metadata(
session, asset_info_id=info.id, user_metadata={"old": "data"}
)
session.commit()
set_asset_info_metadata(
session, asset_info_id=info.id, user_metadata={"new": "data"}
)
session.commit()
meta = session.query(AssetInfoMeta).filter_by(asset_info_id=info.id).all()
assert len(meta) == 1
assert meta[0].key == "new"
def test_clears_metadata_with_empty_dict(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
session.commit()
set_asset_info_metadata(
session, asset_info_id=info.id, user_metadata={"key": "value"}
)
session.commit()
set_asset_info_metadata(
session, asset_info_id=info.id, user_metadata={}
)
session.commit()
session.refresh(info)
assert info.user_metadata == {}
meta = session.query(AssetInfoMeta).filter_by(asset_info_id=info.id).all()
assert len(meta) == 0
def test_raises_for_nonexistent(self, session: Session):
with pytest.raises(ValueError, match="not found"):
set_asset_info_metadata(
session, asset_info_id="nonexistent", user_metadata={"key": "value"}
)
class TestBulkInsertAssetInfosIgnoreConflicts:
def test_inserts_multiple_infos(self, session: Session):
asset = _make_asset(session, "hash1")
now = get_utc_now()
rows = [
{
"id": str(uuid.uuid4()),
"owner_id": "",
"name": "bulk1.bin",
"asset_id": asset.id,
"preview_id": None,
"user_metadata": {},
"created_at": now,
"updated_at": now,
"last_access_time": now,
},
{
"id": str(uuid.uuid4()),
"owner_id": "",
"name": "bulk2.bin",
"asset_id": asset.id,
"preview_id": None,
"user_metadata": {},
"created_at": now,
"updated_at": now,
"last_access_time": now,
},
]
bulk_insert_asset_infos_ignore_conflicts(session, rows)
session.commit()
infos = session.query(AssetInfo).all()
assert len(infos) == 2
def test_ignores_conflicts(self, session: Session):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, name="existing.bin", owner_id="")
session.commit()
now = get_utc_now()
rows = [
{
"id": str(uuid.uuid4()),
"owner_id": "",
"name": "existing.bin",
"asset_id": asset.id,
"preview_id": None,
"user_metadata": {},
"created_at": now,
"updated_at": now,
"last_access_time": now,
},
{
"id": str(uuid.uuid4()),
"owner_id": "",
"name": "new.bin",
"asset_id": asset.id,
"preview_id": None,
"user_metadata": {},
"created_at": now,
"updated_at": now,
"last_access_time": now,
},
]
bulk_insert_asset_infos_ignore_conflicts(session, rows)
session.commit()
infos = session.query(AssetInfo).all()
assert len(infos) == 2 # existing + new, not 3
def test_empty_list_is_noop(self, session: Session):
bulk_insert_asset_infos_ignore_conflicts(session, [])
assert session.query(AssetInfo).count() == 0
class TestGetAssetInfoIdsByIds:
def test_returns_existing_ids(self, session: Session):
asset = _make_asset(session, "hash1")
info1 = _make_asset_info(session, asset, name="a.bin")
info2 = _make_asset_info(session, asset, name="b.bin")
session.commit()
found = get_asset_info_ids_by_ids(session, [info1.id, info2.id, "nonexistent"])
assert found == {info1.id, info2.id}
def test_empty_list_returns_empty(self, session: Session):
found = get_asset_info_ids_by_ids(session, [])
assert found == set()

View File

@@ -0,0 +1,468 @@
"""Tests for cache_state query functions."""
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetCacheState, AssetInfo
from app.assets.database.queries import (
list_cache_states_by_asset_id,
upsert_cache_state,
get_unreferenced_unhashed_asset_ids,
delete_assets_by_ids,
get_cache_states_for_prefixes,
bulk_update_needs_verify,
delete_cache_states_by_ids,
delete_orphaned_seed_asset,
bulk_insert_cache_states_ignore_conflicts,
get_cache_states_by_paths_and_asset_ids,
mark_cache_states_missing_outside_prefixes,
restore_cache_states_by_paths,
)
from app.assets.helpers import select_best_live_path, get_utc_now
def _make_asset(session: Session, hash_val: str | None = None, size: int = 1024) -> Asset:
asset = Asset(hash=hash_val, size_bytes=size)
session.add(asset)
session.flush()
return asset
def _make_cache_state(
session: Session,
asset: Asset,
file_path: str,
mtime_ns: int | None = None,
needs_verify: bool = False,
) -> AssetCacheState:
state = AssetCacheState(
asset_id=asset.id,
file_path=file_path,
mtime_ns=mtime_ns,
needs_verify=needs_verify,
)
session.add(state)
session.flush()
return state
class TestListCacheStatesByAssetId:
def test_returns_empty_for_no_states(self, session: Session):
asset = _make_asset(session, "hash1")
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
assert list(states) == []
def test_returns_states_for_asset(self, session: Session):
asset = _make_asset(session, "hash1")
_make_cache_state(session, asset, "/path/a.bin")
_make_cache_state(session, asset, "/path/b.bin")
session.commit()
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
paths = [s.file_path for s in states]
assert set(paths) == {"/path/a.bin", "/path/b.bin"}
def test_does_not_return_other_assets_states(self, session: Session):
asset1 = _make_asset(session, "hash1")
asset2 = _make_asset(session, "hash2")
_make_cache_state(session, asset1, "/path/asset1.bin")
_make_cache_state(session, asset2, "/path/asset2.bin")
session.commit()
states = list_cache_states_by_asset_id(session, asset_id=asset1.id)
paths = [s.file_path for s in states]
assert paths == ["/path/asset1.bin"]
class TestSelectBestLivePath:
def test_returns_empty_for_empty_list(self):
result = select_best_live_path([])
assert result == ""
def test_returns_empty_when_no_files_exist(self, session: Session):
asset = _make_asset(session, "hash1")
state = _make_cache_state(session, asset, "/nonexistent/path.bin")
session.commit()
result = select_best_live_path([state])
assert result == ""
def test_prefers_verified_path(self, session: Session, tmp_path):
"""needs_verify=False should be preferred."""
asset = _make_asset(session, "hash1")
verified_file = tmp_path / "verified.bin"
verified_file.write_bytes(b"data")
unverified_file = tmp_path / "unverified.bin"
unverified_file.write_bytes(b"data")
state_verified = _make_cache_state(
session, asset, str(verified_file), needs_verify=False
)
state_unverified = _make_cache_state(
session, asset, str(unverified_file), needs_verify=True
)
session.commit()
states = [state_unverified, state_verified]
result = select_best_live_path(states)
assert result == str(verified_file)
def test_falls_back_to_existing_unverified(self, session: Session, tmp_path):
"""If all states need verification, return first existing path."""
asset = _make_asset(session, "hash1")
existing_file = tmp_path / "exists.bin"
existing_file.write_bytes(b"data")
state = _make_cache_state(session, asset, str(existing_file), needs_verify=True)
session.commit()
result = select_best_live_path([state])
assert result == str(existing_file)
class TestSelectBestLivePathWithMocking:
def test_handles_missing_file_path_attr(self):
"""Gracefully handle states with None file_path."""
class MockState:
file_path = None
needs_verify = False
result = select_best_live_path([MockState()])
assert result == ""
class TestUpsertCacheState:
@pytest.mark.parametrize(
"initial_mtime,second_mtime,expect_created,expect_updated,final_mtime",
[
# New state creation
(None, 12345, True, False, 12345),
# Existing state, same mtime - no update
(100, 100, False, False, 100),
# Existing state, different mtime - update
(100, 200, False, True, 200),
],
ids=["new_state", "existing_no_change", "existing_update_mtime"],
)
def test_upsert_scenarios(
self, session: Session, initial_mtime, second_mtime, expect_created, expect_updated, final_mtime
):
asset = _make_asset(session, "hash1")
file_path = f"/path_{initial_mtime}_{second_mtime}.bin"
# Create initial state if needed
if initial_mtime is not None:
upsert_cache_state(session, asset_id=asset.id, file_path=file_path, mtime_ns=initial_mtime)
session.commit()
# The upsert call we're testing
created, updated = upsert_cache_state(
session, asset_id=asset.id, file_path=file_path, mtime_ns=second_mtime
)
session.commit()
assert created is expect_created
assert updated is expect_updated
state = session.query(AssetCacheState).filter_by(file_path=file_path).one()
assert state.mtime_ns == final_mtime
def test_upsert_restores_missing_state(self, session: Session):
"""Upserting a cache state that was marked missing should restore it."""
asset = _make_asset(session, "hash1")
file_path = "/restored/file.bin"
state = _make_cache_state(session, asset, file_path, mtime_ns=100)
state.is_missing = True
session.commit()
created, updated = upsert_cache_state(
session, asset_id=asset.id, file_path=file_path, mtime_ns=100
)
session.commit()
assert created is False
assert updated is True
restored_state = session.query(AssetCacheState).filter_by(file_path=file_path).one()
assert restored_state.is_missing is False
class TestRestoreCacheStatesByPaths:
def test_restores_missing_states(self, session: Session):
asset = _make_asset(session, "hash1")
missing_path = "/missing/file.bin"
active_path = "/active/file.bin"
missing_state = _make_cache_state(session, asset, missing_path)
missing_state.is_missing = True
_make_cache_state(session, asset, active_path)
session.commit()
restored = restore_cache_states_by_paths(session, [missing_path])
session.commit()
assert restored == 1
state = session.query(AssetCacheState).filter_by(file_path=missing_path).one()
assert state.is_missing is False
def test_empty_list_restores_nothing(self, session: Session):
restored = restore_cache_states_by_paths(session, [])
assert restored == 0
class TestMarkCacheStatesMissingOutsidePrefixes:
def test_marks_states_missing_outside_prefixes(self, session: Session, tmp_path):
asset = _make_asset(session, "hash1")
valid_dir = tmp_path / "valid"
valid_dir.mkdir()
invalid_dir = tmp_path / "invalid"
invalid_dir.mkdir()
valid_path = str(valid_dir / "file.bin")
invalid_path = str(invalid_dir / "file.bin")
_make_cache_state(session, asset, valid_path)
_make_cache_state(session, asset, invalid_path)
session.commit()
marked = mark_cache_states_missing_outside_prefixes(session, [str(valid_dir)])
session.commit()
assert marked == 1
all_states = session.query(AssetCacheState).all()
assert len(all_states) == 2
valid_state = next(s for s in all_states if s.file_path == valid_path)
invalid_state = next(s for s in all_states if s.file_path == invalid_path)
assert valid_state.is_missing is False
assert invalid_state.is_missing is True
def test_empty_prefixes_marks_nothing(self, session: Session):
asset = _make_asset(session, "hash1")
_make_cache_state(session, asset, "/some/path.bin")
session.commit()
marked = mark_cache_states_missing_outside_prefixes(session, [])
assert marked == 0
class TestGetUnreferencedUnhashedAssetIds:
def test_returns_unreferenced_unhashed_assets(self, session: Session):
# Unhashed asset (hash=None) with no cache states
no_states = _make_asset(session, hash_val=None)
# Unhashed asset with active cache state (not unreferenced)
with_active_state = _make_asset(session, hash_val=None)
_make_cache_state(session, with_active_state, "/has/state.bin")
# Unhashed asset with only missing cache state (should be unreferenced)
with_missing_state = _make_asset(session, hash_val=None)
missing_state = _make_cache_state(session, with_missing_state, "/missing/state.bin")
missing_state.is_missing = True
# Regular asset (hash not None) - should not be returned
_make_asset(session, hash_val="blake3:regular")
session.commit()
unreferenced = get_unreferenced_unhashed_asset_ids(session)
assert no_states.id in unreferenced
assert with_missing_state.id in unreferenced
assert with_active_state.id not in unreferenced
class TestDeleteAssetsByIds:
def test_deletes_assets_and_infos(self, session: Session):
asset = _make_asset(session, "hash1")
now = get_utc_now()
info = AssetInfo(
owner_id="", name="test", asset_id=asset.id,
created_at=now, updated_at=now, last_access_time=now
)
session.add(info)
session.commit()
deleted = delete_assets_by_ids(session, [asset.id])
session.commit()
assert deleted == 1
assert session.query(Asset).count() == 0
assert session.query(AssetInfo).count() == 0
def test_empty_list_deletes_nothing(self, session: Session):
_make_asset(session, "hash1")
session.commit()
deleted = delete_assets_by_ids(session, [])
assert deleted == 0
assert session.query(Asset).count() == 1
class TestGetCacheStatesForPrefixes:
def test_returns_states_matching_prefix(self, session: Session, tmp_path):
asset = _make_asset(session, "hash1")
dir1 = tmp_path / "dir1"
dir1.mkdir()
dir2 = tmp_path / "dir2"
dir2.mkdir()
path1 = str(dir1 / "file.bin")
path2 = str(dir2 / "file.bin")
_make_cache_state(session, asset, path1, mtime_ns=100)
_make_cache_state(session, asset, path2, mtime_ns=200)
session.commit()
rows = get_cache_states_for_prefixes(session, [str(dir1)])
assert len(rows) == 1
assert rows[0].file_path == path1
def test_empty_prefixes_returns_empty(self, session: Session):
asset = _make_asset(session, "hash1")
_make_cache_state(session, asset, "/some/path.bin")
session.commit()
rows = get_cache_states_for_prefixes(session, [])
assert rows == []
class TestBulkSetNeedsVerify:
def test_sets_needs_verify_flag(self, session: Session):
asset = _make_asset(session, "hash1")
state1 = _make_cache_state(session, asset, "/path1.bin", needs_verify=False)
state2 = _make_cache_state(session, asset, "/path2.bin", needs_verify=False)
session.commit()
updated = bulk_update_needs_verify(session, [state1.id, state2.id], True)
session.commit()
assert updated == 2
session.refresh(state1)
session.refresh(state2)
assert state1.needs_verify is True
assert state2.needs_verify is True
def test_empty_list_updates_nothing(self, session: Session):
updated = bulk_update_needs_verify(session, [], True)
assert updated == 0
class TestDeleteCacheStatesByIds:
def test_deletes_states_by_id(self, session: Session):
asset = _make_asset(session, "hash1")
state1 = _make_cache_state(session, asset, "/path1.bin")
_make_cache_state(session, asset, "/path2.bin")
session.commit()
deleted = delete_cache_states_by_ids(session, [state1.id])
session.commit()
assert deleted == 1
assert session.query(AssetCacheState).count() == 1
def test_empty_list_deletes_nothing(self, session: Session):
deleted = delete_cache_states_by_ids(session, [])
assert deleted == 0
class TestDeleteOrphanedSeedAsset:
@pytest.mark.parametrize(
"create_asset,expected_deleted,expected_count",
[
(True, True, 0), # Existing asset gets deleted
(False, False, 0), # Nonexistent returns False
],
ids=["deletes_existing", "nonexistent_returns_false"],
)
def test_delete_orphaned_seed_asset(
self, session: Session, create_asset, expected_deleted, expected_count
):
asset_id = "nonexistent-id"
if create_asset:
asset = _make_asset(session, hash_val=None)
asset_id = asset.id
now = get_utc_now()
info = AssetInfo(
owner_id="", name="test", asset_id=asset.id,
created_at=now, updated_at=now, last_access_time=now
)
session.add(info)
session.commit()
deleted = delete_orphaned_seed_asset(session, asset_id)
if create_asset:
session.commit()
assert deleted is expected_deleted
assert session.query(Asset).count() == expected_count
class TestBulkInsertCacheStatesIgnoreConflicts:
def test_inserts_multiple_states(self, session: Session):
asset = _make_asset(session, "hash1")
rows = [
{"asset_id": asset.id, "file_path": "/bulk1.bin", "mtime_ns": 100},
{"asset_id": asset.id, "file_path": "/bulk2.bin", "mtime_ns": 200},
]
bulk_insert_cache_states_ignore_conflicts(session, rows)
session.commit()
assert session.query(AssetCacheState).count() == 2
def test_ignores_conflicts(self, session: Session):
asset = _make_asset(session, "hash1")
_make_cache_state(session, asset, "/existing.bin", mtime_ns=100)
session.commit()
rows = [
{"asset_id": asset.id, "file_path": "/existing.bin", "mtime_ns": 999},
{"asset_id": asset.id, "file_path": "/new.bin", "mtime_ns": 200},
]
bulk_insert_cache_states_ignore_conflicts(session, rows)
session.commit()
assert session.query(AssetCacheState).count() == 2
existing = session.query(AssetCacheState).filter_by(file_path="/existing.bin").one()
assert existing.mtime_ns == 100 # Original value preserved
def test_empty_list_is_noop(self, session: Session):
bulk_insert_cache_states_ignore_conflicts(session, [])
assert session.query(AssetCacheState).count() == 0
class TestGetCacheStatesByPathsAndAssetIds:
def test_returns_matching_paths(self, session: Session):
asset1 = _make_asset(session, "hash1")
asset2 = _make_asset(session, "hash2")
_make_cache_state(session, asset1, "/path1.bin")
_make_cache_state(session, asset2, "/path2.bin")
session.commit()
path_to_asset = {
"/path1.bin": asset1.id,
"/path2.bin": asset2.id,
}
winners = get_cache_states_by_paths_and_asset_ids(session, path_to_asset)
assert winners == {"/path1.bin", "/path2.bin"}
def test_excludes_non_matching_asset_ids(self, session: Session):
asset1 = _make_asset(session, "hash1")
asset2 = _make_asset(session, "hash2")
_make_cache_state(session, asset1, "/path1.bin")
session.commit()
# Path exists but with different asset_id
path_to_asset = {"/path1.bin": asset2.id}
winners = get_cache_states_by_paths_and_asset_ids(session, path_to_asset)
assert winners == set()
def test_empty_dict_returns_empty(self, session: Session):
winners = get_cache_states_by_paths_and_asset_ids(session, {})
assert winners == set()

View File

@@ -0,0 +1,184 @@
"""Tests for metadata filtering logic in asset_info queries."""
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta
from app.assets.database.queries import list_asset_infos_page
from app.assets.database.queries.asset_info import convert_metadata_to_rows
from app.assets.helpers import get_utc_now
def _make_asset(session: Session, hash_val: str) -> Asset:
asset = Asset(hash=hash_val, size_bytes=1024)
session.add(asset)
session.flush()
return asset
def _make_asset_info(
session: Session,
asset: Asset,
name: str,
metadata: dict | None = None,
) -> AssetInfo:
now = get_utc_now()
info = AssetInfo(
owner_id="",
name=name,
asset_id=asset.id,
user_metadata=metadata,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
session.flush()
if metadata:
for key, val in metadata.items():
for row in convert_metadata_to_rows(key, val):
meta_row = AssetInfoMeta(
asset_info_id=info.id,
key=row["key"],
ordinal=row.get("ordinal", 0),
val_str=row.get("val_str"),
val_num=row.get("val_num"),
val_bool=row.get("val_bool"),
val_json=row.get("val_json"),
)
session.add(meta_row)
session.flush()
return info
class TestMetadataFilterByType:
"""Table-driven tests for metadata filtering by different value types."""
@pytest.mark.parametrize(
"match_meta,nomatch_meta,filter_key,filter_val",
[
# String matching
({"category": "models"}, {"category": "images"}, "category", "models"),
# Integer matching
({"epoch": 5}, {"epoch": 10}, "epoch", 5),
# Float matching
({"score": 0.95}, {"score": 0.5}, "score", 0.95),
# Boolean True matching
({"enabled": True}, {"enabled": False}, "enabled", True),
# Boolean False matching
({"enabled": False}, {"enabled": True}, "enabled", False),
],
ids=["string", "int", "float", "bool_true", "bool_false"],
)
def test_filter_matches_correct_value(
self, session: Session, match_meta, nomatch_meta, filter_key, filter_val
):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, "match", match_meta)
_make_asset_info(session, asset, "nomatch", nomatch_meta)
session.commit()
infos, _, total = list_asset_infos_page(
session, metadata_filter={filter_key: filter_val}
)
assert total == 1
assert infos[0].name == "match"
@pytest.mark.parametrize(
"stored_meta,filter_key,filter_val",
[
# String no match
({"category": "models"}, "category", "other"),
# Int no match
({"epoch": 5}, "epoch", 99),
# Float no match
({"score": 0.5}, "score", 0.99),
],
ids=["string_no_match", "int_no_match", "float_no_match"],
)
def test_filter_returns_empty_when_no_match(
self, session: Session, stored_meta, filter_key, filter_val
):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, "item", stored_meta)
session.commit()
infos, _, total = list_asset_infos_page(
session, metadata_filter={filter_key: filter_val}
)
assert total == 0
class TestMetadataFilterNull:
"""Tests for null/missing key filtering."""
@pytest.mark.parametrize(
"match_name,match_meta,nomatch_name,nomatch_meta,filter_key",
[
# Null matches missing key
("missing_key", {}, "has_key", {"optional": "value"}, "optional"),
# Null matches explicit null
("explicit_null", {"nullable": None}, "has_value", {"nullable": "present"}, "nullable"),
],
ids=["missing_key", "explicit_null"],
)
def test_null_filter_matches(
self, session: Session, match_name, match_meta, nomatch_name, nomatch_meta, filter_key
):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, match_name, match_meta)
_make_asset_info(session, asset, nomatch_name, nomatch_meta)
session.commit()
infos, _, total = list_asset_infos_page(session, metadata_filter={filter_key: None})
assert total == 1
assert infos[0].name == match_name
class TestMetadataFilterList:
"""Tests for list-based (OR) filtering."""
def test_filter_by_list_matches_any(self, session: Session):
"""List values should match ANY of the values (OR)."""
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, "cat_a", {"category": "a"})
_make_asset_info(session, asset, "cat_b", {"category": "b"})
_make_asset_info(session, asset, "cat_c", {"category": "c"})
session.commit()
infos, _, total = list_asset_infos_page(session, metadata_filter={"category": ["a", "b"]})
assert total == 2
names = {i.name for i in infos}
assert names == {"cat_a", "cat_b"}
class TestMetadataFilterMultipleKeys:
"""Tests for multiple filter keys (AND semantics)."""
def test_multiple_keys_must_all_match(self, session: Session):
"""Multiple keys should ALL match (AND)."""
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, "match", {"type": "model", "version": 2})
_make_asset_info(session, asset, "wrong_type", {"type": "config", "version": 2})
_make_asset_info(session, asset, "wrong_version", {"type": "model", "version": 1})
session.commit()
infos, _, total = list_asset_infos_page(
session, metadata_filter={"type": "model", "version": 2}
)
assert total == 1
assert infos[0].name == "match"
class TestMetadataFilterEmptyDict:
"""Tests for empty filter behavior."""
def test_empty_filter_returns_all(self, session: Session):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, "a", {"key": "val"})
_make_asset_info(session, asset, "b", {})
session.commit()
infos, _, total = list_asset_infos_page(session, metadata_filter={})
assert total == 2

View File

@@ -0,0 +1,366 @@
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetInfo, AssetInfoTag, AssetInfoMeta, Tag
from app.assets.database.queries import (
ensure_tags_exist,
get_asset_tags,
set_asset_info_tags,
add_tags_to_asset_info,
remove_tags_from_asset_info,
add_missing_tag_for_asset_id,
remove_missing_tag_for_asset_id,
list_tags_with_usage,
bulk_insert_tags_and_meta,
)
from app.assets.helpers import get_utc_now
def _make_asset(session: Session, hash_val: str | None = None) -> Asset:
asset = Asset(hash=hash_val, size_bytes=1024)
session.add(asset)
session.flush()
return asset
def _make_asset_info(session: Session, asset: Asset, name: str = "test", owner_id: str = "") -> AssetInfo:
now = get_utc_now()
info = AssetInfo(
owner_id=owner_id,
name=name,
asset_id=asset.id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
session.flush()
return info
class TestEnsureTagsExist:
def test_creates_new_tags(self, session: Session):
ensure_tags_exist(session, ["alpha", "beta"], tag_type="user")
session.commit()
tags = session.query(Tag).all()
assert {t.name for t in tags} == {"alpha", "beta"}
def test_is_idempotent(self, session: Session):
ensure_tags_exist(session, ["alpha"], tag_type="user")
ensure_tags_exist(session, ["alpha"], tag_type="user")
session.commit()
assert session.query(Tag).count() == 1
def test_normalizes_tags(self, session: Session):
ensure_tags_exist(session, [" ALPHA ", "Beta", "alpha"])
session.commit()
tags = session.query(Tag).all()
assert {t.name for t in tags} == {"alpha", "beta"}
def test_empty_list_is_noop(self, session: Session):
ensure_tags_exist(session, [])
session.commit()
assert session.query(Tag).count() == 0
def test_tag_type_is_set(self, session: Session):
ensure_tags_exist(session, ["system-tag"], tag_type="system")
session.commit()
tag = session.query(Tag).filter_by(name="system-tag").one()
assert tag.tag_type == "system"
class TestGetAssetTags:
def test_returns_empty_for_no_tags(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
tags = get_asset_tags(session, asset_info_id=info.id)
assert tags == []
def test_returns_tags_for_asset(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
ensure_tags_exist(session, ["tag1", "tag2"])
session.add_all([
AssetInfoTag(asset_info_id=info.id, tag_name="tag1", origin="manual", added_at=get_utc_now()),
AssetInfoTag(asset_info_id=info.id, tag_name="tag2", origin="manual", added_at=get_utc_now()),
])
session.flush()
tags = get_asset_tags(session, asset_info_id=info.id)
assert set(tags) == {"tag1", "tag2"}
class TestSetAssetInfoTags:
def test_adds_new_tags(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
result = set_asset_info_tags(session, asset_info_id=info.id, tags=["a", "b"])
session.commit()
assert set(result["added"]) == {"a", "b"}
assert result["removed"] == []
assert set(result["total"]) == {"a", "b"}
def test_removes_old_tags(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
set_asset_info_tags(session, asset_info_id=info.id, tags=["a", "b", "c"])
result = set_asset_info_tags(session, asset_info_id=info.id, tags=["a"])
session.commit()
assert result["added"] == []
assert set(result["removed"]) == {"b", "c"}
assert result["total"] == ["a"]
def test_replaces_tags(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
set_asset_info_tags(session, asset_info_id=info.id, tags=["a", "b"])
result = set_asset_info_tags(session, asset_info_id=info.id, tags=["b", "c"])
session.commit()
assert result["added"] == ["c"]
assert result["removed"] == ["a"]
assert set(result["total"]) == {"b", "c"}
class TestAddTagsToAssetInfo:
def test_adds_tags(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
result = add_tags_to_asset_info(session, asset_info_id=info.id, tags=["x", "y"])
session.commit()
assert set(result["added"]) == {"x", "y"}
assert result["already_present"] == []
def test_reports_already_present(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["x"])
result = add_tags_to_asset_info(session, asset_info_id=info.id, tags=["x", "y"])
session.commit()
assert result["added"] == ["y"]
assert result["already_present"] == ["x"]
def test_raises_for_missing_asset_info(self, session: Session):
with pytest.raises(ValueError, match="not found"):
add_tags_to_asset_info(session, asset_info_id="nonexistent", tags=["x"])
class TestRemoveTagsFromAssetInfo:
def test_removes_tags(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["a", "b", "c"])
result = remove_tags_from_asset_info(session, asset_info_id=info.id, tags=["a", "b"])
session.commit()
assert set(result["removed"]) == {"a", "b"}
assert result["not_present"] == []
assert result["total_tags"] == ["c"]
def test_reports_not_present(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["a"])
result = remove_tags_from_asset_info(session, asset_info_id=info.id, tags=["a", "x"])
session.commit()
assert result["removed"] == ["a"]
assert result["not_present"] == ["x"]
def test_raises_for_missing_asset_info(self, session: Session):
with pytest.raises(ValueError, match="not found"):
remove_tags_from_asset_info(session, asset_info_id="nonexistent", tags=["x"])
class TestMissingTagFunctions:
def test_add_missing_tag_for_asset_id(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
ensure_tags_exist(session, ["missing"], tag_type="system")
add_missing_tag_for_asset_id(session, asset_id=asset.id)
session.commit()
tags = get_asset_tags(session, asset_info_id=info.id)
assert "missing" in tags
def test_add_missing_tag_is_idempotent(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
ensure_tags_exist(session, ["missing"], tag_type="system")
add_missing_tag_for_asset_id(session, asset_id=asset.id)
add_missing_tag_for_asset_id(session, asset_id=asset.id)
session.commit()
links = session.query(AssetInfoTag).filter_by(asset_info_id=info.id, tag_name="missing").all()
assert len(links) == 1
def test_remove_missing_tag_for_asset_id(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
ensure_tags_exist(session, ["missing"], tag_type="system")
add_missing_tag_for_asset_id(session, asset_id=asset.id)
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
session.commit()
tags = get_asset_tags(session, asset_info_id=info.id)
assert "missing" not in tags
class TestListTagsWithUsage:
def test_returns_tags_with_counts(self, session: Session):
ensure_tags_exist(session, ["used", "unused"])
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["used"])
session.commit()
rows, total = list_tags_with_usage(session)
tag_dict = {name: count for name, _, count in rows}
assert tag_dict["used"] == 1
assert tag_dict["unused"] == 0
assert total == 2
def test_exclude_zero_counts(self, session: Session):
ensure_tags_exist(session, ["used", "unused"])
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["used"])
session.commit()
rows, total = list_tags_with_usage(session, include_zero=False)
tag_names = {name for name, _, _ in rows}
assert "used" in tag_names
assert "unused" not in tag_names
def test_prefix_filter(self, session: Session):
ensure_tags_exist(session, ["alpha", "beta", "alphabet"])
session.commit()
rows, total = list_tags_with_usage(session, prefix="alph")
tag_names = {name for name, _, _ in rows}
assert tag_names == {"alpha", "alphabet"}
def test_order_by_name(self, session: Session):
ensure_tags_exist(session, ["zebra", "alpha", "middle"])
session.commit()
rows, _ = list_tags_with_usage(session, order="name_asc")
names = [name for name, _, _ in rows]
assert names == ["alpha", "middle", "zebra"]
def test_owner_visibility(self, session: Session):
ensure_tags_exist(session, ["shared-tag", "owner-tag"])
asset = _make_asset(session, "hash1")
shared_info = _make_asset_info(session, asset, name="shared", owner_id="")
owner_info = _make_asset_info(session, asset, name="owned", owner_id="user1")
add_tags_to_asset_info(session, asset_info_id=shared_info.id, tags=["shared-tag"])
add_tags_to_asset_info(session, asset_info_id=owner_info.id, tags=["owner-tag"])
session.commit()
# Empty owner sees only shared
rows, _ = list_tags_with_usage(session, owner_id="", include_zero=False)
tag_dict = {name: count for name, _, count in rows}
assert tag_dict.get("shared-tag", 0) == 1
assert tag_dict.get("owner-tag", 0) == 0
# User1 sees both
rows, _ = list_tags_with_usage(session, owner_id="user1", include_zero=False)
tag_dict = {name: count for name, _, count in rows}
assert tag_dict.get("shared-tag", 0) == 1
assert tag_dict.get("owner-tag", 0) == 1
class TestBulkInsertTagsAndMeta:
def test_inserts_tags(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
ensure_tags_exist(session, ["bulk-tag1", "bulk-tag2"])
session.commit()
now = get_utc_now()
tag_rows = [
{"asset_info_id": info.id, "tag_name": "bulk-tag1", "origin": "manual", "added_at": now},
{"asset_info_id": info.id, "tag_name": "bulk-tag2", "origin": "manual", "added_at": now},
]
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=[])
session.commit()
tags = get_asset_tags(session, asset_info_id=info.id)
assert set(tags) == {"bulk-tag1", "bulk-tag2"}
def test_inserts_meta(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
session.commit()
meta_rows = [
{
"asset_info_id": info.id,
"key": "meta-key",
"ordinal": 0,
"val_str": "meta-value",
"val_num": None,
"val_bool": None,
"val_json": None,
},
]
bulk_insert_tags_and_meta(session, tag_rows=[], meta_rows=meta_rows)
session.commit()
meta = session.query(AssetInfoMeta).filter_by(asset_info_id=info.id).all()
assert len(meta) == 1
assert meta[0].key == "meta-key"
assert meta[0].val_str == "meta-value"
def test_ignores_conflicts(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
ensure_tags_exist(session, ["existing-tag"])
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["existing-tag"])
session.commit()
now = get_utc_now()
tag_rows = [
{"asset_info_id": info.id, "tag_name": "existing-tag", "origin": "duplicate", "added_at": now},
]
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=[])
session.commit()
# Should still have only one tag link
links = session.query(AssetInfoTag).filter_by(asset_info_id=info.id, tag_name="existing-tag").all()
assert len(links) == 1
# Origin should be original, not overwritten
assert links[0].origin == "manual"
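# Consistent with insert-or-ignore semantics on the (asset_info_id, tag_name) pair
# (e.g. SQLite ON CONFLICT DO NOTHING); that mechanism is an assumption, the test only
# checks the observable outcome.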
def test_empty_lists_is_noop(self, session: Session):
bulk_insert_tags_and_meta(session, tag_rows=[], meta_rows=[])
assert session.query(AssetInfoTag).count() == 0
assert session.query(AssetInfoMeta).count() == 0

View File

@@ -0,0 +1 @@
# Service layer tests

View File

@@ -0,0 +1,48 @@
import tempfile
from pathlib import Path
from unittest.mock import patch
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from app.assets.database.models import Base
@pytest.fixture
def db_engine():
"""In-memory SQLite engine for fast unit tests."""
engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
return engine
@pytest.fixture
def session(db_engine):
"""Session fixture for tests that need direct DB access."""
with Session(db_engine) as sess:
yield sess
@pytest.fixture
def mock_create_session(db_engine):
"""Patch create_session to use our in-memory database."""
from contextlib import contextmanager
from sqlalchemy.orm import Session as SASession
@contextmanager
def _create_session():
with SASession(db_engine) as sess:
yield sess
with patch("app.assets.services.ingest.create_session", _create_session), \
patch("app.assets.services.asset_management.create_session", _create_session), \
patch("app.assets.services.tagging.create_session", _create_session):
yield _create_session
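# The three patch targets are assumed to be every service module that imports
# create_session directly at module level; patching one shared location would not
# reach names already bound in those modules.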
@pytest.fixture
def temp_dir():
"""Temporary directory for file operations."""
with tempfile.TemporaryDirectory() as tmpdir:
yield Path(tmpdir)

View File

@@ -0,0 +1,264 @@
"""Tests for asset_management services."""
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetInfo
from app.assets.database.queries import ensure_tags_exist, add_tags_to_asset_info
from app.assets.helpers import get_utc_now
from app.assets.services import (
get_asset_detail,
update_asset_metadata,
delete_asset_reference,
set_asset_preview,
)
def _make_asset(session: Session, hash_val: str = "blake3:test", size: int = 1024) -> Asset:
asset = Asset(hash=hash_val, size_bytes=size, mime_type="application/octet-stream")
session.add(asset)
session.flush()
return asset
def _make_asset_info(
session: Session,
asset: Asset,
name: str = "test",
owner_id: str = "",
) -> AssetInfo:
now = get_utc_now()
info = AssetInfo(
owner_id=owner_id,
name=name,
asset_id=asset.id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
session.flush()
return info
class TestGetAssetDetail:
def test_returns_none_for_nonexistent(self, mock_create_session):
result = get_asset_detail(asset_info_id="nonexistent")
assert result is None
def test_returns_asset_with_tags(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset, name="test.bin")
ensure_tags_exist(session, ["alpha", "beta"])
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["alpha", "beta"])
session.commit()
result = get_asset_detail(asset_info_id=info.id)
assert result is not None
assert result.info.id == info.id
assert result.asset.hash == asset.hash
assert set(result.tags) == {"alpha", "beta"}
def test_respects_owner_visibility(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset, owner_id="user1")
session.commit()
# Wrong owner cannot see
result = get_asset_detail(asset_info_id=info.id, owner_id="user2")
assert result is None
# Correct owner can see
result = get_asset_detail(asset_info_id=info.id, owner_id="user1")
assert result is not None
class TestUpdateAssetMetadata:
def test_updates_name(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset, name="old_name.bin")
info_id = info.id
session.commit()
update_asset_metadata(
asset_info_id=info_id,
name="new_name.bin",
)
# Verify by re-fetching from DB
session.expire_all()
updated_info = session.get(AssetInfo, info_id)
assert updated_info.name == "new_name.bin"
def test_updates_tags(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset)
ensure_tags_exist(session, ["old"])
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["old"])
session.commit()
result = update_asset_metadata(
asset_info_id=info.id,
tags=["new1", "new2"],
)
assert set(result.tags) == {"new1", "new2"}
assert "old" not in result.tags
def test_updates_user_metadata(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset)
info_id = info.id
session.commit()
update_asset_metadata(
asset_info_id=info_id,
user_metadata={"key": "value", "num": 42},
)
# Verify by re-fetching from DB
session.expire_all()
updated_info = session.get(AssetInfo, info_id)
assert updated_info.user_metadata["key"] == "value"
assert updated_info.user_metadata["num"] == 42
def test_raises_for_nonexistent(self, mock_create_session):
with pytest.raises(ValueError, match="not found"):
update_asset_metadata(asset_info_id="nonexistent", name="fail")
def test_raises_for_wrong_owner(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset, owner_id="user1")
session.commit()
with pytest.raises(PermissionError, match="not owner"):
update_asset_metadata(
asset_info_id=info.id,
name="new",
owner_id="user2",
)
class TestDeleteAssetReference:
def test_deletes_asset_info(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset)
info_id = info.id
session.commit()
result = delete_asset_reference(
asset_info_id=info_id,
owner_id="",
delete_content_if_orphan=False,
)
assert result is True
assert session.get(AssetInfo, info_id) is None
def test_returns_false_for_nonexistent(self, mock_create_session):
result = delete_asset_reference(
asset_info_id="nonexistent",
owner_id="",
)
assert result is False
def test_returns_false_for_wrong_owner(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset, owner_id="user1")
info_id = info.id
session.commit()
result = delete_asset_reference(
asset_info_id=info_id,
owner_id="user2",
)
assert result is False
assert session.get(AssetInfo, info_id) is not None
def test_keeps_asset_if_other_infos_exist(self, mock_create_session, session: Session):
asset = _make_asset(session)
info1 = _make_asset_info(session, asset, name="info1")
_make_asset_info(session, asset, name="info2") # Second info keeps asset alive
asset_id = asset.id
session.commit()
delete_asset_reference(
asset_info_id=info1.id,
owner_id="",
delete_content_if_orphan=True,
)
# Asset should still exist
assert session.get(Asset, asset_id) is not None
def test_deletes_orphaned_asset(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset)
asset_id = asset.id
info_id = info.id
session.commit()
delete_asset_reference(
asset_info_id=info_id,
owner_id="",
delete_content_if_orphan=True,
)
# Both info and asset should be gone
assert session.get(AssetInfo, info_id) is None
assert session.get(Asset, asset_id) is None
class TestSetAssetPreview:
def test_sets_preview(self, mock_create_session, session: Session):
asset = _make_asset(session, hash_val="blake3:main")
preview_asset = _make_asset(session, hash_val="blake3:preview")
info = _make_asset_info(session, asset)
info_id = info.id
preview_id = preview_asset.id
session.commit()
set_asset_preview(
asset_info_id=info_id,
preview_asset_id=preview_id,
)
# Verify by re-fetching from DB
session.expire_all()
updated_info = session.get(AssetInfo, info_id)
assert updated_info.preview_id == preview_id
def test_clears_preview(self, mock_create_session, session: Session):
asset = _make_asset(session)
preview_asset = _make_asset(session, hash_val="blake3:preview")
info = _make_asset_info(session, asset)
info.preview_id = preview_asset.id
info_id = info.id
session.commit()
set_asset_preview(
asset_info_id=info_id,
preview_asset_id=None,
)
# Verify by re-fetching from DB
session.expire_all()
updated_info = session.get(AssetInfo, info_id)
assert updated_info.preview_id is None
def test_raises_for_nonexistent_info(self, mock_create_session):
with pytest.raises(ValueError, match="not found"):
set_asset_preview(asset_info_id="nonexistent")
def test_raises_for_wrong_owner(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset, owner_id="user1")
session.commit()
with pytest.raises(PermissionError, match="not owner"):
set_asset_preview(
asset_info_id=info.id,
preview_asset_id=None,
owner_id="user2",
)

View File

@@ -0,0 +1,138 @@
"""Tests for bulk ingest services."""
from pathlib import Path
from sqlalchemy.orm import Session
from app.assets.database.models import Asset
from app.assets.services.bulk_ingest import SeedAssetSpec, batch_insert_seed_assets
class TestBatchInsertSeedAssets:
def test_populates_mime_type_for_model_files(self, session: Session, temp_dir: Path):
"""Verify mime_type is stored in the Asset table for model files."""
file_path = temp_dir / "model.safetensors"
file_path.write_bytes(b"fake safetensors content")
specs: list[SeedAssetSpec] = [
{
"abs_path": str(file_path),
"size_bytes": 24,
"mtime_ns": 1234567890000000000,
"info_name": "Test Model",
"tags": ["models"],
"fname": "model.safetensors",
"metadata": None,
"hash": None,
"mime_type": "application/safetensors",
}
]
result = batch_insert_seed_assets(session, specs=specs, owner_id="")
assert result.inserted_infos == 1
# Verify Asset has mime_type populated
assets = session.query(Asset).all()
assert len(assets) == 1
assert assets[0].mime_type == "application/safetensors"
def test_mime_type_none_when_not_provided(self, session: Session, temp_dir: Path):
"""Verify mime_type is None when not provided in spec."""
file_path = temp_dir / "unknown.bin"
file_path.write_bytes(b"binary data")
specs: list[SeedAssetSpec] = [
{
"abs_path": str(file_path),
"size_bytes": 11,
"mtime_ns": 1234567890000000000,
"info_name": "Unknown File",
"tags": [],
"fname": "unknown.bin",
"metadata": None,
"hash": None,
"mime_type": None,
}
]
result = batch_insert_seed_assets(session, specs=specs, owner_id="")
assert result.inserted_infos == 1
assets = session.query(Asset).all()
assert len(assets) == 1
assert assets[0].mime_type is None
def test_various_model_mime_types(self, session: Session, temp_dir: Path):
"""Verify various model file types get correct mime_type."""
test_cases = [
("model.safetensors", "application/safetensors"),
("model.pt", "application/pytorch"),
("model.ckpt", "application/pickle"),
("model.gguf", "application/gguf"),
]
specs: list[SeedAssetSpec] = []
for filename, mime_type in test_cases:
file_path = temp_dir / filename
file_path.write_bytes(b"content")
specs.append(
{
"abs_path": str(file_path),
"size_bytes": 7,
"mtime_ns": 1234567890000000000,
"info_name": filename,
"tags": [],
"fname": filename,
"metadata": None,
"hash": None,
"mime_type": mime_type,
}
)
result = batch_insert_seed_assets(session, specs=specs, owner_id="")
assert result.inserted_infos == len(test_cases)
from app.assets.database.models import AssetInfo
for filename, expected_mime in test_cases:
info = session.query(AssetInfo).filter_by(name=filename).first()
assert info is not None
asset = session.query(Asset).filter_by(id=info.asset_id).first()
assert asset.mime_type == expected_mime, f"Expected {expected_mime} for {filename}, got {asset.mime_type}"
class TestMetadataExtraction:
def test_extracts_mime_type_for_model_files(self, temp_dir: Path):
"""Verify metadata extraction returns correct mime_type for model files."""
from app.assets.services.metadata_extract import extract_file_metadata
file_path = temp_dir / "model.safetensors"
file_path.write_bytes(b"fake safetensors content")
meta = extract_file_metadata(str(file_path))
assert meta.content_type == "application/safetensors"
def test_mime_type_for_various_model_formats(self, temp_dir: Path):
"""Verify various model file types get correct mime_type from metadata."""
from app.assets.services.metadata_extract import extract_file_metadata
test_cases = [
("model.safetensors", "application/safetensors"),
("model.sft", "application/safetensors"),
("model.pt", "application/pytorch"),
("model.pth", "application/pytorch"),
("model.ckpt", "application/pickle"),
("model.pkl", "application/pickle"),
("model.gguf", "application/gguf"),
]
for filename, expected_mime in test_cases:
file_path = temp_dir / filename
file_path.write_bytes(b"content")
meta = extract_file_metadata(str(file_path))
assert meta.content_type == expected_mime, f"Expected {expected_mime} for {filename}, got {meta.content_type}"
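# These content_type values presumably come from custom MIME registrations along the
# lines of mimetypes.add_type("application/safetensors", ".safetensors") and similar
# entries for .pt/.pth, .ckpt/.pkl and .gguf; the registration mechanism is an
# assumption here, only the resulting values are asserted.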

View File

@@ -0,0 +1,227 @@
"""Tests for ingest services."""
from pathlib import Path
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetCacheState, AssetInfo, Tag
from app.assets.database.queries import get_asset_tags
from app.assets.services.ingest import _ingest_file_from_path, _register_existing_asset
class TestIngestFileFromPath:
def test_creates_asset_and_cache_state(self, mock_create_session, temp_dir: Path, session: Session):
file_path = temp_dir / "test_file.bin"
file_path.write_bytes(b"test content")
result = _ingest_file_from_path(
abs_path=str(file_path),
asset_hash="blake3:abc123",
size_bytes=12,
mtime_ns=1234567890000000000,
mime_type="application/octet-stream",
)
assert result.asset_created is True
assert result.state_created is True
assert result.asset_info_id is None # no info_name provided
# Verify DB state
assets = session.query(Asset).all()
assert len(assets) == 1
assert assets[0].hash == "blake3:abc123"
states = session.query(AssetCacheState).all()
assert len(states) == 1
assert states[0].file_path == str(file_path)
def test_creates_asset_info_when_name_provided(self, mock_create_session, temp_dir: Path, session: Session):
file_path = temp_dir / "model.safetensors"
file_path.write_bytes(b"model data")
result = _ingest_file_from_path(
abs_path=str(file_path),
asset_hash="blake3:def456",
size_bytes=10,
mtime_ns=1234567890000000000,
mime_type="application/octet-stream",
info_name="My Model",
owner_id="user1",
)
assert result.asset_created is True
assert result.asset_info_id is not None
info = session.query(AssetInfo).first()
assert info is not None
assert info.name == "My Model"
assert info.owner_id == "user1"
def test_creates_tags_when_provided(self, mock_create_session, temp_dir: Path, session: Session):
file_path = temp_dir / "tagged.bin"
file_path.write_bytes(b"data")
result = _ingest_file_from_path(
abs_path=str(file_path),
asset_hash="blake3:ghi789",
size_bytes=4,
mtime_ns=1234567890000000000,
info_name="Tagged Asset",
tags=["models", "checkpoints"],
)
assert result.asset_info_id is not None
# Verify tags were created and linked
tags = session.query(Tag).all()
tag_names = {t.name for t in tags}
assert "models" in tag_names
assert "checkpoints" in tag_names
asset_tags = get_asset_tags(session, asset_info_id=result.asset_info_id)
assert set(asset_tags) == {"models", "checkpoints"}
def test_idempotent_upsert(self, mock_create_session, temp_dir: Path, session: Session):
file_path = temp_dir / "dup.bin"
file_path.write_bytes(b"content")
# First ingest
r1 = _ingest_file_from_path(
abs_path=str(file_path),
asset_hash="blake3:repeat",
size_bytes=7,
mtime_ns=1234567890000000000,
)
assert r1.asset_created is True
# Second ingest with same hash - should update, not create
r2 = _ingest_file_from_path(
abs_path=str(file_path),
asset_hash="blake3:repeat",
size_bytes=7,
mtime_ns=1234567890000000001, # different mtime
)
assert r2.asset_created is False
assert r2.state_updated is True or r2.state_created is False
# Still only one asset
assets = session.query(Asset).all()
assert len(assets) == 1
def test_validates_preview_id(self, mock_create_session, temp_dir: Path, session: Session):
file_path = temp_dir / "with_preview.bin"
file_path.write_bytes(b"data")
# Create a preview asset first
preview_asset = Asset(hash="blake3:preview", size_bytes=100)
session.add(preview_asset)
session.commit()
preview_id = preview_asset.id
result = _ingest_file_from_path(
abs_path=str(file_path),
asset_hash="blake3:main",
size_bytes=4,
mtime_ns=1234567890000000000,
info_name="With Preview",
preview_id=preview_id,
)
assert result.asset_info_id is not None
info = session.query(AssetInfo).filter_by(id=result.asset_info_id).first()
assert info.preview_id == preview_id
def test_invalid_preview_id_is_cleared(self, mock_create_session, temp_dir: Path, session: Session):
file_path = temp_dir / "bad_preview.bin"
file_path.write_bytes(b"data")
result = _ingest_file_from_path(
abs_path=str(file_path),
asset_hash="blake3:badpreview",
size_bytes=4,
mtime_ns=1234567890000000000,
info_name="Bad Preview",
preview_id="nonexistent-uuid",
)
assert result.asset_info_id is not None
info = session.query(AssetInfo).filter_by(id=result.asset_info_id).first()
assert info.preview_id is None
class TestRegisterExistingAsset:
def test_creates_info_for_existing_asset(self, mock_create_session, session: Session):
# Create existing asset
asset = Asset(hash="blake3:existing", size_bytes=1024, mime_type="image/png")
session.add(asset)
session.commit()
result = _register_existing_asset(
asset_hash="blake3:existing",
name="Registered Asset",
user_metadata={"key": "value"},
tags=["models"],
)
assert result.created is True
assert "models" in result.tags
# Verify by re-fetching from DB
session.expire_all()
infos = session.query(AssetInfo).filter_by(name="Registered Asset").all()
assert len(infos) == 1
def test_returns_existing_info(self, mock_create_session, session: Session):
# Create asset and info
asset = Asset(hash="blake3:withinfo", size_bytes=512)
session.add(asset)
session.flush()
from app.assets.helpers import get_utc_now
info = AssetInfo(
owner_id="",
name="Existing Info",
asset_id=asset.id,
created_at=get_utc_now(),
updated_at=get_utc_now(),
last_access_time=get_utc_now(),
)
session.add(info)
session.flush() # Flush to get the ID
info_id = info.id
session.commit()
result = _register_existing_asset(
asset_hash="blake3:withinfo",
name="Existing Info",
owner_id="",
)
assert result.created is False
# Verify only one AssetInfo exists for this name
session.expire_all()
infos = session.query(AssetInfo).filter_by(name="Existing Info").all()
assert len(infos) == 1
assert infos[0].id == info_id
def test_raises_for_nonexistent_hash(self, mock_create_session):
with pytest.raises(ValueError, match="No asset with hash"):
_register_existing_asset(
asset_hash="blake3:doesnotexist",
name="Fail",
)
def test_applies_tags_to_new_info(self, mock_create_session, session: Session):
asset = Asset(hash="blake3:tagged", size_bytes=256)
session.add(asset)
session.commit()
result = _register_existing_asset(
asset_hash="blake3:tagged",
name="Tagged Info",
tags=["alpha", "beta"],
)
assert result.created is True
assert set(result.tags) == {"alpha", "beta"}

View File

@@ -0,0 +1,197 @@
"""Tests for tagging services."""
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetInfo
from app.assets.database.queries import ensure_tags_exist, add_tags_to_asset_info
from app.assets.helpers import get_utc_now
from app.assets.services import apply_tags, remove_tags, list_tags
def _make_asset(session: Session, hash_val: str = "blake3:test") -> Asset:
asset = Asset(hash=hash_val, size_bytes=1024)
session.add(asset)
session.flush()
return asset
def _make_asset_info(
session: Session,
asset: Asset,
name: str = "test",
owner_id: str = "",
) -> AssetInfo:
now = get_utc_now()
info = AssetInfo(
owner_id=owner_id,
name=name,
asset_id=asset.id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
session.flush()
return info
class TestApplyTags:
def test_adds_new_tags(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset)
session.commit()
result = apply_tags(
asset_info_id=info.id,
tags=["alpha", "beta"],
)
assert set(result.added) == {"alpha", "beta"}
assert result.already_present == []
assert set(result.total_tags) == {"alpha", "beta"}
def test_reports_already_present(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset)
ensure_tags_exist(session, ["existing"])
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["existing"])
session.commit()
result = apply_tags(
asset_info_id=info.id,
tags=["existing", "new"],
)
assert result.added == ["new"]
assert result.already_present == ["existing"]
def test_raises_for_nonexistent_info(self, mock_create_session):
with pytest.raises(ValueError, match="not found"):
apply_tags(asset_info_id="nonexistent", tags=["x"])
def test_raises_for_wrong_owner(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset, owner_id="user1")
session.commit()
with pytest.raises(PermissionError, match="not owner"):
apply_tags(
asset_info_id=info.id,
tags=["new"],
owner_id="user2",
)
class TestRemoveTags:
def test_removes_tags(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset)
ensure_tags_exist(session, ["a", "b", "c"])
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["a", "b", "c"])
session.commit()
result = remove_tags(
asset_info_id=info.id,
tags=["a", "b"],
)
assert set(result.removed) == {"a", "b"}
assert result.not_present == []
assert result.total_tags == ["c"]
def test_reports_not_present(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset)
ensure_tags_exist(session, ["present"])
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["present"])
session.commit()
result = remove_tags(
asset_info_id=info.id,
tags=["present", "absent"],
)
assert result.removed == ["present"]
assert result.not_present == ["absent"]
def test_raises_for_nonexistent_info(self, mock_create_session):
with pytest.raises(ValueError, match="not found"):
remove_tags(asset_info_id="nonexistent", tags=["x"])
def test_raises_for_wrong_owner(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset, owner_id="user1")
session.commit()
with pytest.raises(PermissionError, match="not owner"):
remove_tags(
asset_info_id=info.id,
tags=["x"],
owner_id="user2",
)
class TestListTags:
def test_returns_tags_with_counts(self, mock_create_session, session: Session):
ensure_tags_exist(session, ["used", "unused"])
asset = _make_asset(session)
info = _make_asset_info(session, asset)
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["used"])
session.commit()
rows, total = list_tags()
tag_dict = {name: count for name, _, count in rows}
assert tag_dict["used"] == 1
assert tag_dict["unused"] == 0
assert total == 2
def test_excludes_zero_counts(self, mock_create_session, session: Session):
ensure_tags_exist(session, ["used", "unused"])
asset = _make_asset(session)
info = _make_asset_info(session, asset)
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["used"])
session.commit()
rows, total = list_tags(include_zero=False)
tag_names = {name for name, _, _ in rows}
assert "used" in tag_names
assert "unused" not in tag_names
def test_prefix_filter(self, mock_create_session, session: Session):
ensure_tags_exist(session, ["alpha", "beta", "alphabet"])
session.commit()
rows, _ = list_tags(prefix="alph")
tag_names = {name for name, _, _ in rows}
assert tag_names == {"alpha", "alphabet"}
def test_order_by_name(self, mock_create_session, session: Session):
ensure_tags_exist(session, ["zebra", "alpha", "middle"])
session.commit()
rows, _ = list_tags(order="name_asc")
names = [name for name, _, _ in rows]
assert names == ["alpha", "middle", "zebra"]
def test_pagination(self, mock_create_session, session: Session):
ensure_tags_exist(session, ["a", "b", "c", "d", "e"])
session.commit()
rows, total = list_tags(limit=2, offset=1, order="name_asc")
assert total == 5
assert len(rows) == 2
names = [name for name, _, _ in rows]
assert names == ["b", "c"]
def test_clamps_limit(self, mock_create_session, session: Session):
ensure_tags_exist(session, ["a"])
session.commit()
# Service should clamp limit to max 1000
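# (assumed to behave like min(limit, 1000) rather than rejecting the request)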
rows, _ = list_tags(limit=2000)
assert len(rows) <= 1000

View File

@@ -4,7 +4,7 @@ from pathlib import Path
import pytest
import requests
from conftest import get_asset_filename, trigger_sync_seed_assets
from helpers import get_asset_filename, trigger_sync_seed_assets

View File

@@ -4,7 +4,7 @@ from pathlib import Path
import pytest
import requests
from conftest import get_asset_filename, trigger_sync_seed_assets
from helpers import get_asset_filename, trigger_sync_seed_assets
def test_create_from_hash_success(
@@ -126,42 +126,52 @@ def test_head_asset_bad_hash_returns_400_and_no_body(http: requests.Session, api
assert body == b""
def test_delete_nonexistent_returns_404(http: requests.Session, api_base: str):
bogus = str(uuid.uuid4())
r = http.delete(f"{api_base}/api/assets/{bogus}", timeout=120)
@pytest.mark.parametrize(
"method,endpoint_template,payload,expected_status,error_code",
[
# Delete nonexistent asset
("delete", "/api/assets/{uuid}", None, 404, "ASSET_NOT_FOUND"),
# Bad hash algorithm in from-hash
(
"post",
"/api/assets/from-hash",
{"hash": "sha256:" + "0" * 64, "name": "x.bin", "tags": ["models", "checkpoints", "unit-tests"]},
400,
"INVALID_BODY",
),
# Get with bad UUID format
("get", "/api/assets/not-a-uuid", None, 404, None),
# Get content with bad UUID format
("get", "/api/assets/not-a-uuid/content", None, 404, None),
],
ids=["delete_nonexistent", "bad_hash_algorithm", "get_bad_uuid", "content_bad_uuid"],
)
def test_error_responses(
http: requests.Session, api_base: str, method, endpoint_template, payload, expected_status, error_code
):
# Replace the {uuid} placeholder with a random UUID for the delete case
endpoint = endpoint_template.replace("{uuid}", str(uuid.uuid4()))
url = f"{api_base}{endpoint}"
if method == "get":
r = http.get(url, timeout=120)
elif method == "post":
r = http.post(url, json=payload, timeout=120)
elif method == "delete":
r = http.delete(url, timeout=120)
assert r.status_code == expected_status
if error_code:
body = r.json()
assert body["error"]["code"] == error_code
def test_create_from_hash_invalid_json(http: requests.Session, api_base: str):
"""Invalid JSON body requires special handling (data= instead of json=)."""
r = http.post(f"{api_base}/api/assets/from-hash", data=b"{not json}", timeout=120)
body = r.json()
assert r.status_code == 404
assert body["error"]["code"] == "ASSET_NOT_FOUND"
def test_create_from_hash_invalids(http: requests.Session, api_base: str):
# Bad hash algorithm
bad = {
"hash": "sha256:" + "0" * 64,
"name": "x.bin",
"tags": ["models", "checkpoints", "unit-tests"],
}
r1 = http.post(f"{api_base}/api/assets/from-hash", json=bad, timeout=120)
b1 = r1.json()
assert r1.status_code == 400
assert b1["error"]["code"] == "INVALID_BODY"
# Invalid JSON body
r2 = http.post(f"{api_base}/api/assets/from-hash", data=b"{not json}", timeout=120)
b2 = r2.json()
assert r2.status_code == 400
assert b2["error"]["code"] == "INVALID_JSON"
def test_get_update_download_bad_ids(http: requests.Session, api_base: str):
# All endpoints should return 404, since a UUID regex is applied directly in the route definition.
bad_id = "not-a-uuid"
r1 = http.get(f"{api_base}/api/assets/{bad_id}", timeout=120)
assert r1.status_code == 404
r3 = http.get(f"{api_base}/api/assets/{bad_id}/content", timeout=120)
assert r3.status_code == 404
assert r.status_code == 400
assert body["error"]["code"] == "INVALID_JSON"
def test_update_requires_at_least_one_field(http: requests.Session, api_base: str, seeded_asset: dict):

View File

@@ -6,7 +6,7 @@ from typing import Optional
import pytest
import requests
from conftest import get_asset_filename, trigger_sync_seed_assets
from helpers import get_asset_filename, trigger_sync_seed_assets
def test_download_attachment_and_inline(http: requests.Session, api_base: str, seeded_asset: dict):

View File

@@ -0,0 +1,55 @@
from app.assets.services.file_utils import is_visible, list_files_recursively
class TestIsVisible:
def test_visible_file(self):
assert is_visible("file.txt") is True
def test_hidden_file(self):
assert is_visible(".hidden") is False
def test_hidden_directory(self):
assert is_visible(".git") is False
def test_visible_directory(self):
assert is_visible("src") is True
def test_dotdot_is_hidden(self):
assert is_visible("..") is False
def test_dot_is_hidden(self):
assert is_visible(".") is False
class TestListFilesRecursively:
def test_skips_hidden_files(self, tmp_path):
(tmp_path / "visible.txt").write_text("a")
(tmp_path / ".hidden").write_text("b")
result = list_files_recursively(str(tmp_path))
assert len(result) == 1
assert result[0].endswith("visible.txt")
def test_skips_hidden_directories(self, tmp_path):
hidden_dir = tmp_path / ".hidden_dir"
hidden_dir.mkdir()
(hidden_dir / "file.txt").write_text("a")
visible_dir = tmp_path / "visible_dir"
visible_dir.mkdir()
(visible_dir / "file.txt").write_text("b")
result = list_files_recursively(str(tmp_path))
assert len(result) == 1
assert "visible_dir" in result[0]
assert ".hidden_dir" not in result[0]
def test_empty_directory(self, tmp_path):
result = list_files_recursively(str(tmp_path))
assert result == []
def test_nonexistent_directory(self, tmp_path):
result = list_files_recursively(str(tmp_path / "nonexistent"))
assert result == []

View File

@@ -1,6 +1,7 @@
import time
import uuid
import pytest
import requests
@@ -283,30 +284,21 @@ def test_list_assets_offset_beyond_total_and_limit_boundary(http, api_base, asse
assert b2["has_more"] is False
def test_list_assets_offset_negative_and_limit_nonint_rejected(http, api_base):
r1 = http.get(api_base + "/api/assets", params={"offset": "-1"}, timeout=120)
b1 = r1.json()
assert r1.status_code == 400
assert b1["error"]["code"] == "INVALID_QUERY"
r2 = http.get(api_base + "/api/assets", params={"limit": "abc"}, timeout=120)
b2 = r2.json()
assert r2.status_code == 400
assert b2["error"]["code"] == "INVALID_QUERY"
def test_list_assets_invalid_query_rejected(http: requests.Session, api_base: str):
# limit too small
r1 = http.get(api_base + "/api/assets", params={"limit": "0"}, timeout=120)
b1 = r1.json()
assert r1.status_code == 400
assert b1["error"]["code"] == "INVALID_QUERY"
# bad metadata JSON
r2 = http.get(api_base + "/api/assets", params={"metadata_filter": "{not json"}, timeout=120)
b2 = r2.json()
assert r2.status_code == 400
assert b2["error"]["code"] == "INVALID_QUERY"
@pytest.mark.parametrize(
"params,error_code",
[
({"offset": "-1"}, "INVALID_QUERY"),
({"limit": "abc"}, "INVALID_QUERY"),
({"limit": "0"}, "INVALID_QUERY"),
({"metadata_filter": "{not json"}, "INVALID_QUERY"),
],
ids=["negative_offset", "non_int_limit", "zero_limit", "invalid_metadata_json"],
)
def test_list_assets_invalid_query_rejected(http: requests.Session, api_base: str, params, error_code):
r = http.get(api_base + "/api/assets", params=params, timeout=120)
body = r.json()
assert r.status_code == 400
assert body["error"]["code"] == error_code
def test_list_assets_name_contains_literal_underscore(

View File

@@ -3,7 +3,7 @@ from pathlib import Path
import pytest
import requests
from conftest import get_asset_filename, trigger_sync_seed_assets
from helpers import get_asset_filename, trigger_sync_seed_assets
@pytest.fixture

View File

@@ -0,0 +1,423 @@
"""Unit tests for the AssetSeeder background scanning class."""
import threading
import time
from unittest.mock import patch
import pytest
from app.assets.seeder import AssetSeeder, Progress, State
@pytest.fixture
def fresh_seeder():
"""Create a fresh AssetSeeder instance for testing (bypasses singleton)."""
seeder = object.__new__(AssetSeeder)
seeder._initialized = False
seeder.__init__()
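# object.__new__ bypasses the singleton-returning __new__, and resetting _initialized is
# assumed to let __init__ run again on this throwaway instance.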
yield seeder
seeder.shutdown(timeout=1.0)
@pytest.fixture
def mock_dependencies():
"""Mock all external dependencies for isolated testing."""
with (
patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
patch("app.assets.seeder.insert_asset_specs", return_value=0),
):
yield
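# The patched names are assumed to correspond to the scan pipeline stages:
# sync_root_safely (per-root sync), collect_paths_for_roots (filesystem walk),
# build_asset_specs (spec construction), insert_asset_specs (database writes).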
class TestSeederStateTransitions:
"""Test state machine transitions."""
def test_initial_state_is_idle(self, fresh_seeder: AssetSeeder):
assert fresh_seeder.get_status().state == State.IDLE
def test_start_transitions_to_running(
self, fresh_seeder: AssetSeeder, mock_dependencies
):
started = fresh_seeder.start(roots=("models",))
assert started is True
status = fresh_seeder.get_status()
assert status.state in (State.RUNNING, State.IDLE)
def test_start_while_running_returns_false(
self, fresh_seeder: AssetSeeder, mock_dependencies
):
barrier = threading.Event()
def slow_collect(*args):
barrier.wait(timeout=5.0)
return []
with patch(
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
):
fresh_seeder.start(roots=("models",))
time.sleep(0.05)
second_start = fresh_seeder.start(roots=("models",))
assert second_start is False
barrier.set()
def test_cancel_transitions_to_cancelling(
self, fresh_seeder: AssetSeeder, mock_dependencies
):
barrier = threading.Event()
def slow_collect(*args):
barrier.wait(timeout=5.0)
return []
with patch(
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
):
fresh_seeder.start(roots=("models",))
time.sleep(0.05)
cancelled = fresh_seeder.cancel()
assert cancelled is True
assert fresh_seeder.get_status().state == State.CANCELLING
barrier.set()
def test_cancel_when_idle_returns_false(self, fresh_seeder: AssetSeeder):
cancelled = fresh_seeder.cancel()
assert cancelled is False
def test_state_returns_to_idle_after_completion(
self, fresh_seeder: AssetSeeder, mock_dependencies
):
fresh_seeder.start(roots=("models",))
completed = fresh_seeder.wait(timeout=5.0)
assert completed is True
assert fresh_seeder.get_status().state == State.IDLE
class TestSeederWait:
"""Test wait() behavior."""
def test_wait_blocks_until_complete(
self, fresh_seeder: AssetSeeder, mock_dependencies
):
fresh_seeder.start(roots=("models",))
completed = fresh_seeder.wait(timeout=5.0)
assert completed is True
assert fresh_seeder.get_status().state == State.IDLE
def test_wait_returns_false_on_timeout(
self, fresh_seeder: AssetSeeder, mock_dependencies
):
barrier = threading.Event()
def slow_collect(*args):
barrier.wait(timeout=10.0)
return []
with patch(
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
):
fresh_seeder.start(roots=("models",))
completed = fresh_seeder.wait(timeout=0.1)
assert completed is False
barrier.set()
def test_wait_when_idle_returns_true(self, fresh_seeder: AssetSeeder):
completed = fresh_seeder.wait(timeout=1.0)
assert completed is True
class TestSeederProgress:
"""Test progress tracking."""
def test_get_status_returns_progress_during_scan(
self, fresh_seeder: AssetSeeder, mock_dependencies
):
progress_seen = []
barrier = threading.Event()
def slow_collect(*args):
barrier.wait(timeout=5.0)
return ["/path/file1.safetensors", "/path/file2.safetensors"]
with patch(
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
):
fresh_seeder.start(roots=("models",))
time.sleep(0.05)
status = fresh_seeder.get_status()
assert status.progress is not None
progress_seen.append(status.progress)
barrier.set()
def test_progress_callback_is_invoked(
self, fresh_seeder: AssetSeeder, mock_dependencies
):
progress_updates: list[Progress] = []
def callback(p: Progress):
progress_updates.append(p)
with patch(
"app.assets.seeder.collect_paths_for_roots",
return_value=[f"/path/file{i}.safetensors" for i in range(10)],
):
fresh_seeder.start(roots=("models",), progress_callback=callback)
fresh_seeder.wait(timeout=5.0)
assert len(progress_updates) > 0
class TestSeederCancellation:
"""Test cancellation behavior."""
def test_scan_commits_partial_progress_on_cancellation(
self, fresh_seeder: AssetSeeder
):
insert_count = 0
barrier = threading.Event()
def slow_insert(specs, tags):
nonlocal insert_count
insert_count += 1
if insert_count >= 2:
barrier.wait(timeout=5.0)
return len(specs)
paths = [f"/path/file{i}.safetensors" for i in range(1500)]
specs = [
{
"abs_path": p,
"size_bytes": 100,
"mtime_ns": 0,
"info_name": f"file{i}",
"tags": [],
"fname": f"file{i}",
}
for i, p in enumerate(paths)
]
with (
patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder.collect_paths_for_roots", return_value=paths),
patch(
"app.assets.seeder.build_asset_specs", return_value=(specs, set(), 0)
),
patch("app.assets.seeder.insert_asset_specs", side_effect=slow_insert),
):
fresh_seeder.start(roots=("models",))
time.sleep(0.1)
fresh_seeder.cancel()
barrier.set()
fresh_seeder.wait(timeout=5.0)
assert insert_count >= 1
class TestSeederErrorHandling:
"""Test error handling behavior."""
def test_database_errors_captured_in_status(self, fresh_seeder: AssetSeeder):
with (
patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch(
"app.assets.seeder.collect_paths_for_roots",
return_value=["/path/file.safetensors"],
),
patch(
"app.assets.seeder.build_asset_specs",
return_value=(
[
{
"abs_path": "/path/file.safetensors",
"size_bytes": 100,
"mtime_ns": 0,
"info_name": "file",
"tags": [],
"fname": "file",
}
],
set(),
0,
),
),
patch(
"app.assets.seeder.insert_asset_specs",
side_effect=Exception("DB connection failed"),
),
):
fresh_seeder.start(roots=("models",))
fresh_seeder.wait(timeout=5.0)
status = fresh_seeder.get_status()
assert len(status.errors) > 0
assert "DB connection failed" in status.errors[0]
def test_dependencies_unavailable_captured_in_errors(
self, fresh_seeder: AssetSeeder
):
with patch("app.assets.seeder.dependencies_available", return_value=False):
fresh_seeder.start(roots=("models",))
fresh_seeder.wait(timeout=5.0)
status = fresh_seeder.get_status()
assert len(status.errors) > 0
assert "dependencies" in status.errors[0].lower()
def test_thread_crash_resets_state_to_idle(self, fresh_seeder: AssetSeeder):
with (
patch("app.assets.seeder.dependencies_available", return_value=True),
patch(
"app.assets.seeder.sync_root_safely",
side_effect=RuntimeError("Unexpected crash"),
),
):
fresh_seeder.start(roots=("models",))
fresh_seeder.wait(timeout=5.0)
status = fresh_seeder.get_status()
assert status.state == State.IDLE
assert len(status.errors) > 0
class TestSeederThreadSafety:
"""Test thread safety of concurrent operations."""
def test_concurrent_start_calls_spawn_only_one_thread(
self, fresh_seeder: AssetSeeder, mock_dependencies
):
barrier = threading.Event()
def slow_collect(*args):
barrier.wait(timeout=5.0)
return []
with patch(
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
):
results = []
def try_start():
results.append(fresh_seeder.start(roots=("models",)))
threads = [threading.Thread(target=try_start) for _ in range(10)]
for t in threads:
t.start()
for t in threads:
t.join()
barrier.set()
assert sum(results) == 1
def test_get_status_safe_during_scan(
self, fresh_seeder: AssetSeeder, mock_dependencies
):
barrier = threading.Event()
def slow_collect(*args):
barrier.wait(timeout=5.0)
return []
with patch(
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
):
fresh_seeder.start(roots=("models",))
statuses = []
for _ in range(100):
statuses.append(fresh_seeder.get_status())
time.sleep(0.001)
barrier.set()
assert all(
s.state in (State.RUNNING, State.IDLE, State.CANCELLING)
for s in statuses
)
class TestSeederMarkMissing:
"""Test mark_missing_outside_prefixes behavior."""
def test_mark_missing_when_idle(self, fresh_seeder: AssetSeeder):
with (
patch("app.assets.seeder.dependencies_available", return_value=True),
patch(
"app.assets.seeder.get_all_known_prefixes",
return_value=["/models", "/input", "/output"],
),
patch(
"app.assets.seeder.mark_missing_outside_prefixes_safely", return_value=5
) as mock_mark,
):
result = fresh_seeder.mark_missing_outside_prefixes()
assert result == 5
mock_mark.assert_called_once_with(["/models", "/input", "/output"])
def test_mark_missing_returns_zero_when_running(
self, fresh_seeder: AssetSeeder, mock_dependencies
):
barrier = threading.Event()
def slow_collect(*args):
barrier.wait(timeout=5.0)
return []
with patch(
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
):
fresh_seeder.start(roots=("models",))
time.sleep(0.05)
result = fresh_seeder.mark_missing_outside_prefixes()
assert result == 0
barrier.set()
def test_mark_missing_returns_zero_when_dependencies_unavailable(
self, fresh_seeder: AssetSeeder
):
with patch("app.assets.seeder.dependencies_available", return_value=False):
result = fresh_seeder.mark_missing_outside_prefixes()
assert result == 0
def test_prune_first_flag_triggers_mark_missing_before_scan(
self, fresh_seeder: AssetSeeder
):
call_order = []
def track_mark(prefixes):
call_order.append("mark_missing")
return 3
def track_sync(root):
call_order.append(f"sync_{root}")
return set()
with (
patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.get_all_known_prefixes", return_value=["/models"]),
patch("app.assets.seeder.mark_missing_outside_prefixes_safely", side_effect=track_mark),
patch("app.assets.seeder.sync_root_safely", side_effect=track_sync),
patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
patch("app.assets.seeder.insert_asset_specs", return_value=0),
):
fresh_seeder.start(roots=("models",), prune_first=True)
fresh_seeder.wait(timeout=5.0)
assert call_order[0] == "mark_missing"
assert "sync_models" in call_order