Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-11 10:40:04 +00:00)

Compare commits: v0.12.2...luke-mino- (56 commits)
Commits in this compare (SHA1; author and date columns were not captured in this mirror view):
043a75acde, dcba47251a, 8f7362d8b0, 0121a5532e, 0ff02860bc, 0bc3a6a377, 59bff14c97, 9b284163f5, 593cc980e9, 72b6c3f065, 2af0cf18f5, 6edcd690b6, 6708c02446, 58f70f2f92, 32011c403b, 58ddf46c0a, 9222ff6d81, ebb2f5b0e9, 28c4b58dd6, 56e9a75ca2, 8fb77c080f, fa3749ced7, 16b5d9112b, abeec3072b, b23302f372, adf6eb73fd, 5259959fef, 5474d8bf84, 9290e26e9f, 37ecc5b663, 80d99e7b63, d8cb122dfb, 0f75def5b5, 6b1f9f7755, 3311b13740, bf7fbb6317, 5571508e61, e3b8e512ca, ea01cd665d, ccfc5dedd4, e9ca190098, ed60e93696, fef2f01671, 481a2fa263, 11ca1995a3, 4e02245012, 9f9db2c2c2, e987bd268f, 2eb100adf9, a02f160e20, c3105b1174, 64d2f51dfc, fba4570e49, 15ee03f65c, 70a600baf0, 17ad7e393f
37  alembic_db/versions/0002_add_is_missing_to_cache_state.py  Normal file
@@ -0,0 +1,37 @@
"""
Add is_missing column to asset_cache_state for non-destructive soft-delete

Revision ID: 0002_add_is_missing
Revises: 0001_assets
Create Date: 2025-02-05 00:00:00
"""

from alembic import op
import sqlalchemy as sa

revision = "0002_add_is_missing"
down_revision = "0001_assets"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "asset_cache_state",
        sa.Column(
            "is_missing",
            sa.Boolean(),
            nullable=False,
            server_default=sa.text("false"),
        ),
    )
    op.create_index(
        "ix_asset_cache_state_is_missing",
        "asset_cache_state",
        ["is_missing"],
    )


def downgrade() -> None:
    op.drop_index("ix_asset_cache_state_is_missing", table_name="asset_cache_state")
    op.drop_column("asset_cache_state", "is_missing")
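As an aside on how the new flag is meant to be used: `is_missing` is a non-destructive soft-delete, so a scan can flip it both ways instead of deleting rows. A minimal illustrative sketch follows (the `AssetCacheState` model path appears elsewhere in this changeset; the helper itself is an assumption for illustration, not code from the PR):

```python
# Illustrative sketch only -- not part of this changeset.
import os

import sqlalchemy as sa
from sqlalchemy.orm import Session

from app.assets.database.models import AssetCacheState  # model touched by this PR


def refresh_missing_flags(session: Session) -> int:
    """Set is_missing when a cached path vanished; clear it if the file came back."""
    changed = 0
    for state in session.execute(sa.select(AssetCacheState)).scalars():
        exists = os.path.exists(state.file_path)
        if state.is_missing == exists:  # flag disagrees with what is on disk
            state.is_missing = not exists
            changed += 1
    session.commit()
    return changed
```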
@@ -1,19 +1,39 @@
|
||||
import logging
|
||||
import uuid
|
||||
import urllib.parse
|
||||
import os
|
||||
import contextlib
|
||||
from aiohttp import web
|
||||
import urllib.parse
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from aiohttp import web
|
||||
from pydantic import ValidationError
|
||||
|
||||
import app.assets.manager as manager
|
||||
from app import user_manager
|
||||
from app.assets.api import schemas_in
|
||||
from app.assets.helpers import get_query_dict
|
||||
from app.assets.scanner import seed_assets
|
||||
|
||||
import folder_paths
|
||||
from app import user_manager
|
||||
from app.assets.api import schemas_in, schemas_out
|
||||
from app.assets.api.schemas_in import (
|
||||
AssetValidationError,
|
||||
UploadError,
|
||||
)
|
||||
from app.assets.api.upload import (
|
||||
delete_temp_file_if_exists,
|
||||
parse_multipart_upload,
|
||||
)
|
||||
from app.assets.seeder import asset_seeder
|
||||
from app.assets.services import (
|
||||
DependencyMissingError,
|
||||
HashMismatchError,
|
||||
apply_tags,
|
||||
asset_exists,
|
||||
create_from_hash,
|
||||
delete_asset_reference,
|
||||
get_asset_detail,
|
||||
list_assets_page,
|
||||
list_tags,
|
||||
remove_tags,
|
||||
resolve_asset_for_download,
|
||||
update_asset_metadata,
|
||||
upload_from_temp_path,
|
||||
)
|
||||
|
||||
ROUTES = web.RouteTableDef()
|
||||
USER_MANAGER: user_manager.UserManager | None = None
|
||||
@@ -21,36 +41,80 @@ USER_MANAGER: user_manager.UserManager | None = None
|
||||
# UUID regex (canonical hyphenated form, case-insensitive)
|
||||
UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
|
||||
|
||||
|
||||
def get_query_dict(request: web.Request) -> dict[str, Any]:
    """
    Gets a dictionary of query parameters from the request.

    'request.query' is a MultiMapping[str], needs to be converted to a dictionary to be validated by Pydantic.
    """
    query_dict = {
        key: request.query.getall(key)
        if len(request.query.getall(key)) > 1
        else request.query.get(key)
        for key in request.query.keys()
    }
    return query_dict

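To make the conversion concrete, a small runnable sketch of the same getall/get logic against a hand-built MultiDict (aiohttp's `request.query` is a MultiMapping backed by the multidict package; the URL in the comment is made up):

```python
# Illustrative sketch only -- mirrors get_query_dict() for a hand-built query.
from multidict import MultiDict

# Simulates:  GET /api/assets?include_tags=models&include_tags=checkpoints&limit=20
query = MultiDict([("include_tags", "models"), ("include_tags", "checkpoints"), ("limit", "20")])
query_dict = {
    key: query.getall(key) if len(query.getall(key)) > 1 else query.get(key)
    for key in query.keys()
}
# Repeated keys become lists, single keys stay scalar strings for Pydantic to coerce:
assert query_dict == {"include_tags": ["models", "checkpoints"], "limit": "20"}
```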
# Note to any custom node developers reading this code:
|
||||
# The assets system is not yet fully implemented, do not rely on the code in /app/assets remaining the same.
|
||||
|
||||
def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None:
|
||||
|
||||
def register_assets_system(
|
||||
app: web.Application, user_manager_instance: user_manager.UserManager
|
||||
) -> None:
|
||||
global USER_MANAGER
|
||||
USER_MANAGER = user_manager_instance
|
||||
app.add_routes(ROUTES)
|
||||
|
||||
def _error_response(status: int, code: str, message: str, details: dict | None = None) -> web.Response:
|
||||
return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status)
|
||||
|
||||
def _build_error_response(
|
||||
status: int, code: str, message: str, details: dict | None = None
|
||||
) -> web.Response:
|
||||
return web.json_response(
|
||||
{"error": {"code": code, "message": message, "details": details or {}}},
|
||||
status=status,
|
||||
)
|
||||
|
||||
|
||||
def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
|
||||
return _error_response(400, code, "Validation failed.", {"errors": ve.json()})
|
||||
def _build_validation_error_response(code: str, ve: ValidationError) -> web.Response:
|
||||
import json
|
||||
errors = json.loads(ve.json())
|
||||
return _build_error_response(400, code, "Validation failed.", {"errors": errors})
|
||||
|
||||
|
||||
def _validate_sort_field(requested: str | None) -> str:
|
||||
if not requested:
|
||||
return "created_at"
|
||||
v = requested.lower()
|
||||
if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
|
||||
return v
|
||||
return "created_at"
|
||||
|
||||
|
||||
@ROUTES.head("/api/assets/hash/{hash}")
|
||||
async def head_asset_by_hash(request: web.Request) -> web.Response:
|
||||
hash_str = request.match_info.get("hash", "").strip().lower()
|
||||
if not hash_str or ":" not in hash_str:
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
return _build_error_response(
|
||||
400, "INVALID_HASH", "hash must be like 'blake3:<hex>'"
|
||||
)
|
||||
algo, digest = hash_str.split(":", 1)
|
||||
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
exists = manager.asset_exists(asset_hash=hash_str)
|
||||
if (
|
||||
algo != "blake3"
|
||||
or not digest
|
||||
or any(c for c in digest if c not in "0123456789abcdef")
|
||||
):
|
||||
return _build_error_response(
|
||||
400, "INVALID_HASH", "hash must be like 'blake3:<hex>'"
|
||||
)
|
||||
exists = asset_exists(hash_str)
|
||||
return web.Response(status=200 if exists else 404)
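For a sense of the hash format this endpoint expects, a hedged client-side sketch; the `blake3` package, the default port, and the helper name are assumptions, while the `blake3:<hex>` form and the 200/404 semantics come from the handler above:

```python
# Illustrative client sketch -- not part of this changeset.
# Assumes `pip install blake3 aiohttp` and a server on the default port.
import aiohttp
from blake3 import blake3


async def asset_content_exists(path: str, base_url: str = "http://127.0.0.1:8188") -> bool:
    with open(path, "rb") as f:
        digest = blake3(f.read()).hexdigest()  # fine for a sketch; hash in chunks for huge files
    async with aiohttp.ClientSession() as session:
        async with session.head(f"{base_url}/api/assets/hash/blake3:{digest}") as resp:
            return resp.status == 200  # 404 means the content is not known to the server
```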
|
||||
|
||||
|
||||
@ROUTES.get("/api/assets")
|
||||
async def list_assets(request: web.Request) -> web.Response:
|
||||
async def list_assets_route(request: web.Request) -> web.Response:
|
||||
"""
|
||||
GET request to list assets.
|
||||
"""
|
||||
@@ -58,66 +122,117 @@ async def list_assets(request: web.Request) -> web.Response:
|
||||
try:
|
||||
q = schemas_in.ListAssetsQuery.model_validate(query_dict)
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_QUERY", ve)
|
||||
return _build_validation_error_response("INVALID_QUERY", ve)
|
||||
|
||||
payload = manager.list_assets(
|
||||
sort = _validate_sort_field(q.sort)
|
||||
order_candidate = (q.order or "desc").lower()
|
||||
order = order_candidate if order_candidate in {"asc", "desc"} else "desc"
|
||||
|
||||
result = list_assets_page(
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
include_tags=q.include_tags,
|
||||
exclude_tags=q.exclude_tags,
|
||||
name_contains=q.name_contains,
|
||||
metadata_filter=q.metadata_filter,
|
||||
limit=q.limit,
|
||||
offset=q.offset,
|
||||
sort=q.sort,
|
||||
order=q.order,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
sort=sort,
|
||||
order=order,
|
||||
)
|
||||
|
||||
summaries = [
|
||||
schemas_out.AssetSummary(
|
||||
id=item.info.id,
|
||||
name=item.info.name,
|
||||
asset_hash=item.asset.hash if item.asset else None,
|
||||
size=int(item.asset.size_bytes) if item.asset else None,
|
||||
mime_type=item.asset.mime_type if item.asset else None,
|
||||
tags=item.tags,
|
||||
created_at=item.info.created_at,
|
||||
updated_at=item.info.updated_at,
|
||||
last_access_time=item.info.last_access_time,
|
||||
)
|
||||
for item in result.items
|
||||
]
|
||||
|
||||
payload = schemas_out.AssetsList(
|
||||
assets=summaries,
|
||||
total=result.total,
|
||||
has_more=(q.offset + len(summaries)) < result.total,
|
||||
)
|
||||
return web.json_response(payload.model_dump(mode="json", exclude_none=True))
|
||||
|
||||
|
||||
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
async def get_asset(request: web.Request) -> web.Response:
|
||||
async def get_asset_route(request: web.Request) -> web.Response:
|
||||
"""
|
||||
GET request to get an asset's info as JSON.
|
||||
"""
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
result = manager.get_asset(
|
||||
result = get_asset_detail(
|
||||
asset_info_id=asset_info_id,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
if not result:
|
||||
return _build_error_response(
|
||||
404,
|
||||
"ASSET_NOT_FOUND",
|
||||
f"AssetInfo {asset_info_id} not found",
|
||||
{"id": asset_info_id},
|
||||
)
|
||||
|
||||
payload = schemas_out.AssetDetail(
|
||||
id=result.info.id,
|
||||
name=result.info.name,
|
||||
asset_hash=result.asset.hash if result.asset else None,
|
||||
size=int(result.asset.size_bytes) if result.asset else None,
|
||||
mime_type=result.asset.mime_type if result.asset else None,
|
||||
tags=result.tags,
|
||||
user_metadata=result.info.user_metadata or {},
|
||||
preview_id=result.info.preview_id,
|
||||
created_at=result.info.created_at,
|
||||
last_access_time=result.info.last_access_time,
|
||||
)
|
||||
except ValueError as e:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(e), {"id": asset_info_id})
|
||||
return _build_error_response(
|
||||
404, "ASSET_NOT_FOUND", str(e), {"id": asset_info_id}
|
||||
)
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"get_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response(payload.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
|
||||
async def download_asset_content(request: web.Request) -> web.Response:
|
||||
# question: do we need disposition? could we just stick with one of these?
|
||||
disposition = request.query.get("disposition", "attachment").lower().strip()
|
||||
if disposition not in {"inline", "attachment"}:
|
||||
disposition = "attachment"
|
||||
|
||||
try:
|
||||
abs_path, content_type, filename = manager.resolve_asset_content_for_download(
|
||||
result = resolve_asset_for_download(
|
||||
asset_info_id=str(uuid.UUID(request.match_info["id"])),
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
abs_path = result.abs_path
|
||||
content_type = result.content_type
|
||||
filename = result.download_name
|
||||
except ValueError as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve))
|
||||
return _build_error_response(404, "ASSET_NOT_FOUND", str(ve))
|
||||
except NotImplementedError as nie:
|
||||
return _error_response(501, "BACKEND_UNSUPPORTED", str(nie))
|
||||
return _build_error_response(501, "BACKEND_UNSUPPORTED", str(nie))
|
||||
except FileNotFoundError:
|
||||
return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.")
|
||||
return _build_error_response(
|
||||
404, "FILE_NOT_FOUND", "Underlying file not found on disk."
|
||||
)
|
||||
|
||||
quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'")
|
||||
cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}'
|
||||
cd = f"{disposition}; filename=\"{quoted}\"; filename*=UTF-8''{urllib.parse.quote(quoted)}"
|
||||
|
||||
file_size = os.path.getsize(abs_path)
|
||||
logging.info(
|
||||
@@ -129,7 +244,7 @@ async def download_asset_content(request: web.Request) -> web.Response:
|
||||
filename,
|
||||
)
|
||||
|
||||
async def file_sender():
|
||||
async def stream_file_chunks():
|
||||
chunk_size = 64 * 1024
|
||||
with open(abs_path, "rb") as f:
|
||||
while True:
|
||||
@@ -139,7 +254,7 @@ async def download_asset_content(request: web.Request) -> web.Response:
|
||||
yield chunk
|
||||
|
||||
return web.Response(
|
||||
body=file_sender(),
|
||||
body=stream_file_chunks(),
|
||||
content_type=content_type,
|
||||
headers={
|
||||
"Content-Disposition": cd,
|
||||
@@ -149,16 +264,18 @@ async def download_asset_content(request: web.Request) -> web.Response:
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets/from-hash")
|
||||
async def create_asset_from_hash(request: web.Request) -> web.Response:
|
||||
async def create_asset_from_hash_route(request: web.Request) -> web.Response:
|
||||
try:
|
||||
payload = await request.json()
|
||||
body = schemas_in.CreateFromHashBody.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
return _build_validation_error_response("INVALID_BODY", ve)
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
return _build_error_response(
|
||||
400, "INVALID_JSON", "Request body must be valid JSON."
|
||||
)
|
||||
|
||||
result = manager.create_asset_from_hash(
|
||||
result = create_from_hash(
|
||||
hash_str=body.hash,
|
||||
name=body.name,
|
||||
tags=body.tags,
|
||||
@@ -166,228 +283,183 @@ async def create_asset_from_hash(request: web.Request) -> web.Response:
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
if result is None:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist")
|
||||
return web.json_response(result.model_dump(mode="json"), status=201)
|
||||
return _build_error_response(
|
||||
404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist"
|
||||
)
|
||||
|
||||
payload_out = schemas_out.AssetCreated(
|
||||
id=result.info.id,
|
||||
name=result.info.name,
|
||||
asset_hash=result.asset.hash,
|
||||
size=int(result.asset.size_bytes),
|
||||
mime_type=result.asset.mime_type,
|
||||
tags=result.tags,
|
||||
user_metadata=result.info.user_metadata or {},
|
||||
preview_id=result.info.preview_id,
|
||||
created_at=result.info.created_at,
|
||||
last_access_time=result.info.last_access_time,
|
||||
created_new=result.created_new,
|
||||
)
|
||||
return web.json_response(payload_out.model_dump(mode="json"), status=201)
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets")
|
||||
async def upload_asset(request: web.Request) -> web.Response:
|
||||
"""Multipart/form-data endpoint for Asset uploads."""
|
||||
if not (request.content_type or "").lower().startswith("multipart/"):
|
||||
return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.")
|
||||
|
||||
reader = await request.multipart()
|
||||
|
||||
file_present = False
|
||||
file_client_name: str | None = None
|
||||
tags_raw: list[str] = []
|
||||
provided_name: str | None = None
|
||||
user_metadata_raw: str | None = None
|
||||
provided_hash: str | None = None
|
||||
provided_hash_exists: bool | None = None
|
||||
|
||||
file_written = 0
|
||||
tmp_path: str | None = None
|
||||
while True:
|
||||
field = await reader.next()
|
||||
if field is None:
|
||||
break
|
||||
|
||||
fname = getattr(field, "name", "") or ""
|
||||
|
||||
if fname == "hash":
|
||||
try:
|
||||
s = ((await field.text()) or "").strip().lower()
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
|
||||
if s:
|
||||
if ":" not in s:
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
algo, digest = s.split(":", 1)
|
||||
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
provided_hash = f"{algo}:{digest}"
|
||||
try:
|
||||
provided_hash_exists = manager.asset_exists(asset_hash=provided_hash)
|
||||
except Exception:
|
||||
provided_hash_exists = None # do not fail the whole request here
|
||||
|
||||
elif fname == "file":
|
||||
file_present = True
|
||||
file_client_name = (field.filename or "").strip()
|
||||
|
||||
if provided_hash and provided_hash_exists is True:
|
||||
# If client supplied a hash that we know exists, drain but do not write to disk
|
||||
try:
|
||||
while True:
|
||||
chunk = await field.read_chunk(8 * 1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
file_written += len(chunk)
|
||||
except Exception:
|
||||
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file.")
|
||||
continue # Do not create temp file; we will create AssetInfo from the existing content
|
||||
|
||||
# Otherwise, store to temp for hashing/ingest
|
||||
uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
|
||||
unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
|
||||
os.makedirs(unique_dir, exist_ok=True)
|
||||
tmp_path = os.path.join(unique_dir, ".upload.part")
|
||||
|
||||
try:
|
||||
with open(tmp_path, "wb") as f:
|
||||
while True:
|
||||
chunk = await field.read_chunk(8 * 1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
file_written += len(chunk)
|
||||
except Exception:
|
||||
try:
|
||||
if os.path.exists(tmp_path or ""):
|
||||
os.remove(tmp_path)
|
||||
finally:
|
||||
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.")
|
||||
elif fname == "tags":
|
||||
tags_raw.append((await field.text()) or "")
|
||||
elif fname == "name":
|
||||
provided_name = (await field.text()) or None
|
||||
elif fname == "user_metadata":
|
||||
user_metadata_raw = (await field.text()) or None
|
||||
|
||||
# If client did not send file, and we are not doing a from-hash fast path -> error
|
||||
if not file_present and not (provided_hash and provided_hash_exists):
|
||||
return _error_response(400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'.")
|
||||
|
||||
if file_present and file_written == 0 and not (provided_hash and provided_hash_exists):
|
||||
# Empty upload is only acceptable if we are fast-pathing from existing hash
|
||||
try:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
finally:
|
||||
return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.")
|
||||
|
||||
try:
|
||||
spec = schemas_in.UploadAssetSpec.model_validate({
|
||||
"tags": tags_raw,
|
||||
"name": provided_name,
|
||||
"user_metadata": user_metadata_raw,
|
||||
"hash": provided_hash,
|
||||
})
|
||||
except ValidationError as ve:
|
||||
try:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
finally:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
|
||||
# Validate models category against configured folders (consistent with previous behavior)
|
||||
if spec.tags and spec.tags[0] == "models":
|
||||
if len(spec.tags) < 2 or spec.tags[1] not in folder_paths.folder_names_and_paths:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
return _error_response(
|
||||
400, "INVALID_BODY", f"unknown models category '{spec.tags[1] if len(spec.tags) >= 2 else ''}'"
|
||||
)
|
||||
parsed = await parse_multipart_upload(request, check_hash_exists=asset_exists)
|
||||
except UploadError as e:
|
||||
return _build_error_response(e.status, e.code, e.message)
|
||||
|
||||
owner_id = USER_MANAGER.get_request_user_id(request)
|
||||
|
||||
# Fast path: if a valid provided hash exists, create AssetInfo without writing anything
|
||||
if spec.hash and provided_hash_exists is True:
|
||||
try:
|
||||
result = manager.create_asset_from_hash(
|
||||
try:
|
||||
spec = schemas_in.UploadAssetSpec.model_validate(
|
||||
{
|
||||
"tags": parsed.tags_raw,
|
||||
"name": parsed.provided_name,
|
||||
"user_metadata": parsed.user_metadata_raw,
|
||||
"hash": parsed.provided_hash,
|
||||
}
|
||||
)
|
||||
except ValidationError as ve:
|
||||
delete_temp_file_if_exists(parsed.tmp_path)
|
||||
return _build_error_response(
|
||||
400, "INVALID_BODY", f"Validation failed: {ve.json()}"
|
||||
)
|
||||
|
||||
if spec.tags and spec.tags[0] == "models":
|
||||
if (
|
||||
len(spec.tags) < 2
|
||||
or spec.tags[1] not in folder_paths.folder_names_and_paths
|
||||
):
|
||||
delete_temp_file_if_exists(parsed.tmp_path)
|
||||
category = spec.tags[1] if len(spec.tags) >= 2 else ""
|
||||
return _build_error_response(
|
||||
400, "INVALID_BODY", f"unknown models category '{category}'"
|
||||
)
|
||||
|
||||
try:
|
||||
# Fast path: if a valid provided hash exists, create AssetInfo without writing anything
|
||||
if spec.hash and parsed.provided_hash_exists is True:
|
||||
result = create_from_hash(
|
||||
hash_str=spec.hash,
|
||||
name=spec.name or (spec.hash.split(":", 1)[1]),
|
||||
tags=spec.tags,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
owner_id=owner_id,
|
||||
)
|
||||
except Exception:
|
||||
logging.exception("create_asset_from_hash failed for hash=%s, owner_id=%s", spec.hash, owner_id)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
if result is None:
|
||||
delete_temp_file_if_exists(parsed.tmp_path)
|
||||
return _build_error_response(
|
||||
404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist"
|
||||
)
|
||||
delete_temp_file_if_exists(parsed.tmp_path)
|
||||
else:
|
||||
# Otherwise, we must have a temp file path to ingest
|
||||
if not parsed.tmp_path or not os.path.exists(parsed.tmp_path):
|
||||
return _build_error_response(
|
||||
404,
|
||||
"ASSET_NOT_FOUND",
|
||||
"Provided hash not found and no file uploaded.",
|
||||
)
|
||||
|
||||
if result is None:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist")
|
||||
|
||||
# Drain temp if we accidentally saved (e.g., hash field came after file)
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
with contextlib.suppress(Exception):
|
||||
os.remove(tmp_path)
|
||||
|
||||
status = 200 if (not result.created_new) else 201
|
||||
return web.json_response(result.model_dump(mode="json"), status=status)
|
||||
|
||||
# Otherwise, we must have a temp file path to ingest
|
||||
if not tmp_path or not os.path.exists(tmp_path):
|
||||
# The only case we reach here without a temp file is: client sent a hash that does not exist and no file
|
||||
return _error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.")
|
||||
|
||||
try:
|
||||
created = manager.upload_asset_from_temp_path(
|
||||
spec,
|
||||
temp_path=tmp_path,
|
||||
client_filename=file_client_name,
|
||||
owner_id=owner_id,
|
||||
expected_asset_hash=spec.hash,
|
||||
)
|
||||
status = 201 if created.created_new else 200
|
||||
return web.json_response(created.model_dump(mode="json"), status=status)
|
||||
except ValueError as e:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
msg = str(e)
|
||||
if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH":
|
||||
return _error_response(
|
||||
400,
|
||||
"HASH_MISMATCH",
|
||||
"Uploaded file hash does not match provided hash.",
|
||||
result = upload_from_temp_path(
|
||||
temp_path=parsed.tmp_path,
|
||||
name=spec.name,
|
||||
tags=spec.tags,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
client_filename=parsed.file_client_name,
|
||||
owner_id=owner_id,
|
||||
expected_hash=spec.hash,
|
||||
)
|
||||
return _error_response(400, "BAD_REQUEST", "Invalid inputs.")
|
||||
except AssetValidationError as e:
|
||||
delete_temp_file_if_exists(parsed.tmp_path)
|
||||
return _build_error_response(400, e.code, str(e))
|
||||
except ValueError as e:
|
||||
delete_temp_file_if_exists(parsed.tmp_path)
|
||||
return _build_error_response(400, "BAD_REQUEST", str(e))
|
||||
except HashMismatchError as e:
|
||||
delete_temp_file_if_exists(parsed.tmp_path)
|
||||
return _build_error_response(400, "HASH_MISMATCH", str(e))
|
||||
except DependencyMissingError as e:
|
||||
delete_temp_file_if_exists(parsed.tmp_path)
|
||||
return _build_error_response(503, "DEPENDENCY_MISSING", e.message)
|
||||
except Exception:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
logging.exception("upload_asset_from_temp_path failed for tmp_path=%s, owner_id=%s", tmp_path, owner_id)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
delete_temp_file_if_exists(parsed.tmp_path)
|
||||
logging.exception("upload_asset failed for owner_id=%s", owner_id)
|
||||
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
payload = schemas_out.AssetCreated(
|
||||
id=result.info.id,
|
||||
name=result.info.name,
|
||||
asset_hash=result.asset.hash,
|
||||
size=int(result.asset.size_bytes),
|
||||
mime_type=result.asset.mime_type,
|
||||
tags=result.tags,
|
||||
user_metadata=result.info.user_metadata or {},
|
||||
preview_id=result.info.preview_id,
|
||||
created_at=result.info.created_at,
|
||||
last_access_time=result.info.last_access_time,
|
||||
created_new=result.created_new,
|
||||
)
|
||||
status = 201 if result.created_new else 200
|
||||
return web.json_response(payload.model_dump(mode="json"), status=status)
|
||||
|
||||
|
||||
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
async def update_asset(request: web.Request) -> web.Response:
|
||||
async def update_asset_route(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
body = schemas_in.UpdateAssetBody.model_validate(await request.json())
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
return _build_validation_error_response("INVALID_BODY", ve)
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
return _build_error_response(
|
||||
400, "INVALID_JSON", "Request body must be valid JSON."
|
||||
)
|
||||
|
||||
try:
|
||||
result = manager.update_asset(
|
||||
result = update_asset_metadata(
|
||||
asset_info_id=asset_info_id,
|
||||
name=body.name,
|
||||
user_metadata=body.user_metadata,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
payload = schemas_out.AssetUpdated(
|
||||
id=result.info.id,
|
||||
name=result.info.name,
|
||||
asset_hash=result.asset.hash if result.asset else None,
|
||||
tags=result.tags,
|
||||
user_metadata=result.info.user_metadata or {},
|
||||
updated_at=result.info.updated_at,
|
||||
)
|
||||
except (ValueError, PermissionError) as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
return _build_error_response(
|
||||
404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}
|
||||
)
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"update_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response(payload.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
async def delete_asset(request: web.Request) -> web.Response:
|
||||
async def delete_asset_route(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
delete_content = request.query.get("delete_content")
|
||||
delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"}
|
||||
delete_content_param = request.query.get("delete_content")
|
||||
delete_content = (
|
||||
True
|
||||
if delete_content_param is None
|
||||
else delete_content_param.lower() not in {"0", "false", "no"}
|
||||
)
|
||||
|
||||
try:
|
||||
deleted = manager.delete_asset_reference(
|
||||
deleted = delete_asset_reference(
|
||||
asset_info_id=asset_info_id,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
delete_content_if_orphan=delete_content,
|
||||
@@ -398,10 +470,12 @@ async def delete_asset(request: web.Request) -> web.Response:
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
if not deleted:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.")
|
||||
return _build_error_response(
|
||||
404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found."
|
||||
)
|
||||
return web.Response(status=204)
|
||||
|
||||
|
||||
@@ -415,12 +489,12 @@ async def get_tags(request: web.Request) -> web.Response:
|
||||
try:
|
||||
query = schemas_in.TagsListQuery.model_validate(query_map)
|
||||
except ValidationError as e:
|
||||
return web.json_response(
|
||||
{"error": {"code": "INVALID_QUERY", "message": "Invalid query parameters", "details": e.errors()}},
|
||||
status=400,
|
||||
import json
|
||||
return _build_error_response(
|
||||
400, "INVALID_QUERY", "Invalid query parameters", {"errors": json.loads(e.json())}
|
||||
)
|
||||
|
||||
result = manager.list_tags(
|
||||
rows, total = list_tags(
|
||||
prefix=query.prefix,
|
||||
limit=query.limit,
|
||||
offset=query.offset,
|
||||
@@ -428,87 +502,201 @@ async def get_tags(request: web.Request) -> web.Response:
|
||||
include_zero=query.include_zero,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return web.json_response(result.model_dump(mode="json"))
|
||||
|
||||
tags = [
|
||||
schemas_out.TagUsage(name=name, count=count, type=tag_type)
|
||||
for (name, tag_type, count) in rows
|
||||
]
|
||||
payload = schemas_out.TagsList(
|
||||
tags=tags, total=total, has_more=(query.offset + len(tags)) < total
|
||||
)
|
||||
return web.json_response(payload.model_dump(mode="json"))
|
||||
|
||||
|
||||
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
||||
async def add_asset_tags(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
payload = await request.json()
|
||||
data = schemas_in.TagsAdd.model_validate(payload)
|
||||
json_payload = await request.json()
|
||||
data = schemas_in.TagsAdd.model_validate(json_payload)
|
||||
except ValidationError as ve:
|
||||
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()})
|
||||
return _build_error_response(
|
||||
400,
|
||||
"INVALID_BODY",
|
||||
"Invalid JSON body for tags add.",
|
||||
{"errors": ve.errors()},
|
||||
)
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
return _build_error_response(
|
||||
400, "INVALID_JSON", "Request body must be valid JSON."
|
||||
)
|
||||
|
||||
try:
|
||||
result = manager.add_tags_to_asset(
|
||||
result = apply_tags(
|
||||
asset_info_id=asset_info_id,
|
||||
tags=data.tags,
|
||||
origin="manual",
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
payload = schemas_out.TagsAdd(
|
||||
added=result.added,
|
||||
already_present=result.already_present,
|
||||
total_tags=result.total_tags,
|
||||
)
|
||||
except (ValueError, PermissionError) as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
return _build_error_response(
|
||||
404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}
|
||||
)
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"add_tags_to_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
return web.json_response(payload.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
||||
async def delete_asset_tags(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
payload = await request.json()
|
||||
data = schemas_in.TagsRemove.model_validate(payload)
|
||||
json_payload = await request.json()
|
||||
data = schemas_in.TagsRemove.model_validate(json_payload)
|
||||
except ValidationError as ve:
|
||||
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()})
|
||||
return _build_error_response(
|
||||
400,
|
||||
"INVALID_BODY",
|
||||
"Invalid JSON body for tags remove.",
|
||||
{"errors": ve.errors()},
|
||||
)
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
return _build_error_response(
|
||||
400, "INVALID_JSON", "Request body must be valid JSON."
|
||||
)
|
||||
|
||||
try:
|
||||
result = manager.remove_tags_from_asset(
|
||||
result = remove_tags(
|
||||
asset_info_id=asset_info_id,
|
||||
tags=data.tags,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
payload = schemas_out.TagsRemove(
|
||||
removed=result.removed,
|
||||
not_present=result.not_present,
|
||||
total_tags=result.total_tags,
|
||||
)
|
||||
except ValueError as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
return _build_error_response(
|
||||
404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}
|
||||
)
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
return web.json_response(payload.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets/seed")
|
||||
async def seed_assets_endpoint(request: web.Request) -> web.Response:
|
||||
"""Trigger asset seeding for specified roots (models, input, output)."""
|
||||
async def seed_assets(request: web.Request) -> web.Response:
|
||||
"""Trigger asset seeding for specified roots (models, input, output).
|
||||
|
||||
Query params:
|
||||
wait: If "true", block until scan completes (synchronous behavior for tests)
|
||||
|
||||
Returns:
|
||||
202 Accepted if scan started
|
||||
409 Conflict if scan already running
|
||||
200 OK with final stats if wait=true
|
||||
"""
|
||||
try:
|
||||
payload = await request.json()
|
||||
roots = payload.get("roots", ["models", "input", "output"])
|
||||
except Exception:
|
||||
roots = ["models", "input", "output"]
|
||||
|
||||
valid_roots = [r for r in roots if r in ("models", "input", "output")]
|
||||
valid_roots = tuple(r for r in roots if r in ("models", "input", "output"))
|
||||
if not valid_roots:
|
||||
return _error_response(400, "INVALID_BODY", "No valid roots specified")
|
||||
return _build_error_response(400, "INVALID_BODY", "No valid roots specified")
|
||||
|
||||
try:
|
||||
seed_assets(tuple(valid_roots))
|
||||
except Exception:
|
||||
logging.exception("seed_assets failed for roots=%s", valid_roots)
|
||||
return _error_response(500, "INTERNAL", "Seed operation failed")
|
||||
wait_param = request.query.get("wait", "").lower()
|
||||
should_wait = wait_param in ("true", "1", "yes")
|
||||
|
||||
return web.json_response({"seeded": valid_roots}, status=200)
|
||||
started = asset_seeder.start(roots=valid_roots)
|
||||
if not started:
|
||||
return web.json_response({"status": "already_running"}, status=409)
|
||||
|
||||
if should_wait:
|
||||
asset_seeder.wait()
|
||||
status = asset_seeder.get_status()
|
||||
return web.json_response(
|
||||
{
|
||||
"status": "completed",
|
||||
"progress": {
|
||||
"scanned": status.progress.scanned if status.progress else 0,
|
||||
"total": status.progress.total if status.progress else 0,
|
||||
"created": status.progress.created if status.progress else 0,
|
||||
"skipped": status.progress.skipped if status.progress else 0,
|
||||
},
|
||||
"errors": status.errors,
|
||||
},
|
||||
status=200,
|
||||
)
|
||||
|
||||
return web.json_response({"status": "started"}, status=202)
|
||||
|
||||
|
||||
@ROUTES.get("/api/assets/seed/status")
|
||||
async def get_seed_status(request: web.Request) -> web.Response:
|
||||
"""Get current scan status and progress."""
|
||||
status = asset_seeder.get_status()
|
||||
return web.json_response(
|
||||
{
|
||||
"state": status.state.value,
|
||||
"progress": {
|
||||
"scanned": status.progress.scanned,
|
||||
"total": status.progress.total,
|
||||
"created": status.progress.created,
|
||||
"skipped": status.progress.skipped,
|
||||
}
|
||||
if status.progress
|
||||
else None,
|
||||
"errors": status.errors,
|
||||
},
|
||||
status=200,
|
||||
)
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets/seed/cancel")
|
||||
async def cancel_seed(request: web.Request) -> web.Response:
|
||||
"""Request cancellation of in-progress scan."""
|
||||
cancelled = asset_seeder.cancel()
|
||||
if cancelled:
|
||||
return web.json_response({"status": "cancelling"}, status=200)
|
||||
return web.json_response({"status": "idle"}, status=200)
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets/prune")
|
||||
async def mark_missing_assets(request: web.Request) -> web.Response:
|
||||
"""Mark assets as missing when their cache states point to files outside all known root prefixes.
|
||||
|
||||
This is a non-destructive soft-delete operation. Assets and their metadata
|
||||
are preserved, but cache states are flagged as missing. They can be restored
|
||||
if the file reappears in a future scan.
|
||||
|
||||
Returns:
|
||||
200 OK with count of marked assets
|
||||
409 Conflict if a scan is currently running
|
||||
"""
|
||||
marked = asset_seeder.mark_missing_outside_prefixes()
|
||||
if marked == 0 and asset_seeder.get_status().state.value != "IDLE":
|
||||
return web.json_response(
|
||||
{"status": "scan_running", "marked": 0},
|
||||
status=409,
|
||||
)
|
||||
return web.json_response({"status": "completed", "marked": marked}, status=200)
|
||||
|
||||
@@ -1,4 +1,5 @@
import json
from dataclasses import dataclass
from typing import Any, Literal

from pydantic import (
@@ -10,6 +11,65 @@ from pydantic import (
    model_validator,
)


class UploadError(Exception):
    """Error during upload parsing with HTTP status and code (used in HTTP layer only)."""

    def __init__(self, status: int, code: str, message: str):
        super().__init__(message)
        self.status = status
        self.code = code
        self.message = message


class AssetValidationError(Exception):
    """Validation error in asset processing (invalid tags, metadata, etc.)."""

    def __init__(self, code: str, message: str):
        super().__init__(message)
        self.code = code
        self.message = message


class AssetNotFoundError(Exception):
    """Asset or asset content not found."""

    def __init__(self, message: str):
        super().__init__(message)
        self.message = message


class HashMismatchError(Exception):
    """Uploaded file hash does not match provided hash."""

    def __init__(self, message: str):
        super().__init__(message)
        self.message = message


class DependencyMissingError(Exception):
    """A required dependency is not installed."""

    def __init__(self, message: str):
        super().__init__(message)
        self.message = message


@dataclass
class ParsedUpload:
    """Result of parsing a multipart upload request."""

    file_present: bool
    file_written: int
    file_client_name: str | None
    tmp_path: str | None
    tags_raw: list[str]
    provided_name: str | None
    user_metadata_raw: str | None
    provided_hash: str | None
    provided_hash_exists: bool | None


class ListAssetsQuery(BaseModel):
    include_tags: list[str] = Field(default_factory=list)
    exclude_tags: list[str] = Field(default_factory=list)
@@ -21,7 +81,9 @@ class ListAssetsQuery(BaseModel):
    limit: conint(ge=1, le=500) = 20
    offset: conint(ge=0) = 0

    sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at"
    sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = (
        "created_at"
    )
    order: Literal["asc", "desc"] = "desc"

    @field_validator("include_tags", "exclude_tags", mode="before")
@@ -61,7 +123,7 @@ class UpdateAssetBody(BaseModel):
    user_metadata: dict[str, Any] | None = None

    @model_validator(mode="after")
    def _at_least_one(self):
    def _validate_at_least_one_field(self):
        if self.name is None and self.user_metadata is None:
            raise ValueError("Provide at least one of: name, user_metadata.")
        return self
@@ -90,7 +152,7 @@ class CreateFromHashBody(BaseModel):

    @field_validator("tags", mode="before")
    @classmethod
    def _tags_norm(cls, v):
    def _normalize_tags_field(cls, v):
        if v is None:
            return []
        if isinstance(v, list):
@@ -163,6 +225,7 @@ class UploadAssetSpec(BaseModel):
    Files created via this endpoint are stored on disk using the **content hash** as the filename stem
    and the original extension is preserved when available.
    """

    model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)

    tags: list[str] = Field(..., min_length=1)
@@ -260,5 +323,7 @@ class UploadAssetSpec(BaseModel):
            raise ValueError("first tag must be one of: models, input, output")
        if root == "models":
            if len(self.tags) < 2:
                raise ValueError("models uploads require a category tag as the second tag")
                raise ValueError(
                    "models uploads require a category tag as the second tag"
                )
        return self
@@ -19,7 +19,7 @@ class AssetSummary(BaseModel):
    model_config = ConfigDict(from_attributes=True)

    @field_serializer("created_at", "updated_at", "last_access_time")
    def _ser_dt(self, v: datetime | None, _info):
    def _serialize_datetime(self, v: datetime | None, _info):
        return v.isoformat() if v else None


@@ -40,7 +40,7 @@ class AssetUpdated(BaseModel):
    model_config = ConfigDict(from_attributes=True)

    @field_serializer("updated_at")
    def _ser_updated(self, v: datetime | None, _info):
    def _serialize_updated_at(self, v: datetime | None, _info):
        return v.isoformat() if v else None


@@ -59,7 +59,7 @@ class AssetDetail(BaseModel):
    model_config = ConfigDict(from_attributes=True)

    @field_serializer("created_at", "last_access_time")
    def _ser_dt(self, v: datetime | None, _info):
    def _serialize_datetime(self, v: datetime | None, _info):
        return v.isoformat() if v else None

170  app/assets/api/upload.py  Normal file
@@ -0,0 +1,170 @@
import logging
import os
import uuid
from typing import Callable

from aiohttp import web

import folder_paths
from app.assets.api.schemas_in import ParsedUpload, UploadError


def normalize_and_validate_hash(s: str) -> str:
    """
    Validate and normalize a hash string.

    Returns canonical 'blake3:<hex>' or raises UploadError.
    """
    s = s.strip().lower()
    if not s:
        raise UploadError(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
    if ":" not in s:
        raise UploadError(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
    algo, digest = s.split(":", 1)
    if (
        algo != "blake3"
        or not digest
        or any(c for c in digest if c not in "0123456789abcdef")
    ):
        raise UploadError(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
    return f"{algo}:{digest}"


async def parse_multipart_upload(
    request: web.Request,
    check_hash_exists: Callable[[str], bool],
) -> ParsedUpload:
    """
    Parse a multipart/form-data upload request.

    Args:
        request: The aiohttp request
        check_hash_exists: Callable(hash_str) -> bool to check if a hash exists

    Returns:
        ParsedUpload with parsed fields and temp file path

    Raises:
        UploadError: On validation or I/O errors
    """
    if not (request.content_type or "").lower().startswith("multipart/"):
        raise UploadError(
            415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads."
        )

    reader = await request.multipart()

    file_present = False
    file_client_name: str | None = None
    tags_raw: list[str] = []
    provided_name: str | None = None
    user_metadata_raw: str | None = None
    provided_hash: str | None = None
    provided_hash_exists: bool | None = None

    file_written = 0
    tmp_path: str | None = None

    while True:
        field = await reader.next()
        if field is None:
            break

        fname = getattr(field, "name", "") or ""

        if fname == "hash":
            try:
                s = ((await field.text()) or "").strip().lower()
            except Exception:
                raise UploadError(
                    400, "INVALID_HASH", "hash must be like 'blake3:<hex>'"
                )

            if s:
                provided_hash = normalize_and_validate_hash(s)
                try:
                    provided_hash_exists = check_hash_exists(provided_hash)
                except Exception as e:
                    logging.warning(
                        "check_hash_exists failed for hash=%s: %s", provided_hash, e
                    )
                    provided_hash_exists = None  # do not fail the whole request here

        elif fname == "file":
            file_present = True
            file_client_name = (field.filename or "").strip()

            if provided_hash and provided_hash_exists is True:
                # If client supplied a hash that we know exists, drain but do not write to disk
                try:
                    while True:
                        chunk = await field.read_chunk(8 * 1024 * 1024)
                        if not chunk:
                            break
                        file_written += len(chunk)
                except Exception:
                    raise UploadError(
                        500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file."
                    )
                continue

            uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
            unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
            os.makedirs(unique_dir, exist_ok=True)
            tmp_path = os.path.join(unique_dir, ".upload.part")

            try:
                with open(tmp_path, "wb") as f:
                    while True:
                        chunk = await field.read_chunk(8 * 1024 * 1024)
                        if not chunk:
                            break
                        f.write(chunk)
                        file_written += len(chunk)
            except Exception:
                delete_temp_file_if_exists(tmp_path)
                raise UploadError(
                    500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file."
                )

        elif fname == "tags":
            tags_raw.append((await field.text()) or "")
        elif fname == "name":
            provided_name = (await field.text()) or None
        elif fname == "user_metadata":
            user_metadata_raw = (await field.text()) or None

    if not file_present and not (provided_hash and provided_hash_exists):
        raise UploadError(
            400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'."
        )

    if (
        file_present
        and file_written == 0
        and not (provided_hash and provided_hash_exists)
    ):
        delete_temp_file_if_exists(tmp_path)
        raise UploadError(400, "EMPTY_UPLOAD", "Uploaded file is empty.")

    return ParsedUpload(
        file_present=file_present,
        file_written=file_written,
        file_client_name=file_client_name,
        tmp_path=tmp_path,
        tags_raw=tags_raw,
        provided_name=provided_name,
        user_metadata_raw=user_metadata_raw,
        provided_hash=provided_hash,
        provided_hash_exists=provided_hash_exists,
    )


def delete_temp_file_if_exists(tmp_path: str | None) -> None:
    """Safely remove a temp file if it exists."""
    if tmp_path:
        try:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        except OSError as e:
            logging.debug("Failed to delete temp file %s: %s", tmp_path, e)
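To show how the helpers in this new module are meant to be driven, a simplified, hedged sketch of a caller in the HTTP layer (the handler name and the follow-up calls are assumptions; the real route in this changeset also validates tags and builds the response schema):

```python
# Illustrative sketch only -- a simplified caller, not the route handler from this changeset.
from aiohttp import web

from app.assets.api.schemas_in import UploadError
from app.assets.api.upload import delete_temp_file_if_exists, parse_multipart_upload


async def handle_upload(request: web.Request, asset_exists) -> web.Response:
    try:
        parsed = await parse_multipart_upload(request, check_hash_exists=asset_exists)
    except UploadError as e:
        # UploadError already carries the HTTP status and a machine-readable code.
        return web.json_response(
            {"error": {"code": e.code, "message": e.message}}, status=e.status
        )

    if parsed.provided_hash and parsed.provided_hash_exists:
        # Fast path: the content already exists, so no temp file needs ingesting;
        # only a new AssetInfo record is created from the hash.
        delete_temp_file_if_exists(parsed.tmp_path)
        return web.Response(status=200)

    # Otherwise parsed.tmp_path points at the spooled upload, ready to be hashed
    # and ingested (upload_from_temp_path in app.assets.services in this changeset).
    return web.Response(status=201)
```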
@@ -1,204 +0,0 @@
|
||||
import os
|
||||
import uuid
|
||||
import sqlalchemy
|
||||
from typing import Iterable
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.dialects import sqlite
|
||||
|
||||
from app.assets.helpers import utcnow
|
||||
from app.assets.database.models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, AssetInfoMeta
|
||||
|
||||
MAX_BIND_PARAMS = 800
|
||||
|
||||
def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]:
|
||||
if not rows:
|
||||
return []
|
||||
rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row))
|
||||
for i in range(0, len(rows), rows_per_stmt):
|
||||
yield rows[i:i + rows_per_stmt]
|
||||
|
||||
def _iter_chunks(seq, n: int):
|
||||
for i in range(0, len(seq), n):
|
||||
yield seq[i:i + n]
|
||||
|
||||
def _rows_per_stmt(cols: int) -> int:
|
||||
return max(1, MAX_BIND_PARAMS // max(1, cols))
|
||||
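To make the bind-parameter sizing concrete, a small worked example (numbers only; this is not code from the file): with MAX_BIND_PARAMS = 800 and 5 columns per row, each INSERT carries at most 160 rows, which keeps every statement under SQLite's bind-parameter ceiling.

```python
# Worked example of the chunk sizing (illustrative only).
MAX_BIND_PARAMS = 800
cols_per_row = 5
rows = [{"col": i} for i in range(1000)]

rows_per_stmt = max(1, MAX_BIND_PARAMS // max(1, cols_per_row))  # 160 rows per INSERT
batches = [rows[i:i + rows_per_stmt] for i in range(0, len(rows), rows_per_stmt)]
assert len(batches) == 7 and len(batches[0]) == 160 and len(batches[-1]) == 40
```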
|
||||
|
||||
def seed_from_paths_batch(
|
||||
session: Session,
|
||||
*,
|
||||
specs: list[dict],
|
||||
owner_id: str = "",
|
||||
) -> dict:
|
||||
"""Each spec is a dict with keys:
|
||||
- abs_path: str
|
||||
- size_bytes: int
|
||||
- mtime_ns: int
|
||||
- info_name: str
|
||||
- tags: list[str]
|
||||
- fname: Optional[str]
|
||||
"""
|
||||
if not specs:
|
||||
return {"inserted_infos": 0, "won_states": 0, "lost_states": 0}
|
||||
|
||||
now = utcnow()
|
||||
asset_rows: list[dict] = []
|
||||
state_rows: list[dict] = []
|
||||
path_to_asset: dict[str, str] = {}
|
||||
asset_to_info: dict[str, dict] = {} # asset_id -> prepared info row
|
||||
path_list: list[str] = []
|
||||
|
||||
for sp in specs:
|
||||
ap = os.path.abspath(sp["abs_path"])
|
||||
aid = str(uuid.uuid4())
|
||||
iid = str(uuid.uuid4())
|
||||
path_list.append(ap)
|
||||
path_to_asset[ap] = aid
|
||||
|
||||
asset_rows.append(
|
||||
{
|
||||
"id": aid,
|
||||
"hash": None,
|
||||
"size_bytes": sp["size_bytes"],
|
||||
"mime_type": None,
|
||||
"created_at": now,
|
||||
}
|
||||
)
|
||||
state_rows.append(
|
||||
{
|
||||
"asset_id": aid,
|
||||
"file_path": ap,
|
||||
"mtime_ns": sp["mtime_ns"],
|
||||
}
|
||||
)
|
||||
asset_to_info[aid] = {
|
||||
"id": iid,
|
||||
"owner_id": owner_id,
|
||||
"name": sp["info_name"],
|
||||
"asset_id": aid,
|
||||
"preview_id": None,
|
||||
"user_metadata": {"filename": sp["fname"]} if sp["fname"] else None,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"last_access_time": now,
|
||||
"_tags": sp["tags"],
|
||||
"_filename": sp["fname"],
|
||||
}
|
||||
|
||||
    # insert all seed Assets (hash=NULL)
    ins_asset = sqlite.insert(Asset)
    for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)):
        session.execute(ins_asset, chunk)

    # try to claim AssetCacheState (file_path)
    # Insert with ON CONFLICT DO NOTHING, then query to find which paths were actually inserted
    ins_state = (
        sqlite.insert(AssetCacheState)
        .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
    )
    for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
        session.execute(ins_state, chunk)

    # Query to find which of our paths won (were actually inserted)
    winners_by_path: set[str] = set()
    for chunk in _iter_chunks(path_list, MAX_BIND_PARAMS):
        result = session.execute(
            sqlalchemy.select(AssetCacheState.file_path)
            .where(AssetCacheState.file_path.in_(chunk))
            .where(AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk]))
        )
        winners_by_path.update(result.scalars().all())

    all_paths_set = set(path_list)
    losers_by_path = all_paths_set - winners_by_path
    lost_assets = [path_to_asset[p] for p in losers_by_path]
    if lost_assets:  # losers get their Asset removed
        for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS):
            session.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(id_chunk)))

    if not winners_by_path:
        return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}

    # insert AssetInfo only for winners
    # Insert with ON CONFLICT DO NOTHING, then query to find which were actually inserted
    winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
    ins_info = (
        sqlite.insert(AssetInfo)
        .on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
    )
    for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
        session.execute(ins_info, chunk)

    # Query to find which info rows were actually inserted (by matching our generated IDs)
    all_info_ids = [row["id"] for row in winner_info_rows]
    inserted_info_ids: set[str] = set()
    for chunk in _iter_chunks(all_info_ids, MAX_BIND_PARAMS):
        result = session.execute(
            sqlalchemy.select(AssetInfo.id).where(AssetInfo.id.in_(chunk))
        )
        inserted_info_ids.update(result.scalars().all())

    # build and insert tag + meta rows for the AssetInfo
    tag_rows: list[dict] = []
    meta_rows: list[dict] = []
    if inserted_info_ids:
        for row in winner_info_rows:
            iid = row["id"]
            if iid not in inserted_info_ids:
                continue
            for t in row["_tags"]:
                tag_rows.append({
                    "asset_info_id": iid,
                    "tag_name": t,
                    "origin": "automatic",
                    "added_at": now,
                })
            if row["_filename"]:
                meta_rows.append(
                    {
                        "asset_info_id": iid,
                        "key": "filename",
                        "ordinal": 0,
                        "val_str": row["_filename"],
                        "val_num": None,
                        "val_bool": None,
                        "val_json": None,
                    }
                )

    bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS)
    return {
        "inserted_infos": len(inserted_info_ids),
        "won_states": len(winners_by_path),
        "lost_states": len(losers_by_path),
    }


def bulk_insert_tags_and_meta(
    session: Session,
    *,
    tag_rows: list[dict],
    meta_rows: list[dict],
    max_bind_params: int,
) -> None:
    """Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.
    - tag_rows keys: asset_info_id, tag_name, origin, added_at
    - meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
    """
    if tag_rows:
        ins_links = (
            sqlite.insert(AssetInfoTag)
            .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
        )
        for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params):
            session.execute(ins_links, chunk)
    if meta_rows:
        ins_meta = (
            sqlite.insert(AssetInfoMeta)
            .on_conflict_do_nothing(
                index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
            )
        )
        for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params):
            session.execute(ins_meta, chunk)
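
A minimal usage sketch for bulk_insert_tags_and_meta, assuming an open SQLAlchemy Session on the SQLite-backed assets database; the AssetInfo id, tag names, and file name below are illustrative placeholders, not values from this changeset.

# Hypothetical example: attach two tags and a "filename" metadata row to an
# existing AssetInfo; ON CONFLICT DO NOTHING makes re-running this a no-op.
from datetime import datetime, timezone

example_info_id = "11111111-2222-3333-4444-555555555555"  # assumed to already exist
now = datetime.now(timezone.utc).replace(tzinfo=None)     # naive UTC, matching the models
bulk_insert_tags_and_meta(
    session,
    tag_rows=[
        {"asset_info_id": example_info_id, "tag_name": "checkpoints", "origin": "automatic", "added_at": now},
        {"asset_info_id": example_info_id, "tag_name": "models", "origin": "automatic", "added_at": now},
    ],
    meta_rows=[
        {"asset_info_id": example_info_id, "key": "filename", "ordinal": 0,
         "val_str": "checkpoints/example.safetensors", "val_num": None, "val_bool": None, "val_json": None},
    ],
    max_bind_params=MAX_BIND_PARAMS,
)
session.commit()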
@@ -2,8 +2,8 @@ from __future__ import annotations

import uuid
from datetime import datetime

from typing import Any

from sqlalchemy import (
    JSON,
    BigInteger,
@@ -20,19 +20,21 @@ from sqlalchemy import (
)
from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship

from app.assets.helpers import utcnow
from app.database.models import to_dict, Base
from app.assets.helpers import get_utc_now
from app.database.models import Base, to_dict


class Asset(Base):
    __tablename__ = "assets"

    id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    id: Mapped[str] = mapped_column(
        String(36), primary_key=True, default=lambda: str(uuid.uuid4())
    )
    hash: Mapped[str | None] = mapped_column(String(256), nullable=True)
    size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
    mime_type: Mapped[str | None] = mapped_column(String(255))
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=False), nullable=False, default=utcnow
        DateTime(timezone=False), nullable=False, default=get_utc_now
    )

    infos: Mapped[list[AssetInfo]] = relationship(
@@ -75,17 +77,23 @@ class AssetCacheState(Base):
    __tablename__ = "asset_cache_state"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False)
    asset_id: Mapped[str] = mapped_column(
        String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False
    )
    file_path: Mapped[str] = mapped_column(Text, nullable=False)
    mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
    needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
    is_missing: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)

    asset: Mapped[Asset] = relationship(back_populates="cache_states")

    __table_args__ = (
        Index("ix_asset_cache_state_file_path", "file_path"),
        Index("ix_asset_cache_state_asset_id", "asset_id"),
        CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
        Index("ix_asset_cache_state_is_missing", "is_missing"),
        CheckConstraint(
            "(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"
        ),
        UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
    )

@@ -99,15 +107,29 @@ class AssetCacheState(Base):
class AssetInfo(Base):
    __tablename__ = "assets_info"

    id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    id: Mapped[str] = mapped_column(
        String(36), primary_key=True, default=lambda: str(uuid.uuid4())
    )
    owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
    name: Mapped[str] = mapped_column(String(512), nullable=False)
    asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False)
    preview_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL"))
    user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON(none_as_null=True))
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
    updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
    last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
    asset_id: Mapped[str] = mapped_column(
        String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False
    )
    preview_id: Mapped[str | None] = mapped_column(
        String(36), ForeignKey("assets.id", ondelete="SET NULL")
    )
    user_metadata: Mapped[dict[str, Any] | None] = mapped_column(
        JSON(none_as_null=True)
    )
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=False), nullable=False, default=get_utc_now
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=False), nullable=False, default=get_utc_now
    )
    last_access_time: Mapped[datetime] = mapped_column(
        DateTime(timezone=False), nullable=False, default=get_utc_now
    )

    asset: Mapped[Asset] = relationship(
        "Asset",
@@ -143,7 +165,9 @@ class AssetInfo(Base):
    )

    __table_args__ = (
        UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
        UniqueConstraint(
            "asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"
        ),
        Index("ix_assets_info_owner_name", "owner_id", "name"),
        Index("ix_assets_info_owner_id", "owner_id"),
        Index("ix_assets_info_asset_id", "asset_id"),
@@ -196,7 +220,7 @@ class AssetInfoTag(Base):
    )
    origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
    added_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=False), nullable=False, default=utcnow
        DateTime(timezone=False), nullable=False, default=get_utc_now
    )

    asset_info: Mapped[AssetInfo] = relationship(back_populates="tag_links")
@@ -225,9 +249,7 @@ class Tag(Base):
        overlaps="asset_info_links,tag_links,tags,asset_info",
    )

    __table_args__ = (
        Index("ix_tags_tag_type", "tag_type"),
    )
    __table_args__ = (Index("ix_tags_tag_type", "tag_type"),)

    def __repr__(self) -> str:
        return f"<Tag {self.name}>"

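A hedged sketch of how the new AssetCacheState.is_missing flag could drive non-destructive soft-deletes during a filesystem scan; the helper name and the way paths are gathered are assumptions for illustration, not part of this changeset.

# Illustrative only: flip is_missing instead of deleting rows, so records whose
# files reappear can be restored without losing tags or metadata.
import os
import sqlalchemy as sa

def refresh_is_missing(session, file_paths: list[str]) -> None:
    gone = [p for p in file_paths if not os.path.isfile(p)]
    present = [p for p in file_paths if os.path.isfile(p)]
    if gone:
        session.execute(
            sa.update(AssetCacheState)
            .where(AssetCacheState.file_path.in_(gone))
            .values(is_missing=True)
        )
    if present:
        session.execute(
            sa.update(AssetCacheState)
            .where(AssetCacheState.file_path.in_(present), AssetCacheState.is_missing.is_(True))
            .values(is_missing=False)
        )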
@@ -1,976 +0,0 @@
|
||||
import os
|
||||
import logging
|
||||
import sqlalchemy as sa
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Iterable, Any
|
||||
from sqlalchemy import select, delete, exists, func
|
||||
from sqlalchemy.dialects import sqlite
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session, contains_eager, noload
|
||||
from app.assets.database.models import Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag
|
||||
from app.assets.helpers import (
|
||||
compute_relative_filename, escape_like_prefix, normalize_tags, project_kv, utcnow
|
||||
)
|
||||
from typing import Sequence
|
||||
|
||||
|
||||
def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
|
||||
"""Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
|
||||
owner_id = (owner_id or "").strip()
|
||||
if owner_id == "":
|
||||
return AssetInfo.owner_id == ""
|
||||
return AssetInfo.owner_id.in_(["", owner_id])
|
||||
|
||||
|
||||
def pick_best_live_path(states: Sequence[AssetCacheState]) -> str:
|
||||
"""
|
||||
Return the best on-disk path among cache states:
|
||||
1) Prefer a path that exists with needs_verify == False (already verified).
|
||||
2) Otherwise, pick the first path that exists.
|
||||
3) Otherwise return empty string.
|
||||
"""
|
||||
alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)]
|
||||
if not alive:
|
||||
return ""
|
||||
for s in alive:
|
||||
if not getattr(s, "needs_verify", False):
|
||||
return s.file_path
|
||||
return alive[0].file_path
|
||||
|
||||
|
||||
def apply_tag_filters(
|
||||
stmt: sa.sql.Select,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
) -> sa.sql.Select:
|
||||
"""include_tags: every tag must be present; exclude_tags: none may be present."""
|
||||
include_tags = normalize_tags(include_tags)
|
||||
exclude_tags = normalize_tags(exclude_tags)
|
||||
|
||||
if include_tags:
|
||||
for tag_name in include_tags:
|
||||
stmt = stmt.where(
|
||||
exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name == tag_name)
|
||||
)
|
||||
)
|
||||
|
||||
if exclude_tags:
|
||||
stmt = stmt.where(
|
||||
~exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name.in_(exclude_tags))
|
||||
)
|
||||
)
|
||||
return stmt
|
||||
|
||||
|
||||
def apply_metadata_filter(
|
||||
stmt: sa.sql.Select,
|
||||
metadata_filter: dict | None = None,
|
||||
) -> sa.sql.Select:
|
||||
"""Apply filters using asset_info_meta projection table."""
|
||||
if not metadata_filter:
|
||||
return stmt
|
||||
|
||||
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
|
||||
return sa.exists().where(
|
||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||
AssetInfoMeta.key == key,
|
||||
*preds,
|
||||
)
|
||||
|
||||
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
|
||||
if value is None:
|
||||
no_row_for_key = sa.not_(
|
||||
sa.exists().where(
|
||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||
AssetInfoMeta.key == key,
|
||||
)
|
||||
)
|
||||
null_row = _exists_for_pred(
|
||||
key,
|
||||
AssetInfoMeta.val_json.is_(None),
|
||||
AssetInfoMeta.val_str.is_(None),
|
||||
AssetInfoMeta.val_num.is_(None),
|
||||
AssetInfoMeta.val_bool.is_(None),
|
||||
)
|
||||
return sa.or_(no_row_for_key, null_row)
|
||||
|
||||
if isinstance(value, bool):
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
|
||||
if isinstance(value, (int, float)):
|
||||
from decimal import Decimal
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
|
||||
if isinstance(value, str):
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
|
||||
|
||||
for k, v in metadata_filter.items():
|
||||
if isinstance(v, list):
|
||||
ors = [_exists_clause_for_value(k, elem) for elem in v]
|
||||
if ors:
|
||||
stmt = stmt.where(sa.or_(*ors))
|
||||
else:
|
||||
stmt = stmt.where(_exists_clause_for_value(k, v))
|
||||
return stmt
|
||||
|
||||
|
||||
def asset_exists_by_hash(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if an asset with a given hash exists in database.
|
||||
"""
|
||||
row = (
|
||||
session.execute(
|
||||
select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1)
|
||||
)
|
||||
).first()
|
||||
return row is not None
|
||||
|
||||
|
||||
def asset_info_exists_for_asset_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_id: str,
|
||||
) -> bool:
|
||||
q = (
|
||||
select(sa.literal(True))
|
||||
.select_from(AssetInfo)
|
||||
.where(AssetInfo.asset_id == asset_id)
|
||||
.limit(1)
|
||||
)
|
||||
return (session.execute(q)).first() is not None
|
||||
|
||||
|
||||
def get_asset_by_hash(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
) -> Asset | None:
|
||||
return (
|
||||
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
|
||||
).scalars().first()
|
||||
|
||||
|
||||
def get_asset_info_by_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
) -> AssetInfo | None:
|
||||
return session.get(AssetInfo, asset_info_id)
|
||||
|
||||
|
||||
def list_asset_infos_page(
|
||||
session: Session,
|
||||
owner_id: str = "",
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
) -> tuple[list[AssetInfo], dict[str, list[str]], int]:
|
||||
base = (
|
||||
select(AssetInfo)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
|
||||
.where(visible_owner_clause(owner_id))
|
||||
)
|
||||
|
||||
if name_contains:
|
||||
escaped, esc = escape_like_prefix(name_contains)
|
||||
base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
|
||||
|
||||
base = apply_tag_filters(base, include_tags, exclude_tags)
|
||||
base = apply_metadata_filter(base, metadata_filter)
|
||||
|
||||
sort = (sort or "created_at").lower()
|
||||
order = (order or "desc").lower()
|
||||
sort_map = {
|
||||
"name": AssetInfo.name,
|
||||
"created_at": AssetInfo.created_at,
|
||||
"updated_at": AssetInfo.updated_at,
|
||||
"last_access_time": AssetInfo.last_access_time,
|
||||
"size": Asset.size_bytes,
|
||||
}
|
||||
sort_col = sort_map.get(sort, AssetInfo.created_at)
|
||||
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
|
||||
|
||||
base = base.order_by(sort_exp).limit(limit).offset(offset)
|
||||
|
||||
count_stmt = (
|
||||
select(sa.func.count())
|
||||
.select_from(AssetInfo)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.where(visible_owner_clause(owner_id))
|
||||
)
|
||||
if name_contains:
|
||||
escaped, esc = escape_like_prefix(name_contains)
|
||||
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
|
||||
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
||||
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
|
||||
|
||||
total = int((session.execute(count_stmt)).scalar_one() or 0)
|
||||
|
||||
infos = (session.execute(base)).unique().scalars().all()
|
||||
|
||||
id_list: list[str] = [i.id for i in infos]
|
||||
tag_map: dict[str, list[str]] = defaultdict(list)
|
||||
if id_list:
|
||||
rows = session.execute(
|
||||
select(AssetInfoTag.asset_info_id, Tag.name)
|
||||
.join(Tag, Tag.name == AssetInfoTag.tag_name)
|
||||
.where(AssetInfoTag.asset_info_id.in_(id_list))
|
||||
.order_by(AssetInfoTag.added_at)
|
||||
)
|
||||
for aid, tag_name in rows.all():
|
||||
tag_map[aid].append(tag_name)
|
||||
|
||||
return infos, tag_map, total
|
||||
|
||||
|
||||
def fetch_asset_info_asset_and_tags(
|
||||
session: Session,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> tuple[AssetInfo, Asset, list[str]] | None:
|
||||
stmt = (
|
||||
select(AssetInfo, Asset, Tag.name)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
|
||||
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
|
||||
.where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
.options(noload(AssetInfo.tags))
|
||||
.order_by(Tag.name.asc())
|
||||
)
|
||||
|
||||
rows = (session.execute(stmt)).all()
|
||||
if not rows:
|
||||
return None
|
||||
|
||||
first_info, first_asset, _ = rows[0]
|
||||
tags: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for _info, _asset, tag_name in rows:
|
||||
if tag_name and tag_name not in seen:
|
||||
seen.add(tag_name)
|
||||
tags.append(tag_name)
|
||||
return first_info, first_asset, tags
|
||||
|
||||
|
||||
def fetch_asset_info_and_asset(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> tuple[AssetInfo, Asset] | None:
|
||||
stmt = (
|
||||
select(AssetInfo, Asset)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
.limit(1)
|
||||
.options(noload(AssetInfo.tags))
|
||||
)
|
||||
row = session.execute(stmt)
|
||||
pair = row.first()
|
||||
if not pair:
|
||||
return None
|
||||
return pair[0], pair[1]
|
||||
|
||||
def list_cache_states_by_asset_id(
|
||||
session: Session, *, asset_id: str
|
||||
) -> Sequence[AssetCacheState]:
|
||||
return (
|
||||
session.execute(
|
||||
select(AssetCacheState)
|
||||
.where(AssetCacheState.asset_id == asset_id)
|
||||
.order_by(AssetCacheState.id.asc())
|
||||
)
|
||||
).scalars().all()
|
||||
|
||||
|
||||
def touch_asset_info_by_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
ts: datetime | None = None,
|
||||
only_if_newer: bool = True,
|
||||
) -> None:
|
||||
ts = ts or utcnow()
|
||||
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
|
||||
if only_if_newer:
|
||||
stmt = stmt.where(
|
||||
sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts)
|
||||
)
|
||||
session.execute(stmt.values(last_access_time=ts))
|
||||
|
||||
|
||||
def create_asset_info_for_existing_asset(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
name: str,
|
||||
user_metadata: dict | None = None,
|
||||
tags: Sequence[str] | None = None,
|
||||
tag_origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> AssetInfo:
|
||||
"""Create or return an existing AssetInfo for an Asset identified by asset_hash."""
|
||||
now = utcnow()
|
||||
asset = get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
if not asset:
|
||||
raise ValueError(f"Unknown asset hash {asset_hash}")
|
||||
|
||||
info = AssetInfo(
|
||||
owner_id=owner_id,
|
||||
name=name,
|
||||
asset_id=asset.id,
|
||||
preview_id=None,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_access_time=now,
|
||||
)
|
||||
try:
|
||||
with session.begin_nested():
|
||||
session.add(info)
|
||||
session.flush()
|
||||
except IntegrityError:
|
||||
existing = (
|
||||
session.execute(
|
||||
select(AssetInfo)
|
||||
.options(noload(AssetInfo.tags))
|
||||
.where(
|
||||
AssetInfo.asset_id == asset.id,
|
||||
AssetInfo.name == name,
|
||||
AssetInfo.owner_id == owner_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
).unique().scalars().first()
|
||||
if not existing:
|
||||
raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
|
||||
return existing
|
||||
|
||||
# metadata["filename"] hack
|
||||
new_meta = dict(user_metadata or {})
|
||||
computed_filename = None
|
||||
try:
|
||||
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
|
||||
if p:
|
||||
computed_filename = compute_relative_filename(p)
|
||||
except Exception:
|
||||
computed_filename = None
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
if new_meta:
|
||||
replace_asset_info_metadata_projection(
|
||||
session,
|
||||
asset_info_id=info.id,
|
||||
user_metadata=new_meta,
|
||||
)
|
||||
|
||||
if tags is not None:
|
||||
set_asset_info_tags(
|
||||
session,
|
||||
asset_info_id=info.id,
|
||||
tags=tags,
|
||||
origin=tag_origin,
|
||||
)
|
||||
return info
|
||||
|
||||
|
||||
def set_asset_info_tags(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
) -> dict:
|
||||
desired = normalize_tags(tags)
|
||||
|
||||
current = set(
|
||||
tag_name for (tag_name,) in (
|
||||
session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
|
||||
).all()
|
||||
)
|
||||
|
||||
to_add = [t for t in desired if t not in current]
|
||||
to_remove = [t for t in current if t not in desired]
|
||||
|
||||
if to_add:
|
||||
ensure_tags_exist(session, to_add, tag_type="user")
|
||||
session.add_all([
|
||||
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow())
|
||||
for t in to_add
|
||||
])
|
||||
session.flush()
|
||||
|
||||
if to_remove:
|
||||
session.execute(
|
||||
delete(AssetInfoTag)
|
||||
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
|
||||
)
|
||||
session.flush()
|
||||
|
||||
return {"added": to_add, "removed": to_remove, "total": desired}
|
||||
|
||||
|
||||
def replace_asset_info_metadata_projection(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
user_metadata: dict | None = None,
|
||||
) -> None:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
info.user_metadata = user_metadata or {}
|
||||
info.updated_at = utcnow()
|
||||
session.flush()
|
||||
|
||||
session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
|
||||
session.flush()
|
||||
|
||||
if not user_metadata:
|
||||
return
|
||||
|
||||
rows: list[AssetInfoMeta] = []
|
||||
for k, v in user_metadata.items():
|
||||
for r in project_kv(k, v):
|
||||
rows.append(
|
||||
AssetInfoMeta(
|
||||
asset_info_id=asset_info_id,
|
||||
key=r["key"],
|
||||
ordinal=int(r["ordinal"]),
|
||||
val_str=r.get("val_str"),
|
||||
val_num=r.get("val_num"),
|
||||
val_bool=r.get("val_bool"),
|
||||
val_json=r.get("val_json"),
|
||||
)
|
||||
)
|
||||
if rows:
|
||||
session.add_all(rows)
|
||||
session.flush()
|
||||
|
||||
|
||||
def ingest_fs_asset(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
abs_path: str,
|
||||
size_bytes: int,
|
||||
mtime_ns: int,
|
||||
mime_type: str | None = None,
|
||||
info_name: str | None = None,
|
||||
owner_id: str = "",
|
||||
preview_id: str | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
tags: Sequence[str] = (),
|
||||
tag_origin: str = "manual",
|
||||
require_existing_tags: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Idempotently upsert:
|
||||
- Asset by content hash (create if missing)
|
||||
- AssetCacheState(file_path) pointing to asset_id
|
||||
- Optionally AssetInfo + tag links and metadata projection
|
||||
Returns flags and ids.
|
||||
"""
|
||||
locator = os.path.abspath(abs_path)
|
||||
now = utcnow()
|
||||
|
||||
if preview_id:
|
||||
if not session.get(Asset, preview_id):
|
||||
preview_id = None
|
||||
|
||||
out: dict[str, Any] = {
|
||||
"asset_created": False,
|
||||
"asset_updated": False,
|
||||
"state_created": False,
|
||||
"state_updated": False,
|
||||
"asset_info_id": None,
|
||||
}
|
||||
|
||||
# 1) Asset by hash
|
||||
asset = (
|
||||
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
|
||||
).scalars().first()
|
||||
if not asset:
|
||||
vals = {
|
||||
"hash": asset_hash,
|
||||
"size_bytes": int(size_bytes),
|
||||
"mime_type": mime_type,
|
||||
"created_at": now,
|
||||
}
|
||||
res = session.execute(
|
||||
sqlite.insert(Asset)
|
||||
.values(**vals)
|
||||
.on_conflict_do_nothing(index_elements=[Asset.hash])
|
||||
)
|
||||
if int(res.rowcount or 0) > 0:
|
||||
out["asset_created"] = True
|
||||
asset = (
|
||||
session.execute(
|
||||
select(Asset).where(Asset.hash == asset_hash).limit(1)
|
||||
)
|
||||
).scalars().first()
|
||||
if not asset:
|
||||
raise RuntimeError("Asset row not found after upsert.")
|
||||
else:
|
||||
changed = False
|
||||
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
|
||||
asset.size_bytes = int(size_bytes)
|
||||
changed = True
|
||||
if mime_type and asset.mime_type != mime_type:
|
||||
asset.mime_type = mime_type
|
||||
changed = True
|
||||
if changed:
|
||||
out["asset_updated"] = True
|
||||
|
||||
# 2) AssetCacheState upsert by file_path (unique)
|
||||
vals = {
|
||||
"asset_id": asset.id,
|
||||
"file_path": locator,
|
||||
"mtime_ns": int(mtime_ns),
|
||||
}
|
||||
ins = (
|
||||
sqlite.insert(AssetCacheState)
|
||||
.values(**vals)
|
||||
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
|
||||
)
|
||||
|
||||
res = session.execute(ins)
|
||||
if int(res.rowcount or 0) > 0:
|
||||
out["state_created"] = True
|
||||
else:
|
||||
upd = (
|
||||
sa.update(AssetCacheState)
|
||||
.where(AssetCacheState.file_path == locator)
|
||||
.where(
|
||||
sa.or_(
|
||||
AssetCacheState.asset_id != asset.id,
|
||||
AssetCacheState.mtime_ns.is_(None),
|
||||
AssetCacheState.mtime_ns != int(mtime_ns),
|
||||
)
|
||||
)
|
||||
.values(asset_id=asset.id, mtime_ns=int(mtime_ns))
|
||||
)
|
||||
res2 = session.execute(upd)
|
||||
if int(res2.rowcount or 0) > 0:
|
||||
out["state_updated"] = True
|
||||
|
||||
# 3) Optional AssetInfo + tags + metadata
|
||||
if info_name:
|
||||
try:
|
||||
with session.begin_nested():
|
||||
info = AssetInfo(
|
||||
owner_id=owner_id,
|
||||
name=info_name,
|
||||
asset_id=asset.id,
|
||||
preview_id=preview_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_access_time=now,
|
||||
)
|
||||
session.add(info)
|
||||
session.flush()
|
||||
out["asset_info_id"] = info.id
|
||||
except IntegrityError:
|
||||
pass
|
||||
|
||||
existing_info = (
|
||||
session.execute(
|
||||
select(AssetInfo)
|
||||
.where(
|
||||
AssetInfo.asset_id == asset.id,
|
||||
AssetInfo.name == info_name,
|
||||
(AssetInfo.owner_id == owner_id),
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
).unique().scalar_one_or_none()
|
||||
if not existing_info:
|
||||
raise RuntimeError("Failed to update or insert AssetInfo.")
|
||||
|
||||
if preview_id and existing_info.preview_id != preview_id:
|
||||
existing_info.preview_id = preview_id
|
||||
|
||||
existing_info.updated_at = now
|
||||
if existing_info.last_access_time < now:
|
||||
existing_info.last_access_time = now
|
||||
session.flush()
|
||||
out["asset_info_id"] = existing_info.id
|
||||
|
||||
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||
if norm and out["asset_info_id"] is not None:
|
||||
if not require_existing_tags:
|
||||
ensure_tags_exist(session, norm, tag_type="user")
|
||||
|
||||
existing_tag_names = set(
|
||||
name for (name,) in (session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
|
||||
)
|
||||
missing = [t for t in norm if t not in existing_tag_names]
|
||||
if missing and require_existing_tags:
|
||||
raise ValueError(f"Unknown tags: {missing}")
|
||||
|
||||
existing_links = set(
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
|
||||
)
|
||||
).all()
|
||||
)
|
||||
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
|
||||
if to_add:
|
||||
session.add_all(
|
||||
[
|
||||
AssetInfoTag(
|
||||
asset_info_id=out["asset_info_id"],
|
||||
tag_name=t,
|
||||
origin=tag_origin,
|
||||
added_at=now,
|
||||
)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
session.flush()
|
||||
|
||||
# metadata["filename"] hack
|
||||
if out["asset_info_id"] is not None:
|
||||
primary_path = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
|
||||
computed_filename = compute_relative_filename(primary_path) if primary_path else None
|
||||
|
||||
current_meta = existing_info.user_metadata or {}
|
||||
new_meta = dict(current_meta)
|
||||
if user_metadata is not None:
|
||||
for k, v in user_metadata.items():
|
||||
new_meta[k] = v
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
|
||||
if new_meta != current_meta:
|
||||
replace_asset_info_metadata_projection(
|
||||
session,
|
||||
asset_info_id=out["asset_info_id"],
|
||||
user_metadata=new_meta,
|
||||
)
|
||||
|
||||
try:
|
||||
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
except Exception:
|
||||
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
|
||||
return out
|
||||
|
||||
|
||||
def update_asset_info_full(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
name: str | None = None,
|
||||
tags: Sequence[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
tag_origin: str = "manual",
|
||||
asset_info_row: Any = None,
|
||||
) -> AssetInfo:
|
||||
if not asset_info_row:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
else:
|
||||
info = asset_info_row
|
||||
|
||||
touched = False
|
||||
if name is not None and name != info.name:
|
||||
info.name = name
|
||||
touched = True
|
||||
|
||||
computed_filename = None
|
||||
try:
|
||||
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=info.asset_id))
|
||||
if p:
|
||||
computed_filename = compute_relative_filename(p)
|
||||
except Exception:
|
||||
computed_filename = None
|
||||
|
||||
if user_metadata is not None:
|
||||
new_meta = dict(user_metadata)
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
replace_asset_info_metadata_projection(
|
||||
session, asset_info_id=asset_info_id, user_metadata=new_meta
|
||||
)
|
||||
touched = True
|
||||
else:
|
||||
if computed_filename:
|
||||
current_meta = info.user_metadata or {}
|
||||
if current_meta.get("filename") != computed_filename:
|
||||
new_meta = dict(current_meta)
|
||||
new_meta["filename"] = computed_filename
|
||||
replace_asset_info_metadata_projection(
|
||||
session, asset_info_id=asset_info_id, user_metadata=new_meta
|
||||
)
|
||||
touched = True
|
||||
|
||||
if tags is not None:
|
||||
set_asset_info_tags(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
origin=tag_origin,
|
||||
)
|
||||
touched = True
|
||||
|
||||
if touched and user_metadata is None:
|
||||
info.updated_at = utcnow()
|
||||
session.flush()
|
||||
|
||||
return info
|
||||
|
||||
|
||||
def delete_asset_info_by_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str,
|
||||
) -> bool:
|
||||
stmt = sa.delete(AssetInfo).where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
return int((session.execute(stmt)).rowcount or 0) > 0
|
||||
|
||||
|
||||
def list_tags_with_usage(
|
||||
session: Session,
|
||||
prefix: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
include_zero: bool = True,
|
||||
order: str = "count_desc",
|
||||
owner_id: str = "",
|
||||
) -> tuple[list[tuple[str, str, int]], int]:
|
||||
counts_sq = (
|
||||
select(
|
||||
AssetInfoTag.tag_name.label("tag_name"),
|
||||
func.count(AssetInfoTag.asset_info_id).label("cnt"),
|
||||
)
|
||||
.select_from(AssetInfoTag)
|
||||
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
|
||||
.where(visible_owner_clause(owner_id))
|
||||
.group_by(AssetInfoTag.tag_name)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
q = (
|
||||
select(
|
||||
Tag.name,
|
||||
Tag.tag_type,
|
||||
func.coalesce(counts_sq.c.cnt, 0).label("count"),
|
||||
)
|
||||
.select_from(Tag)
|
||||
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
|
||||
)
|
||||
|
||||
if prefix:
|
||||
escaped, esc = escape_like_prefix(prefix.strip().lower())
|
||||
q = q.where(Tag.name.like(escaped + "%", escape=esc))
|
||||
|
||||
if not include_zero:
|
||||
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
|
||||
|
||||
if order == "name_asc":
|
||||
q = q.order_by(Tag.name.asc())
|
||||
else:
|
||||
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
|
||||
|
||||
total_q = select(func.count()).select_from(Tag)
|
||||
if prefix:
|
||||
escaped, esc = escape_like_prefix(prefix.strip().lower())
|
||||
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
|
||||
if not include_zero:
|
||||
total_q = total_q.where(
|
||||
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
|
||||
)
|
||||
|
||||
rows = (session.execute(q.limit(limit).offset(offset))).all()
|
||||
total = (session.execute(total_q)).scalar_one()
|
||||
|
||||
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
|
||||
return rows_norm, int(total or 0)
|
||||
|
||||
|
||||
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
|
||||
wanted = normalize_tags(list(names))
|
||||
if not wanted:
|
||||
return
|
||||
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
|
||||
ins = (
|
||||
sqlite.insert(Tag)
|
||||
.values(rows)
|
||||
.on_conflict_do_nothing(index_elements=[Tag.name])
|
||||
)
|
||||
session.execute(ins)
|
||||
|
||||
|
||||
def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]:
|
||||
return [
|
||||
tag_name for (tag_name,) in (
|
||||
session.execute(
|
||||
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
]
|
||||
|
||||
|
||||
def add_tags_to_asset_info(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
create_if_missing: bool = True,
|
||||
asset_info_row: Any = None,
|
||||
) -> dict:
|
||||
if not asset_info_row:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
norm = normalize_tags(tags)
|
||||
if not norm:
|
||||
total = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"added": [], "already_present": [], "total_tags": total}
|
||||
|
||||
if create_if_missing:
|
||||
ensure_tags_exist(session, norm, tag_type="user")
|
||||
|
||||
current = {
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
|
||||
want = set(norm)
|
||||
to_add = sorted(want - current)
|
||||
|
||||
if to_add:
|
||||
with session.begin_nested() as nested:
|
||||
try:
|
||||
session.add_all(
|
||||
[
|
||||
AssetInfoTag(
|
||||
asset_info_id=asset_info_id,
|
||||
tag_name=t,
|
||||
origin=origin,
|
||||
added_at=utcnow(),
|
||||
)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
session.flush()
|
||||
except IntegrityError:
|
||||
nested.rollback()
|
||||
|
||||
after = set(get_asset_tags(session, asset_info_id=asset_info_id))
|
||||
return {
|
||||
"added": sorted(((after - current) & want)),
|
||||
"already_present": sorted(want & current),
|
||||
"total_tags": sorted(after),
|
||||
}
|
||||
|
||||
|
||||
def remove_tags_from_asset_info(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
) -> dict:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
norm = normalize_tags(tags)
|
||||
if not norm:
|
||||
total = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"removed": [], "not_present": [], "total_tags": total}
|
||||
|
||||
existing = {
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
|
||||
to_remove = sorted(set(t for t in norm if t in existing))
|
||||
not_present = sorted(set(t for t in norm if t not in existing))
|
||||
|
||||
if to_remove:
|
||||
session.execute(
|
||||
delete(AssetInfoTag)
|
||||
.where(
|
||||
AssetInfoTag.asset_info_id == asset_info_id,
|
||||
AssetInfoTag.tag_name.in_(to_remove),
|
||||
)
|
||||
)
|
||||
session.flush()
|
||||
|
||||
total = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
|
||||
|
||||
|
||||
def remove_missing_tag_for_asset_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_id: str,
|
||||
) -> None:
|
||||
session.execute(
|
||||
sa.delete(AssetInfoTag).where(
|
||||
AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
|
||||
AssetInfoTag.tag_name == "missing",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def set_asset_info_preview(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
preview_asset_id: str | None = None,
|
||||
) -> None:
|
||||
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
if preview_asset_id is None:
|
||||
info.preview_id = None
|
||||
else:
|
||||
# validate preview asset exists
|
||||
if not session.get(Asset, preview_asset_id):
|
||||
raise ValueError(f"Preview Asset {preview_asset_id} not found")
|
||||
info.preview_id = preview_asset_id
|
||||
|
||||
info.updated_at = utcnow()
|
||||
session.flush()
|
||||
app/assets/database/queries/__init__.py (Normal file, 105 lines added)
@@ -0,0 +1,105 @@
from app.assets.database.queries.asset import (
    asset_exists_by_hash,
    bulk_insert_assets,
    get_asset_by_hash,
    get_existing_asset_ids,
    upsert_asset,
)
from app.assets.database.queries.asset_info import (
    asset_info_exists_for_asset_id,
    bulk_insert_asset_infos_ignore_conflicts,
    delete_asset_info_by_id,
    fetch_asset_info_and_asset,
    fetch_asset_info_asset_and_tags,
    get_asset_info_by_id,
    get_asset_info_ids_by_ids,
    get_or_create_asset_info,
    insert_asset_info,
    list_asset_infos_page,
    set_asset_info_metadata,
    set_asset_info_preview,
    update_asset_info_access_time,
    update_asset_info_name,
    update_asset_info_timestamps,
    update_asset_info_updated_at,
)
from app.assets.database.queries.cache_state import (
    CacheStateRow,
    bulk_insert_cache_states_ignore_conflicts,
    bulk_update_is_missing,
    bulk_update_needs_verify,
    delete_assets_by_ids,
    delete_cache_states_by_ids,
    delete_orphaned_seed_asset,
    get_cache_states_by_paths_and_asset_ids,
    get_cache_states_for_prefixes,
    get_unreferenced_unhashed_asset_ids,
    list_cache_states_by_asset_id,
    mark_cache_states_missing_outside_prefixes,
    restore_cache_states_by_paths,
    upsert_cache_state,
)
from app.assets.database.queries.tags import (
    AddTagsDict,
    RemoveTagsDict,
    SetTagsDict,
    add_missing_tag_for_asset_id,
    add_tags_to_asset_info,
    bulk_insert_tags_and_meta,
    ensure_tags_exist,
    get_asset_tags,
    list_tags_with_usage,
    remove_missing_tag_for_asset_id,
    remove_tags_from_asset_info,
    set_asset_info_tags,
)

__all__ = [
    "AddTagsDict",
    "CacheStateRow",
    "RemoveTagsDict",
    "SetTagsDict",
    "add_missing_tag_for_asset_id",
    "add_tags_to_asset_info",
    "asset_exists_by_hash",
    "asset_info_exists_for_asset_id",
    "bulk_insert_asset_infos_ignore_conflicts",
    "bulk_insert_assets",
    "bulk_insert_cache_states_ignore_conflicts",
    "bulk_insert_tags_and_meta",
    "bulk_update_is_missing",
    "bulk_update_needs_verify",
    "delete_asset_info_by_id",
    "delete_assets_by_ids",
    "delete_cache_states_by_ids",
    "delete_orphaned_seed_asset",
    "ensure_tags_exist",
    "fetch_asset_info_and_asset",
    "fetch_asset_info_asset_and_tags",
    "get_asset_by_hash",
    "get_existing_asset_ids",
    "get_asset_info_by_id",
    "get_asset_info_ids_by_ids",
    "get_asset_tags",
    "get_cache_states_by_paths_and_asset_ids",
    "get_cache_states_for_prefixes",
    "get_or_create_asset_info",
    "get_unreferenced_unhashed_asset_ids",
    "insert_asset_info",
    "list_asset_infos_page",
    "list_cache_states_by_asset_id",
    "list_tags_with_usage",
    "mark_cache_states_missing_outside_prefixes",
    "remove_missing_tag_for_asset_id",
    "remove_tags_from_asset_info",
    "restore_cache_states_by_paths",
    "set_asset_info_metadata",
    "set_asset_info_preview",
    "set_asset_info_tags",
    "update_asset_info_access_time",
    "update_asset_info_name",
    "update_asset_info_timestamps",
    "update_asset_info_updated_at",
    "upsert_asset",
    "upsert_cache_state",
]
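
The __init__ above exposes the split query modules as one flat namespace; a minimal import-site sketch, assuming the package path stays app.assets.database.queries and an open Session is available.

# Illustrative only: callers import helpers from the package root rather than
# from the individual asset / asset_info / cache_state / tags modules.
from app.assets.database.queries import asset_exists_by_hash, get_asset_by_hash

def describe_asset(session, digest: str) -> str:
    if not asset_exists_by_hash(session, asset_hash=digest):
        return "no asset with this hash"
    asset = get_asset_by_hash(session, asset_hash=digest)
    return f"{asset.id}: {asset.size_bytes} bytes ({asset.mime_type or 'unknown type'})"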
app/assets/database/queries/asset.py (Normal file, 103 lines added)
@@ -0,0 +1,103 @@
import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy.dialects import sqlite
from sqlalchemy.orm import Session

from app.assets.database.models import Asset
from app.assets.database.queries.common import calculate_rows_per_statement, iter_chunks


def asset_exists_by_hash(
    session: Session,
    asset_hash: str,
) -> bool:
    """
    Check if an asset with a given hash exists in database.
    """
    row = (
        session.execute(
            select(sa.literal(True))
            .select_from(Asset)
            .where(Asset.hash == asset_hash)
            .limit(1)
        )
    ).first()
    return row is not None


def get_asset_by_hash(
    session: Session,
    asset_hash: str,
) -> Asset | None:
    return (
        (session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)))
        .scalars()
        .first()
    )


def upsert_asset(
    session: Session,
    asset_hash: str,
    size_bytes: int,
    mime_type: str | None = None,
) -> tuple[Asset, bool, bool]:
    """Upsert an Asset by hash. Returns (asset, created, updated)."""
    vals = {"hash": asset_hash, "size_bytes": int(size_bytes)}
    if mime_type:
        vals["mime_type"] = mime_type

    ins = (
        sqlite.insert(Asset)
        .values(**vals)
        .on_conflict_do_nothing(index_elements=[Asset.hash])
    )
    res = session.execute(ins)
    created = int(res.rowcount or 0) > 0

    asset = (
        session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
        .scalars()
        .first()
    )
    if not asset:
        raise RuntimeError("Asset row not found after upsert.")

    updated = False
    if not created:
        changed = False
        if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
            asset.size_bytes = int(size_bytes)
            changed = True
        if mime_type and asset.mime_type != mime_type:
            asset.mime_type = mime_type
            changed = True
        if changed:
            updated = True

    return asset, created, updated


def bulk_insert_assets(
    session: Session,
    rows: list[dict],
) -> None:
    """Bulk insert Asset rows. Each dict should have: id, hash, size_bytes, mime_type, created_at."""
    if not rows:
        return
    ins = sqlite.insert(Asset).on_conflict_do_nothing(index_elements=[Asset.hash])
    for chunk in iter_chunks(rows, calculate_rows_per_statement(5)):
        session.execute(ins, chunk)


def get_existing_asset_ids(
    session: Session,
    asset_ids: list[str],
) -> set[str]:
    """Return the subset of asset_ids that exist in the database."""
    if not asset_ids:
        return set()
    rows = session.execute(
        select(Asset.id).where(Asset.id.in_(asset_ids))
    ).fetchall()
    return {row[0] for row in rows}
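
A short usage sketch for upsert_asset and get_existing_asset_ids, assuming an open Session; the hash, size, and MIME type are placeholder values rather than anything produced by this changeset.

# Illustrative only: idempotently register content by hash, then check which
# Asset ids are already present in the database.
asset, created, updated = upsert_asset(
    session,
    asset_hash="0" * 64,  # placeholder content hash
    size_bytes=1_234_567,
    mime_type="application/octet-stream",
)
print("created" if created else "updated" if updated else "unchanged", asset.id)

known_ids = get_existing_asset_ids(session, asset_ids=[asset.id, "not-a-real-id"])
assert asset.id in known_ids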
app/assets/database/queries/asset_info.py (Normal file, 527 lines added)
@@ -0,0 +1,527 @@
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from typing import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import delete, exists, select
|
||||
from sqlalchemy.dialects import sqlite
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session, contains_eager, noload
|
||||
|
||||
from app.assets.database.models import (
|
||||
Asset,
|
||||
AssetInfo,
|
||||
AssetInfoMeta,
|
||||
AssetInfoTag,
|
||||
Tag,
|
||||
)
|
||||
from app.assets.database.queries.common import (
|
||||
MAX_BIND_PARAMS,
|
||||
build_visible_owner_clause,
|
||||
calculate_rows_per_statement,
|
||||
iter_chunks,
|
||||
)
|
||||
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
|
||||
|
||||
|
||||
def _check_is_scalar(v):
|
||||
if v is None:
|
||||
return True
|
||||
if isinstance(v, bool):
|
||||
return True
|
||||
if isinstance(v, (int, float, Decimal, str)):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _scalar_to_row(key: str, ordinal: int, value) -> dict:
|
||||
"""Convert a scalar value to a typed projection row."""
|
||||
if value is None:
|
||||
return {
|
||||
"key": key,
|
||||
"ordinal": ordinal,
|
||||
"val_str": None,
|
||||
"val_num": None,
|
||||
"val_bool": None,
|
||||
"val_json": None,
|
||||
}
|
||||
if isinstance(value, bool):
|
||||
return {"key": key, "ordinal": ordinal, "val_bool": bool(value)}
|
||||
if isinstance(value, (int, float, Decimal)):
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
return {"key": key, "ordinal": ordinal, "val_num": num}
|
||||
if isinstance(value, str):
|
||||
return {"key": key, "ordinal": ordinal, "val_str": value}
|
||||
return {"key": key, "ordinal": ordinal, "val_json": value}
|
||||
|
||||
|
||||
def convert_metadata_to_rows(key: str, value) -> list[dict]:
|
||||
"""
|
||||
Turn a metadata key/value into typed projection rows.
|
||||
Returns list[dict] with keys:
|
||||
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
|
||||
"""
|
||||
if value is None:
|
||||
return [_scalar_to_row(key, 0, None)]
|
||||
|
||||
if _check_is_scalar(value):
|
||||
return [_scalar_to_row(key, 0, value)]
|
||||
|
||||
if isinstance(value, list):
|
||||
if all(_check_is_scalar(x) for x in value):
|
||||
return [_scalar_to_row(key, i, x) for i, x in enumerate(value)]
|
||||
return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value)]
|
||||
|
||||
return [{"key": key, "ordinal": 0, "val_json": value}]
|
||||
|
||||
|
||||
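A small worked example of how convert_metadata_to_rows projects user metadata into typed rows; the metadata keys and values are illustrative.

# Illustrative only: scalars land in the matching typed column, lists fan out
# by ordinal, and nested objects fall back to val_json.
rows = convert_metadata_to_rows("resolution", [512, 768])
# -> [{"key": "resolution", "ordinal": 0, "val_num": Decimal("512")},
#     {"key": "resolution", "ordinal": 1, "val_num": Decimal("768")}]
rows = convert_metadata_to_rows("trigger_word", "myword")
# -> [{"key": "trigger_word", "ordinal": 0, "val_str": "myword"}]
rows = convert_metadata_to_rows("extra", {"source": "local"})
# -> [{"key": "extra", "ordinal": 0, "val_json": {"source": "local"}}]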
def _apply_tag_filters(
|
||||
stmt: sa.sql.Select,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
) -> sa.sql.Select:
|
||||
"""include_tags: every tag must be present; exclude_tags: none may be present."""
|
||||
include_tags = normalize_tags(include_tags)
|
||||
exclude_tags = normalize_tags(exclude_tags)
|
||||
|
||||
if include_tags:
|
||||
for tag_name in include_tags:
|
||||
stmt = stmt.where(
|
||||
exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name == tag_name)
|
||||
)
|
||||
)
|
||||
|
||||
if exclude_tags:
|
||||
stmt = stmt.where(
|
||||
~exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name.in_(exclude_tags))
|
||||
)
|
||||
)
|
||||
return stmt
|
||||
|
||||
|
||||
def _apply_metadata_filter(
|
||||
stmt: sa.sql.Select,
|
||||
metadata_filter: dict | None = None,
|
||||
) -> sa.sql.Select:
|
||||
"""Apply filters using asset_info_meta projection table."""
|
||||
if not metadata_filter:
|
||||
return stmt
|
||||
|
||||
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
|
||||
return sa.exists().where(
|
||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||
AssetInfoMeta.key == key,
|
||||
*preds,
|
||||
)
|
||||
|
||||
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
|
||||
if value is None:
|
||||
no_row_for_key = sa.not_(
|
||||
sa.exists().where(
|
||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||
AssetInfoMeta.key == key,
|
||||
)
|
||||
)
|
||||
null_row = _exists_for_pred(
|
||||
key,
|
||||
AssetInfoMeta.val_json.is_(None),
|
||||
AssetInfoMeta.val_str.is_(None),
|
||||
AssetInfoMeta.val_num.is_(None),
|
||||
AssetInfoMeta.val_bool.is_(None),
|
||||
)
|
||||
return sa.or_(no_row_for_key, null_row)
|
||||
|
||||
if isinstance(value, bool):
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
|
||||
if isinstance(value, (int, float)):
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
|
||||
if isinstance(value, str):
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
|
||||
|
||||
for k, v in metadata_filter.items():
|
||||
if isinstance(v, list):
|
||||
ors = [_exists_clause_for_value(k, elem) for elem in v]
|
||||
if ors:
|
||||
stmt = stmt.where(sa.or_(*ors))
|
||||
else:
|
||||
stmt = stmt.where(_exists_clause_for_value(k, v))
|
||||
return stmt
|
||||
|
||||
|
||||
def asset_info_exists_for_asset_id(
|
||||
session: Session,
|
||||
asset_id: str,
|
||||
) -> bool:
|
||||
q = (
|
||||
select(sa.literal(True))
|
||||
.select_from(AssetInfo)
|
||||
.where(AssetInfo.asset_id == asset_id)
|
||||
.limit(1)
|
||||
)
|
||||
return (session.execute(q)).first() is not None
|
||||
|
||||
|
||||
def get_asset_info_by_id(
|
||||
session: Session,
|
||||
asset_info_id: str,
|
||||
) -> AssetInfo | None:
|
||||
return session.get(AssetInfo, asset_info_id)
|
||||
|
||||
|
||||
def insert_asset_info(
|
||||
session: Session,
|
||||
asset_id: str,
|
||||
owner_id: str,
|
||||
name: str,
|
||||
preview_id: str | None = None,
|
||||
) -> AssetInfo | None:
|
||||
"""Insert a new AssetInfo. Returns None if unique constraint violated."""
|
||||
now = get_utc_now()
|
||||
try:
|
||||
with session.begin_nested():
|
||||
info = AssetInfo(
|
||||
owner_id=owner_id,
|
||||
name=name,
|
||||
asset_id=asset_id,
|
||||
preview_id=preview_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_access_time=now,
|
||||
)
|
||||
session.add(info)
|
||||
session.flush()
|
||||
return info
|
||||
except IntegrityError:
|
||||
return None
|
||||
|
||||
|
||||
def get_or_create_asset_info(
|
||||
session: Session,
|
||||
asset_id: str,
|
||||
owner_id: str,
|
||||
name: str,
|
||||
preview_id: str | None = None,
|
||||
) -> tuple[AssetInfo, bool]:
|
||||
"""Get existing or create new AssetInfo. Returns (info, created)."""
|
||||
info = insert_asset_info(
|
||||
session,
|
||||
asset_id=asset_id,
|
||||
owner_id=owner_id,
|
||||
name=name,
|
||||
preview_id=preview_id,
|
||||
)
|
||||
if info:
|
||||
return info, True
|
||||
|
||||
existing = (
|
||||
session.execute(
|
||||
select(AssetInfo)
|
||||
.where(
|
||||
AssetInfo.asset_id == asset_id,
|
||||
AssetInfo.name == name,
|
||||
AssetInfo.owner_id == owner_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
.unique()
|
||||
.scalar_one_or_none()
|
||||
)
|
||||
if not existing:
|
||||
raise RuntimeError("Failed to find AssetInfo after insert conflict.")
|
||||
return existing, False
|
||||
|
||||
|
||||
def update_asset_info_timestamps(
|
||||
session: Session,
|
||||
asset_info: AssetInfo,
|
||||
preview_id: str | None = None,
|
||||
) -> None:
|
||||
"""Update timestamps and optionally preview_id on existing AssetInfo."""
|
||||
now = get_utc_now()
|
||||
if preview_id and asset_info.preview_id != preview_id:
|
||||
asset_info.preview_id = preview_id
|
||||
asset_info.updated_at = now
|
||||
if asset_info.last_access_time < now:
|
||||
asset_info.last_access_time = now
|
||||
session.flush()
|
||||
|
||||
|
||||
def list_asset_infos_page(
|
||||
session: Session,
|
||||
owner_id: str = "",
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
) -> tuple[list[AssetInfo], dict[str, list[str]], int]:
|
||||
base = (
|
||||
select(AssetInfo)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
|
||||
.where(build_visible_owner_clause(owner_id))
|
||||
)
|
||||
|
||||
if name_contains:
|
||||
escaped, esc = escape_sql_like_string(name_contains)
|
||||
base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
|
||||
|
||||
base = _apply_tag_filters(base, include_tags, exclude_tags)
|
||||
base = _apply_metadata_filter(base, metadata_filter)
|
||||
|
||||
sort = (sort or "created_at").lower()
|
||||
order = (order or "desc").lower()
|
||||
sort_map = {
|
||||
"name": AssetInfo.name,
|
||||
"created_at": AssetInfo.created_at,
|
||||
"updated_at": AssetInfo.updated_at,
|
||||
"last_access_time": AssetInfo.last_access_time,
|
||||
"size": Asset.size_bytes,
|
||||
}
|
||||
sort_col = sort_map.get(sort, AssetInfo.created_at)
|
||||
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
|
||||
|
||||
base = base.order_by(sort_exp).limit(limit).offset(offset)
|
||||
|
||||
count_stmt = (
|
||||
select(sa.func.count())
|
||||
.select_from(AssetInfo)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.where(build_visible_owner_clause(owner_id))
|
||||
)
|
||||
if name_contains:
|
||||
escaped, esc = escape_sql_like_string(name_contains)
|
||||
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
|
||||
count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
||||
count_stmt = _apply_metadata_filter(count_stmt, metadata_filter)
|
||||
|
||||
total = int((session.execute(count_stmt)).scalar_one() or 0)
|
||||
|
||||
infos = (session.execute(base)).unique().scalars().all()
|
||||
|
||||
id_list: list[str] = [i.id for i in infos]
|
||||
tag_map: dict[str, list[str]] = defaultdict(list)
|
||||
if id_list:
|
||||
rows = session.execute(
|
||||
select(AssetInfoTag.asset_info_id, Tag.name)
|
||||
.join(Tag, Tag.name == AssetInfoTag.tag_name)
|
||||
.where(AssetInfoTag.asset_info_id.in_(id_list))
|
||||
.order_by(AssetInfoTag.added_at)
|
||||
)
|
||||
for aid, tag_name in rows.all():
|
||||
tag_map[aid].append(tag_name)
|
||||
|
||||
return infos, tag_map, total
|
||||
|
||||
|
||||
def fetch_asset_info_asset_and_tags(
|
||||
session: Session,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> tuple[AssetInfo, Asset, list[str]] | None:
|
||||
stmt = (
|
||||
select(AssetInfo, Asset, Tag.name)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
|
||||
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
|
||||
.where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
build_visible_owner_clause(owner_id),
|
||||
)
|
||||
.options(noload(AssetInfo.tags))
|
||||
.order_by(Tag.name.asc())
|
||||
)
|
||||
|
||||
rows = (session.execute(stmt)).all()
|
||||
if not rows:
|
||||
return None
|
||||
|
||||
first_info, first_asset, _ = rows[0]
|
||||
tags: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for _info, _asset, tag_name in rows:
|
||||
if tag_name and tag_name not in seen:
|
||||
seen.add(tag_name)
|
||||
tags.append(tag_name)
|
||||
return first_info, first_asset, tags
|
||||
|
||||
|
||||
def fetch_asset_info_and_asset(
|
||||
session: Session,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> tuple[AssetInfo, Asset] | None:
|
||||
stmt = (
|
||||
select(AssetInfo, Asset)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
build_visible_owner_clause(owner_id),
|
||||
)
|
||||
.limit(1)
|
||||
.options(noload(AssetInfo.tags))
|
||||
)
|
||||
row = session.execute(stmt)
|
||||
pair = row.first()
|
||||
if not pair:
|
||||
return None
|
||||
return pair[0], pair[1]
|
||||
|
||||
|
||||
def update_asset_info_access_time(
|
||||
session: Session,
|
||||
asset_info_id: str,
|
||||
ts: datetime | None = None,
|
||||
only_if_newer: bool = True,
|
||||
) -> None:
|
||||
ts = ts or get_utc_now()
|
||||
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
|
||||
if only_if_newer:
|
||||
stmt = stmt.where(
|
||||
sa.or_(
|
||||
AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts
|
||||
)
|
||||
)
|
||||
session.execute(stmt.values(last_access_time=ts))
|
||||
|
||||
|
||||
def update_asset_info_name(
|
||||
session: Session,
|
||||
asset_info_id: str,
|
||||
name: str,
|
||||
) -> None:
|
||||
"""Update the name of an AssetInfo."""
|
||||
now = get_utc_now()
|
||||
session.execute(
|
||||
sa.update(AssetInfo)
|
||||
.where(AssetInfo.id == asset_info_id)
|
||||
.values(name=name, updated_at=now)
|
||||
)
|
||||
|
||||
|
||||
def update_asset_info_updated_at(
|
||||
session: Session,
|
||||
asset_info_id: str,
|
||||
ts: datetime | None = None,
|
||||
) -> None:
|
||||
"""Update the updated_at timestamp of an AssetInfo."""
|
||||
ts = ts or get_utc_now()
|
||||
session.execute(
|
||||
sa.update(AssetInfo).where(AssetInfo.id == asset_info_id).values(updated_at=ts)
|
||||
)
|
||||
|
||||
|
||||
def set_asset_info_metadata(
|
||||
session: Session,
|
||||
asset_info_id: str,
|
||||
user_metadata: dict | None = None,
|
||||
) -> None:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
info.user_metadata = user_metadata or {}
|
||||
info.updated_at = get_utc_now()
|
||||
session.flush()
|
||||
|
||||
session.execute(
|
||||
delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id)
|
||||
)
|
||||
session.flush()
|
||||
|
||||
if not user_metadata:
|
||||
return
|
||||
|
||||
rows: list[AssetInfoMeta] = []
|
||||
for k, v in user_metadata.items():
|
||||
for r in convert_metadata_to_rows(k, v):
|
||||
rows.append(
|
||||
AssetInfoMeta(
|
||||
asset_info_id=asset_info_id,
|
||||
key=r["key"],
|
||||
ordinal=int(r["ordinal"]),
|
||||
val_str=r.get("val_str"),
|
||||
val_num=r.get("val_num"),
|
||||
val_bool=r.get("val_bool"),
|
||||
val_json=r.get("val_json"),
|
||||
)
|
||||
)
|
||||
if rows:
|
||||
session.add_all(rows)
|
||||
session.flush()
|
||||
|
||||
|
||||
def delete_asset_info_by_id(
|
||||
session: Session,
|
||||
asset_info_id: str,
|
||||
owner_id: str,
|
||||
) -> bool:
|
||||
stmt = sa.delete(AssetInfo).where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
build_visible_owner_clause(owner_id),
|
||||
)
|
||||
return int((session.execute(stmt)).rowcount or 0) > 0
|
||||
|
||||
|
||||
def set_asset_info_preview(
|
||||
session: Session,
|
||||
asset_info_id: str,
|
||||
preview_asset_id: str | None = None,
|
||||
) -> None:
|
||||
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
if preview_asset_id is None:
|
||||
info.preview_id = None
|
||||
else:
|
||||
if not session.get(Asset, preview_asset_id):
|
||||
raise ValueError(f"Preview Asset {preview_asset_id} not found")
|
||||
info.preview_id = preview_asset_id
|
||||
|
||||
info.updated_at = get_utc_now()
|
||||
session.flush()
|
||||
|
||||
|
||||
def bulk_insert_asset_infos_ignore_conflicts(
|
||||
session: Session,
|
||||
rows: list[dict],
|
||||
) -> None:
|
||||
"""Bulk insert AssetInfo rows with ON CONFLICT DO NOTHING.
|
||||
|
||||
Each dict should have: id, owner_id, name, asset_id, preview_id,
|
||||
user_metadata, created_at, updated_at, last_access_time
|
||||
"""
|
||||
if not rows:
|
||||
return
|
||||
ins = sqlite.insert(AssetInfo).on_conflict_do_nothing(
|
||||
index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name]
|
||||
)
|
||||
for chunk in iter_chunks(rows, calculate_rows_per_statement(9)):
|
||||
session.execute(ins, chunk)
|
||||
|
||||
|
||||
def get_asset_info_ids_by_ids(
|
||||
session: Session,
|
||||
info_ids: list[str],
|
||||
) -> set[str]:
|
||||
"""Query to find which AssetInfo IDs exist in the database."""
|
||||
if not info_ids:
|
||||
return set()
|
||||
|
||||
found: set[str] = set()
|
||||
for chunk in iter_chunks(info_ids, MAX_BIND_PARAMS):
|
||||
result = session.execute(select(AssetInfo.id).where(AssetInfo.id.in_(chunk)))
|
||||
found.update(result.scalars().all())
|
||||
return found
|
||||
351
app/assets/database/queries/cache_state.py
Normal file
@@ -0,0 +1,351 @@
|
||||
import os
|
||||
from typing import NamedTuple, Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects import sqlite
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import Asset, AssetCacheState, AssetInfo
|
||||
from app.assets.database.queries.common import (
|
||||
MAX_BIND_PARAMS,
|
||||
calculate_rows_per_statement,
|
||||
iter_chunks,
|
||||
)
|
||||
from app.assets.helpers import escape_sql_like_string
|
||||
|
||||
|
||||
class CacheStateRow(NamedTuple):
|
||||
"""Row from cache state query with joined asset data."""
|
||||
|
||||
state_id: int
|
||||
file_path: str
|
||||
mtime_ns: int | None
|
||||
needs_verify: bool
|
||||
asset_id: str
|
||||
asset_hash: str | None
|
||||
size_bytes: int
|
||||
|
||||
|
||||
def list_cache_states_by_asset_id(
|
||||
session: Session, *, asset_id: str
|
||||
) -> Sequence[AssetCacheState]:
|
||||
return (
|
||||
(
|
||||
session.execute(
|
||||
select(AssetCacheState)
|
||||
.where(AssetCacheState.asset_id == asset_id)
|
||||
.order_by(AssetCacheState.id.asc())
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def upsert_cache_state(
|
||||
session: Session,
|
||||
asset_id: str,
|
||||
file_path: str,
|
||||
mtime_ns: int,
|
||||
) -> tuple[bool, bool]:
|
||||
"""Upsert a cache state by file_path. Returns (created, updated).
|
||||
|
||||
Also restores cache states that were previously marked as missing.
|
||||
"""
|
||||
vals = {
|
||||
"asset_id": asset_id,
|
||||
"file_path": file_path,
|
||||
"mtime_ns": int(mtime_ns),
|
||||
"is_missing": False,
|
||||
}
|
||||
ins = (
|
||||
sqlite.insert(AssetCacheState)
|
||||
.values(**vals)
|
||||
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
|
||||
)
|
||||
res = session.execute(ins)
|
||||
created = int(res.rowcount or 0) > 0
|
||||
|
||||
if created:
|
||||
return True, False
|
||||
|
||||
upd = (
|
||||
sa.update(AssetCacheState)
|
||||
.where(AssetCacheState.file_path == file_path)
|
||||
.where(
|
||||
sa.or_(
|
||||
AssetCacheState.asset_id != asset_id,
|
||||
AssetCacheState.mtime_ns.is_(None),
|
||||
AssetCacheState.mtime_ns != int(mtime_ns),
|
||||
AssetCacheState.is_missing == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
.values(asset_id=asset_id, mtime_ns=int(mtime_ns), is_missing=False)
|
||||
)
|
||||
res2 = session.execute(upd)
|
||||
updated = int(res2.rowcount or 0) > 0
|
||||
return False, updated
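A minimal usage sketch of the (created, updated) contract above; the open `session` and the literal values are illustrative assumptions, not part of this change:
created, updated = upsert_cache_state(
    session,
    asset_id="a1b2c3",
    file_path="/abs/models/checkpoints/model.safetensors",
    mtime_ns=1_700_000_000_000_000_000,
)
# created=True                  -> a new cache-state row was inserted for this path
# created=False, updated=True   -> the existing row pointed at a different asset,
#                                  had a stale mtime, or was flagged missing, and was refreshed
# created=False, updated=False  -> the row already matched; nothing changed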
|
||||
|
||||
|
||||
def mark_cache_states_missing_outside_prefixes(
|
||||
session: Session, valid_prefixes: list[str]
|
||||
) -> int:
|
||||
"""Mark cache states as missing when file_path doesn't match any valid prefix.
|
||||
|
||||
This is a non-destructive soft-delete that preserves user metadata.
|
||||
Cache states can be restored if the file reappears in a future scan.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
valid_prefixes: List of absolute directory prefixes that are valid
|
||||
|
||||
Returns:
|
||||
Number of cache states marked as missing
|
||||
"""
|
||||
if not valid_prefixes:
|
||||
return 0
|
||||
|
||||
def make_prefix_condition(prefix: str):
|
||||
base = prefix if prefix.endswith(os.sep) else prefix + os.sep
|
||||
escaped, esc = escape_sql_like_string(base)
|
||||
return AssetCacheState.file_path.like(escaped + "%", escape=esc)
|
||||
|
||||
matches_valid_prefix = sa.or_(*[make_prefix_condition(p) for p in valid_prefixes])
|
||||
result = session.execute(
|
||||
sa.update(AssetCacheState)
|
||||
.where(~matches_valid_prefix)
|
||||
.where(AssetCacheState.is_missing == False) # noqa: E712
|
||||
.values(is_missing=True)
|
||||
)
|
||||
return result.rowcount
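A short sketch of the intended call, assuming an open Session and illustrative directories:
valid = [os.path.abspath(p) for p in ("/data/models", "/data/input", "/data/output")]
flagged = mark_cache_states_missing_outside_prefixes(session, valid)
# Cache states whose file_path falls outside every prefix get is_missing=True;
# the Asset and AssetInfo rows (and the user metadata on them) stay untouched.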
|
||||
|
||||
|
||||
def restore_cache_states_by_paths(session: Session, file_paths: list[str]) -> int:
|
||||
"""Restore cache states that were previously marked as missing.
|
||||
|
||||
Called when a file path is re-scanned and found to exist.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
file_paths: List of file paths that exist and should be restored
|
||||
|
||||
Returns:
|
||||
Number of cache states restored
|
||||
"""
|
||||
if not file_paths:
|
||||
return 0
|
||||
|
||||
result = session.execute(
|
||||
sa.update(AssetCacheState)
|
||||
.where(AssetCacheState.file_path.in_(file_paths))
|
||||
.where(AssetCacheState.is_missing == True) # noqa: E712
|
||||
.values(is_missing=False)
|
||||
)
|
||||
return result.rowcount
|
||||
|
||||
|
||||
def get_unreferenced_unhashed_asset_ids(session: Session) -> list[str]:
|
||||
"""Get IDs of unhashed assets (hash=None) with no active cache states.
|
||||
|
||||
An asset is considered unreferenced if it has no cache states,
|
||||
or all its cache states are marked as missing.
|
||||
|
||||
Returns:
|
||||
List of asset IDs that are unreferenced
|
||||
"""
|
||||
active_cache_state_exists = (
|
||||
sa.select(sa.literal(1))
|
||||
.where(AssetCacheState.asset_id == Asset.id)
|
||||
.where(AssetCacheState.is_missing == False) # noqa: E712
|
||||
.correlate(Asset)
|
||||
.exists()
|
||||
)
|
||||
unreferenced_subq = sa.select(Asset.id).where(
|
||||
Asset.hash.is_(None), ~active_cache_state_exists
|
||||
)
|
||||
return [row[0] for row in session.execute(unreferenced_subq).all()]
|
||||
|
||||
|
||||
def delete_assets_by_ids(session: Session, asset_ids: list[str]) -> int:
|
||||
"""Delete assets and their AssetInfos by ID.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
asset_ids: List of asset IDs to delete
|
||||
|
||||
Returns:
|
||||
Number of assets deleted
|
||||
"""
|
||||
if not asset_ids:
|
||||
return 0
|
||||
session.execute(sa.delete(AssetInfo).where(AssetInfo.asset_id.in_(asset_ids)))
|
||||
result = session.execute(sa.delete(Asset).where(Asset.id.in_(asset_ids)))
|
||||
return result.rowcount
|
||||
|
||||
|
||||
def get_cache_states_for_prefixes(
|
||||
session: Session,
|
||||
prefixes: list[str],
|
||||
*,
|
||||
include_missing: bool = False,
|
||||
) -> list[CacheStateRow]:
|
||||
"""Get all cache states with paths matching any of the given prefixes.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
prefixes: List of absolute directory prefixes to match
|
||||
include_missing: If False (default), exclude cache states marked as missing
|
||||
|
||||
Returns:
|
||||
List of cache state rows with joined asset data, ordered by asset_id, state_id
|
||||
"""
|
||||
if not prefixes:
|
||||
return []
|
||||
|
||||
conds = []
|
||||
for p in prefixes:
|
||||
base = os.path.abspath(p)
|
||||
if not base.endswith(os.sep):
|
||||
base += os.sep
|
||||
escaped, esc = escape_sql_like_string(base)
|
||||
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
|
||||
|
||||
query = (
|
||||
sa.select(
|
||||
AssetCacheState.id,
|
||||
AssetCacheState.file_path,
|
||||
AssetCacheState.mtime_ns,
|
||||
AssetCacheState.needs_verify,
|
||||
AssetCacheState.asset_id,
|
||||
Asset.hash,
|
||||
Asset.size_bytes,
|
||||
)
|
||||
.join(Asset, Asset.id == AssetCacheState.asset_id)
|
||||
.where(sa.or_(*conds))
|
||||
)
|
||||
|
||||
if not include_missing:
|
||||
query = query.where(AssetCacheState.is_missing == False) # noqa: E712
|
||||
|
||||
rows = session.execute(
|
||||
query.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
|
||||
).all()
|
||||
|
||||
return [
|
||||
CacheStateRow(
|
||||
state_id=row[0],
|
||||
file_path=row[1],
|
||||
mtime_ns=row[2],
|
||||
needs_verify=row[3],
|
||||
asset_id=row[4],
|
||||
asset_hash=row[5],
|
||||
size_bytes=int(row[6] or 0),
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
|
||||
def bulk_update_needs_verify(session: Session, state_ids: list[int], value: bool) -> int:
|
||||
"""Set needs_verify flag for multiple cache states.
|
||||
|
||||
Returns: Number of rows updated
|
||||
"""
|
||||
if not state_ids:
|
||||
return 0
|
||||
result = session.execute(
|
||||
sa.update(AssetCacheState)
|
||||
.where(AssetCacheState.id.in_(state_ids))
|
||||
.values(needs_verify=value)
|
||||
)
|
||||
return result.rowcount
|
||||
|
||||
|
||||
def bulk_update_is_missing(session: Session, state_ids: list[int], value: bool) -> int:
|
||||
"""Set is_missing flag for multiple cache states.
|
||||
|
||||
Returns: Number of rows updated
|
||||
"""
|
||||
if not state_ids:
|
||||
return 0
|
||||
result = session.execute(
|
||||
sa.update(AssetCacheState)
|
||||
.where(AssetCacheState.id.in_(state_ids))
|
||||
.values(is_missing=value)
|
||||
)
|
||||
return result.rowcount
|
||||
|
||||
|
||||
def delete_cache_states_by_ids(session: Session, state_ids: list[int]) -> int:
|
||||
"""Delete cache states by their IDs.
|
||||
|
||||
Returns: Number of rows deleted
|
||||
"""
|
||||
if not state_ids:
|
||||
return 0
|
||||
result = session.execute(
|
||||
sa.delete(AssetCacheState).where(AssetCacheState.id.in_(state_ids))
|
||||
)
|
||||
return result.rowcount
|
||||
|
||||
|
||||
def delete_orphaned_seed_asset(session: Session, asset_id: str) -> bool:
|
||||
"""Delete a seed asset (hash is None) and its AssetInfos.
|
||||
|
||||
Returns: True if asset was deleted, False if not found
|
||||
"""
|
||||
session.execute(sa.delete(AssetInfo).where(AssetInfo.asset_id == asset_id))
|
||||
asset = session.get(Asset, asset_id)
|
||||
if asset:
|
||||
session.delete(asset)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def bulk_insert_cache_states_ignore_conflicts(
|
||||
session: Session,
|
||||
rows: list[dict],
|
||||
) -> None:
|
||||
"""Bulk insert cache state rows with ON CONFLICT DO NOTHING on file_path.
|
||||
|
||||
Each dict should have: asset_id, file_path, mtime_ns
|
||||
The is_missing field is automatically set to False for new inserts.
|
||||
"""
|
||||
if not rows:
|
||||
return
|
||||
enriched_rows = [{**row, "is_missing": False} for row in rows]
|
||||
ins = sqlite.insert(AssetCacheState).on_conflict_do_nothing(
|
||||
index_elements=[AssetCacheState.file_path]
|
||||
)
|
||||
for chunk in iter_chunks(enriched_rows, calculate_rows_per_statement(4)):
|
||||
session.execute(ins, chunk)
|
||||
|
||||
|
||||
def get_cache_states_by_paths_and_asset_ids(
|
||||
session: Session,
|
||||
path_to_asset: dict[str, str],
|
||||
) -> set[str]:
|
||||
"""Query cache states to find paths where our asset_id won the insert.
|
||||
|
||||
Args:
|
||||
path_to_asset: Mapping of file_path -> asset_id we tried to insert
|
||||
|
||||
Returns:
|
||||
Set of file_paths where our asset_id is present
|
||||
"""
|
||||
if not path_to_asset:
|
||||
return set()
|
||||
|
||||
paths = list(path_to_asset.keys())
|
||||
winners: set[str] = set()
|
||||
|
||||
for chunk in iter_chunks(paths, MAX_BIND_PARAMS):
|
||||
result = session.execute(
|
||||
select(AssetCacheState.file_path).where(
|
||||
AssetCacheState.file_path.in_(chunk),
|
||||
AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk]),
|
||||
)
|
||||
)
|
||||
winners.update(result.scalars().all())
|
||||
|
||||
return winners
|
||||
37
app/assets/database/queries/common.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""Shared utilities for database query modules."""
|
||||
|
||||
from typing import Iterable
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from app.assets.database.models import AssetInfo
|
||||
|
||||
MAX_BIND_PARAMS = 800
|
||||
|
||||
|
||||
def calculate_rows_per_statement(cols: int) -> int:
|
||||
"""Calculate how many rows can fit in one statement given column count."""
|
||||
return max(1, MAX_BIND_PARAMS // max(1, cols))
|
||||
|
||||
|
||||
def iter_chunks(seq, n: int):
|
||||
"""Yield successive n-sized chunks from seq."""
|
||||
for i in range(0, len(seq), n):
|
||||
yield seq[i : i + n]
|
||||
|
||||
|
||||
def iter_row_chunks(rows: list[dict], cols_per_row: int) -> Iterable[list[dict]]:
|
||||
"""Yield chunks of rows sized to fit within bind param limits."""
|
||||
if not rows:
|
||||
return
|
||||
rows_per_stmt = calculate_rows_per_statement(cols_per_row)
|
||||
for i in range(0, len(rows), rows_per_stmt):
|
||||
yield rows[i : i + rows_per_stmt]
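To make the bind-parameter budget concrete, a small illustrative example (row contents are made up):
# MAX_BIND_PARAMS = 800 with 4 columns per row (as for cache-state inserts)
# allows 800 // 4 = 200 rows per statement.
rows = [
    {"asset_id": "a", "file_path": f"/tmp/{i}", "mtime_ns": 0, "is_missing": False}
    for i in range(1000)
]
batches = list(iter_row_chunks(rows, cols_per_row=4))
# len(batches) == 5, each batch carrying 200 rows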
|
||||
|
||||
|
||||
def build_visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
|
||||
"""Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
|
||||
owner_id = (owner_id or "").strip()
|
||||
if owner_id == "":
|
||||
return AssetInfo.owner_id == ""
|
||||
return AssetInfo.owner_id.in_(["", owner_id])
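A brief sketch of how the clause composes into a query, mirroring the callers elsewhere in this diff:
stmt = sa.select(AssetInfo).where(build_visible_owner_clause("user-123"))
# matches rows with owner_id in ("", "user-123")
stmt_anon = sa.select(AssetInfo).where(build_visible_owner_clause(""))
# matches only owner-less rows (owner_id == "")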
|
||||
349
app/assets/database/queries/tags.py
Normal file
@@ -0,0 +1,349 @@
|
||||
from typing import Iterable, Sequence, TypedDict
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.dialects import sqlite
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
|
||||
from app.assets.database.queries.common import (
|
||||
build_visible_owner_clause,
|
||||
iter_row_chunks,
|
||||
)
|
||||
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
|
||||
|
||||
|
||||
class AddTagsDict(TypedDict):
|
||||
added: list[str]
|
||||
already_present: list[str]
|
||||
total_tags: list[str]
|
||||
|
||||
|
||||
class RemoveTagsDict(TypedDict):
|
||||
removed: list[str]
|
||||
not_present: list[str]
|
||||
total_tags: list[str]
|
||||
|
||||
|
||||
class SetTagsDict(TypedDict):
|
||||
added: list[str]
|
||||
removed: list[str]
|
||||
total: list[str]
|
||||
|
||||
|
||||
def ensure_tags_exist(
|
||||
session: Session, names: Iterable[str], tag_type: str = "user"
|
||||
) -> None:
|
||||
wanted = normalize_tags(list(names))
|
||||
if not wanted:
|
||||
return
|
||||
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
|
||||
ins = (
|
||||
sqlite.insert(Tag)
|
||||
.values(rows)
|
||||
.on_conflict_do_nothing(index_elements=[Tag.name])
|
||||
)
|
||||
session.execute(ins)
|
||||
|
||||
|
||||
def get_asset_tags(session: Session, asset_info_id: str) -> list[str]:
|
||||
return [
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
select(AssetInfoTag.tag_name).where(
|
||||
AssetInfoTag.asset_info_id == asset_info_id
|
||||
)
|
||||
)
|
||||
).all()
|
||||
]
|
||||
|
||||
|
||||
def set_asset_info_tags(
|
||||
session: Session,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
) -> SetTagsDict:
|
||||
desired = normalize_tags(tags)
|
||||
|
||||
current = set(
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
select(AssetInfoTag.tag_name).where(
|
||||
AssetInfoTag.asset_info_id == asset_info_id
|
||||
)
|
||||
)
|
||||
).all()
|
||||
)
|
||||
|
||||
to_add = [t for t in desired if t not in current]
|
||||
to_remove = [t for t in current if t not in desired]
|
||||
|
||||
if to_add:
|
||||
ensure_tags_exist(session, to_add, tag_type="user")
|
||||
session.add_all(
|
||||
[
|
||||
AssetInfoTag(
|
||||
asset_info_id=asset_info_id,
|
||||
tag_name=t,
|
||||
origin=origin,
|
||||
added_at=get_utc_now(),
|
||||
)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
session.flush()
|
||||
|
||||
if to_remove:
|
||||
session.execute(
|
||||
delete(AssetInfoTag).where(
|
||||
AssetInfoTag.asset_info_id == asset_info_id,
|
||||
AssetInfoTag.tag_name.in_(to_remove),
|
||||
)
|
||||
)
|
||||
session.flush()
|
||||
|
||||
return {"added": to_add, "removed": to_remove, "total": desired}
|
||||
|
||||
|
||||
def add_tags_to_asset_info(
|
||||
session: Session,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
create_if_missing: bool = True,
|
||||
asset_info_row: AssetInfo | None = None,
|
||||
) -> AddTagsDict:
|
||||
if not asset_info_row:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
norm = normalize_tags(tags)
|
||||
if not norm:
|
||||
total = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"added": [], "already_present": [], "total_tags": total}
|
||||
|
||||
if create_if_missing:
|
||||
ensure_tags_exist(session, norm, tag_type="user")
|
||||
|
||||
current = {
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
sa.select(AssetInfoTag.tag_name).where(
|
||||
AssetInfoTag.asset_info_id == asset_info_id
|
||||
)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
|
||||
want = set(norm)
|
||||
to_add = sorted(want - current)
|
||||
|
||||
if to_add:
|
||||
with session.begin_nested() as nested:
|
||||
try:
|
||||
session.add_all(
|
||||
[
|
||||
AssetInfoTag(
|
||||
asset_info_id=asset_info_id,
|
||||
tag_name=t,
|
||||
origin=origin,
|
||||
added_at=get_utc_now(),
|
||||
)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
session.flush()
|
||||
except IntegrityError:
|
||||
nested.rollback()
|
||||
|
||||
after = set(get_asset_tags(session, asset_info_id=asset_info_id))
|
||||
return {
|
||||
"added": sorted(((after - current) & want)),
|
||||
"already_present": sorted(want & current),
|
||||
"total_tags": sorted(after),
|
||||
}
|
||||
|
||||
|
||||
def remove_tags_from_asset_info(
|
||||
session: Session,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
) -> RemoveTagsDict:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
norm = normalize_tags(tags)
|
||||
if not norm:
|
||||
total = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"removed": [], "not_present": [], "total_tags": total}
|
||||
|
||||
existing = {
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
sa.select(AssetInfoTag.tag_name).where(
|
||||
AssetInfoTag.asset_info_id == asset_info_id
|
||||
)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
|
||||
to_remove = sorted(set(t for t in norm if t in existing))
|
||||
not_present = sorted(set(t for t in norm if t not in existing))
|
||||
|
||||
if to_remove:
|
||||
session.execute(
|
||||
delete(AssetInfoTag).where(
|
||||
AssetInfoTag.asset_info_id == asset_info_id,
|
||||
AssetInfoTag.tag_name.in_(to_remove),
|
||||
)
|
||||
)
|
||||
session.flush()
|
||||
|
||||
total = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
|
||||
|
||||
|
||||
def add_missing_tag_for_asset_id(
|
||||
session: Session,
|
||||
asset_id: str,
|
||||
origin: str = "automatic",
|
||||
) -> None:
|
||||
select_rows = (
|
||||
sa.select(
|
||||
AssetInfo.id.label("asset_info_id"),
|
||||
sa.literal("missing").label("tag_name"),
|
||||
sa.literal(origin).label("origin"),
|
||||
sa.literal(get_utc_now()).label("added_at"),
|
||||
)
|
||||
.where(AssetInfo.asset_id == asset_id)
|
||||
.where(
|
||||
sa.not_(
|
||||
sa.exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name == "missing")
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
session.execute(
|
||||
sqlite.insert(AssetInfoTag)
|
||||
.from_select(
|
||||
["asset_info_id", "tag_name", "origin", "added_at"],
|
||||
select_rows,
|
||||
)
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def remove_missing_tag_for_asset_id(
|
||||
session: Session,
|
||||
asset_id: str,
|
||||
) -> None:
|
||||
session.execute(
|
||||
sa.delete(AssetInfoTag).where(
|
||||
AssetInfoTag.asset_info_id.in_(
|
||||
sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)
|
||||
),
|
||||
AssetInfoTag.tag_name == "missing",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def list_tags_with_usage(
|
||||
session: Session,
|
||||
prefix: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
include_zero: bool = True,
|
||||
order: str = "count_desc",
|
||||
owner_id: str = "",
|
||||
) -> tuple[list[tuple[str, str, int]], int]:
|
||||
counts_sq = (
|
||||
select(
|
||||
AssetInfoTag.tag_name.label("tag_name"),
|
||||
func.count(AssetInfoTag.asset_info_id).label("cnt"),
|
||||
)
|
||||
.select_from(AssetInfoTag)
|
||||
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
|
||||
.where(build_visible_owner_clause(owner_id))
|
||||
.group_by(AssetInfoTag.tag_name)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
q = (
|
||||
select(
|
||||
Tag.name,
|
||||
Tag.tag_type,
|
||||
func.coalesce(counts_sq.c.cnt, 0).label("count"),
|
||||
)
|
||||
.select_from(Tag)
|
||||
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
|
||||
)
|
||||
|
||||
if prefix:
|
||||
escaped, esc = escape_sql_like_string(prefix.strip().lower())
|
||||
q = q.where(Tag.name.like(escaped + "%", escape=esc))
|
||||
|
||||
if not include_zero:
|
||||
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
|
||||
|
||||
if order == "name_asc":
|
||||
q = q.order_by(Tag.name.asc())
|
||||
else:
|
||||
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
|
||||
|
||||
total_q = select(func.count()).select_from(Tag)
|
||||
if prefix:
|
||||
escaped, esc = escape_sql_like_string(prefix.strip().lower())
|
||||
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
|
||||
if not include_zero:
|
||||
total_q = total_q.where(
|
||||
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
|
||||
)
|
||||
|
||||
rows = (session.execute(q.limit(limit).offset(offset))).all()
|
||||
total = (session.execute(total_q)).scalar_one()
|
||||
|
||||
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
|
||||
return rows_norm, int(total or 0)
|
||||
|
||||
|
||||
def bulk_insert_tags_and_meta(
|
||||
session: Session,
|
||||
tag_rows: list[dict],
|
||||
meta_rows: list[dict],
|
||||
) -> None:
|
||||
"""Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
tag_rows: List of dicts with keys: asset_info_id, tag_name, origin, added_at
|
||||
meta_rows: List of dicts with keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
|
||||
"""
|
||||
if tag_rows:
|
||||
ins_tags = sqlite.insert(AssetInfoTag).on_conflict_do_nothing(
|
||||
index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]
|
||||
)
|
||||
for chunk in iter_row_chunks(tag_rows, cols_per_row=4):
|
||||
session.execute(ins_tags, chunk)
|
||||
|
||||
if meta_rows:
|
||||
ins_meta = sqlite.insert(AssetInfoMeta).on_conflict_do_nothing(
|
||||
index_elements=[
|
||||
AssetInfoMeta.asset_info_id,
|
||||
AssetInfoMeta.key,
|
||||
AssetInfoMeta.ordinal,
|
||||
]
|
||||
)
|
||||
for chunk in iter_row_chunks(meta_rows, cols_per_row=7):
|
||||
session.execute(ins_meta, chunk)
|
||||
@@ -1,62 +0,0 @@
|
||||
from typing import Iterable
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.dialects import sqlite
|
||||
|
||||
from app.assets.helpers import normalize_tags, utcnow
|
||||
from app.assets.database.models import Tag, AssetInfoTag, AssetInfo
|
||||
|
||||
|
||||
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
|
||||
wanted = normalize_tags(list(names))
|
||||
if not wanted:
|
||||
return
|
||||
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
|
||||
ins = (
|
||||
sqlite.insert(Tag)
|
||||
.values(rows)
|
||||
.on_conflict_do_nothing(index_elements=[Tag.name])
|
||||
)
|
||||
return session.execute(ins)
|
||||
|
||||
def add_missing_tag_for_asset_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_id: str,
|
||||
origin: str = "automatic",
|
||||
) -> None:
|
||||
select_rows = (
|
||||
sqlalchemy.select(
|
||||
AssetInfo.id.label("asset_info_id"),
|
||||
sqlalchemy.literal("missing").label("tag_name"),
|
||||
sqlalchemy.literal(origin).label("origin"),
|
||||
sqlalchemy.literal(utcnow()).label("added_at"),
|
||||
)
|
||||
.where(AssetInfo.asset_id == asset_id)
|
||||
.where(
|
||||
sqlalchemy.not_(
|
||||
sqlalchemy.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing"))
|
||||
)
|
||||
)
|
||||
)
|
||||
session.execute(
|
||||
sqlite.insert(AssetInfoTag)
|
||||
.from_select(
|
||||
["asset_info_id", "tag_name", "origin", "added_at"],
|
||||
select_rows,
|
||||
)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
|
||||
)
|
||||
|
||||
def remove_missing_tag_for_asset_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_id: str,
|
||||
) -> None:
|
||||
session.execute(
|
||||
sqlalchemy.delete(AssetInfoTag).where(
|
||||
AssetInfoTag.asset_info_id.in_(sqlalchemy.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
|
||||
AssetInfoTag.tag_name == "missing",
|
||||
)
|
||||
)
|
||||
@@ -1,75 +0,0 @@
|
||||
from blake3 import blake3
|
||||
from typing import IO
|
||||
import os
|
||||
import asyncio
|
||||
|
||||
|
||||
DEFAULT_CHUNK = 8 * 1024 * 1024  # 8MB
|
||||
|
||||
# NOTE: this allows hashing different representations of a file-like object
|
||||
def blake3_hash(
|
||||
fp: str | IO[bytes],
|
||||
chunk_size: int = DEFAULT_CHUNK,
|
||||
) -> str:
|
||||
"""
|
||||
Returns a BLAKE3 hex digest for ``fp``, which may be:
|
||||
- a filename (str/bytes) or PathLike
|
||||
- an open binary file object
|
||||
If ``fp`` is a file object, it must be opened in **binary** mode and support
|
||||
``read``, ``seek``, and ``tell``. The function will seek to the start before
|
||||
reading and will attempt to restore the original position afterward.
|
||||
"""
|
||||
# duck typing to check if input is a file-like object
|
||||
if hasattr(fp, "read"):
|
||||
return _hash_file_obj(fp, chunk_size)
|
||||
|
||||
with open(os.fspath(fp), "rb") as f:
|
||||
return _hash_file_obj(f, chunk_size)
|
||||
|
||||
|
||||
async def blake3_hash_async(
|
||||
fp: str | IO[bytes],
|
||||
chunk_size: int = DEFAULT_CHUNK,
|
||||
) -> str:
|
||||
"""Async wrapper for ``blake3_hash_sync``.
|
||||
Uses a worker thread so the event loop remains responsive.
|
||||
"""
|
||||
# If it is a path, open inside the worker thread to keep I/O off the loop.
|
||||
if hasattr(fp, "read"):
|
||||
return await asyncio.to_thread(blake3_hash, fp, chunk_size)
|
||||
|
||||
def _worker() -> str:
|
||||
with open(os.fspath(fp), "rb") as f:
|
||||
return _hash_file_obj(f, chunk_size)
|
||||
|
||||
return await asyncio.to_thread(_worker)
|
||||
|
||||
|
||||
def _hash_file_obj(file_obj: IO, chunk_size: int = DEFAULT_CHUNK) -> str:
|
||||
"""
|
||||
Hash an already-open binary file object by streaming in chunks.
|
||||
- Seeks to the beginning before reading (if supported).
|
||||
- Restores the original position afterward (if tell/seek are supported).
|
||||
"""
|
||||
if chunk_size <= 0:
|
||||
chunk_size = DEFAULT_CHUNK
|
||||
|
||||
# in case the file object is already open and not at the beginning, record the position so it can be restored after hashing
|
||||
orig_pos = file_obj.tell()
|
||||
|
||||
try:
|
||||
# seek to the beginning before reading
|
||||
if orig_pos != 0:
|
||||
file_obj.seek(0)
|
||||
|
||||
h = blake3()
|
||||
while True:
|
||||
chunk = file_obj.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
h.update(chunk)
|
||||
return h.hexdigest()
|
||||
finally:
|
||||
# restore original position in file object, if needed
|
||||
if orig_pos != 0:
|
||||
file_obj.seek(orig_pos)
|
||||
@@ -1,52 +1,36 @@
|
||||
import contextlib
|
||||
import os
|
||||
from decimal import Decimal
|
||||
from aiohttp import web
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Literal, Any
|
||||
|
||||
import folder_paths
|
||||
from typing import Literal, Sequence
|
||||
|
||||
|
||||
RootType = Literal["models", "input", "output"]
|
||||
ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")
|
||||
|
||||
def get_query_dict(request: web.Request) -> dict[str, Any]:
|
||||
def select_best_live_path(states: Sequence) -> str:
|
||||
"""
|
||||
Gets a dictionary of query parameters from the request.
|
||||
|
||||
'request.query' is a MultiMapping[str]; it needs to be converted to a dictionary before it can be validated by Pydantic.
|
||||
Return the best on-disk path among cache states:
|
||||
1) Prefer a path that exists with needs_verify == False (already verified).
|
||||
2) Otherwise, pick the first path that exists.
|
||||
3) Otherwise return empty string.
|
||||
"""
|
||||
query_dict = {
|
||||
key: request.query.getall(key) if len(request.query.getall(key)) > 1 else request.query.get(key)
|
||||
for key in request.query.keys()
|
||||
}
|
||||
return query_dict
|
||||
alive = [
|
||||
s
|
||||
for s in states
|
||||
if getattr(s, "file_path", None) and os.path.isfile(s.file_path)
|
||||
]
|
||||
if not alive:
|
||||
return ""
|
||||
for s in alive:
|
||||
if not getattr(s, "needs_verify", False):
|
||||
return s.file_path
|
||||
return alive[0].file_path
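An illustrative outcome of the preference order above (paths are hypothetical):
# Two cache states for the same asset:
#   /fast_disk/model.safetensors  exists, needs_verify=True
#   /slow_nas/model.safetensors   exists, needs_verify=False
# select_best_live_path(states) -> "/slow_nas/model.safetensors"  (already-verified path wins)
# If no listed path exists on disk, the function returns "".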
|
||||
|
||||
def list_tree(base_dir: str) -> list[str]:
|
||||
out: list[str] = []
|
||||
base_abs = os.path.abspath(base_dir)
|
||||
if not os.path.isdir(base_abs):
|
||||
return out
|
||||
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
|
||||
for name in filenames:
|
||||
out.append(os.path.abspath(os.path.join(dirpath, name)))
|
||||
return out
|
||||
|
||||
def prefixes_for_root(root: RootType) -> list[str]:
|
||||
if root == "models":
|
||||
bases: list[str] = []
|
||||
for _bucket, paths in get_comfy_models_folders():
|
||||
bases.extend(paths)
|
||||
return [os.path.abspath(p) for p in bases]
|
||||
if root == "input":
|
||||
return [os.path.abspath(folder_paths.get_input_directory())]
|
||||
if root == "output":
|
||||
return [os.path.abspath(folder_paths.get_output_directory())]
|
||||
return []
|
||||
ALLOWED_ROOTS: tuple[Literal["models", "input", "output"], ...] = (
|
||||
"models",
|
||||
"input",
|
||||
"output",
|
||||
)
|
||||
|
||||
def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]:
|
||||
|
||||
def escape_sql_like_string(s: str, escape: str = "!") -> tuple[str, str]:
|
||||
"""Escapes %, _ and the escape char itself in a LIKE prefix.
|
||||
Returns (escaped_prefix, escape_char). Caller should append '%' and pass escape=escape_char to .like().
|
||||
"""
|
||||
@@ -54,173 +38,11 @@ def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]:
|
||||
s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards
|
||||
return s, escape
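A short example of the escaping contract, assuming the lines hidden by this hunk also escape the '!' character itself as the docstring states:
escaped, esc = escape_sql_like_string("50%_off")
# escaped == "50!%!_off", esc == "!"
stmt = select(AssetInfo).where(AssetInfo.name.like(escaped + "%", escape=esc))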
|
||||
|
||||
def fast_asset_file_check(
|
||||
*,
|
||||
mtime_db: int | None,
|
||||
size_db: int | None,
|
||||
stat_result: os.stat_result,
|
||||
) -> bool:
|
||||
if mtime_db is None:
|
||||
return False
|
||||
actual_mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000))
|
||||
if int(mtime_db) != int(actual_mtime_ns):
|
||||
return False
|
||||
sz = int(size_db or 0)
|
||||
if sz > 0:
|
||||
return int(stat_result.st_size) == sz
|
||||
return True
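A hedged usage sketch; `state` and `asset` stand in for a cache-state row and its Asset, and the path is illustrative:
st = os.stat("/abs/models/vae/ae.safetensors", follow_symlinks=True)
unchanged = fast_asset_file_check(
    mtime_db=state.mtime_ns,    # mtime recorded in asset_cache_state
    size_db=asset.size_bytes,   # size recorded on the Asset row
    stat_result=st,
)
# True means the mtime (and size, when known) still match, so re-hashing can be skipped.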
|
||||
|
||||
def utcnow() -> datetime:
|
||||
def get_utc_now() -> datetime:
|
||||
"""Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC."""
|
||||
return datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
|
||||
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
||||
"""Build a list of (folder_name, base_paths[]) categories that are configured for model locations.
|
||||
|
||||
We trust `folder_paths.folder_names_and_paths` and include a category if
|
||||
*any* of its base paths lies under the Comfy `models_dir`.
|
||||
"""
|
||||
targets: list[tuple[str, list[str]]] = []
|
||||
models_root = os.path.abspath(folder_paths.models_dir)
|
||||
for name, values in folder_paths.folder_names_and_paths.items():
|
||||
paths, _exts = values[0], values[1] # NOTE: this prevents nodepacks that hackily edit folder_... from breaking ComfyUI
|
||||
if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
|
||||
targets.append((name, paths))
|
||||
return targets
|
||||
|
||||
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
|
||||
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
|
||||
root = tags[0]
|
||||
if root == "models":
|
||||
if len(tags) < 2:
|
||||
raise ValueError("at least two tags required for model asset")
|
||||
try:
|
||||
bases = folder_paths.folder_names_and_paths[tags[1]][0]
|
||||
except KeyError:
|
||||
raise ValueError(f"unknown model category '{tags[1]}'")
|
||||
if not bases:
|
||||
raise ValueError(f"no base path configured for category '{tags[1]}'")
|
||||
base_dir = os.path.abspath(bases[0])
|
||||
raw_subdirs = tags[2:]
|
||||
else:
|
||||
base_dir = os.path.abspath(
|
||||
folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory()
|
||||
)
|
||||
raw_subdirs = tags[1:]
|
||||
for i in raw_subdirs:
|
||||
if i in (".", ".."):
|
||||
raise ValueError("invalid path component in tags")
|
||||
|
||||
return base_dir, raw_subdirs if raw_subdirs else []
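Illustrative mappings (the category and folder names are examples only):
# tags=["models", "loras", "sdxl"]   -> (first configured "loras" base dir, ["sdxl"])
# tags=["output", "renders", "2025"] -> (output directory, ["renders", "2025"])
# tags=["models"]                    -> raises ValueError (a model category tag is required)
# tags=["input", ".."]               -> raises ValueError (path traversal is rejected)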
|
||||
|
||||
def ensure_within_base(candidate: str, base: str) -> None:
|
||||
cand_abs = os.path.abspath(candidate)
|
||||
base_abs = os.path.abspath(base)
|
||||
try:
|
||||
if os.path.commonpath([cand_abs, base_abs]) != base_abs:
|
||||
raise ValueError("destination escapes base directory")
|
||||
except Exception:
|
||||
raise ValueError("invalid destination path")
|
||||
|
||||
def compute_relative_filename(file_path: str) -> str | None:
|
||||
"""
|
||||
Return the model's path relative to the last well-known folder (the model category),
|
||||
using forward slashes, e.g.:
|
||||
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
|
||||
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
|
||||
|
||||
For non-model paths, returns None.
|
||||
NOTE: this is a temporary helper, used only for initializing metadata["filename"] field.
|
||||
"""
|
||||
try:
|
||||
root_category, rel_path = get_relative_to_root_category_path_of_asset(file_path)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
p = Path(rel_path)
|
||||
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
|
||||
if not parts:
|
||||
return None
|
||||
|
||||
if root_category == "models":
|
||||
# parts[0] is the category ("checkpoints", "vae", etc) – drop it
|
||||
inside = parts[1:] if len(parts) > 1 else [parts[0]]
|
||||
return "/".join(inside)
|
||||
return "/".join(parts) # input/output: keep all parts
|
||||
|
||||
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
|
||||
"""Given an absolute or relative file path, determine which root category the path belongs to:
|
||||
- 'input' if the file resides under `folder_paths.get_input_directory()`
|
||||
- 'output' if the file resides under `folder_paths.get_output_directory()`
|
||||
- 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()`
|
||||
|
||||
Returns:
|
||||
(root_category, relative_path_inside_that_root)
|
||||
For 'models', the relative path is prefixed with the category name:
|
||||
e.g. ('models', 'vae/test/sub/ae.safetensors')
|
||||
|
||||
Raises:
|
||||
ValueError: if the path does not belong to input, output, or configured model bases.
|
||||
"""
|
||||
fp_abs = os.path.abspath(file_path)
|
||||
|
||||
def _is_within(child: str, parent: str) -> bool:
|
||||
try:
|
||||
return os.path.commonpath([child, parent]) == parent
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _rel(child: str, parent: str) -> str:
|
||||
return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep)
|
||||
|
||||
# 1) input
|
||||
input_base = os.path.abspath(folder_paths.get_input_directory())
|
||||
if _is_within(fp_abs, input_base):
|
||||
return "input", _rel(fp_abs, input_base)
|
||||
|
||||
# 2) output
|
||||
output_base = os.path.abspath(folder_paths.get_output_directory())
|
||||
if _is_within(fp_abs, output_base):
|
||||
return "output", _rel(fp_abs, output_base)
|
||||
|
||||
# 3) models (check deepest matching base to avoid ambiguity)
|
||||
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
|
||||
for bucket, bases in get_comfy_models_folders():
|
||||
for b in bases:
|
||||
base_abs = os.path.abspath(b)
|
||||
if not _is_within(fp_abs, base_abs):
|
||||
continue
|
||||
cand = (len(base_abs), bucket, _rel(fp_abs, base_abs))
|
||||
if best is None or cand[0] > best[0]:
|
||||
best = cand
|
||||
|
||||
if best is not None:
|
||||
_, bucket, rel_inside = best
|
||||
combined = os.path.join(bucket, rel_inside)
|
||||
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
|
||||
|
||||
raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}")
|
||||
|
||||
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
|
||||
"""Return a tuple (name, tags) derived from a filesystem path.
|
||||
|
||||
Semantics:
|
||||
- Root category is determined by `get_relative_to_root_category_path_of_asset`.
|
||||
- The returned `name` is the base filename with extension from the relative path.
|
||||
- The returned `tags` are:
|
||||
[root_category] + parent folders of the relative path (in order)
|
||||
For 'models', this means:
|
||||
file '/.../ModelsDir/vae/test_tag/ae.safetensors'
|
||||
-> root_category='models', some_path='vae/test_tag/ae.safetensors'
|
||||
-> name='ae.safetensors', tags=['models', 'vae', 'test_tag']
|
||||
|
||||
Raises:
|
||||
ValueError: if the path does not belong to input, output, or configured model bases.
|
||||
"""
|
||||
root_category, some_path = get_relative_to_root_category_path_of_asset(file_path)
|
||||
p = Path(some_path)
|
||||
parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)]
|
||||
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
|
||||
|
||||
def normalize_tags(tags: list[str] | None) -> list[str]:
|
||||
"""
|
||||
@@ -229,84 +51,3 @@ def normalize_tags(tags: list[str] | None) -> list[str]:
|
||||
- Removing duplicates.
|
||||
"""
|
||||
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||
|
||||
def collect_models_files() -> list[str]:
|
||||
out: list[str] = []
|
||||
for folder_name, bases in get_comfy_models_folders():
|
||||
rel_files = folder_paths.get_filename_list(folder_name) or []
|
||||
for rel_path in rel_files:
|
||||
abs_path = folder_paths.get_full_path(folder_name, rel_path)
|
||||
if not abs_path:
|
||||
continue
|
||||
abs_path = os.path.abspath(abs_path)
|
||||
allowed = False
|
||||
for b in bases:
|
||||
base_abs = os.path.abspath(b)
|
||||
with contextlib.suppress(Exception):
|
||||
if os.path.commonpath([abs_path, base_abs]) == base_abs:
|
||||
allowed = True
|
||||
break
|
||||
if allowed:
|
||||
out.append(abs_path)
|
||||
return out
|
||||
|
||||
def is_scalar(v):
|
||||
if v is None:
|
||||
return True
|
||||
if isinstance(v, bool):
|
||||
return True
|
||||
if isinstance(v, (int, float, Decimal, str)):
|
||||
return True
|
||||
return False
|
||||
|
||||
def project_kv(key: str, value):
|
||||
"""
|
||||
Turn a metadata key/value into typed projection rows.
|
||||
Returns list[dict] with keys:
|
||||
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
|
||||
"""
|
||||
rows: list[dict] = []
|
||||
|
||||
def _null_row(ordinal: int) -> dict:
|
||||
return {
|
||||
"key": key, "ordinal": ordinal,
|
||||
"val_str": None, "val_num": None, "val_bool": None, "val_json": None
|
||||
}
|
||||
|
||||
if value is None:
|
||||
rows.append(_null_row(0))
|
||||
return rows
|
||||
|
||||
if is_scalar(value):
|
||||
if isinstance(value, bool):
|
||||
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
|
||||
elif isinstance(value, (int, float, Decimal)):
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
rows.append({"key": key, "ordinal": 0, "val_num": num})
|
||||
elif isinstance(value, str):
|
||||
rows.append({"key": key, "ordinal": 0, "val_str": value})
|
||||
else:
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
|
||||
if isinstance(value, list):
|
||||
if all(is_scalar(x) for x in value):
|
||||
for i, x in enumerate(value):
|
||||
if x is None:
|
||||
rows.append(_null_row(i))
|
||||
elif isinstance(x, bool):
|
||||
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
|
||||
elif isinstance(x, (int, float, Decimal)):
|
||||
num = x if isinstance(x, Decimal) else Decimal(str(x))
|
||||
rows.append({"key": key, "ordinal": i, "val_num": num})
|
||||
elif isinstance(x, str):
|
||||
rows.append({"key": key, "ordinal": i, "val_str": x})
|
||||
else:
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
for i, x in enumerate(value):
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
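For instance, a few representative projections (a sketch; Decimal rendering is abbreviated):
project_kv("epochs", 30)
# -> [{"key": "epochs", "ordinal": 0, "val_num": Decimal("30")}]
project_kv("trigger_words", ["style", "photo"])
# -> [{"key": "trigger_words", "ordinal": 0, "val_str": "style"},
#     {"key": "trigger_words", "ordinal": 1, "val_str": "photo"}]
project_kv("config", {"lr": 0.0001})
# -> [{"key": "config", "ordinal": 0, "val_json": {"lr": 0.0001}}]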
|
||||
|
||||
@@ -1,516 +0,0 @@
|
||||
import os
|
||||
import mimetypes
|
||||
import contextlib
|
||||
from typing import Sequence
|
||||
|
||||
from app.database.db import create_session
|
||||
from app.assets.api import schemas_out, schemas_in
|
||||
from app.assets.database.queries import (
|
||||
asset_exists_by_hash,
|
||||
asset_info_exists_for_asset_id,
|
||||
get_asset_by_hash,
|
||||
get_asset_info_by_id,
|
||||
fetch_asset_info_asset_and_tags,
|
||||
fetch_asset_info_and_asset,
|
||||
create_asset_info_for_existing_asset,
|
||||
touch_asset_info_by_id,
|
||||
update_asset_info_full,
|
||||
delete_asset_info_by_id,
|
||||
list_cache_states_by_asset_id,
|
||||
list_asset_infos_page,
|
||||
list_tags_with_usage,
|
||||
get_asset_tags,
|
||||
add_tags_to_asset_info,
|
||||
remove_tags_from_asset_info,
|
||||
pick_best_live_path,
|
||||
ingest_fs_asset,
|
||||
set_asset_info_preview,
|
||||
)
|
||||
from app.assets.helpers import resolve_destination_from_tags, ensure_within_base
|
||||
from app.assets.database.models import Asset
|
||||
|
||||
|
||||
def _safe_sort_field(requested: str | None) -> str:
|
||||
if not requested:
|
||||
return "created_at"
|
||||
v = requested.lower()
|
||||
if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
|
||||
return v
|
||||
return "created_at"
|
||||
|
||||
|
||||
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
|
||||
st = os.stat(path, follow_symlinks=True)
|
||||
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
|
||||
|
||||
|
||||
def _safe_filename(name: str | None, fallback: str) -> str:
|
||||
n = os.path.basename((name or "").strip() or fallback)
|
||||
if n:
|
||||
return n
|
||||
return fallback
|
||||
|
||||
|
||||
def asset_exists(*, asset_hash: str) -> bool:
|
||||
"""
|
||||
Check if an asset with a given hash exists in database.
|
||||
"""
|
||||
with create_session() as session:
|
||||
return asset_exists_by_hash(session, asset_hash=asset_hash)
|
||||
|
||||
|
||||
def list_assets(
|
||||
*,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetsList:
|
||||
sort = _safe_sort_field(sort)
|
||||
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
|
||||
|
||||
with create_session() as session:
|
||||
infos, tag_map, total = list_asset_infos_page(
|
||||
session,
|
||||
owner_id=owner_id,
|
||||
include_tags=include_tags,
|
||||
exclude_tags=exclude_tags,
|
||||
name_contains=name_contains,
|
||||
metadata_filter=metadata_filter,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
sort=sort,
|
||||
order=order,
|
||||
)
|
||||
|
||||
summaries: list[schemas_out.AssetSummary] = []
|
||||
for info in infos:
|
||||
asset = info.asset
|
||||
tags = tag_map.get(info.id, [])
|
||||
summaries.append(
|
||||
schemas_out.AssetSummary(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tags,
|
||||
created_at=info.created_at,
|
||||
updated_at=info.updated_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
)
|
||||
|
||||
return schemas_out.AssetsList(
|
||||
assets=summaries,
|
||||
total=total,
|
||||
has_more=(offset + len(summaries)) < total,
|
||||
)
|
||||
|
||||
|
||||
def get_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetDetail:
|
||||
with create_session() as session:
|
||||
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not res:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
info, asset, tag_names = res
|
||||
preview_id = info.preview_id
|
||||
|
||||
return schemas_out.AssetDetail(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
|
||||
|
||||
def resolve_asset_content_for_download(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> tuple[str, str, str]:
|
||||
with create_session() as session:
|
||||
pair = fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not pair:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
info, asset = pair
|
||||
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
|
||||
abs_path = pick_best_live_path(states)
|
||||
if not abs_path:
|
||||
raise FileNotFoundError
|
||||
|
||||
touch_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
session.commit()
|
||||
|
||||
ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
|
||||
download_name = info.name or os.path.basename(abs_path)
|
||||
return abs_path, ctype, download_name
|
||||
|
||||
|
||||
def upload_asset_from_temp_path(
|
||||
spec: schemas_in.UploadAssetSpec,
|
||||
*,
|
||||
temp_path: str,
|
||||
client_filename: str | None = None,
|
||||
owner_id: str = "",
|
||||
expected_asset_hash: str | None = None,
|
||||
) -> schemas_out.AssetCreated:
|
||||
"""
|
||||
Create new asset or update existing asset from a temporary file path.
|
||||
"""
|
||||
try:
|
||||
# NOTE: blake3 is not a required dependency right now, so this will fail if blake3 is not installed in the local environment
|
||||
import app.assets.hashing as hashing
|
||||
digest = hashing.blake3_hash(temp_path)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to hash uploaded file: {e}")
|
||||
asset_hash = "blake3:" + digest
|
||||
|
||||
if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower():
|
||||
raise ValueError("HASH_MISMATCH")
|
||||
|
||||
with create_session() as session:
|
||||
existing = get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
if existing is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
if temp_path and os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest)
|
||||
info = create_asset_info_for_existing_asset(
|
||||
session,
|
||||
asset_hash=asset_hash,
|
||||
name=display_name,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
tags=spec.tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
tag_names = get_asset_tags(session, asset_info_id=info.id)
|
||||
session.commit()
|
||||
|
||||
return schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=existing.hash,
|
||||
size=int(existing.size_bytes) if existing.size_bytes is not None else None,
|
||||
mime_type=existing.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=False,
|
||||
)
|
||||
|
||||
base_dir, subdirs = resolve_destination_from_tags(spec.tags)
|
||||
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
|
||||
src_for_ext = (client_filename or spec.name or "").strip()
|
||||
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
|
||||
ext = _ext if 0 < len(_ext) <= 16 else ""
|
||||
hashed_basename = f"{digest}{ext}"
|
||||
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
|
||||
ensure_within_base(dest_abs, base_dir)
|
||||
|
||||
content_type = (
|
||||
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
|
||||
or mimetypes.guess_type(hashed_basename, strict=False)[0]
|
||||
or "application/octet-stream"
|
||||
)
|
||||
|
||||
try:
|
||||
os.replace(temp_path, dest_abs)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to move uploaded file into place: {e}")
|
||||
|
||||
try:
|
||||
size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs)
|
||||
except OSError as e:
|
||||
raise RuntimeError(f"failed to stat destination file: {e}")
|
||||
|
||||
with create_session() as session:
|
||||
result = ingest_fs_asset(
|
||||
session,
|
||||
asset_hash=asset_hash,
|
||||
abs_path=dest_abs,
|
||||
size_bytes=size_bytes,
|
||||
mtime_ns=mtime_ns,
|
||||
mime_type=content_type,
|
||||
info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest),
|
||||
owner_id=owner_id,
|
||||
preview_id=None,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
tags=spec.tags,
|
||||
tag_origin="manual",
|
||||
require_existing_tags=False,
|
||||
)
|
||||
info_id = result["asset_info_id"]
|
||||
if not info_id:
|
||||
raise RuntimeError("failed to create asset metadata")
|
||||
|
||||
pair = fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id)
|
||||
if not pair:
|
||||
raise RuntimeError("inconsistent DB state after ingest")
|
||||
info, asset = pair
|
||||
tag_names = get_asset_tags(session, asset_info_id=info.id)
|
||||
created_result = schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash,
|
||||
size=int(asset.size_bytes),
|
||||
mime_type=asset.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=result["asset_created"],
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return created_result
|
||||
|
||||
|
||||
def update_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
name: str | None = None,
|
||||
tags: list[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetUpdated:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
info = update_asset_info_full(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
name=name,
|
||||
tags=tags,
|
||||
user_metadata=user_metadata,
|
||||
tag_origin="manual",
|
||||
asset_info_row=info_row,
|
||||
)
|
||||
|
||||
tag_names = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
result = schemas_out.AssetUpdated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=info.asset.hash if info.asset else None,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
updated_at=info.updated_at,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def set_asset_preview(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
preview_asset_id: str | None = None,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetDetail:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
set_asset_info_preview(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
preview_asset_id=preview_asset_id,
|
||||
)
|
||||
|
||||
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not res:
|
||||
raise RuntimeError("State changed during preview update")
|
||||
info, asset, tags = res
|
||||
result = schemas_out.AssetDetail(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tags,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
asset_id = info_row.asset_id if info_row else None
|
||||
deleted = delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not deleted:
|
||||
session.commit()
|
||||
return False
|
||||
|
||||
if not delete_content_if_orphan or not asset_id:
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id)
|
||||
if still_exists:
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
states = list_cache_states_by_asset_id(session, asset_id=asset_id)
|
||||
file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]
|
||||
|
||||
asset_row = session.get(Asset, asset_id)
|
||||
if asset_row is not None:
|
||||
session.delete(asset_row)
|
||||
|
||||
session.commit()
|
||||
for p in file_paths:
|
||||
with contextlib.suppress(Exception):
|
||||
if p and os.path.isfile(p):
|
||||
os.remove(p)
|
||||
return True
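
# Illustrative sketch: dropping a single reference. With
# delete_content_if_orphan=True the Asset row and its cached files are removed
# only once no other AssetInfo still points at the same asset.
def _example_delete_reference(asset_info_id: str) -> bool:
    return delete_asset_reference(
        asset_info_id=asset_info_id,
        owner_id="",
        delete_content_if_orphan=True,
    )
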
|
||||
|
||||
|
||||
def create_asset_from_hash(
|
||||
*,
|
||||
hash_str: str,
|
||||
name: str,
|
||||
tags: list[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetCreated | None:
|
||||
canonical = hash_str.strip().lower()
|
||||
with create_session() as session:
|
||||
asset = get_asset_by_hash(session, asset_hash=canonical)
|
||||
if not asset:
|
||||
return None
|
||||
|
||||
info = create_asset_info_for_existing_asset(
|
||||
session,
|
||||
asset_hash=canonical,
|
||||
name=_safe_filename(name, fallback=canonical.split(":", 1)[1]),
|
||||
user_metadata=user_metadata or {},
|
||||
tags=tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
tag_names = get_asset_tags(session, asset_info_id=info.id)
|
||||
result = schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash,
|
||||
size=int(asset.size_bytes),
|
||||
mime_type=asset.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=False,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return result
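
# Illustrative sketch: referencing content that was already hashed and stored.
# A None result means the hash is unknown and the caller is expected to fall
# back to a regular upload; the name and tag values are placeholders.
def _example_reference_existing(hash_str: str) -> bool:
    created = create_asset_from_hash(
        hash_str=hash_str,
        name="reused-model.safetensors",
        tags=["checkpoints"],
        owner_id="",
    )
    return created is not None
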
|
||||
|
||||
|
||||
def add_tags_to_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: list[str],
|
||||
origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsAdd:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
data = add_tags_to_asset_info(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
origin=origin,
|
||||
create_if_missing=True,
|
||||
asset_info_row=info_row,
|
||||
)
|
||||
session.commit()
|
||||
return schemas_out.TagsAdd(**data)
|
||||
|
||||
|
||||
def remove_tags_from_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: list[str],
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsRemove:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
data = remove_tags_from_asset_info(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
)
|
||||
session.commit()
|
||||
return schemas_out.TagsRemove(**data)
|
||||
|
||||
|
||||
def list_tags(
|
||||
prefix: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
order: str = "count_desc",
|
||||
include_zero: bool = True,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsList:
|
||||
limit = max(1, min(1000, limit))
|
||||
offset = max(0, offset)
|
||||
|
||||
with create_session() as session:
|
||||
rows, total = list_tags_with_usage(
|
||||
session,
|
||||
prefix=prefix,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
include_zero=include_zero,
|
||||
order=order,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
|
||||
tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows]
|
||||
return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total)
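
# Illustrative sketch: draining every tag page by page via the has_more flag.
# The page size of 200 is an arbitrary example value.
def _example_iter_all_tags() -> list[str]:
    names: list[str] = []
    offset = 0
    while True:
        page = list_tags(limit=200, offset=offset, owner_id="")
        names.extend(t.name for t in page.tags)
        if not page.has_more:
            return names
        offset += len(page.tags)
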
|
||||
@@ -1,263 +1,396 @@
|
||||
import contextlib
import logging
import os
import time
from typing import Literal, TypedDict

import sqlalchemy

import folder_paths
from app.assets.database.models import Asset, AssetCacheState, AssetInfo
from app.assets.database.queries import (
    add_missing_tag_for_asset_id,
    bulk_update_is_missing,
    bulk_update_needs_verify,
    delete_cache_states_by_ids,
    delete_orphaned_seed_asset,
    ensure_tags_exist,
    get_cache_states_for_prefixes,
    remove_missing_tag_for_asset_id,
)
from app.assets.helpers import (
    escape_like_prefix,
    fast_asset_file_check,
    list_tree,
    prefixes_for_root,
)
from app.assets.services.bulk_ingest import (
    SeedAssetSpec,
    batch_insert_seed_assets,
    mark_assets_missing_outside_prefixes,
)
from app.assets.services.file_utils import (
    get_mtime_ns,
    list_files_recursively,
    verify_file_unchanged,
)
from app.assets.services.hashing import compute_blake3_hash
from app.assets.services.metadata_extract import extract_file_metadata
from app.assets.services.path_utils import (
    compute_relative_filename,
    get_comfy_models_folders,
    get_name_and_tags_from_asset_path,
)
from app.database.db import create_session, dependencies_available

|
||||
class _StateInfo(TypedDict):
|
||||
sid: int
|
||||
fp: str
|
||||
exists: bool
|
||||
fast_ok: bool
|
||||
needs_verify: bool
|
||||
|
||||
|
||||
class _AssetAccumulator(TypedDict):
|
||||
hash: str | None
|
||||
size_db: int
|
||||
states: list[_StateInfo]
|
||||
|
||||
|
||||
RootType = Literal["models", "input", "output"]
|
||||
|
||||
|
||||
def get_prefixes_for_root(root: RootType) -> list[str]:
|
||||
if root == "models":
|
||||
bases: list[str] = []
|
||||
for _bucket, paths in get_comfy_models_folders():
|
||||
bases.extend(paths)
|
||||
return [os.path.abspath(p) for p in bases]
|
||||
if root == "input":
|
||||
return [os.path.abspath(folder_paths.get_input_directory())]
|
||||
if root == "output":
|
||||
return [os.path.abspath(folder_paths.get_output_directory())]
|
||||
return []
|
||||
|
||||
|
||||
def get_all_known_prefixes() -> list[str]:
|
||||
"""Get all known asset prefixes across all root types."""
|
||||
all_roots: tuple[RootType, ...] = ("models", "input", "output")
|
||||
return [
|
||||
os.path.abspath(p) for root in all_roots for p in get_prefixes_for_root(root)
|
||||
]
|
||||
|
||||
|
||||
def collect_models_files() -> list[str]:
|
||||
out: list[str] = []
|
||||
for folder_name, bases in get_comfy_models_folders():
|
||||
rel_files = folder_paths.get_filename_list(folder_name) or []
|
||||
for rel_path in rel_files:
|
||||
abs_path = folder_paths.get_full_path(folder_name, rel_path)
|
||||
if not abs_path:
|
||||
continue
|
||||
abs_path = os.path.abspath(abs_path)
|
||||
allowed = False
|
||||
for b in bases:
|
||||
base_abs = os.path.abspath(b)
|
||||
with contextlib.suppress(ValueError):
|
||||
if os.path.commonpath([abs_path, base_abs]) == base_abs:
|
||||
allowed = True
|
||||
break
|
||||
if allowed:
|
||||
out.append(abs_path)
|
||||
return out
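
# The containment test above leans on os.path.commonpath raising ValueError
# for unrelated paths (e.g. different drives); the same check as a minimal
# standalone sketch:
def _is_within_base(abs_path: str, base_abs: str) -> bool:
    with contextlib.suppress(ValueError):
        return os.path.commonpath([abs_path, base_abs]) == base_abs
    return False
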
|
||||
|
||||
|
||||
def sync_cache_states_with_filesystem(
|
||||
session,
|
||||
root: RootType,
|
||||
collect_existing_paths: bool = False,
|
||||
update_missing_tags: bool = False,
|
||||
) -> set[str] | None:
|
||||
"""Reconcile cache states with filesystem for a root.
|
||||
|
||||
- Toggle needs_verify per state using fast mtime/size check
|
||||
- For hashed assets with at least one fast-ok state in this root: delete stale missing states
|
||||
- For seed assets with all states missing: delete Asset and its AssetInfos
|
||||
- Optionally add/remove 'missing' tags based on fast-ok in this root
|
||||
- Optionally return surviving absolute paths
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
root: Root type to scan
|
||||
collect_existing_paths: If True, return set of surviving file paths
|
||||
update_missing_tags: If True, update 'missing' tags based on file status
|
||||
|
||||
Returns:
|
||||
Set of surviving absolute paths if collect_existing_paths=True, else None
|
||||
"""
|
||||
prefixes = get_prefixes_for_root(root)
|
||||
if not prefixes:
|
||||
return set() if collect_existing_paths else None
|
||||
|
||||
rows = get_cache_states_for_prefixes(
|
||||
session, prefixes, include_missing=update_missing_tags
|
||||
)
|
||||
|
||||
by_asset: dict[str, _AssetAccumulator] = {}
|
||||
for row in rows:
|
||||
acc = by_asset.get(row.asset_id)
|
||||
if acc is None:
|
||||
acc = {"hash": row.asset_hash, "size_db": row.size_bytes, "states": []}
|
||||
by_asset[row.asset_id] = acc
|
||||
|
||||
fast_ok = False
|
||||
try:
|
||||
exists = True
|
||||
fast_ok = verify_file_unchanged(
|
||||
mtime_db=row.mtime_ns,
|
||||
size_db=acc["size_db"],
|
||||
stat_result=os.stat(row.file_path, follow_symlinks=True),
|
||||
)
|
||||
except FileNotFoundError:
|
||||
exists = False
|
||||
except PermissionError:
|
||||
exists = True
|
||||
logging.debug("Permission denied accessing %s", row.file_path)
|
||||
except OSError as e:
|
||||
exists = False
|
||||
logging.debug("OSError checking %s: %s", row.file_path, e)
|
||||
|
||||
acc["states"].append(
|
||||
{
|
||||
"sid": row.state_id,
|
||||
"fp": row.file_path,
|
||||
"exists": exists,
|
||||
"fast_ok": fast_ok,
|
||||
"needs_verify": row.needs_verify,
|
||||
}
|
||||
)
|
||||
|
||||
to_set_verify: list[int] = []
|
||||
to_clear_verify: list[int] = []
|
||||
stale_state_ids: list[int] = []
|
||||
to_mark_missing: list[int] = []
|
||||
to_clear_missing: list[int] = []
|
||||
survivors: set[str] = set()
|
||||
|
||||
for aid, acc in by_asset.items():
|
||||
a_hash = acc["hash"]
|
||||
states = acc["states"]
|
||||
any_fast_ok = any(s["fast_ok"] for s in states)
|
||||
all_missing = all(not s["exists"] for s in states)
|
||||
|
||||
for s in states:
|
||||
if not s["exists"]:
|
||||
to_mark_missing.append(s["sid"])
|
||||
continue
|
||||
if s["fast_ok"]:
|
||||
to_clear_missing.append(s["sid"])
|
||||
if s["needs_verify"]:
|
||||
to_clear_verify.append(s["sid"])
|
||||
if not s["fast_ok"] and not s["needs_verify"]:
|
||||
to_set_verify.append(s["sid"])
|
||||
|
||||
if a_hash is None:
|
||||
if states and all_missing:
|
||||
delete_orphaned_seed_asset(session, aid)
|
||||
else:
|
||||
for s in states:
|
||||
if s["exists"]:
|
||||
survivors.add(os.path.abspath(s["fp"]))
|
||||
continue
|
||||
|
||||
if any_fast_ok:
|
||||
for s in states:
|
||||
if not s["exists"]:
|
||||
stale_state_ids.append(s["sid"])
|
||||
if update_missing_tags:
|
||||
try:
|
||||
remove_missing_tag_for_asset_id(session, asset_id=aid)
|
||||
except Exception as e:
|
||||
logging.warning("Failed to remove missing tag for asset %s: %s", aid, e)
|
||||
elif update_missing_tags:
|
||||
try:
|
||||
add_missing_tag_for_asset_id(session, asset_id=aid, origin="automatic")
|
||||
except Exception as e:
|
||||
logging.warning("Failed to add missing tag for asset %s: %s", aid, e)
|
||||
|
||||
for s in states:
|
||||
if s["exists"]:
|
||||
survivors.add(os.path.abspath(s["fp"]))
|
||||
|
||||
delete_cache_states_by_ids(session, stale_state_ids)
|
||||
stale_set = set(stale_state_ids)
|
||||
to_mark_missing = [sid for sid in to_mark_missing if sid not in stale_set]
|
||||
bulk_update_is_missing(session, to_mark_missing, value=True)
|
||||
bulk_update_is_missing(session, to_clear_missing, value=False)
|
||||
bulk_update_needs_verify(session, to_set_verify, value=True)
|
||||
bulk_update_needs_verify(session, to_clear_verify, value=False)
|
||||
|
||||
return survivors if collect_existing_paths else None
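
# Illustrative sketch: reconciling a single root inside one session and
# committing the flag updates; sync_root_safely below wraps the same call with
# error handling.
def _example_sync_models_root() -> set[str]:
    with create_session() as session:
        survivors = sync_cache_states_with_filesystem(
            session,
            "models",
            collect_existing_paths=True,
            update_missing_tags=True,
        )
        session.commit()
    return survivors or set()
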
|
||||
|
||||
|
||||
def sync_root_safely(root: RootType) -> set[str]:
|
||||
"""Sync a single root's cache states with the filesystem.
|
||||
|
||||
Returns survivors (existing paths) or empty set on failure.
|
||||
"""
|
||||
try:
|
||||
with create_session() as sess:
|
||||
survivors = sync_cache_states_with_filesystem(
|
||||
sess,
|
||||
root,
|
||||
collect_existing_paths=True,
|
||||
update_missing_tags=True,
|
||||
)
|
||||
sess.commit()
|
||||
return survivors or set()
|
||||
except Exception as e:
|
||||
logging.exception("fast DB scan failed for %s: %s", root, e)
|
||||
return set()
|
||||
|
||||
|
||||
def mark_missing_outside_prefixes_safely(prefixes: list[str]) -> int:
|
||||
"""Mark cache states as missing when outside the given prefixes.
|
||||
|
||||
This is a non-destructive soft-delete. Returns count marked or 0 on failure.
|
||||
"""
|
||||
try:
|
||||
with create_session() as sess:
|
||||
count = mark_assets_missing_outside_prefixes(sess, prefixes)
|
||||
sess.commit()
|
||||
return count
|
||||
except Exception as e:
|
||||
logging.exception("marking missing assets failed: %s", e)
|
||||
return 0
|
||||
|
||||
|
||||
def collect_paths_for_roots(roots: tuple[RootType, ...]) -> list[str]:
|
||||
"""Collect all file paths for the given roots."""
|
||||
paths: list[str] = []
|
||||
if "models" in roots:
|
||||
paths.extend(collect_models_files())
|
||||
if "input" in roots:
|
||||
paths.extend(list_files_recursively(folder_paths.get_input_directory()))
|
||||
if "output" in roots:
|
||||
paths.extend(list_files_recursively(folder_paths.get_output_directory()))
|
||||
return paths
|
||||
|
||||
|
||||
def build_asset_specs(
|
||||
paths: list[str],
|
||||
existing_paths: set[str],
|
||||
enable_metadata_extraction: bool = True,
|
||||
compute_hashes: bool = False,
|
||||
) -> tuple[list[SeedAssetSpec], set[str], int]:
|
||||
"""Build asset specs from paths, returning (specs, tag_pool, skipped_count).
|
||||
|
||||
Args:
|
||||
paths: List of file paths to process
|
||||
existing_paths: Set of paths that already exist in the database
|
||||
enable_metadata_extraction: If True, extract tier 1 & 2 metadata from files
|
||||
compute_hashes: If True, compute blake3 hashes for each file (slow for large files)
|
||||
"""
|
||||
specs: list[SeedAssetSpec] = []
|
||||
tag_pool: set[str] = set()
|
||||
skipped = 0
|
||||
|
||||
for p in paths:
|
||||
abs_p = os.path.abspath(p)
|
||||
if abs_p in existing_paths:
|
||||
skipped += 1
|
||||
continue
|
||||
try:
|
||||
stat_p = os.stat(abs_p, follow_symlinks=False)
|
||||
except OSError:
|
||||
continue
|
||||
if not stat_p.st_size:
|
||||
continue
|
||||
name, tags = get_name_and_tags_from_asset_path(abs_p)
|
||||
rel_fname = compute_relative_filename(abs_p)
|
||||
|
||||
# Extract metadata (tier 1: filesystem, tier 2: safetensors header)
|
||||
metadata = None
|
||||
if enable_metadata_extraction:
|
||||
metadata = extract_file_metadata(
|
||||
abs_p,
|
||||
stat_result=stat_p,
|
||||
enable_safetensors=True,
|
||||
relative_filename=rel_fname,
|
||||
)
|
||||
|
||||
# Compute hash if requested
|
||||
asset_hash: str | None = None
|
||||
if compute_hashes:
|
||||
try:
|
||||
digest = compute_blake3_hash(abs_p)
|
||||
asset_hash = "blake3:" + digest
|
||||
except Exception as e:
|
||||
logging.warning("Failed to hash %s: %s", abs_p, e)
|
||||
|
||||
mime_type = metadata.content_type if metadata else None
|
||||
|
||||
specs.append(
|
||||
{
|
||||
"abs_path": abs_p,
|
||||
"size_bytes": stat_p.st_size,
|
||||
"mtime_ns": get_mtime_ns(stat_p),
|
||||
"info_name": name,
|
||||
"tags": tags,
|
||||
"fname": rel_fname,
|
||||
"metadata": metadata,
|
||||
"hash": asset_hash,
|
||||
"mime_type": mime_type,
|
||||
}
|
||||
)
|
||||
tag_pool.update(tags)
|
||||
|
||||
return specs, tag_pool, skipped
|
||||
|
||||
|
||||
def insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int:
|
||||
"""Insert asset specs into database, returning count of created infos."""
|
||||
if not specs:
|
||||
return 0
|
||||
with create_session() as sess:
|
||||
if tag_pool:
|
||||
ensure_tags_exist(sess, tag_pool, tag_type="user")
|
||||
result = batch_insert_seed_assets(sess, specs=specs, owner_id="")
|
||||
sess.commit()
|
||||
return result.inserted_infos
|
||||
|
||||
|
||||
def seed_assets(
|
||||
roots: tuple[RootType, ...],
|
||||
enable_logging: bool = False,
|
||||
compute_hashes: bool = False,
|
||||
) -> None:
|
||||
"""Scan the given roots and seed the assets into the database.
|
||||
|
||||
Args:
|
||||
roots: Tuple of root types to scan (models, input, output)
|
||||
enable_logging: If True, log progress and completion messages
|
||||
compute_hashes: If True, compute blake3 hashes for each file (slow for large files)
|
||||
|
||||
Note: This function does not mark missing assets. Call mark_missing_outside_prefixes_safely
|
||||
separately if cleanup is needed.
|
||||
"""
|
||||
if not dependencies_available():
|
||||
if enable_logging:
|
||||
logging.warning("Database dependencies not available, skipping assets scan")
|
||||
return
|
||||
|
||||
    t_start = time.perf_counter()
    created = 0
    skipped_existing = 0
    paths: list[str] = []
    try:
        existing_paths: set[str] = set()
        for r in roots:
            existing_paths.update(sync_root_safely(r))

        paths = collect_paths_for_roots(roots)
        specs, tag_pool, skipped_existing = build_asset_specs(
            paths, existing_paths, compute_hashes=compute_hashes
        )
        created = insert_asset_specs(specs, tag_pool)
    finally:
        if enable_logging:
            logging.info(
                "Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, total_seen=%d)",
                roots,
                time.perf_counter() - t_start,
                created,
                skipped_existing,
                len(paths),
            )
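
# Illustrative startup-style call, not executed here: a fast metadata-only scan
# of every root; a hashed pass is far slower and can be run separately.
def _example_startup_scan() -> None:
    seed_assets(("models", "input", "output"), enable_logging=True)
    # seed_assets(("models",), enable_logging=True, compute_hashes=True)
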
|
||||
|
||||
|
||||
def _prune_orphaned_assets(roots: tuple[RootType, ...]) -> int:
|
||||
"""Prune cache states outside configured prefixes, then delete orphaned seed assets."""
|
||||
all_prefixes = [os.path.abspath(p) for r in roots for p in prefixes_for_root(r)]
|
||||
if not all_prefixes:
|
||||
return 0
|
||||
|
||||
def make_prefix_condition(prefix: str):
|
||||
base = prefix if prefix.endswith(os.sep) else prefix + os.sep
|
||||
escaped, esc = escape_like_prefix(base)
|
||||
return AssetCacheState.file_path.like(escaped + "%", escape=esc)
|
||||
|
||||
matches_valid_prefix = sqlalchemy.or_(*[make_prefix_condition(p) for p in all_prefixes])
|
||||
|
||||
orphan_subq = (
|
||||
sqlalchemy.select(Asset.id)
|
||||
.outerjoin(AssetCacheState, AssetCacheState.asset_id == Asset.id)
|
||||
.where(Asset.hash.is_(None), AssetCacheState.id.is_(None))
|
||||
).scalar_subquery()
|
||||
|
||||
with create_session() as sess:
|
||||
sess.execute(sqlalchemy.delete(AssetCacheState).where(~matches_valid_prefix))
|
||||
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id.in_(orphan_subq)))
|
||||
result = sess.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(orphan_subq)))
|
||||
sess.commit()
|
||||
return result.rowcount
|
||||
|
||||
|
||||
def _fast_db_consistency_pass(
|
||||
root: RootType,
|
||||
*,
|
||||
collect_existing_paths: bool = False,
|
||||
update_missing_tags: bool = False,
|
||||
) -> set[str] | None:
|
||||
"""Fast DB+FS pass for a root:
|
||||
- Toggle needs_verify per state using fast check
|
||||
- For hashed assets with at least one fast-ok state in this root: delete stale missing states
|
||||
- For seed assets with all states missing: delete Asset and its AssetInfos
|
||||
- Optionally add/remove 'missing' tags based on fast-ok in this root
|
||||
- Optionally return surviving absolute paths
|
||||
"""
|
||||
prefixes = prefixes_for_root(root)
|
||||
if not prefixes:
|
||||
return set() if collect_existing_paths else None
|
||||
|
||||
conds = []
|
||||
for p in prefixes:
|
||||
base = os.path.abspath(p)
|
||||
if not base.endswith(os.sep):
|
||||
base += os.sep
|
||||
escaped, esc = escape_like_prefix(base)
|
||||
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
|
||||
|
||||
with create_session() as sess:
|
||||
rows = (
|
||||
sess.execute(
|
||||
sqlalchemy.select(
|
||||
AssetCacheState.id,
|
||||
AssetCacheState.file_path,
|
||||
AssetCacheState.mtime_ns,
|
||||
AssetCacheState.needs_verify,
|
||||
AssetCacheState.asset_id,
|
||||
Asset.hash,
|
||||
Asset.size_bytes,
|
||||
)
|
||||
.join(Asset, Asset.id == AssetCacheState.asset_id)
|
||||
.where(sqlalchemy.or_(*conds))
|
||||
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
|
||||
)
|
||||
).all()
|
||||
|
||||
by_asset: dict[str, dict] = {}
|
||||
for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows:
|
||||
acc = by_asset.get(aid)
|
||||
if acc is None:
|
||||
acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []}
|
||||
by_asset[aid] = acc
|
||||
|
||||
fast_ok = False
|
||||
try:
|
||||
exists = True
|
||||
fast_ok = fast_asset_file_check(
|
||||
mtime_db=mtime_db,
|
||||
size_db=acc["size_db"],
|
||||
stat_result=os.stat(fp, follow_symlinks=True),
|
||||
)
|
||||
except FileNotFoundError:
|
||||
exists = False
|
||||
except OSError:
|
||||
exists = False
|
||||
|
||||
acc["states"].append({
|
||||
"sid": sid,
|
||||
"fp": fp,
|
||||
"exists": exists,
|
||||
"fast_ok": fast_ok,
|
||||
"needs_verify": bool(needs_verify),
|
||||
})
|
||||
|
||||
to_set_verify: list[int] = []
|
||||
to_clear_verify: list[int] = []
|
||||
stale_state_ids: list[int] = []
|
||||
survivors: set[str] = set()
|
||||
|
||||
for aid, acc in by_asset.items():
|
||||
a_hash = acc["hash"]
|
||||
states = acc["states"]
|
||||
any_fast_ok = any(s["fast_ok"] for s in states)
|
||||
all_missing = all(not s["exists"] for s in states)
|
||||
|
||||
for s in states:
|
||||
if not s["exists"]:
|
||||
continue
|
||||
if s["fast_ok"] and s["needs_verify"]:
|
||||
to_clear_verify.append(s["sid"])
|
||||
if not s["fast_ok"] and not s["needs_verify"]:
|
||||
to_set_verify.append(s["sid"])
|
||||
|
||||
if a_hash is None:
|
||||
if states and all_missing: # remove seed Asset completely, if no valid AssetCache exists
|
||||
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id == aid))
|
||||
asset = sess.get(Asset, aid)
|
||||
if asset:
|
||||
sess.delete(asset)
|
||||
else:
|
||||
for s in states:
|
||||
if s["exists"]:
|
||||
survivors.add(os.path.abspath(s["fp"]))
|
||||
continue
|
||||
|
||||
if any_fast_ok: # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records
|
||||
for s in states:
|
||||
if not s["exists"]:
|
||||
stale_state_ids.append(s["sid"])
|
||||
if update_missing_tags:
|
||||
with contextlib.suppress(Exception):
|
||||
remove_missing_tag_for_asset_id(sess, asset_id=aid)
|
||||
elif update_missing_tags:
|
||||
with contextlib.suppress(Exception):
|
||||
add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
|
||||
|
||||
for s in states:
|
||||
if s["exists"]:
|
||||
survivors.add(os.path.abspath(s["fp"]))
|
||||
|
||||
if stale_state_ids:
|
||||
sess.execute(sqlalchemy.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids)))
|
||||
if to_set_verify:
|
||||
sess.execute(
|
||||
sqlalchemy.update(AssetCacheState)
|
||||
.where(AssetCacheState.id.in_(to_set_verify))
|
||||
.values(needs_verify=True)
|
||||
)
|
||||
if to_clear_verify:
|
||||
sess.execute(
|
||||
sqlalchemy.update(AssetCacheState)
|
||||
.where(AssetCacheState.id.in_(to_clear_verify))
|
||||
.values(needs_verify=False)
|
||||
)
|
||||
sess.commit()
|
||||
return survivors if collect_existing_paths else None
|
||||
|
||||
|
||||
433
app/assets/seeder.py
Normal file
433
app/assets/seeder.py
Normal file
@@ -0,0 +1,433 @@
|
||||
"""Background asset seeder with thread management and cancellation support."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Callable
|
||||
|
||||
from app.assets.scanner import (
|
||||
RootType,
|
||||
build_asset_specs,
|
||||
collect_paths_for_roots,
|
||||
get_all_known_prefixes,
|
||||
get_prefixes_for_root,
|
||||
insert_asset_specs,
|
||||
mark_missing_outside_prefixes_safely,
|
||||
sync_root_safely,
|
||||
)
|
||||
from app.database.db import dependencies_available
|
||||
|
||||
|
||||
class State(Enum):
|
||||
"""Seeder state machine states."""
|
||||
|
||||
IDLE = "IDLE"
|
||||
RUNNING = "RUNNING"
|
||||
CANCELLING = "CANCELLING"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Progress:
|
||||
"""Progress information for a scan operation."""
|
||||
|
||||
scanned: int = 0
|
||||
total: int = 0
|
||||
created: int = 0
|
||||
skipped: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScanStatus:
|
||||
"""Current status of the asset seeder."""
|
||||
|
||||
state: State
|
||||
progress: Progress | None
|
||||
errors: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
ProgressCallback = Callable[[Progress], None]
|
||||
|
||||
|
||||
class AssetSeeder:
|
||||
"""Singleton class managing background asset scanning.
|
||||
|
||||
Thread-safe singleton that spawns ephemeral daemon threads for scanning.
|
||||
Each scan creates a new thread that exits when complete.
|
||||
"""
|
||||
|
||||
_instance: "AssetSeeder | None" = None
|
||||
_instance_lock = threading.Lock()
|
||||
|
||||
def __new__(cls) -> "AssetSeeder":
|
||||
with cls._instance_lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self) -> None:
|
||||
if self._initialized:
|
||||
return
|
||||
self._initialized = True
|
||||
self._lock = threading.Lock()
|
||||
self._state = State.IDLE
|
||||
self._progress: Progress | None = None
|
||||
self._errors: list[str] = []
|
||||
self._thread: threading.Thread | None = None
|
||||
self._cancel_event = threading.Event()
|
||||
self._roots: tuple[RootType, ...] = ()
|
||||
        self._prune_first: bool = False
        self._compute_hashes: bool = False
|
||||
self._progress_callback: ProgressCallback | None = None
|
||||
|
||||
def start(
|
||||
self,
|
||||
roots: tuple[RootType, ...] = ("models", "input", "output"),
|
||||
progress_callback: ProgressCallback | None = None,
|
||||
prune_first: bool = False,
|
||||
compute_hashes: bool = False,
|
||||
) -> bool:
|
||||
"""Start a background scan for the given roots.
|
||||
|
||||
Args:
|
||||
roots: Tuple of root types to scan (models, input, output)
|
||||
progress_callback: Optional callback called with progress updates
|
||||
prune_first: If True, prune orphaned assets before scanning
|
||||
compute_hashes: If True, compute blake3 hashes for each file (slow for large files)
|
||||
|
||||
Returns:
|
||||
True if scan was started, False if already running
|
||||
"""
|
||||
with self._lock:
|
||||
if self._state != State.IDLE:
|
||||
return False
|
||||
self._state = State.RUNNING
|
||||
self._progress = Progress()
|
||||
self._errors = []
|
||||
self._roots = roots
|
||||
self._prune_first = prune_first
|
||||
self._compute_hashes = compute_hashes
|
||||
self._progress_callback = progress_callback
|
||||
self._cancel_event.clear()
|
||||
self._thread = threading.Thread(
|
||||
target=self._run_scan,
|
||||
name="AssetSeeder",
|
||||
daemon=True,
|
||||
)
|
||||
self._thread.start()
|
||||
return True
|
||||
|
||||
def cancel(self) -> bool:
|
||||
"""Request cancellation of the current scan.
|
||||
|
||||
Returns:
|
||||
True if cancellation was requested, False if not running
|
||||
"""
|
||||
with self._lock:
|
||||
if self._state != State.RUNNING:
|
||||
return False
|
||||
self._state = State.CANCELLING
|
||||
self._cancel_event.set()
|
||||
return True
|
||||
|
||||
def wait(self, timeout: float | None = None) -> bool:
|
||||
"""Wait for the current scan to complete.
|
||||
|
||||
Args:
|
||||
timeout: Maximum seconds to wait, or None for no timeout
|
||||
|
||||
Returns:
|
||||
True if scan completed, False if timeout expired or no scan running
|
||||
"""
|
||||
with self._lock:
|
||||
thread = self._thread
|
||||
if thread is None:
|
||||
return True
|
||||
thread.join(timeout=timeout)
|
||||
return not thread.is_alive()
|
||||
|
||||
def get_status(self) -> ScanStatus:
|
||||
"""Get the current status and progress of the seeder."""
|
||||
with self._lock:
|
||||
return ScanStatus(
|
||||
state=self._state,
|
||||
progress=Progress(
|
||||
scanned=self._progress.scanned,
|
||||
total=self._progress.total,
|
||||
created=self._progress.created,
|
||||
skipped=self._progress.skipped,
|
||||
)
|
||||
if self._progress
|
||||
else None,
|
||||
errors=list(self._errors),
|
||||
)
|
||||
|
||||
def shutdown(self, timeout: float = 5.0) -> None:
|
||||
"""Gracefully shutdown: cancel any running scan and wait for thread.
|
||||
|
||||
Args:
|
||||
timeout: Maximum seconds to wait for thread to exit
|
||||
"""
|
||||
self.cancel()
|
||||
self.wait(timeout=timeout)
|
||||
with self._lock:
|
||||
self._thread = None
|
||||
|
||||
def mark_missing_outside_prefixes(self) -> int:
|
||||
"""Mark cache states as missing when outside all known root prefixes.
|
||||
|
||||
This is a non-destructive soft-delete operation. Assets and their
|
||||
metadata are preserved, but cache states are flagged as missing.
|
||||
They can be restored if the file reappears in a future scan.
|
||||
|
||||
This operation is decoupled from scanning to prevent partial scans
|
||||
from accidentally marking assets belonging to other roots.
|
||||
|
||||
Should be called explicitly when cleanup is desired, typically after
|
||||
a full scan of all roots or during maintenance.
|
||||
|
||||
Returns:
|
||||
Number of cache states marked as missing, or 0 if dependencies
|
||||
unavailable or a scan is currently running
|
||||
"""
|
||||
with self._lock:
|
||||
if self._state != State.IDLE:
|
||||
logging.warning(
|
||||
"Cannot mark missing assets while scan is running"
|
||||
)
|
||||
return 0
|
||||
self._state = State.RUNNING
|
||||
|
||||
try:
|
||||
if not dependencies_available():
|
||||
logging.warning(
|
||||
"Database dependencies not available, skipping mark missing"
|
||||
)
|
||||
return 0
|
||||
|
||||
all_prefixes = get_all_known_prefixes()
|
||||
marked = mark_missing_outside_prefixes_safely(all_prefixes)
|
||||
if marked > 0:
|
||||
logging.info("Marked %d cache states as missing", marked)
|
||||
return marked
|
||||
finally:
|
||||
with self._lock:
|
||||
self._state = State.IDLE
|
||||
|
||||
def _is_cancelled(self) -> bool:
|
||||
"""Check if cancellation has been requested."""
|
||||
return self._cancel_event.is_set()
|
||||
|
||||
def _emit_event(self, event_type: str, data: dict) -> None:
|
||||
"""Emit a WebSocket event if server is available."""
|
||||
try:
|
||||
from server import PromptServer
|
||||
|
||||
if hasattr(PromptServer, "instance") and PromptServer.instance:
|
||||
PromptServer.instance.send_sync(event_type, data)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _update_progress(
|
||||
self,
|
||||
scanned: int | None = None,
|
||||
total: int | None = None,
|
||||
created: int | None = None,
|
||||
skipped: int | None = None,
|
||||
) -> None:
|
||||
"""Update progress counters (thread-safe)."""
|
||||
callback: ProgressCallback | None = None
|
||||
progress: Progress | None = None
|
||||
|
||||
with self._lock:
|
||||
if self._progress is None:
|
||||
return
|
||||
if scanned is not None:
|
||||
self._progress.scanned = scanned
|
||||
if total is not None:
|
||||
self._progress.total = total
|
||||
if created is not None:
|
||||
self._progress.created = created
|
||||
if skipped is not None:
|
||||
self._progress.skipped = skipped
|
||||
if self._progress_callback:
|
||||
callback = self._progress_callback
|
||||
progress = Progress(
|
||||
scanned=self._progress.scanned,
|
||||
total=self._progress.total,
|
||||
created=self._progress.created,
|
||||
skipped=self._progress.skipped,
|
||||
)
|
||||
|
||||
if callback and progress:
|
||||
try:
|
||||
callback(progress)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _add_error(self, message: str) -> None:
|
||||
"""Add an error message (thread-safe)."""
|
||||
with self._lock:
|
||||
self._errors.append(message)
|
||||
|
||||
def _log_scan_config(self, roots: tuple[RootType, ...]) -> None:
|
||||
"""Log the directories that will be scanned."""
|
||||
import folder_paths
|
||||
|
||||
for root in roots:
|
||||
if root == "models":
|
||||
logging.info(
|
||||
"Asset scan [models] directory: %s",
|
||||
os.path.abspath(folder_paths.models_dir),
|
||||
)
|
||||
else:
|
||||
prefixes = get_prefixes_for_root(root)
|
||||
if prefixes:
|
||||
logging.info("Asset scan [%s] directories: %s", root, prefixes)
|
||||
|
||||
def _run_scan(self) -> None:
|
||||
"""Main scan loop running in background thread."""
|
||||
t_start = time.perf_counter()
|
||||
roots = self._roots
|
||||
cancelled = False
|
||||
total_created = 0
|
||||
skipped_existing = 0
|
||||
total_paths = 0
|
||||
|
||||
try:
|
||||
if not dependencies_available():
|
||||
self._add_error("Database dependencies not available")
|
||||
self._emit_event(
|
||||
"assets.seed.error",
|
||||
{"message": "Database dependencies not available"},
|
||||
)
|
||||
return
|
||||
|
||||
if self._prune_first:
|
||||
all_prefixes = get_all_known_prefixes()
|
||||
marked = mark_missing_outside_prefixes_safely(all_prefixes)
|
||||
if marked > 0:
|
||||
logging.info("Marked %d cache states as missing before scan", marked)
|
||||
|
||||
if self._is_cancelled():
|
||||
logging.info("Asset scan cancelled after pruning phase")
|
||||
cancelled = True
|
||||
return
|
||||
|
||||
self._log_scan_config(roots)
|
||||
|
||||
existing_paths: set[str] = set()
|
||||
for r in roots:
|
||||
if self._is_cancelled():
|
||||
logging.info("Asset scan cancelled during sync phase")
|
||||
cancelled = True
|
||||
return
|
||||
existing_paths.update(sync_root_safely(r))
|
||||
|
||||
if self._is_cancelled():
|
||||
logging.info("Asset scan cancelled after sync phase")
|
||||
cancelled = True
|
||||
return
|
||||
|
||||
paths = collect_paths_for_roots(roots)
|
||||
total_paths = len(paths)
|
||||
self._update_progress(total=total_paths)
|
||||
|
||||
self._emit_event(
|
||||
"assets.seed.started",
|
||||
{"roots": list(roots), "total": total_paths},
|
||||
)
|
||||
|
||||
specs, tag_pool, skipped_existing = build_asset_specs(
|
||||
paths, existing_paths, compute_hashes=self._compute_hashes
|
||||
)
|
||||
self._update_progress(skipped=skipped_existing)
|
||||
|
||||
if self._is_cancelled():
|
||||
logging.info("Asset scan cancelled after building specs")
|
||||
cancelled = True
|
||||
return
|
||||
|
||||
batch_size = 500
|
||||
last_progress_time = time.perf_counter()
|
||||
progress_interval = 1.0
|
||||
|
||||
for i in range(0, len(specs), batch_size):
|
||||
if self._is_cancelled():
|
||||
logging.info(
|
||||
"Asset scan cancelled after %d/%d files (created=%d)",
|
||||
i,
|
||||
len(specs),
|
||||
total_created,
|
||||
)
|
||||
cancelled = True
|
||||
return
|
||||
|
||||
batch = specs[i : i + batch_size]
|
||||
batch_tags = {t for spec in batch for t in spec["tags"]}
|
||||
try:
|
||||
created = insert_asset_specs(batch, batch_tags)
|
||||
total_created += created
|
||||
except Exception as e:
|
||||
self._add_error(f"Batch insert failed at offset {i}: {e}")
|
||||
logging.exception("Batch insert failed at offset %d", i)
|
||||
|
||||
scanned = i + len(batch)
|
||||
now = time.perf_counter()
|
||||
self._update_progress(scanned=scanned, created=total_created)
|
||||
|
||||
if now - last_progress_time >= progress_interval:
|
||||
self._emit_event(
|
||||
"assets.seed.progress",
|
||||
{
|
||||
"scanned": scanned,
|
||||
"total": len(specs),
|
||||
"created": total_created,
|
||||
},
|
||||
)
|
||||
last_progress_time = now
|
||||
|
||||
self._update_progress(scanned=len(specs), created=total_created)
|
||||
|
||||
elapsed = time.perf_counter() - t_start
|
||||
logging.info(
|
||||
"Asset scan(roots=%s) completed in %.3fs (created=%d, skipped=%d, total=%d)",
|
||||
roots,
|
||||
elapsed,
|
||||
total_created,
|
||||
skipped_existing,
|
||||
len(paths),
|
||||
)
|
||||
|
||||
self._emit_event(
|
||||
"assets.seed.completed",
|
||||
{
|
||||
"scanned": len(specs),
|
||||
"total": total_paths,
|
||||
"created": total_created,
|
||||
"skipped": skipped_existing,
|
||||
"elapsed": round(elapsed, 3),
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self._add_error(f"Scan failed: {e}")
|
||||
logging.exception("Asset scan failed")
|
||||
self._emit_event("assets.seed.error", {"message": str(e)})
|
||||
finally:
|
||||
if cancelled:
|
||||
self._emit_event(
|
||||
"assets.seed.cancelled",
|
||||
{
|
||||
"scanned": self._progress.scanned if self._progress else 0,
|
||||
"total": total_paths,
|
||||
"created": total_created,
|
||||
},
|
||||
)
|
||||
with self._lock:
|
||||
self._state = State.IDLE
|
||||
|
||||
|
||||
asset_seeder = AssetSeeder()
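
# Illustrative usage sketch (assumes a configured database); the polling
# interval and roots are example values only.
def _example_run_seeder_blocking() -> None:
    if not asset_seeder.start(roots=("models",), compute_hashes=False):
        return  # another scan is already running
    while True:
        status = asset_seeder.get_status()
        if status.state is State.IDLE:
            break
        time.sleep(0.5)
    if status.errors:
        logging.warning("Asset seeding finished with errors: %s", status.errors)
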
|
||||
89
app/assets/services/__init__.py
Normal file
89
app/assets/services/__init__.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from app.assets.services.asset_management import (
|
||||
asset_exists,
|
||||
delete_asset_reference,
|
||||
get_asset_by_hash,
|
||||
get_asset_detail,
|
||||
list_assets_page,
|
||||
resolve_asset_for_download,
|
||||
set_asset_preview,
|
||||
update_asset_metadata,
|
||||
)
|
||||
from app.assets.services.bulk_ingest import (
|
||||
BulkInsertResult,
|
||||
batch_insert_seed_assets,
|
||||
cleanup_unreferenced_assets,
|
||||
mark_assets_missing_outside_prefixes,
|
||||
)
|
||||
from app.assets.services.file_utils import (
|
||||
get_mtime_ns,
|
||||
get_size_and_mtime_ns,
|
||||
list_files_recursively,
|
||||
verify_file_unchanged,
|
||||
)
|
||||
from app.assets.services.ingest import (
|
||||
DependencyMissingError,
|
||||
HashMismatchError,
|
||||
create_from_hash,
|
||||
upload_from_temp_path,
|
||||
)
|
||||
from app.assets.services.schemas import (
|
||||
AddTagsResult,
|
||||
AssetData,
|
||||
AssetDetailResult,
|
||||
AssetInfoData,
|
||||
AssetSummaryData,
|
||||
DownloadResolutionResult,
|
||||
IngestResult,
|
||||
ListAssetsResult,
|
||||
RegisterAssetResult,
|
||||
RemoveTagsResult,
|
||||
SetTagsResult,
|
||||
TagUsage,
|
||||
UploadResult,
|
||||
UserMetadata,
|
||||
)
|
||||
from app.assets.services.tagging import (
|
||||
apply_tags,
|
||||
list_tags,
|
||||
remove_tags,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AddTagsResult",
|
||||
"AssetData",
|
||||
"AssetDetailResult",
|
||||
"AssetInfoData",
|
||||
"AssetSummaryData",
|
||||
"BulkInsertResult",
|
||||
"DependencyMissingError",
|
||||
"DownloadResolutionResult",
|
||||
"HashMismatchError",
|
||||
"IngestResult",
|
||||
"ListAssetsResult",
|
||||
"RegisterAssetResult",
|
||||
"RemoveTagsResult",
|
||||
"SetTagsResult",
|
||||
"TagUsage",
|
||||
"UploadResult",
|
||||
"UserMetadata",
|
||||
"apply_tags",
|
||||
"asset_exists",
|
||||
"batch_insert_seed_assets",
|
||||
"create_from_hash",
|
||||
"delete_asset_reference",
|
||||
"get_asset_by_hash",
|
||||
"get_asset_detail",
|
||||
"get_mtime_ns",
|
||||
"get_size_and_mtime_ns",
|
||||
"list_assets_page",
|
||||
"list_files_recursively",
|
||||
"list_tags",
|
||||
"cleanup_unreferenced_assets",
|
||||
"mark_assets_missing_outside_prefixes",
|
||||
"remove_tags",
|
||||
"resolve_asset_for_download",
|
||||
"set_asset_preview",
|
||||
"update_asset_metadata",
|
||||
"upload_from_temp_path",
|
||||
"verify_file_unchanged",
|
||||
]
|
||||
292
app/assets/services/asset_management.py
Normal file
292
app/assets/services/asset_management.py
Normal file
@@ -0,0 +1,292 @@
|
||||
import contextlib
|
||||
import mimetypes
|
||||
import os
|
||||
from typing import Sequence
|
||||
|
||||
|
||||
from app.assets.database.models import Asset
|
||||
from app.assets.database.queries import (
|
||||
asset_exists_by_hash,
|
||||
asset_info_exists_for_asset_id,
|
||||
delete_asset_info_by_id,
|
||||
fetch_asset_info_and_asset,
|
||||
fetch_asset_info_asset_and_tags,
|
||||
get_asset_by_hash as queries_get_asset_by_hash,
|
||||
get_asset_info_by_id,
|
||||
list_asset_infos_page,
|
||||
list_cache_states_by_asset_id,
|
||||
set_asset_info_metadata,
|
||||
set_asset_info_preview,
|
||||
set_asset_info_tags,
|
||||
update_asset_info_access_time,
|
||||
update_asset_info_name,
|
||||
update_asset_info_updated_at,
|
||||
)
|
||||
from app.assets.helpers import select_best_live_path
|
||||
from app.assets.services.path_utils import compute_filename_for_asset
|
||||
from app.assets.services.schemas import (
|
||||
AssetData,
|
||||
AssetDetailResult,
|
||||
AssetSummaryData,
|
||||
DownloadResolutionResult,
|
||||
ListAssetsResult,
|
||||
UserMetadata,
|
||||
extract_asset_data,
|
||||
extract_info_data,
|
||||
)
|
||||
from app.database.db import create_session
|
||||
|
||||
|
||||
def get_asset_detail(
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> AssetDetailResult | None:
|
||||
with create_session() as session:
|
||||
result = fetch_asset_info_asset_and_tags(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
if not result:
|
||||
return None
|
||||
|
||||
info, asset, tags = result
|
||||
return AssetDetailResult(
|
||||
info=extract_info_data(info),
|
||||
asset=extract_asset_data(asset),
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
|
||||
def update_asset_metadata(
|
||||
asset_info_id: str,
|
||||
name: str | None = None,
|
||||
tags: Sequence[str] | None = None,
|
||||
user_metadata: UserMetadata = None,
|
||||
tag_origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> AssetDetailResult:
|
||||
with create_session() as session:
|
||||
info = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info.owner_id and info.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
touched = False
|
||||
if name is not None and name != info.name:
|
||||
update_asset_info_name(session, asset_info_id=asset_info_id, name=name)
|
||||
touched = True
|
||||
|
||||
computed_filename = compute_filename_for_asset(session, info.asset_id)
|
||||
|
||||
new_meta: dict | None = None
|
||||
if user_metadata is not None:
|
||||
new_meta = dict(user_metadata)
|
||||
elif computed_filename:
|
||||
current_meta = info.user_metadata or {}
|
||||
if current_meta.get("filename") != computed_filename:
|
||||
new_meta = dict(current_meta)
|
||||
|
||||
if new_meta is not None:
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
set_asset_info_metadata(
|
||||
session, asset_info_id=asset_info_id, user_metadata=new_meta
|
||||
)
|
||||
touched = True
|
||||
|
||||
if tags is not None:
|
||||
set_asset_info_tags(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
origin=tag_origin,
|
||||
)
|
||||
touched = True
|
||||
|
||||
if touched and user_metadata is None:
|
||||
update_asset_info_updated_at(session, asset_info_id=asset_info_id)
|
||||
|
||||
result = fetch_asset_info_asset_and_tags(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
if not result:
|
||||
raise RuntimeError("State changed during update")
|
||||
|
||||
info, asset, tag_list = result
|
||||
detail = AssetDetailResult(
|
||||
info=extract_info_data(info),
|
||||
asset=extract_asset_data(asset),
|
||||
tags=tag_list,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return detail
|
||||
|
||||
|
||||
def delete_asset_reference(
|
||||
asset_info_id: str,
|
||||
owner_id: str,
|
||||
delete_content_if_orphan: bool = True,
|
||||
) -> bool:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
asset_id = info_row.asset_id if info_row else None
|
||||
|
||||
deleted = delete_asset_info_by_id(
|
||||
session, asset_info_id=asset_info_id, owner_id=owner_id
|
||||
)
|
||||
if not deleted:
|
||||
session.commit()
|
||||
return False
|
||||
|
||||
if not delete_content_if_orphan or not asset_id:
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id)
|
||||
if still_exists:
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
# Orphaned asset - delete it and its files
|
||||
states = list_cache_states_by_asset_id(session, asset_id=asset_id)
|
||||
file_paths = [
|
||||
s.file_path for s in (states or []) if getattr(s, "file_path", None)
|
||||
]
|
||||
|
||||
asset_row = session.get(Asset, asset_id)
|
||||
if asset_row is not None:
|
||||
session.delete(asset_row)
|
||||
|
||||
session.commit()
|
||||
|
||||
# Delete files after commit
|
||||
for p in file_paths:
|
||||
with contextlib.suppress(Exception):
|
||||
if p and os.path.isfile(p):
|
||||
os.remove(p)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def set_asset_preview(
|
||||
asset_info_id: str,
|
||||
preview_asset_id: str | None = None,
|
||||
owner_id: str = "",
|
||||
) -> AssetDetailResult:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
set_asset_info_preview(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
preview_asset_id=preview_asset_id,
|
||||
)
|
||||
|
||||
result = fetch_asset_info_asset_and_tags(
|
||||
session, asset_info_id=asset_info_id, owner_id=owner_id
|
||||
)
|
||||
if not result:
|
||||
raise RuntimeError("State changed during preview update")
|
||||
|
||||
info, asset, tags = result
|
||||
detail = AssetDetailResult(
|
||||
info=extract_info_data(info),
|
||||
asset=extract_asset_data(asset),
|
||||
tags=tags,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return detail
|
||||
|
||||
|
||||
def asset_exists(asset_hash: str) -> bool:
|
||||
with create_session() as session:
|
||||
return asset_exists_by_hash(session, asset_hash=asset_hash)
|
||||
|
||||
|
||||
def get_asset_by_hash(asset_hash: str) -> AssetData | None:
|
||||
with create_session() as session:
|
||||
asset = queries_get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
return extract_asset_data(asset)
|
||||
|
||||
|
||||
def list_assets_page(
|
||||
owner_id: str = "",
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
) -> ListAssetsResult:
|
||||
with create_session() as session:
|
||||
infos, tag_map, total = list_asset_infos_page(
|
||||
session,
|
||||
owner_id=owner_id,
|
||||
include_tags=include_tags,
|
||||
exclude_tags=exclude_tags,
|
||||
name_contains=name_contains,
|
||||
metadata_filter=metadata_filter,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
sort=sort,
|
||||
order=order,
|
||||
)
|
||||
|
||||
items: list[AssetSummaryData] = []
|
||||
for info in infos:
|
||||
items.append(
|
||||
AssetSummaryData(
|
||||
info=extract_info_data(info),
|
||||
asset=extract_asset_data(info.asset),
|
||||
tags=tag_map.get(info.id, []),
|
||||
)
|
||||
)
|
||||
|
||||
return ListAssetsResult(items=items, total=total)
|
||||
|
||||
|
||||
def resolve_asset_for_download(
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> DownloadResolutionResult:
|
||||
with create_session() as session:
|
||||
pair = fetch_asset_info_and_asset(
|
||||
session, asset_info_id=asset_info_id, owner_id=owner_id
|
||||
)
|
||||
if not pair:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
info, asset = pair
|
||||
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
|
||||
abs_path = select_best_live_path(states)
|
||||
if not abs_path:
|
||||
raise FileNotFoundError(
|
||||
f"No live path for AssetInfo {asset_info_id} (asset id={asset.id}, name={info.name})"
|
||||
)
|
||||
|
||||
update_asset_info_access_time(session, asset_info_id=asset_info_id)
|
||||
session.commit()
|
||||
|
||||
ctype = (
|
||||
asset.mime_type
|
||||
or mimetypes.guess_type(info.name or abs_path)[0]
|
||||
or "application/octet-stream"
|
||||
)
|
||||
download_name = info.name or os.path.basename(abs_path)
|
||||
return DownloadResolutionResult(
|
||||
abs_path=abs_path,
|
||||
content_type=ctype,
|
||||
download_name=download_name,
|
||||
)
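
# Illustrative sketch of serving the resolved file from an aiohttp handler.
# The aiohttp import and the "id" match_info key are assumptions made for this
# example and are not part of this module.
def _example_download_handler_factory():
    from aiohttp import web

    async def handler(request: web.Request) -> web.StreamResponse:
        try:
            resolved = resolve_asset_for_download(
                asset_info_id=request.match_info["id"],
                owner_id="",
            )
        except (ValueError, FileNotFoundError):
            raise web.HTTPNotFound()
        return web.FileResponse(
            resolved.abs_path,
            headers={
                "Content-Type": resolved.content_type,
                "Content-Disposition": f'attachment; filename="{resolved.download_name}"',
            },
        )

    return handler
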
|
||||
338
app/assets/services/bulk_ingest.py
Normal file
338
app/assets/services/bulk_ingest.py
Normal file
@@ -0,0 +1,338 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, TypedDict
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.queries import (
|
||||
bulk_insert_asset_infos_ignore_conflicts,
|
||||
bulk_insert_assets,
|
||||
bulk_insert_cache_states_ignore_conflicts,
|
||||
bulk_insert_tags_and_meta,
|
||||
delete_assets_by_ids,
|
||||
get_asset_info_ids_by_ids,
|
||||
get_cache_states_by_paths_and_asset_ids,
|
||||
get_existing_asset_ids,
|
||||
get_unreferenced_unhashed_asset_ids,
|
||||
mark_cache_states_missing_outside_prefixes,
|
||||
restore_cache_states_by_paths,
|
||||
)
|
||||
from app.assets.helpers import get_utc_now
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.assets.services.metadata_extract import ExtractedMetadata
|
||||
|
||||
|
||||
class SeedAssetSpec(TypedDict):
|
||||
"""Spec for seeding an asset from filesystem."""
|
||||
|
||||
abs_path: str
|
||||
size_bytes: int
|
||||
mtime_ns: int
|
||||
info_name: str
|
||||
tags: list[str]
|
||||
fname: str
|
||||
metadata: ExtractedMetadata | None
|
||||
hash: str | None
|
||||
mime_type: str | None
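
# Illustrative example of a spec as the scanner builds it; every value below is
# a made-up placeholder, with the optional hash and metadata left unset.
def _example_seed_spec() -> SeedAssetSpec:
    return {
        "abs_path": "/models/checkpoints/example.safetensors",
        "size_bytes": 2_048_576,
        "mtime_ns": 1_700_000_000_000_000_000,
        "info_name": "example.safetensors",
        "tags": ["models", "checkpoints"],
        "fname": "checkpoints/example.safetensors",
        "metadata": None,
        "hash": None,
        "mime_type": "application/octet-stream",
    }
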
|
||||
|
||||
|
||||
class AssetRow(TypedDict):
|
||||
"""Row data for inserting an Asset."""
|
||||
|
||||
id: str
|
||||
hash: str | None
|
||||
size_bytes: int
|
||||
mime_type: str | None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class CacheStateRow(TypedDict):
|
||||
"""Row data for inserting a CacheState."""
|
||||
|
||||
asset_id: str
|
||||
file_path: str
|
||||
mtime_ns: int
|
||||
|
||||
|
||||
class AssetInfoRow(TypedDict):
|
||||
"""Row data for inserting an AssetInfo."""
|
||||
|
||||
id: str
|
||||
owner_id: str
|
||||
name: str
|
||||
asset_id: str
|
||||
preview_id: str | None
|
||||
user_metadata: dict[str, Any] | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
last_access_time: datetime
|
||||
|
||||
|
||||
class AssetInfoRowInternal(TypedDict):
|
||||
"""Internal row data for AssetInfo with extra tracking fields."""
|
||||
|
||||
id: str
|
||||
owner_id: str
|
||||
name: str
|
||||
asset_id: str
|
||||
preview_id: str | None
|
||||
user_metadata: dict[str, Any] | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
last_access_time: datetime
|
||||
_tags: list[str]
|
||||
_filename: str
|
||||
_extracted_metadata: ExtractedMetadata | None
|
||||
|
||||
|
||||
class TagRow(TypedDict):
|
||||
"""Row data for inserting a Tag."""
|
||||
|
||||
asset_info_id: str
|
||||
tag_name: str
|
||||
origin: str
|
||||
added_at: datetime
|
||||
|
||||
|
||||
class MetadataRow(TypedDict):
|
||||
"""Row data for inserting asset metadata."""
|
||||
|
||||
asset_info_id: str
|
||||
key: str
|
||||
ordinal: int
|
||||
val_str: str | None
|
||||
val_num: float | None
|
||||
val_bool: bool | None
|
||||
val_json: dict[str, Any] | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class BulkInsertResult:
|
||||
"""Result of bulk asset insertion."""
|
||||
|
||||
inserted_infos: int
|
||||
won_states: int
|
||||
lost_states: int
|
||||
|
||||
|
||||
def batch_insert_seed_assets(
    session: Session,
    specs: list[SeedAssetSpec],
    owner_id: str = "",
) -> BulkInsertResult:
    """Seed assets from filesystem specs in batch.

    Each spec is a SeedAssetSpec with keys:
    - abs_path: str
    - size_bytes: int
    - mtime_ns: int
    - info_name: str
    - tags: list[str]
    - fname: str
    - metadata: ExtractedMetadata | None
    - hash: str | None
    - mime_type: str | None

    This function orchestrates:
    1. Insert seed Assets (hash=NULL)
    2. Claim cache states with ON CONFLICT DO NOTHING
    3. Query to find winners (paths where our asset_id was inserted)
    4. Delete Assets for losers (path already claimed by another asset)
    5. Insert AssetInfo for winners
    6. Insert tags and metadata for successfully inserted AssetInfos

    Returns:
        BulkInsertResult with inserted_infos, won_states, lost_states
    """
    if not specs:
        return BulkInsertResult(inserted_infos=0, won_states=0, lost_states=0)

    current_time = get_utc_now()
    asset_rows: list[AssetRow] = []
    cache_state_rows: list[CacheStateRow] = []
    path_to_asset_id: dict[str, str] = {}
    asset_id_to_info: dict[str, AssetInfoRowInternal] = {}
    absolute_path_list: list[str] = []

    for spec in specs:
        absolute_path = os.path.abspath(spec["abs_path"])
        asset_id = str(uuid.uuid4())
        asset_info_id = str(uuid.uuid4())
        absolute_path_list.append(absolute_path)
        path_to_asset_id[absolute_path] = asset_id

        mime_type = spec.get("mime_type")
        if mime_type is None:
            logging.info("batch_insert_seed_assets: no mime_type for %s", absolute_path)
        asset_rows.append(
            {
                "id": asset_id,
                "hash": spec.get("hash"),
                "size_bytes": spec["size_bytes"],
                "mime_type": mime_type,
                "created_at": current_time,
            }
        )
        cache_state_rows.append(
            {
                "asset_id": asset_id,
                "file_path": absolute_path,
                "mtime_ns": spec["mtime_ns"],
            }
        )
        # Build user_metadata from extracted metadata or fallback to filename
        extracted_metadata = spec.get("metadata")
        if extracted_metadata:
            user_metadata: dict[str, Any] | None = extracted_metadata.to_user_metadata()
        elif spec["fname"]:
            user_metadata = {"filename": spec["fname"]}
        else:
            user_metadata = None

        asset_id_to_info[asset_id] = {
            "id": asset_info_id,
            "owner_id": owner_id,
            "name": spec["info_name"],
            "asset_id": asset_id,
            "preview_id": None,
            "user_metadata": user_metadata,
            "created_at": current_time,
            "updated_at": current_time,
            "last_access_time": current_time,
            "_tags": spec["tags"],
            "_filename": spec["fname"],
            "_extracted_metadata": extracted_metadata,
        }

    bulk_insert_assets(session, asset_rows)

    # Filter cache states to only those whose assets were actually inserted
    # (assets with duplicate hashes are silently dropped by ON CONFLICT DO NOTHING)
    inserted_asset_ids = get_existing_asset_ids(
        session, [r["asset_id"] for r in cache_state_rows]
    )
    cache_state_rows = [
        r for r in cache_state_rows if r["asset_id"] in inserted_asset_ids
    ]

    bulk_insert_cache_states_ignore_conflicts(session, cache_state_rows)
    restore_cache_states_by_paths(session, absolute_path_list)
    winning_paths = get_cache_states_by_paths_and_asset_ids(session, path_to_asset_id)

    all_paths_set = set(absolute_path_list)
    losing_paths = all_paths_set - winning_paths
    lost_asset_ids = [path_to_asset_id[path] for path in losing_paths]

    if lost_asset_ids:
        delete_assets_by_ids(session, lost_asset_ids)

    if not winning_paths:
        return BulkInsertResult(
            inserted_infos=0,
            won_states=0,
            lost_states=len(losing_paths),
        )

    winner_info_rows = [
        asset_id_to_info[path_to_asset_id[path]] for path in winning_paths
    ]
    database_info_rows: list[AssetInfoRow] = [
        {
            "id": info_row["id"],
            "owner_id": info_row["owner_id"],
            "name": info_row["name"],
            "asset_id": info_row["asset_id"],
            "preview_id": info_row["preview_id"],
            "user_metadata": info_row["user_metadata"],
            "created_at": info_row["created_at"],
            "updated_at": info_row["updated_at"],
            "last_access_time": info_row["last_access_time"],
        }
        for info_row in winner_info_rows
    ]
    bulk_insert_asset_infos_ignore_conflicts(session, database_info_rows)

    all_info_ids = [info_row["id"] for info_row in winner_info_rows]
    inserted_info_ids = get_asset_info_ids_by_ids(session, all_info_ids)

    tag_rows: list[TagRow] = []
    metadata_rows: list[MetadataRow] = []
    if inserted_info_ids:
        for info_row in winner_info_rows:
            info_id = info_row["id"]
            if info_id not in inserted_info_ids:
                continue
            for tag in info_row["_tags"]:
                tag_rows.append(
                    {
                        "asset_info_id": info_id,
                        "tag_name": tag,
                        "origin": "automatic",
                        "added_at": current_time,
                    }
                )

            # Use extracted metadata for meta rows if available
            extracted_metadata = info_row.get("_extracted_metadata")
            if extracted_metadata:
                metadata_rows.extend(extracted_metadata.to_meta_rows(info_id))
            elif info_row["_filename"]:
                # Fallback: just store filename
                metadata_rows.append(
                    {
                        "asset_info_id": info_id,
                        "key": "filename",
                        "ordinal": 0,
                        "val_str": info_row["_filename"],
                        "val_num": None,
                        "val_bool": None,
                        "val_json": None,
                    }
                )

    bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=metadata_rows)

    return BulkInsertResult(
        inserted_infos=len(inserted_info_ids),
        won_states=len(winning_paths),
        lost_states=len(losing_paths),
    )

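
A minimal sketch of how a scanner might drive this batch path, assuming a session from app.database.db.create_session and a pre-collected list of file paths (values shown are placeholders; the real scanner also fills metadata and mime_type):

    from app.database.db import create_session
    from app.assets.services.file_utils import get_size_and_mtime_ns

    def seed_paths(paths: list[str]) -> None:
        specs: list[SeedAssetSpec] = []
        for p in paths:
            size_bytes, mtime_ns = get_size_and_mtime_ns(p)
            specs.append({
                "abs_path": p,
                "size_bytes": size_bytes,
                "mtime_ns": mtime_ns,
                "info_name": p.rsplit("/", 1)[-1],
                "tags": ["models", "checkpoints"],   # placeholder tags
                "fname": p.rsplit("/", 1)[-1],
                "metadata": None,   # filled by metadata_extract in the real scanner
                "hash": None,       # seed assets start unhashed
                "mime_type": None,
            })
        with create_session() as session:
            result = batch_insert_seed_assets(session, specs)
            session.commit()
            print(result.inserted_infos, result.won_states, result.lost_states)
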
def mark_assets_missing_outside_prefixes(
    session: Session, valid_prefixes: list[str]
) -> int:
    """Mark cache states as missing when outside valid prefixes.

    This is a non-destructive operation that soft-deletes cache states
    by setting is_missing=True. User metadata is preserved and assets
    can be restored if the file reappears in a future scan.

    Note: This does NOT delete unreferenced unhashed assets. Those are
    preserved so user metadata remains intact even when base directories
    change.

    Args:
        session: Database session
        valid_prefixes: List of absolute directory prefixes that are valid

    Returns:
        Number of cache states marked as missing
    """
    return mark_cache_states_missing_outside_prefixes(session, valid_prefixes)


def cleanup_unreferenced_assets(session: Session) -> int:
    """Hard-delete unhashed assets with no active cache states.

    This is a destructive operation intended for explicit cleanup.
    Only deletes assets where hash=None and all cache states are missing.

    Returns:
        Number of assets deleted
    """
    unreferenced_ids = get_unreferenced_unhashed_asset_ids(session)
    return delete_assets_by_ids(session, unreferenced_ids)

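
Taken together, a prune pass can soft-mark everything outside the configured roots and leave hard deletion to an explicit cleanup step. A rough sketch, assuming create_session from app.database.db and an illustrative models path:

    with create_session() as session:
        marked = mark_assets_missing_outside_prefixes(session, ["/data/ComfyUI/models"])
        session.commit()

    # Later, only on explicit user action (destructive):
    with create_session() as session:
        deleted = cleanup_unreferenced_assets(session)
        session.commit()
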
58
app/assets/services/file_utils.py
Normal file
@@ -0,0 +1,58 @@
import os


def get_mtime_ns(stat_result: os.stat_result) -> int:
    """Extract mtime in nanoseconds from a stat result."""
    return getattr(
        stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000)
    )


def get_size_and_mtime_ns(path: str, follow_symlinks: bool = True) -> tuple[int, int]:
    """Get file size in bytes and mtime in nanoseconds."""
    st = os.stat(path, follow_symlinks=follow_symlinks)
    return st.st_size, get_mtime_ns(st)


def verify_file_unchanged(
    mtime_db: int | None,
    size_db: int | None,
    stat_result: os.stat_result,
) -> bool:
    """Check if a file is unchanged based on mtime and size.

    Returns True if the file's mtime and size match the database values.
    Returns False if mtime_db is None or values don't match.

    size_db=None means don't check size; 0 is a valid recorded size.
    """
    if mtime_db is None:
        return False
    actual_mtime_ns = get_mtime_ns(stat_result)
    if int(mtime_db) != int(actual_mtime_ns):
        return False
    if size_db is not None:
        return int(stat_result.st_size) == int(size_db)
    return True


def is_visible(name: str) -> bool:
    """Return True if a file or directory name is visible (not hidden)."""
    return not name.startswith(".")


def list_files_recursively(base_dir: str) -> list[str]:
    """Recursively list all files in a directory."""
    out: list[str] = []
    base_abs = os.path.abspath(base_dir)
    if not os.path.isdir(base_abs):
        return out
    for dirpath, subdirs, filenames in os.walk(
        base_abs, topdown=True, followlinks=False
    ):
        subdirs[:] = [d for d in subdirs if is_visible(d)]
        for name in filenames:
            if not is_visible(name):
                continue
            out.append(os.path.abspath(os.path.join(dirpath, name)))
    return out
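
These helpers let the scanner decide whether a cached hash is still valid without re-reading file contents. A small sketch, where cache_state and asset stand in for ORM rows loaded elsewhere (illustrative names):

    import os

    st = os.stat("/data/ComfyUI/models/vae/ae.safetensors")
    if verify_file_unchanged(
        mtime_db=cache_state.mtime_ns,
        size_db=asset.size_bytes,
        stat_result=st,
    ):
        pass  # mtime and size match the DB row; skip re-hashing
    else:
        pass  # file changed on disk; re-hash and update the cache state
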
67
app/assets/services/hashing.py
Normal file
@@ -0,0 +1,67 @@
import asyncio
import os
from typing import IO

DEFAULT_CHUNK = 8 * 1024 * 1024

_blake3 = None


def _get_blake3():
    global _blake3
    if _blake3 is None:
        try:
            from blake3 import blake3 as _b3

            _blake3 = _b3
        except ImportError:
            raise ImportError(
                "blake3 is required for asset hashing. Install with: pip install blake3"
            )
    return _blake3


def compute_blake3_hash(
    fp: str | IO[bytes],
    chunk_size: int = DEFAULT_CHUNK,
) -> str:
    if hasattr(fp, "read"):
        return _hash_file_obj(fp, chunk_size)

    with open(os.fspath(fp), "rb") as f:
        return _hash_file_obj(f, chunk_size)


async def compute_blake3_hash_async(
    fp: str | IO[bytes],
    chunk_size: int = DEFAULT_CHUNK,
) -> str:
    if hasattr(fp, "read"):
        return await asyncio.to_thread(compute_blake3_hash, fp, chunk_size)

    def _worker() -> str:
        with open(os.fspath(fp), "rb") as f:
            return _hash_file_obj(f, chunk_size)

    return await asyncio.to_thread(_worker)


def _hash_file_obj(file_obj: IO, chunk_size: int = DEFAULT_CHUNK) -> str:
    if chunk_size <= 0:
        chunk_size = DEFAULT_CHUNK

    orig_pos = file_obj.tell()

    try:
        if orig_pos != 0:
            file_obj.seek(0)

        h = _get_blake3()()
        while True:
            chunk = file_obj.read(chunk_size)
            if not chunk:
                break
            h.update(chunk)
        return h.hexdigest()
    finally:
        if orig_pos != 0:
            file_obj.seek(orig_pos)
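
Both entry points accept either a path or an open binary file object, and the async variant just pushes the blocking read into a thread. A short usage sketch (requires the optional blake3 package):

    # synchronous, from a path
    digest = compute_blake3_hash("/tmp/upload.bin")
    asset_hash = "blake3:" + digest

    # asynchronous, from an already-open binary file object
    async def hash_upload(fileobj) -> str:
        return "blake3:" + await compute_blake3_hash_async(fileobj)
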
388
app/assets/services/ingest.py
Normal file
@@ -0,0 +1,388 @@
import contextlib
import logging
import mimetypes
import os
from typing import Sequence

from sqlalchemy import select
from sqlalchemy.orm import Session

import app.assets.services.hashing as hashing
from app.assets.database.models import Asset, AssetInfo, Tag
from app.assets.database.queries import (
    add_tags_to_asset_info,
    fetch_asset_info_and_asset,
    get_asset_by_hash,
    get_asset_tags,
    get_or_create_asset_info,
    remove_missing_tag_for_asset_id,
    set_asset_info_metadata,
    set_asset_info_tags,
    update_asset_info_timestamps,
    upsert_asset,
    upsert_cache_state,
)
from app.assets.helpers import normalize_tags
from app.assets.services.file_utils import get_size_and_mtime_ns
from app.assets.services.path_utils import (
    compute_filename_for_asset,
    resolve_destination_from_tags,
    validate_path_within_base,
)
from app.assets.services.schemas import (
    IngestResult,
    RegisterAssetResult,
    UploadResult,
    UserMetadata,
    extract_asset_data,
    extract_info_data,
)
from app.database.db import create_session


def _ingest_file_from_path(
    abs_path: str,
    asset_hash: str,
    size_bytes: int,
    mtime_ns: int,
    mime_type: str | None = None,
    info_name: str | None = None,
    owner_id: str = "",
    preview_id: str | None = None,
    user_metadata: UserMetadata = None,
    tags: Sequence[str] = (),
    tag_origin: str = "manual",
    require_existing_tags: bool = False,
) -> IngestResult:
    locator = os.path.abspath(abs_path)

    asset_created = False
    asset_updated = False
    state_created = False
    state_updated = False
    asset_info_id: str | None = None

    with create_session() as session:
        if preview_id:
            if not session.get(Asset, preview_id):
                preview_id = None

        asset, asset_created, asset_updated = upsert_asset(
            session,
            asset_hash=asset_hash,
            size_bytes=size_bytes,
            mime_type=mime_type,
        )

        state_created, state_updated = upsert_cache_state(
            session,
            asset_id=asset.id,
            file_path=locator,
            mtime_ns=mtime_ns,
        )

        if info_name:
            info, info_created = get_or_create_asset_info(
                session,
                asset_id=asset.id,
                owner_id=owner_id,
                name=info_name,
                preview_id=preview_id,
            )
            if info_created:
                asset_info_id = info.id
            else:
                update_asset_info_timestamps(
                    session, asset_info=info, preview_id=preview_id
                )
                asset_info_id = info.id

            norm = normalize_tags(list(tags))
            if norm and asset_info_id:
                if require_existing_tags:
                    _validate_tags_exist(session, norm)
                add_tags_to_asset_info(
                    session,
                    asset_info_id=asset_info_id,
                    tags=norm,
                    origin=tag_origin,
                    create_if_missing=not require_existing_tags,
                )

            if asset_info_id:
                _update_metadata_with_filename(
                    session,
                    asset_info_id=asset_info_id,
                    asset_id=asset.id,
                    info=info,
                    user_metadata=user_metadata,
                )

        try:
            remove_missing_tag_for_asset_id(session, asset_id=asset.id)
        except Exception:
            logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)

        session.commit()

    return IngestResult(
        asset_created=asset_created,
        asset_updated=asset_updated,
        state_created=state_created,
        state_updated=state_updated,
        asset_info_id=asset_info_id,
    )


def _register_existing_asset(
    asset_hash: str,
    name: str,
    user_metadata: UserMetadata = None,
    tags: list[str] | None = None,
    tag_origin: str = "manual",
    owner_id: str = "",
) -> RegisterAssetResult:
    with create_session() as session:
        asset = get_asset_by_hash(session, asset_hash=asset_hash)
        if not asset:
            raise ValueError(f"No asset with hash {asset_hash}")

        info, info_created = get_or_create_asset_info(
            session,
            asset_id=asset.id,
            owner_id=owner_id,
            name=name,
            preview_id=None,
        )

        if not info_created:
            tag_names = get_asset_tags(session, asset_info_id=info.id)
            result = RegisterAssetResult(
                info=extract_info_data(info),
                asset=extract_asset_data(asset),
                tags=tag_names,
                created=False,
            )
            session.commit()
            return result

        new_meta = dict(user_metadata or {})
        computed_filename = compute_filename_for_asset(session, asset.id)
        if computed_filename:
            new_meta["filename"] = computed_filename

        if new_meta:
            set_asset_info_metadata(
                session,
                asset_info_id=info.id,
                user_metadata=new_meta,
            )

        if tags is not None:
            set_asset_info_tags(
                session,
                asset_info_id=info.id,
                tags=tags,
                origin=tag_origin,
            )

        tag_names = get_asset_tags(session, asset_info_id=info.id)
        session.refresh(info)
        result = RegisterAssetResult(
            info=extract_info_data(info),
            asset=extract_asset_data(asset),
            tags=tag_names,
            created=True,
        )
        session.commit()

    return result


def _validate_tags_exist(session: Session, tags: list[str]) -> None:
    existing_tag_names = set(
        name
        for (name,) in session.execute(select(Tag.name).where(Tag.name.in_(tags))).all()
    )
    missing = [t for t in tags if t not in existing_tag_names]
    if missing:
        raise ValueError(f"Unknown tags: {missing}")


def _update_metadata_with_filename(
    session: Session,
    asset_info_id: str,
    asset_id: str,
    info: AssetInfo,
    user_metadata: UserMetadata,
) -> None:
    computed_filename = compute_filename_for_asset(session, asset_id)

    current_meta = info.user_metadata or {}
    new_meta = dict(current_meta)
    if user_metadata:
        for k, v in user_metadata.items():
            new_meta[k] = v
    if computed_filename:
        new_meta["filename"] = computed_filename

    if new_meta != current_meta:
        set_asset_info_metadata(
            session,
            asset_info_id=asset_info_id,
            user_metadata=new_meta,
        )


def _sanitize_filename(name: str | None, fallback: str) -> str:
    n = os.path.basename((name or "").strip() or fallback)
    return n if n else fallback


class HashMismatchError(Exception):
    pass


class DependencyMissingError(Exception):
    def __init__(self, message: str):
        self.message = message
        super().__init__(message)


def upload_from_temp_path(
    temp_path: str,
    name: str | None = None,
    tags: list[str] | None = None,
    user_metadata: dict | None = None,
    client_filename: str | None = None,
    owner_id: str = "",
    expected_hash: str | None = None,
) -> UploadResult:
    try:
        digest = hashing.compute_blake3_hash(temp_path)
    except ImportError as e:
        raise DependencyMissingError(str(e))
    except Exception as e:
        raise RuntimeError(f"failed to hash uploaded file: {e}")
    asset_hash = "blake3:" + digest

    if expected_hash and asset_hash != expected_hash.strip().lower():
        raise HashMismatchError("Uploaded file hash does not match provided hash.")

    with create_session() as session:
        existing = get_asset_by_hash(session, asset_hash=asset_hash)

    if existing is not None:
        with contextlib.suppress(Exception):
            if temp_path and os.path.exists(temp_path):
                os.remove(temp_path)

        display_name = _sanitize_filename(name or client_filename, fallback=digest)
        result = _register_existing_asset(
            asset_hash=asset_hash,
            name=display_name,
            user_metadata=user_metadata or {},
            tags=tags or [],
            tag_origin="manual",
            owner_id=owner_id,
        )
        return UploadResult(
            info=result.info,
            asset=result.asset,
            tags=result.tags,
            created_new=False,
        )

    base_dir, subdirs = resolve_destination_from_tags(tags)
    dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
    os.makedirs(dest_dir, exist_ok=True)

    src_for_ext = (client_filename or name or "").strip()
    _ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
    ext = _ext if 0 < len(_ext) <= 16 else ""
    hashed_basename = f"{digest}{ext}"
    dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
    validate_path_within_base(dest_abs, base_dir)

    content_type = (
        mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
        or mimetypes.guess_type(hashed_basename, strict=False)[0]
        or "application/octet-stream"
    )

    try:
        os.replace(temp_path, dest_abs)
    except Exception as e:
        raise RuntimeError(f"failed to move uploaded file into place: {e}")

    try:
        size_bytes, mtime_ns = get_size_and_mtime_ns(dest_abs)
    except OSError as e:
        raise RuntimeError(f"failed to stat destination file: {e}")

    ingest_result = _ingest_file_from_path(
        asset_hash=asset_hash,
        abs_path=dest_abs,
        size_bytes=size_bytes,
        mtime_ns=mtime_ns,
        mime_type=content_type,
        info_name=_sanitize_filename(name or client_filename, fallback=digest),
        owner_id=owner_id,
        preview_id=None,
        user_metadata=user_metadata or {},
        tags=tags,
        tag_origin="manual",
        require_existing_tags=False,
    )
    info_id = ingest_result.asset_info_id
    if not info_id:
        raise RuntimeError("failed to create asset metadata")

    with create_session() as session:
        pair = fetch_asset_info_and_asset(
            session, asset_info_id=info_id, owner_id=owner_id
        )
        if not pair:
            raise RuntimeError("inconsistent DB state after ingest")
        info, asset = pair
        tag_names = get_asset_tags(session, asset_info_id=info.id)

        return UploadResult(
            info=extract_info_data(info),
            asset=extract_asset_data(asset),
            tags=tag_names,
            created_new=ingest_result.asset_created,
        )


def create_from_hash(
    hash_str: str,
    name: str,
    tags: list[str] | None = None,
    user_metadata: dict | None = None,
    owner_id: str = "",
) -> UploadResult | None:
    canonical = hash_str.strip().lower()

    with create_session() as session:
        asset = get_asset_by_hash(session, asset_hash=canonical)
        if not asset:
            return None

    result = _register_existing_asset(
        asset_hash=canonical,
        name=_sanitize_filename(
            name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical
        ),
        user_metadata=user_metadata or {},
        tags=tags or [],
        tag_origin="manual",
        owner_id=owner_id,
    )

    return UploadResult(
        info=result.info,
        asset=result.asset,
        tags=result.tags,
        created_new=False,
    )
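
A rough sketch of the happy path for the upload service, assuming the API layer has already streamed the request body to a temp file (the path and names below are illustrative, not part of the diff):

    try:
        result = upload_from_temp_path(
            temp_path="/tmp/comfy_upload_abc123",
            name="my_lora.safetensors",
            tags=["models", "loras"],
            client_filename="my_lora.safetensors",
            expected_hash=None,
        )
    except HashMismatchError:
        ...  # reject: client-provided hash did not match the uploaded bytes
    except DependencyMissingError:
        ...  # blake3 is not installed
    else:
        print(result.asset.hash, result.created_new)
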
338
app/assets/services/metadata_extract.py
Normal file
@@ -0,0 +1,338 @@
"""Metadata extraction for asset scanning.

Tier 1: Filesystem metadata (zero parsing)
Tier 2: Safetensors header metadata (fast JSON read only)
"""

from __future__ import annotations

import json
import logging
import mimetypes
import os
import struct
from dataclasses import dataclass
from typing import Any

# Supported safetensors extensions
SAFETENSORS_EXTENSIONS = frozenset({".safetensors", ".sft"})

# Maximum safetensors header size to read (8MB)
MAX_SAFETENSORS_HEADER_SIZE = 8 * 1024 * 1024


def _register_custom_mime_types():
    """Register custom MIME types for model and config files.

    Called before each use because mimetypes.init() in server.py resets the database.
    Uses a quick check to avoid redundant registrations.
    """
    # Quick check if already registered (avoids redundant add_type calls)
    test_result, _ = mimetypes.guess_type("test.safetensors")
    if test_result == "application/safetensors":
        return

    mimetypes.add_type("application/safetensors", ".safetensors")
    mimetypes.add_type("application/safetensors", ".sft")
    mimetypes.add_type("application/pytorch", ".pt")
    mimetypes.add_type("application/pytorch", ".pth")
    mimetypes.add_type("application/pickle", ".ckpt")
    mimetypes.add_type("application/pickle", ".pkl")
    mimetypes.add_type("application/gguf", ".gguf")
    mimetypes.add_type("application/yaml", ".yaml")
    mimetypes.add_type("application/yaml", ".yml")


# Register custom types at module load
_register_custom_mime_types()


@dataclass
class ExtractedMetadata:
    """Metadata extracted from a file during scanning."""

    # Tier 1: Filesystem (always available)
    filename: str = ""
    content_length: int = 0
    content_type: str | None = None
    format: str = ""  # file extension without dot

    # Tier 2: Safetensors header (if available)
    base_model: str | None = None
    trained_words: list[str] | None = None
    air: str | None = None  # CivitAI AIR identifier
    has_preview_images: bool = False

    # Source provenance (populated if embedded in safetensors)
    source_url: str | None = None
    source_arn: str | None = None
    repo_url: str | None = None
    preview_url: str | None = None
    source_hash: str | None = None

    # HuggingFace specific
    repo_id: str | None = None
    revision: str | None = None
    filepath: str | None = None
    resolve_url: str | None = None

    def to_user_metadata(self) -> dict[str, Any]:
        """Convert to user_metadata dict for AssetInfo.user_metadata JSON field."""
        data: dict[str, Any] = {
            "filename": self.filename,
            "content_length": self.content_length,
            "format": self.format,
        }
        if self.content_type:
            data["content_type"] = self.content_type

        # Tier 2 fields
        if self.base_model:
            data["base_model"] = self.base_model
        if self.trained_words:
            data["trained_words"] = self.trained_words
        if self.air:
            data["air"] = self.air
        if self.has_preview_images:
            data["has_preview_images"] = True

        # Source provenance
        if self.source_url:
            data["source_url"] = self.source_url
        if self.source_arn:
            data["source_arn"] = self.source_arn
        if self.repo_url:
            data["repo_url"] = self.repo_url
        if self.preview_url:
            data["preview_url"] = self.preview_url
        if self.source_hash:
            data["source_hash"] = self.source_hash

        # HuggingFace
        if self.repo_id:
            data["repo_id"] = self.repo_id
        if self.revision:
            data["revision"] = self.revision
        if self.filepath:
            data["filepath"] = self.filepath
        if self.resolve_url:
            data["resolve_url"] = self.resolve_url

        return data

    def to_meta_rows(self, asset_info_id: str) -> list[dict]:
        """Convert to asset_info_meta rows for typed/indexed querying."""
        rows: list[dict] = []

        def add_str(key: str, val: str | None, ordinal: int = 0) -> None:
            if val:
                rows.append({
                    "asset_info_id": asset_info_id,
                    "key": key,
                    "ordinal": ordinal,
                    "val_str": val[:2048] if len(val) > 2048 else val,
                    "val_num": None,
                    "val_bool": None,
                    "val_json": None,
                })

        def add_num(key: str, val: int | float | None) -> None:
            if val is not None:
                rows.append({
                    "asset_info_id": asset_info_id,
                    "key": key,
                    "ordinal": 0,
                    "val_str": None,
                    "val_num": val,
                    "val_bool": None,
                    "val_json": None,
                })

        def add_bool(key: str, val: bool | None) -> None:
            if val is not None:
                rows.append({
                    "asset_info_id": asset_info_id,
                    "key": key,
                    "ordinal": 0,
                    "val_str": None,
                    "val_num": None,
                    "val_bool": val,
                    "val_json": None,
                })

        # Tier 1
        add_str("filename", self.filename)
        add_num("content_length", self.content_length)
        add_str("content_type", self.content_type)
        add_str("format", self.format)

        # Tier 2
        add_str("base_model", self.base_model)
        add_str("air", self.air)
        add_bool("has_preview_images", self.has_preview_images if self.has_preview_images else None)

        # trained_words as multiple rows with ordinals
        if self.trained_words:
            for i, word in enumerate(self.trained_words[:100]):  # limit to 100 words
                add_str("trained_words", word, ordinal=i)

        # Source provenance
        add_str("source_url", self.source_url)
        add_str("source_arn", self.source_arn)
        add_str("repo_url", self.repo_url)
        add_str("preview_url", self.preview_url)
        add_str("source_hash", self.source_hash)

        # HuggingFace
        add_str("repo_id", self.repo_id)
        add_str("revision", self.revision)
        add_str("filepath", self.filepath)
        add_str("resolve_url", self.resolve_url)

        return rows


def _read_safetensors_header(path: str, max_size: int = MAX_SAFETENSORS_HEADER_SIZE) -> dict[str, Any] | None:
    """Read only the JSON header from a safetensors file.

    This is very fast - reads 8 bytes for header length, then the JSON header.
    No tensor data is loaded.

    Args:
        path: Absolute path to safetensors file
        max_size: Maximum header size to read (default 8MB)

    Returns:
        Parsed header dict or None if failed
    """
    try:
        with open(path, "rb") as f:
            header_bytes = f.read(8)
            if len(header_bytes) < 8:
                return None
            length_of_header = struct.unpack("<Q", header_bytes)[0]
            if length_of_header > max_size:
                return None
            header_data = f.read(length_of_header)
            if len(header_data) < length_of_header:
                return None
            return json.loads(header_data.decode("utf-8"))
    except (OSError, json.JSONDecodeError, UnicodeDecodeError, struct.error):
        return None


def _extract_safetensors_metadata(header: dict[str, Any], meta: ExtractedMetadata) -> None:
    """Extract metadata from safetensors header __metadata__ section.

    Modifies meta in-place.
    """
    st_meta = header.get("__metadata__", {})
    if not isinstance(st_meta, dict):
        return

    # Common model metadata
    meta.base_model = st_meta.get("ss_base_model_version") or st_meta.get("modelspec.base_model") or st_meta.get("base_model")

    # Trained words / trigger words
    trained_words = st_meta.get("ss_tag_frequency")
    if trained_words and isinstance(trained_words, str):
        try:
            tag_freq = json.loads(trained_words)
            # Extract unique tags from all datasets
            all_tags: set[str] = set()
            for dataset_tags in tag_freq.values():
                if isinstance(dataset_tags, dict):
                    all_tags.update(dataset_tags.keys())
            if all_tags:
                meta.trained_words = sorted(all_tags)[:100]
        except json.JSONDecodeError:
            pass

    # Direct trained_words field (some formats)
    if not meta.trained_words:
        tw = st_meta.get("trained_words")
        if isinstance(tw, str):
            try:
                meta.trained_words = json.loads(tw)
            except json.JSONDecodeError:
                meta.trained_words = [w.strip() for w in tw.split(",") if w.strip()]
        elif isinstance(tw, list):
            meta.trained_words = tw

    # CivitAI AIR
    meta.air = st_meta.get("air") or st_meta.get("modelspec.air")

    # Preview images (ssmd_cover_images)
    cover_images = st_meta.get("ssmd_cover_images")
    if cover_images:
        meta.has_preview_images = True

    # Source provenance fields
    meta.source_url = st_meta.get("source_url")
    meta.source_arn = st_meta.get("source_arn")
    meta.repo_url = st_meta.get("repo_url")
    meta.preview_url = st_meta.get("preview_url")
    meta.source_hash = st_meta.get("source_hash") or st_meta.get("sshs_model_hash")

    # HuggingFace fields
    meta.repo_id = st_meta.get("repo_id") or st_meta.get("hf_repo_id")
    meta.revision = st_meta.get("revision") or st_meta.get("hf_revision")
    meta.filepath = st_meta.get("filepath") or st_meta.get("hf_filepath")
    meta.resolve_url = st_meta.get("resolve_url") or st_meta.get("hf_url")


def extract_file_metadata(
    abs_path: str,
    stat_result: os.stat_result | None = None,
    enable_safetensors: bool = True,
    relative_filename: str | None = None,
) -> ExtractedMetadata:
    """Extract metadata from a file using tier 1 and optionally tier 2 methods.

    Tier 1 (always): Filesystem metadata from path and stat
    Tier 2 (optional): Safetensors header parsing if applicable

    Args:
        abs_path: Absolute path to the file
        stat_result: Optional pre-fetched stat result (saves a syscall)
        enable_safetensors: Whether to parse safetensors headers (tier 2)
        relative_filename: Optional relative filename to use instead of basename
            (e.g., "flux/123/model.safetensors" for model paths)

    Returns:
        ExtractedMetadata with all available fields populated
    """
    meta = ExtractedMetadata()

    # Tier 1: Filesystem metadata
    # Use relative_filename if provided (for backward compatibility with existing behavior)
    meta.filename = relative_filename if relative_filename else os.path.basename(abs_path)
    _, ext = os.path.splitext(abs_path)
    meta.format = ext.lstrip(".").lower() if ext else ""

    # MIME type guess (re-register in case mimetypes.init() was called elsewhere);
    # the result may be None for unknown extensions
    _register_custom_mime_types()
    mime_type, _ = mimetypes.guess_type(abs_path)
    meta.content_type = mime_type

    # Size from stat
    if stat_result is None:
        try:
            stat_result = os.stat(abs_path, follow_symlinks=True)
        except OSError:
            pass

    if stat_result:
        meta.content_length = stat_result.st_size

    # Tier 2: Safetensors header (if applicable and enabled)
    if enable_safetensors and ext.lower() in SAFETENSORS_EXTENSIONS:
        header = _read_safetensors_header(abs_path)
        if header:
            try:
                _extract_safetensors_metadata(header, meta)
            except Exception as e:
                logging.debug("Failed to extract safetensors metadata from %s: %s", abs_path, e)

    return meta
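
Typical use during a scan, where the stat result is already in hand and only safetensors files pay the extra header read (paths and IDs below are placeholders):

    import os

    path = "/data/ComfyUI/models/loras/detail.safetensors"
    st = os.stat(path)
    meta = extract_file_metadata(path, stat_result=st, relative_filename="detail.safetensors")
    user_metadata = meta.to_user_metadata()              # stored on AssetInfo.user_metadata
    meta_rows = meta.to_meta_rows(asset_info_id="...")   # typed rows for indexed querying
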
184
app/assets/services/path_utils.py
Normal file
@@ -0,0 +1,184 @@
import os
from pathlib import Path
from typing import Literal

import folder_paths
from app.assets.helpers import normalize_tags


def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
    """Build a list of (folder_name, base_paths[]) categories that are configured for model locations.

    We trust `folder_paths.folder_names_and_paths` and include a category if
    *any* of its base paths lies under the Comfy `models_dir`.
    """
    targets: list[tuple[str, list[str]]] = []
    models_root = os.path.abspath(folder_paths.models_dir)
    for name, values in folder_paths.folder_names_and_paths.items():
        paths, _exts = (
            values[0],
            values[1],
        )  # NOTE: this prevents nodepacks that hackily edit folder_... from breaking ComfyUI
        if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
            targets.append((name, paths))
    return targets


def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
    """Validates and maps tags -> (base_dir, subdirs_for_fs)"""
    root = tags[0]
    if root == "models":
        if len(tags) < 2:
            raise ValueError("at least two tags required for model asset")
        try:
            bases = folder_paths.folder_names_and_paths[tags[1]][0]
        except KeyError:
            raise ValueError(f"unknown model category '{tags[1]}'")
        if not bases:
            raise ValueError(f"no base path configured for category '{tags[1]}'")
        base_dir = os.path.abspath(bases[0])
        raw_subdirs = tags[2:]
    else:
        base_dir = os.path.abspath(
            folder_paths.get_input_directory()
            if root == "input"
            else folder_paths.get_output_directory()
        )
        raw_subdirs = tags[1:]
    for i in raw_subdirs:
        if i in (".", ".."):
            raise ValueError("invalid path component in tags")

    return base_dir, raw_subdirs if raw_subdirs else []


def validate_path_within_base(candidate: str, base: str) -> None:
    cand_abs = os.path.abspath(candidate)
    base_abs = os.path.abspath(base)
    try:
        if os.path.commonpath([cand_abs, base_abs]) != base_abs:
            raise ValueError("destination escapes base directory")
    except Exception:
        raise ValueError("invalid destination path")


def compute_relative_filename(file_path: str) -> str | None:
    """
    Return the model's path relative to the last well-known folder (the model category),
    using forward slashes, eg:
        /.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
        /.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"

    For non-model paths, returns None.
    NOTE: this is a temporary helper, used only for initializing metadata["filename"] field.
    """
    try:
        root_category, rel_path = get_asset_category_and_relative_path(file_path)
    except ValueError:
        return None

    p = Path(rel_path)
    parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
    if not parts:
        return None

    if root_category == "models":
        # parts[0] is the category ("checkpoints", "vae", etc) – drop it
        inside = parts[1:] if len(parts) > 1 else [parts[0]]
        return "/".join(inside)
    return "/".join(parts)  # input/output: keep all parts


def get_asset_category_and_relative_path(
    file_path: str,
) -> tuple[Literal["input", "output", "models"], str]:
    """Given an absolute or relative file path, determine which root category the path belongs to:
    - 'input' if the file resides under `folder_paths.get_input_directory()`
    - 'output' if the file resides under `folder_paths.get_output_directory()`
    - 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()`

    Returns:
        (root_category, relative_path_inside_that_root)
        For 'models', the relative path is prefixed with the category name:
        e.g. ('models', 'vae/test/sub/ae.safetensors')

    Raises:
        ValueError: if the path does not belong to input, output, or configured model bases.
    """
    fp_abs = os.path.abspath(file_path)

    def _check_is_within(child: str, parent: str) -> bool:
        try:
            return os.path.commonpath([child, parent]) == parent
        except Exception:
            return False

    def _compute_relative(child: str, parent: str) -> str:
        return os.path.relpath(
            os.path.join(os.sep, os.path.relpath(child, parent)), os.sep
        )

    # 1) input
    input_base = os.path.abspath(folder_paths.get_input_directory())
    if _check_is_within(fp_abs, input_base):
        return "input", _compute_relative(fp_abs, input_base)

    # 2) output
    output_base = os.path.abspath(folder_paths.get_output_directory())
    if _check_is_within(fp_abs, output_base):
        return "output", _compute_relative(fp_abs, output_base)

    # 3) models (check deepest matching base to avoid ambiguity)
    best: tuple[int, str, str] | None = None  # (base_len, bucket, rel_inside_bucket)
    for bucket, bases in get_comfy_models_folders():
        for b in bases:
            base_abs = os.path.abspath(b)
            if not _check_is_within(fp_abs, base_abs):
                continue
            cand = (len(base_abs), bucket, _compute_relative(fp_abs, base_abs))
            if best is None or cand[0] > best[0]:
                best = cand

    if best is not None:
        _, bucket, rel_inside = best
        combined = os.path.join(bucket, rel_inside)
        return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)

    raise ValueError(
        f"Path is not within input, output, or configured model bases: {file_path}"
    )


def compute_filename_for_asset(session, asset_id: str) -> str | None:
    """Compute the relative filename for an asset from its best live cache state path."""
    from app.assets.database.queries import list_cache_states_by_asset_id
    from app.assets.helpers import select_best_live_path

    primary_path = select_best_live_path(
        list_cache_states_by_asset_id(session, asset_id=asset_id)
    )
    return compute_relative_filename(primary_path) if primary_path else None


def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
    """Return a tuple (name, tags) derived from a filesystem path.

    Semantics:
    - Root category is determined by `get_asset_category_and_relative_path`.
    - The returned `name` is the base filename with extension from the relative path.
    - The returned `tags` are:
        [root_category] + parent folders of the relative path (in order)
    For 'models', this means:
        file '/.../ModelsDir/vae/test_tag/ae.safetensors'
        -> root_category='models', some_path='vae/test_tag/ae.safetensors'
        -> name='ae.safetensors', tags=['models', 'vae', 'test_tag']

    Raises:
        ValueError: if the path does not belong to input, output, or configured model bases.
    """
    root_category, some_path = get_asset_category_and_relative_path(file_path)
    p = Path(some_path)
    parent_parts = [
        part for part in p.parent.parts if part not in (".", "..", p.anchor)
    ]
    return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
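
For example, with the default layout a checkpoint stored in a subfolder maps to a base filename plus hierarchical tags (the absolute path depends on the local folder_paths configuration):

    name, tags = get_name_and_tags_from_asset_path(
        "/data/ComfyUI/models/checkpoints/sdxl/base/sd_xl_base_1.0.safetensors"
    )
    # name == "sd_xl_base_1.0.safetensors"
    # tags == ["models", "checkpoints", "sdxl", "base"]
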
126
app/assets/services/schemas.py
Normal file
@@ -0,0 +1,126 @@
from dataclasses import dataclass
from datetime import datetime
from typing import Any, NamedTuple

from app.assets.database.models import Asset, AssetInfo

UserMetadata = dict[str, Any] | None


@dataclass(frozen=True)
class AssetData:
    hash: str
    size_bytes: int | None
    mime_type: str | None


@dataclass(frozen=True)
class AssetInfoData:
    id: str
    name: str
    user_metadata: UserMetadata
    preview_id: str | None
    created_at: datetime
    updated_at: datetime
    last_access_time: datetime | None


@dataclass(frozen=True)
class AssetDetailResult:
    info: AssetInfoData
    asset: AssetData | None
    tags: list[str]


@dataclass(frozen=True)
class RegisterAssetResult:
    info: AssetInfoData
    asset: AssetData
    tags: list[str]
    created: bool


@dataclass(frozen=True)
class IngestResult:
    asset_created: bool
    asset_updated: bool
    state_created: bool
    state_updated: bool
    asset_info_id: str | None


@dataclass(frozen=True)
class AddTagsResult:
    added: list[str]
    already_present: list[str]
    total_tags: list[str]


@dataclass(frozen=True)
class RemoveTagsResult:
    removed: list[str]
    not_present: list[str]
    total_tags: list[str]


@dataclass(frozen=True)
class SetTagsResult:
    added: list[str]
    removed: list[str]
    total: list[str]


class TagUsage(NamedTuple):
    name: str
    tag_type: str
    count: int


@dataclass(frozen=True)
class AssetSummaryData:
    info: AssetInfoData
    asset: AssetData | None
    tags: list[str]


@dataclass(frozen=True)
class ListAssetsResult:
    items: list[AssetSummaryData]
    total: int


@dataclass(frozen=True)
class DownloadResolutionResult:
    abs_path: str
    content_type: str
    download_name: str


@dataclass(frozen=True)
class UploadResult:
    info: AssetInfoData
    asset: AssetData
    tags: list[str]
    created_new: bool


def extract_info_data(info: AssetInfo) -> AssetInfoData:
    return AssetInfoData(
        id=info.id,
        name=info.name,
        user_metadata=info.user_metadata,
        preview_id=info.preview_id,
        created_at=info.created_at,
        updated_at=info.updated_at,
        last_access_time=info.last_access_time,
    )


def extract_asset_data(asset: Asset | None) -> AssetData | None:
    if asset is None:
        return None
    return AssetData(
        hash=asset.hash,
        size_bytes=asset.size_bytes,
        mime_type=asset.mime_type,
    )
89
app/assets/services/tagging.py
Normal file
@@ -0,0 +1,89 @@
from app.assets.database.queries import (
    add_tags_to_asset_info,
    get_asset_info_by_id,
    list_tags_with_usage,
    remove_tags_from_asset_info,
)
from app.assets.services.schemas import AddTagsResult, RemoveTagsResult, TagUsage
from app.database.db import create_session


def apply_tags(
    asset_info_id: str,
    tags: list[str],
    origin: str = "manual",
    owner_id: str = "",
) -> AddTagsResult:
    with create_session() as session:
        info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
        if not info_row:
            raise ValueError(f"AssetInfo {asset_info_id} not found")
        if info_row.owner_id and info_row.owner_id != owner_id:
            raise PermissionError("not owner")

        data = add_tags_to_asset_info(
            session,
            asset_info_id=asset_info_id,
            tags=tags,
            origin=origin,
            create_if_missing=True,
            asset_info_row=info_row,
        )
        session.commit()

    return AddTagsResult(
        added=data["added"],
        already_present=data["already_present"],
        total_tags=data["total_tags"],
    )


def remove_tags(
    asset_info_id: str,
    tags: list[str],
    owner_id: str = "",
) -> RemoveTagsResult:
    with create_session() as session:
        info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
        if not info_row:
            raise ValueError(f"AssetInfo {asset_info_id} not found")
        if info_row.owner_id and info_row.owner_id != owner_id:
            raise PermissionError("not owner")

        data = remove_tags_from_asset_info(
            session,
            asset_info_id=asset_info_id,
            tags=tags,
        )
        session.commit()

    return RemoveTagsResult(
        removed=data["removed"],
        not_present=data["not_present"],
        total_tags=data["total_tags"],
    )


def list_tags(
    prefix: str | None = None,
    limit: int = 100,
    offset: int = 0,
    order: str = "count_desc",
    include_zero: bool = True,
    owner_id: str = "",
) -> tuple[list[TagUsage], int]:
    limit = max(1, min(1000, limit))
    offset = max(0, offset)

    with create_session() as session:
        rows, total = list_tags_with_usage(
            session,
            prefix=prefix,
            limit=limit,
            offset=offset,
            include_zero=include_zero,
            order=order,
            owner_id=owner_id,
        )

    return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total
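
These service functions are what the API handlers call; each opens its own session and commits before returning. A small usage sketch (the ID below is a placeholder):

    added = apply_tags("some-asset-info-id", ["favorites", "sdxl"], origin="manual")
    print(added.added, added.already_present)

    usage, total = list_tags(prefix="sd", limit=20)
    for tag in usage:
        print(tag.name, tag.tag_type, tag.count)
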
@@ -14,7 +14,7 @@ try:
     from alembic.config import Config
     from alembic.runtime.migration import MigrationContext
     from alembic.script import ScriptDirectory
-    from sqlalchemy import create_engine
+    from sqlalchemy import create_engine, event
     from sqlalchemy.orm import sessionmaker
 
     _DB_AVAILABLE = True
@@ -75,6 +75,13 @@ def init_db():
 
     # Check if we need to upgrade
     engine = create_engine(db_url)
+
+    # Enable foreign key enforcement for SQLite
+    @event.listens_for(engine, "connect")
+    def set_sqlite_pragma(dbapi_connection, connection_record):
+        cursor = dbapi_connection.cursor()
+        cursor.execute("PRAGMA foreign_keys=ON")
+        cursor.close()
     conn = engine.connect()
 
     context = MigrationContext.configure(conn)
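
SQLite ships with foreign_keys off and the setting is per-connection, so the listener has to re-issue the PRAGMA every time a new DBAPI connection is opened rather than once at startup. A quick, standalone way to confirm the pattern takes effect (illustrative sketch, independent of this codebase):

    from sqlalchemy import create_engine, event, text

    engine = create_engine("sqlite:///:memory:")

    @event.listens_for(engine, "connect")
    def _fk_on(dbapi_connection, connection_record):
        cur = dbapi_connection.cursor()
        cur.execute("PRAGMA foreign_keys=ON")
        cur.close()

    with engine.connect() as conn:
        # 1 means foreign key enforcement is active on this connection
        assert conn.execute(text("PRAGMA foreign_keys")).scalar() == 1
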
10
main.py
@@ -7,7 +7,7 @@ import folder_paths
 import time
 from comfy.cli_args import args, enables_dynamic_vram
 from app.logger import setup_logger
-from app.assets.scanner import seed_assets
+from app.assets.seeder import asset_seeder
 import itertools
 import utils.extra_config
 import logging
@@ -354,7 +354,8 @@ def setup_database():
         if dependencies_available():
             init_db()
             if not args.disable_assets_autoscan:
-                seed_assets(["models"], enable_logging=True)
+                if asset_seeder.start(roots=("models", "input", "output"), prune_first=True, compute_hashes=True):
+                    logging.info("Background asset scan initiated for models, input, output")
     except Exception as e:
         logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
 
@@ -438,5 +439,6 @@ if __name__ == "__main__":
             event_loop.run_until_complete(x)
         except KeyboardInterrupt:
             logging.info("\nStopped server")
-
-        cleanup_temp()
+        finally:
+            asset_seeder.shutdown()
+            cleanup_temp()
@@ -259,13 +259,3 @@ def autoclean_unit_test_assets(http: requests.Session, api_base: str):
     for aid in ids:
         with contextlib.suppress(Exception):
             http.delete(f"{api_base}/api/assets/{aid}", timeout=30)
-
-
-def trigger_sync_seed_assets(session: requests.Session, base_url: str) -> None:
-    """Force a fast sync/seed pass by calling the seed endpoint."""
-    session.post(base_url + "/api/assets/seed", json={"roots": ["models", "input", "output"]}, timeout=30)
-    time.sleep(0.2)
-
-
-def get_asset_filename(asset_hash: str, extension: str) -> str:
-    return asset_hash.removeprefix("blake3:") + extension
15
tests-unit/assets_test/helpers.py
Normal file
@@ -0,0 +1,15 @@
"""Helper functions for assets integration tests."""
import requests


def trigger_sync_seed_assets(session: requests.Session, base_url: str) -> None:
    """Force a synchronous sync/seed pass by calling the seed endpoint with wait=true."""
    session.post(
        base_url + "/api/assets/seed?wait=true",
        json={"roots": ["models", "input", "output"]},
        timeout=60,
    )


def get_asset_filename(asset_hash: str, extension: str) -> str:
    return asset_hash.removeprefix("blake3:") + extension
20
tests-unit/assets_test/queries/conftest.py
Normal file
@@ -0,0 +1,20 @@
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session

from app.assets.database.models import Base


@pytest.fixture
def session():
    """In-memory SQLite session for fast unit tests."""
    engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(engine)
    with Session(engine) as sess:
        yield sess


@pytest.fixture(autouse=True)
def autoclean_unit_test_assets():
    """Override parent autouse fixture - query tests don't need server cleanup."""
    yield
142
tests-unit/assets_test/queries/test_asset.py
Normal file
@@ -0,0 +1,142 @@
import uuid
import pytest
from sqlalchemy.orm import Session

from app.assets.database.models import Asset
from app.assets.database.queries import (
    asset_exists_by_hash,
    get_asset_by_hash,
    upsert_asset,
    bulk_insert_assets,
)


class TestAssetExistsByHash:
    @pytest.mark.parametrize(
        "setup_hash,query_hash,expected",
        [
            (None, "nonexistent", False),  # No asset exists
            ("blake3:abc123", "blake3:abc123", True),  # Asset exists with matching hash
            (None, "", False),  # Null hash in DB doesn't match empty string
        ],
        ids=["nonexistent", "existing", "null_hash_no_match"],
    )
    def test_exists_by_hash(self, session: Session, setup_hash, query_hash, expected):
        # Create an asset for the cases that need one in the DB
        # (including hash=None for the null-hash case).
        if setup_hash is not None or query_hash == "":
            asset = Asset(hash=setup_hash, size_bytes=100)
            session.add(asset)
            session.commit()

        assert asset_exists_by_hash(session, asset_hash=query_hash) is expected


class TestGetAssetByHash:
|
||||
@pytest.mark.parametrize(
|
||||
"setup_hash,query_hash,should_find",
|
||||
[
|
||||
(None, "nonexistent", False),
|
||||
("blake3:def456", "blake3:def456", True),
|
||||
],
|
||||
ids=["nonexistent", "existing"],
|
||||
)
|
||||
def test_get_by_hash(self, session: Session, setup_hash, query_hash, should_find):
|
||||
if setup_hash is not None:
|
||||
asset = Asset(hash=setup_hash, size_bytes=200, mime_type="image/png")
|
||||
session.add(asset)
|
||||
session.commit()
|
||||
|
||||
result = get_asset_by_hash(session, asset_hash=query_hash)
|
||||
if should_find:
|
||||
assert result is not None
|
||||
assert result.size_bytes == 200
|
||||
assert result.mime_type == "image/png"
|
||||
else:
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestUpsertAsset:
|
||||
@pytest.mark.parametrize(
|
||||
"first_size,first_mime,second_size,second_mime,expect_created,expect_updated,final_size,final_mime",
|
||||
[
|
||||
# New asset creation
|
||||
(None, None, 1024, "application/octet-stream", True, False, 1024, "application/octet-stream"),
|
||||
# Existing asset, same values - no update
|
||||
(500, "text/plain", 500, "text/plain", False, False, 500, "text/plain"),
|
||||
# Existing asset with size 0, update with new values
|
||||
(0, None, 2048, "image/png", False, True, 2048, "image/png"),
|
||||
# Existing asset, second call with size 0 - no update
|
||||
(1000, None, 0, None, False, False, 1000, None),
|
||||
],
|
||||
ids=["new_asset", "existing_no_change", "update_from_zero", "zero_size_no_update"],
|
||||
)
|
||||
def test_upsert_scenarios(
|
||||
self,
|
||||
session: Session,
|
||||
first_size,
|
||||
first_mime,
|
||||
second_size,
|
||||
second_mime,
|
||||
expect_created,
|
||||
expect_updated,
|
||||
final_size,
|
||||
final_mime,
|
||||
):
|
||||
asset_hash = f"blake3:test_{first_size}_{second_size}"
|
||||
|
||||
# First upsert (if first_size is not None, we're testing the second call)
|
||||
if first_size is not None:
|
||||
upsert_asset(
|
||||
session,
|
||||
asset_hash=asset_hash,
|
||||
size_bytes=first_size,
|
||||
mime_type=first_mime,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# The upsert call we're testing
|
||||
asset, created, updated = upsert_asset(
|
||||
session,
|
||||
asset_hash=asset_hash,
|
||||
size_bytes=second_size,
|
||||
mime_type=second_mime,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
assert created is expect_created
|
||||
assert updated is expect_updated
|
||||
assert asset.size_bytes == final_size
|
||||
assert asset.mime_type == final_mime
|
||||
|
||||
|
||||
class TestBulkInsertAssets:
|
||||
def test_inserts_multiple_assets(self, session: Session):
|
||||
rows = [
|
||||
{"id": str(uuid.uuid4()), "hash": "blake3:bulk1", "size_bytes": 100, "mime_type": "text/plain", "created_at": None},
|
||||
{"id": str(uuid.uuid4()), "hash": "blake3:bulk2", "size_bytes": 200, "mime_type": "image/png", "created_at": None},
|
||||
{"id": str(uuid.uuid4()), "hash": "blake3:bulk3", "size_bytes": 300, "mime_type": None, "created_at": None},
|
||||
]
|
||||
bulk_insert_assets(session, rows)
|
||||
session.commit()
|
||||
|
||||
assets = session.query(Asset).all()
|
||||
assert len(assets) == 3
|
||||
hashes = {a.hash for a in assets}
|
||||
assert hashes == {"blake3:bulk1", "blake3:bulk2", "blake3:bulk3"}
|
||||
|
||||
def test_empty_list_is_noop(self, session: Session):
|
||||
bulk_insert_assets(session, [])
|
||||
session.commit()
|
||||
assert session.query(Asset).count() == 0
|
||||
|
||||
def test_handles_large_batch(self, session: Session):
|
||||
"""Test chunking logic with more rows than MAX_BIND_PARAMS allows."""
|
||||
rows = [
|
||||
{"id": str(uuid.uuid4()), "hash": f"blake3:large{i}", "size_bytes": i, "mime_type": None, "created_at": None}
|
||||
for i in range(200)
|
||||
]
|
||||
bulk_insert_assets(session, rows)
|
||||
session.commit()
|
||||
|
||||
assert session.query(Asset).count() == 200
|
||||
511
tests-unit/assets_test/queries/test_asset_info.py
Normal file
@@ -0,0 +1,511 @@
import time
import uuid
import pytest
from sqlalchemy.orm import Session

from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta
from app.assets.database.queries import (
    asset_info_exists_for_asset_id,
    get_asset_info_by_id,
    insert_asset_info,
    get_or_create_asset_info,
    update_asset_info_timestamps,
    list_asset_infos_page,
    fetch_asset_info_asset_and_tags,
    fetch_asset_info_and_asset,
    update_asset_info_access_time,
    set_asset_info_metadata,
    delete_asset_info_by_id,
    set_asset_info_preview,
    bulk_insert_asset_infos_ignore_conflicts,
    get_asset_info_ids_by_ids,
    ensure_tags_exist,
    add_tags_to_asset_info,
)
from app.assets.helpers import get_utc_now


def _make_asset(session: Session, hash_val: str | None = None, size: int = 1024) -> Asset:
    asset = Asset(hash=hash_val, size_bytes=size, mime_type="application/octet-stream")
    session.add(asset)
    session.flush()
    return asset


def _make_asset_info(
    session: Session,
    asset: Asset,
    name: str = "test",
    owner_id: str = "",
) -> AssetInfo:
    now = get_utc_now()
    info = AssetInfo(
        owner_id=owner_id,
        name=name,
        asset_id=asset.id,
        created_at=now,
        updated_at=now,
        last_access_time=now,
    )
    session.add(info)
    session.flush()
    return info


class TestAssetInfoExistsForAssetId:
|
||||
def test_returns_false_when_no_info(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
assert asset_info_exists_for_asset_id(session, asset_id=asset.id) is False
|
||||
|
||||
def test_returns_true_when_info_exists(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_asset_info(session, asset)
|
||||
assert asset_info_exists_for_asset_id(session, asset_id=asset.id) is True
|
||||
|
||||
|
||||
class TestGetAssetInfoById:
|
||||
def test_returns_none_for_nonexistent(self, session: Session):
|
||||
assert get_asset_info_by_id(session, asset_info_id="nonexistent") is None
|
||||
|
||||
def test_returns_info(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
info = _make_asset_info(session, asset, name="myfile.txt")
|
||||
|
||||
result = get_asset_info_by_id(session, asset_info_id=info.id)
|
||||
assert result is not None
|
||||
assert result.name == "myfile.txt"
|
||||
|
||||
|
||||
class TestListAssetInfosPage:
|
||||
def test_empty_db(self, session: Session):
|
||||
infos, tag_map, total = list_asset_infos_page(session)
|
||||
assert infos == []
|
||||
assert tag_map == {}
|
||||
assert total == 0
|
||||
|
||||
def test_returns_infos_with_tags(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
info = _make_asset_info(session, asset, name="test.bin")
|
||||
ensure_tags_exist(session, ["alpha", "beta"])
|
||||
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["alpha", "beta"])
|
||||
session.commit()
|
||||
|
||||
infos, tag_map, total = list_asset_infos_page(session)
|
||||
assert len(infos) == 1
|
||||
assert infos[0].id == info.id
|
||||
assert set(tag_map[info.id]) == {"alpha", "beta"}
|
||||
assert total == 1
|
||||
|
||||
def test_name_contains_filter(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_asset_info(session, asset, name="model_v1.safetensors")
|
||||
_make_asset_info(session, asset, name="config.json")
|
||||
session.commit()
|
||||
|
||||
infos, _, total = list_asset_infos_page(session, name_contains="model")
|
||||
assert total == 1
|
||||
assert infos[0].name == "model_v1.safetensors"
|
||||
|
||||
def test_owner_visibility(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_asset_info(session, asset, name="public", owner_id="")
|
||||
_make_asset_info(session, asset, name="private", owner_id="user1")
|
||||
session.commit()
|
||||
|
||||
# Empty owner sees only public
|
||||
infos, _, total = list_asset_infos_page(session, owner_id="")
|
||||
assert total == 1
|
||||
assert infos[0].name == "public"
|
||||
|
||||
# Owner sees both
|
||||
infos, _, total = list_asset_infos_page(session, owner_id="user1")
|
||||
assert total == 2
|
||||
|
||||
def test_include_tags_filter(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
info1 = _make_asset_info(session, asset, name="tagged")
|
||||
_make_asset_info(session, asset, name="untagged")
|
||||
ensure_tags_exist(session, ["wanted"])
|
||||
add_tags_to_asset_info(session, asset_info_id=info1.id, tags=["wanted"])
|
||||
session.commit()
|
||||
|
||||
infos, _, total = list_asset_infos_page(session, include_tags=["wanted"])
|
||||
assert total == 1
|
||||
assert infos[0].name == "tagged"
|
||||
|
||||
def test_exclude_tags_filter(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_asset_info(session, asset, name="keep")
|
||||
info_exclude = _make_asset_info(session, asset, name="exclude")
|
||||
ensure_tags_exist(session, ["bad"])
|
||||
add_tags_to_asset_info(session, asset_info_id=info_exclude.id, tags=["bad"])
|
||||
session.commit()
|
||||
|
||||
infos, _, total = list_asset_infos_page(session, exclude_tags=["bad"])
|
||||
assert total == 1
|
||||
assert infos[0].name == "keep"
|
||||
|
||||
def test_sorting(self, session: Session):
|
||||
asset = _make_asset(session, "hash1", size=100)
|
||||
asset2 = _make_asset(session, "hash2", size=500)
|
||||
_make_asset_info(session, asset, name="small")
|
||||
_make_asset_info(session, asset2, name="large")
|
||||
session.commit()
|
||||
|
||||
infos, _, _ = list_asset_infos_page(session, sort="size", order="desc")
|
||||
assert infos[0].name == "large"
|
||||
|
||||
infos, _, _ = list_asset_infos_page(session, sort="name", order="asc")
|
||||
assert infos[0].name == "large"
|
||||
|
||||
|
||||
class TestFetchAssetInfoAssetAndTags:
|
||||
def test_returns_none_for_nonexistent(self, session: Session):
|
||||
result = fetch_asset_info_asset_and_tags(session, "nonexistent")
|
||||
assert result is None
|
||||
|
||||
def test_returns_tuple(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
info = _make_asset_info(session, asset, name="test.bin")
|
||||
ensure_tags_exist(session, ["tag1"])
|
||||
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["tag1"])
|
||||
session.commit()
|
||||
|
||||
result = fetch_asset_info_asset_and_tags(session, info.id)
|
||||
assert result is not None
|
||||
ret_info, ret_asset, ret_tags = result
|
||||
assert ret_info.id == info.id
|
||||
assert ret_asset.id == asset.id
|
||||
assert ret_tags == ["tag1"]
|
||||
|
||||
|
||||
class TestFetchAssetInfoAndAsset:
|
||||
def test_returns_none_for_nonexistent(self, session: Session):
|
||||
result = fetch_asset_info_and_asset(session, asset_info_id="nonexistent")
|
||||
assert result is None
|
||||
|
||||
def test_returns_tuple(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
info = _make_asset_info(session, asset)
|
||||
session.commit()
|
||||
|
||||
result = fetch_asset_info_and_asset(session, asset_info_id=info.id)
|
||||
assert result is not None
|
||||
ret_info, ret_asset = result
|
||||
assert ret_info.id == info.id
|
||||
assert ret_asset.id == asset.id
|
||||
|
||||
|
||||
class TestUpdateAssetInfoAccessTime:
|
||||
def test_updates_last_access_time(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
info = _make_asset_info(session, asset)
|
||||
original_time = info.last_access_time
|
||||
session.commit()
|
||||
|
||||
time.sleep(0.01)
|
||||
|
||||
update_asset_info_access_time(session, asset_info_id=info.id)
|
||||
session.commit()
|
||||
|
||||
session.refresh(info)
|
||||
assert info.last_access_time > original_time
|
||||
|
||||
|
||||
class TestDeleteAssetInfoById:
|
||||
def test_deletes_existing(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
info = _make_asset_info(session, asset)
|
||||
session.commit()
|
||||
|
||||
result = delete_asset_info_by_id(session, asset_info_id=info.id, owner_id="")
|
||||
assert result is True
|
||||
assert get_asset_info_by_id(session, asset_info_id=info.id) is None
|
||||
|
||||
def test_returns_false_for_nonexistent(self, session: Session):
|
||||
result = delete_asset_info_by_id(session, asset_info_id="nonexistent", owner_id="")
|
||||
assert result is False
|
||||
|
||||
def test_respects_owner_visibility(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
info = _make_asset_info(session, asset, owner_id="user1")
|
||||
session.commit()
|
||||
|
||||
result = delete_asset_info_by_id(session, asset_info_id=info.id, owner_id="user2")
|
||||
assert result is False
|
||||
assert get_asset_info_by_id(session, asset_info_id=info.id) is not None
|
||||
|
||||
|
||||
class TestSetAssetInfoPreview:
|
||||
def test_sets_preview(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
preview_asset = _make_asset(session, "preview_hash")
|
||||
info = _make_asset_info(session, asset)
|
||||
session.commit()
|
||||
|
||||
set_asset_info_preview(session, asset_info_id=info.id, preview_asset_id=preview_asset.id)
|
||||
session.commit()
|
||||
|
||||
session.refresh(info)
|
||||
assert info.preview_id == preview_asset.id
|
||||
|
||||
def test_clears_preview(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
preview_asset = _make_asset(session, "preview_hash")
|
||||
info = _make_asset_info(session, asset)
|
||||
info.preview_id = preview_asset.id
|
||||
session.commit()
|
||||
|
||||
set_asset_info_preview(session, asset_info_id=info.id, preview_asset_id=None)
|
||||
session.commit()
|
||||
|
||||
session.refresh(info)
|
||||
assert info.preview_id is None
|
||||
|
||||
def test_raises_for_nonexistent_info(self, session: Session):
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
set_asset_info_preview(session, asset_info_id="nonexistent", preview_asset_id=None)
|
||||
|
||||
def test_raises_for_nonexistent_preview(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
info = _make_asset_info(session, asset)
|
||||
session.commit()
|
||||
|
||||
with pytest.raises(ValueError, match="Preview Asset"):
|
||||
set_asset_info_preview(session, asset_info_id=info.id, preview_asset_id="nonexistent")
|
||||
|
||||
|
||||
class TestInsertAssetInfo:
|
||||
def test_creates_new_info(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
info = insert_asset_info(
|
||||
session, asset_id=asset.id, owner_id="user1", name="test.bin"
|
||||
)
|
||||
session.commit()
|
||||
|
||||
assert info is not None
|
||||
assert info.name == "test.bin"
|
||||
assert info.owner_id == "user1"
|
||||
|
||||
def test_returns_none_on_conflict(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
insert_asset_info(session, asset_id=asset.id, owner_id="user1", name="dup.bin")
|
||||
session.commit()
|
||||
|
||||
# Attempt duplicate with same (asset_id, owner_id, name)
|
||||
result = insert_asset_info(
|
||||
session, asset_id=asset.id, owner_id="user1", name="dup.bin"
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGetOrCreateAssetInfo:
|
||||
def test_creates_new_info(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
info, created = get_or_create_asset_info(
|
||||
session, asset_id=asset.id, owner_id="user1", name="new.bin"
|
||||
)
|
||||
session.commit()
|
||||
|
||||
assert created is True
|
||||
assert info.name == "new.bin"
|
||||
|
||||
def test_returns_existing_info(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
info1, created1 = get_or_create_asset_info(
|
||||
session, asset_id=asset.id, owner_id="user1", name="existing.bin"
|
||||
)
|
||||
session.commit()
|
||||
|
||||
info2, created2 = get_or_create_asset_info(
|
||||
session, asset_id=asset.id, owner_id="user1", name="existing.bin"
|
||||
)
|
||||
session.commit()
|
||||
|
||||
assert created1 is True
|
||||
assert created2 is False
|
||||
assert info1.id == info2.id
|
||||
|
||||
|
||||
class TestUpdateAssetInfoTimestamps:
|
||||
def test_updates_timestamps(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
info = _make_asset_info(session, asset)
|
||||
original_updated_at = info.updated_at
|
||||
session.commit()
|
||||
|
||||
time.sleep(0.01)
|
||||
update_asset_info_timestamps(session, info)
|
||||
session.commit()
|
||||
|
||||
session.refresh(info)
|
||||
assert info.updated_at > original_updated_at
|
||||
|
||||
def test_updates_preview_id(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
preview_asset = _make_asset(session, "preview_hash")
|
||||
info = _make_asset_info(session, asset)
|
||||
session.commit()
|
||||
|
||||
update_asset_info_timestamps(session, info, preview_id=preview_asset.id)
|
||||
session.commit()
|
||||
|
||||
session.refresh(info)
|
||||
assert info.preview_id == preview_asset.id
|
||||
|
||||
|
||||
class TestSetAssetInfoMetadata:
|
||||
def test_sets_metadata(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
info = _make_asset_info(session, asset)
|
||||
session.commit()
|
||||
|
||||
set_asset_info_metadata(
|
||||
session, asset_info_id=info.id, user_metadata={"key": "value"}
|
||||
)
|
||||
session.commit()
|
||||
|
||||
session.refresh(info)
|
||||
assert info.user_metadata == {"key": "value"}
|
||||
# Check metadata table
|
||||
meta = session.query(AssetInfoMeta).filter_by(asset_info_id=info.id).all()
|
||||
assert len(meta) == 1
|
||||
assert meta[0].key == "key"
|
||||
assert meta[0].val_str == "value"
|
||||
|
||||
def test_replaces_existing_metadata(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
info = _make_asset_info(session, asset)
|
||||
session.commit()
|
||||
|
||||
set_asset_info_metadata(
|
||||
session, asset_info_id=info.id, user_metadata={"old": "data"}
|
||||
)
|
||||
session.commit()
|
||||
|
||||
set_asset_info_metadata(
|
||||
session, asset_info_id=info.id, user_metadata={"new": "data"}
|
||||
)
|
||||
session.commit()
|
||||
|
||||
meta = session.query(AssetInfoMeta).filter_by(asset_info_id=info.id).all()
|
||||
assert len(meta) == 1
|
||||
assert meta[0].key == "new"
|
||||
|
||||
def test_clears_metadata_with_empty_dict(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
info = _make_asset_info(session, asset)
|
||||
session.commit()
|
||||
|
||||
set_asset_info_metadata(
|
||||
session, asset_info_id=info.id, user_metadata={"key": "value"}
|
||||
)
|
||||
session.commit()
|
||||
|
||||
set_asset_info_metadata(
|
||||
session, asset_info_id=info.id, user_metadata={}
|
||||
)
|
||||
session.commit()
|
||||
|
||||
session.refresh(info)
|
||||
assert info.user_metadata == {}
|
||||
meta = session.query(AssetInfoMeta).filter_by(asset_info_id=info.id).all()
|
||||
assert len(meta) == 0
|
||||
|
||||
def test_raises_for_nonexistent(self, session: Session):
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
set_asset_info_metadata(
|
||||
session, asset_info_id="nonexistent", user_metadata={"key": "value"}
|
||||
)
|
||||
|
||||
|
||||
class TestBulkInsertAssetInfosIgnoreConflicts:
|
||||
def test_inserts_multiple_infos(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
now = get_utc_now()
|
||||
rows = [
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"owner_id": "",
|
||||
"name": "bulk1.bin",
|
||||
"asset_id": asset.id,
|
||||
"preview_id": None,
|
||||
"user_metadata": {},
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"last_access_time": now,
|
||||
},
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"owner_id": "",
|
||||
"name": "bulk2.bin",
|
||||
"asset_id": asset.id,
|
||||
"preview_id": None,
|
||||
"user_metadata": {},
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"last_access_time": now,
|
||||
},
|
||||
]
|
||||
bulk_insert_asset_infos_ignore_conflicts(session, rows)
|
||||
session.commit()
|
||||
|
||||
infos = session.query(AssetInfo).all()
|
||||
assert len(infos) == 2
|
||||
|
||||
def test_ignores_conflicts(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_asset_info(session, asset, name="existing.bin", owner_id="")
|
||||
session.commit()
|
||||
|
||||
now = get_utc_now()
|
||||
rows = [
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"owner_id": "",
|
||||
"name": "existing.bin",
|
||||
"asset_id": asset.id,
|
||||
"preview_id": None,
|
||||
"user_metadata": {},
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"last_access_time": now,
|
||||
},
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"owner_id": "",
|
||||
"name": "new.bin",
|
||||
"asset_id": asset.id,
|
||||
"preview_id": None,
|
||||
"user_metadata": {},
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"last_access_time": now,
|
||||
},
|
||||
]
|
||||
bulk_insert_asset_infos_ignore_conflicts(session, rows)
|
||||
session.commit()
|
||||
|
||||
infos = session.query(AssetInfo).all()
|
||||
assert len(infos) == 2 # existing + new, not 3
|
||||
|
||||
def test_empty_list_is_noop(self, session: Session):
|
||||
bulk_insert_asset_infos_ignore_conflicts(session, [])
|
||||
assert session.query(AssetInfo).count() == 0
|
||||
|
||||
|
||||
class TestGetAssetInfoIdsByIds:
|
||||
def test_returns_existing_ids(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
info1 = _make_asset_info(session, asset, name="a.bin")
|
||||
info2 = _make_asset_info(session, asset, name="b.bin")
|
||||
session.commit()
|
||||
|
||||
found = get_asset_info_ids_by_ids(session, [info1.id, info2.id, "nonexistent"])
|
||||
|
||||
assert found == {info1.id, info2.id}
|
||||
|
||||
def test_empty_list_returns_empty(self, session: Session):
|
||||
found = get_asset_info_ids_by_ids(session, [])
|
||||
assert found == set()
|
||||
468
tests-unit/assets_test/queries/test_cache_state.py
Normal file
@@ -0,0 +1,468 @@
"""Tests for cache_state query functions."""
import pytest
from sqlalchemy.orm import Session

from app.assets.database.models import Asset, AssetCacheState, AssetInfo
from app.assets.database.queries import (
    list_cache_states_by_asset_id,
    upsert_cache_state,
    get_unreferenced_unhashed_asset_ids,
    delete_assets_by_ids,
    get_cache_states_for_prefixes,
    bulk_update_needs_verify,
    delete_cache_states_by_ids,
    delete_orphaned_seed_asset,
    bulk_insert_cache_states_ignore_conflicts,
    get_cache_states_by_paths_and_asset_ids,
    mark_cache_states_missing_outside_prefixes,
    restore_cache_states_by_paths,
)
from app.assets.helpers import select_best_live_path, get_utc_now


def _make_asset(session: Session, hash_val: str | None = None, size: int = 1024) -> Asset:
    asset = Asset(hash=hash_val, size_bytes=size)
    session.add(asset)
    session.flush()
    return asset


def _make_cache_state(
    session: Session,
    asset: Asset,
    file_path: str,
    mtime_ns: int | None = None,
    needs_verify: bool = False,
) -> AssetCacheState:
    state = AssetCacheState(
        asset_id=asset.id,
        file_path=file_path,
        mtime_ns=mtime_ns,
        needs_verify=needs_verify,
    )
    session.add(state)
    session.flush()
    return state


class TestListCacheStatesByAssetId:
|
||||
def test_returns_empty_for_no_states(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
|
||||
assert list(states) == []
|
||||
|
||||
def test_returns_states_for_asset(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_cache_state(session, asset, "/path/a.bin")
|
||||
_make_cache_state(session, asset, "/path/b.bin")
|
||||
session.commit()
|
||||
|
||||
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
|
||||
paths = [s.file_path for s in states]
|
||||
assert set(paths) == {"/path/a.bin", "/path/b.bin"}
|
||||
|
||||
def test_does_not_return_other_assets_states(self, session: Session):
|
||||
asset1 = _make_asset(session, "hash1")
|
||||
asset2 = _make_asset(session, "hash2")
|
||||
_make_cache_state(session, asset1, "/path/asset1.bin")
|
||||
_make_cache_state(session, asset2, "/path/asset2.bin")
|
||||
session.commit()
|
||||
|
||||
states = list_cache_states_by_asset_id(session, asset_id=asset1.id)
|
||||
paths = [s.file_path for s in states]
|
||||
assert paths == ["/path/asset1.bin"]
|
||||
|
||||
|
||||
class TestSelectBestLivePath:
|
||||
def test_returns_empty_for_empty_list(self):
|
||||
result = select_best_live_path([])
|
||||
assert result == ""
|
||||
|
||||
def test_returns_empty_when_no_files_exist(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
state = _make_cache_state(session, asset, "/nonexistent/path.bin")
|
||||
session.commit()
|
||||
|
||||
result = select_best_live_path([state])
|
||||
assert result == ""
|
||||
|
||||
def test_prefers_verified_path(self, session: Session, tmp_path):
|
||||
"""needs_verify=False should be preferred."""
|
||||
asset = _make_asset(session, "hash1")
|
||||
|
||||
verified_file = tmp_path / "verified.bin"
|
||||
verified_file.write_bytes(b"data")
|
||||
|
||||
unverified_file = tmp_path / "unverified.bin"
|
||||
unverified_file.write_bytes(b"data")
|
||||
|
||||
state_verified = _make_cache_state(
|
||||
session, asset, str(verified_file), needs_verify=False
|
||||
)
|
||||
state_unverified = _make_cache_state(
|
||||
session, asset, str(unverified_file), needs_verify=True
|
||||
)
|
||||
session.commit()
|
||||
|
||||
states = [state_unverified, state_verified]
|
||||
result = select_best_live_path(states)
|
||||
assert result == str(verified_file)
|
||||
|
||||
def test_falls_back_to_existing_unverified(self, session: Session, tmp_path):
|
||||
"""If all states need verification, return first existing path."""
|
||||
asset = _make_asset(session, "hash1")
|
||||
|
||||
existing_file = tmp_path / "exists.bin"
|
||||
existing_file.write_bytes(b"data")
|
||||
|
||||
state = _make_cache_state(session, asset, str(existing_file), needs_verify=True)
|
||||
session.commit()
|
||||
|
||||
result = select_best_live_path([state])
|
||||
assert result == str(existing_file)
|
||||
|
||||
|
||||
class TestSelectBestLivePathWithMocking:
|
||||
def test_handles_missing_file_path_attr(self):
|
||||
"""Gracefully handle states with None file_path."""
|
||||
|
||||
class MockState:
|
||||
file_path = None
|
||||
needs_verify = False
|
||||
|
||||
result = select_best_live_path([MockState()])
|
||||
assert result == ""
|
||||
|
||||
|
||||
class TestUpsertCacheState:
|
||||
@pytest.mark.parametrize(
|
||||
"initial_mtime,second_mtime,expect_created,expect_updated,final_mtime",
|
||||
[
|
||||
# New state creation
|
||||
(None, 12345, True, False, 12345),
|
||||
# Existing state, same mtime - no update
|
||||
(100, 100, False, False, 100),
|
||||
# Existing state, different mtime - update
|
||||
(100, 200, False, True, 200),
|
||||
],
|
||||
ids=["new_state", "existing_no_change", "existing_update_mtime"],
|
||||
)
|
||||
def test_upsert_scenarios(
|
||||
self, session: Session, initial_mtime, second_mtime, expect_created, expect_updated, final_mtime
|
||||
):
|
||||
asset = _make_asset(session, "hash1")
|
||||
file_path = f"/path_{initial_mtime}_{second_mtime}.bin"
|
||||
|
||||
# Create initial state if needed
|
||||
if initial_mtime is not None:
|
||||
upsert_cache_state(session, asset_id=asset.id, file_path=file_path, mtime_ns=initial_mtime)
|
||||
session.commit()
|
||||
|
||||
# The upsert call we're testing
|
||||
created, updated = upsert_cache_state(
|
||||
session, asset_id=asset.id, file_path=file_path, mtime_ns=second_mtime
|
||||
)
|
||||
session.commit()
|
||||
|
||||
assert created is expect_created
|
||||
assert updated is expect_updated
|
||||
state = session.query(AssetCacheState).filter_by(file_path=file_path).one()
|
||||
assert state.mtime_ns == final_mtime
|
||||
|
||||
def test_upsert_restores_missing_state(self, session: Session):
|
||||
"""Upserting a cache state that was marked missing should restore it."""
|
||||
asset = _make_asset(session, "hash1")
|
||||
file_path = "/restored/file.bin"
|
||||
|
||||
state = _make_cache_state(session, asset, file_path, mtime_ns=100)
|
||||
state.is_missing = True
|
||||
session.commit()
|
||||
|
||||
created, updated = upsert_cache_state(
|
||||
session, asset_id=asset.id, file_path=file_path, mtime_ns=100
|
||||
)
|
||||
session.commit()
|
||||
|
||||
assert created is False
|
||||
assert updated is True
|
||||
restored_state = session.query(AssetCacheState).filter_by(file_path=file_path).one()
|
||||
assert restored_state.is_missing is False
|
||||
|
||||
|
||||
class TestRestoreCacheStatesByPaths:
|
||||
def test_restores_missing_states(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
missing_path = "/missing/file.bin"
|
||||
active_path = "/active/file.bin"
|
||||
|
||||
missing_state = _make_cache_state(session, asset, missing_path)
|
||||
missing_state.is_missing = True
|
||||
_make_cache_state(session, asset, active_path)
|
||||
session.commit()
|
||||
|
||||
restored = restore_cache_states_by_paths(session, [missing_path])
|
||||
session.commit()
|
||||
|
||||
assert restored == 1
|
||||
state = session.query(AssetCacheState).filter_by(file_path=missing_path).one()
|
||||
assert state.is_missing is False
|
||||
|
||||
def test_empty_list_restores_nothing(self, session: Session):
|
||||
restored = restore_cache_states_by_paths(session, [])
|
||||
assert restored == 0
|
||||
|
||||
|
||||
class TestMarkCacheStatesMissingOutsidePrefixes:
|
||||
def test_marks_states_missing_outside_prefixes(self, session: Session, tmp_path):
|
||||
asset = _make_asset(session, "hash1")
|
||||
valid_dir = tmp_path / "valid"
|
||||
valid_dir.mkdir()
|
||||
invalid_dir = tmp_path / "invalid"
|
||||
invalid_dir.mkdir()
|
||||
|
||||
valid_path = str(valid_dir / "file.bin")
|
||||
invalid_path = str(invalid_dir / "file.bin")
|
||||
|
||||
_make_cache_state(session, asset, valid_path)
|
||||
_make_cache_state(session, asset, invalid_path)
|
||||
session.commit()
|
||||
|
||||
marked = mark_cache_states_missing_outside_prefixes(session, [str(valid_dir)])
|
||||
session.commit()
|
||||
|
||||
assert marked == 1
|
||||
all_states = session.query(AssetCacheState).all()
|
||||
assert len(all_states) == 2
|
||||
|
||||
valid_state = next(s for s in all_states if s.file_path == valid_path)
|
||||
invalid_state = next(s for s in all_states if s.file_path == invalid_path)
|
||||
assert valid_state.is_missing is False
|
||||
assert invalid_state.is_missing is True
|
||||
|
||||
def test_empty_prefixes_marks_nothing(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_cache_state(session, asset, "/some/path.bin")
|
||||
session.commit()
|
||||
|
||||
marked = mark_cache_states_missing_outside_prefixes(session, [])
|
||||
|
||||
assert marked == 0
|
||||
|
||||
|
||||
class TestGetUnreferencedUnhashedAssetIds:
|
||||
def test_returns_unreferenced_unhashed_assets(self, session: Session):
|
||||
# Unhashed asset (hash=None) with no cache states
|
||||
no_states = _make_asset(session, hash_val=None)
|
||||
# Unhashed asset with active cache state (not unreferenced)
|
||||
with_active_state = _make_asset(session, hash_val=None)
|
||||
_make_cache_state(session, with_active_state, "/has/state.bin")
|
||||
# Unhashed asset with only missing cache state (should be unreferenced)
|
||||
with_missing_state = _make_asset(session, hash_val=None)
|
||||
missing_state = _make_cache_state(session, with_missing_state, "/missing/state.bin")
|
||||
missing_state.is_missing = True
|
||||
# Regular asset (hash not None) - should not be returned
|
||||
_make_asset(session, hash_val="blake3:regular")
|
||||
session.commit()
|
||||
|
||||
unreferenced = get_unreferenced_unhashed_asset_ids(session)
|
||||
|
||||
assert no_states.id in unreferenced
|
||||
assert with_missing_state.id in unreferenced
|
||||
assert with_active_state.id not in unreferenced
|
||||
|
||||
|
||||
class TestDeleteAssetsByIds:
|
||||
def test_deletes_assets_and_infos(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
now = get_utc_now()
|
||||
info = AssetInfo(
|
||||
owner_id="", name="test", asset_id=asset.id,
|
||||
created_at=now, updated_at=now, last_access_time=now
|
||||
)
|
||||
session.add(info)
|
||||
session.commit()
|
||||
|
||||
deleted = delete_assets_by_ids(session, [asset.id])
|
||||
session.commit()
|
||||
|
||||
assert deleted == 1
|
||||
assert session.query(Asset).count() == 0
|
||||
assert session.query(AssetInfo).count() == 0
|
||||
|
||||
def test_empty_list_deletes_nothing(self, session: Session):
|
||||
_make_asset(session, "hash1")
|
||||
session.commit()
|
||||
|
||||
deleted = delete_assets_by_ids(session, [])
|
||||
|
||||
assert deleted == 0
|
||||
assert session.query(Asset).count() == 1
|
||||
|
||||
|
||||
class TestGetCacheStatesForPrefixes:
|
||||
def test_returns_states_matching_prefix(self, session: Session, tmp_path):
|
||||
asset = _make_asset(session, "hash1")
|
||||
dir1 = tmp_path / "dir1"
|
||||
dir1.mkdir()
|
||||
dir2 = tmp_path / "dir2"
|
||||
dir2.mkdir()
|
||||
|
||||
path1 = str(dir1 / "file.bin")
|
||||
path2 = str(dir2 / "file.bin")
|
||||
|
||||
_make_cache_state(session, asset, path1, mtime_ns=100)
|
||||
_make_cache_state(session, asset, path2, mtime_ns=200)
|
||||
session.commit()
|
||||
|
||||
rows = get_cache_states_for_prefixes(session, [str(dir1)])
|
||||
|
||||
assert len(rows) == 1
|
||||
assert rows[0].file_path == path1
|
||||
|
||||
def test_empty_prefixes_returns_empty(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_cache_state(session, asset, "/some/path.bin")
|
||||
session.commit()
|
||||
|
||||
rows = get_cache_states_for_prefixes(session, [])
|
||||
|
||||
assert rows == []
|
||||
|
||||
|
||||
class TestBulkUpdateNeedsVerify:
|
||||
def test_sets_needs_verify_flag(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
state1 = _make_cache_state(session, asset, "/path1.bin", needs_verify=False)
|
||||
state2 = _make_cache_state(session, asset, "/path2.bin", needs_verify=False)
|
||||
session.commit()
|
||||
|
||||
updated = bulk_update_needs_verify(session, [state1.id, state2.id], True)
|
||||
session.commit()
|
||||
|
||||
assert updated == 2
|
||||
session.refresh(state1)
|
||||
session.refresh(state2)
|
||||
assert state1.needs_verify is True
|
||||
assert state2.needs_verify is True
|
||||
|
||||
def test_empty_list_updates_nothing(self, session: Session):
|
||||
updated = bulk_update_needs_verify(session, [], True)
|
||||
assert updated == 0
|
||||
|
||||
|
||||
class TestDeleteCacheStatesByIds:
|
||||
def test_deletes_states_by_id(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
state1 = _make_cache_state(session, asset, "/path1.bin")
|
||||
_make_cache_state(session, asset, "/path2.bin")
|
||||
session.commit()
|
||||
|
||||
deleted = delete_cache_states_by_ids(session, [state1.id])
|
||||
session.commit()
|
||||
|
||||
assert deleted == 1
|
||||
assert session.query(AssetCacheState).count() == 1
|
||||
|
||||
def test_empty_list_deletes_nothing(self, session: Session):
|
||||
deleted = delete_cache_states_by_ids(session, [])
|
||||
assert deleted == 0
|
||||
|
||||
|
||||
class TestDeleteOrphanedSeedAsset:
|
||||
@pytest.mark.parametrize(
|
||||
"create_asset,expected_deleted,expected_count",
|
||||
[
|
||||
(True, True, 0), # Existing asset gets deleted
|
||||
(False, False, 0), # Nonexistent returns False
|
||||
],
|
||||
ids=["deletes_existing", "nonexistent_returns_false"],
|
||||
)
|
||||
def test_delete_orphaned_seed_asset(
|
||||
self, session: Session, create_asset, expected_deleted, expected_count
|
||||
):
|
||||
asset_id = "nonexistent-id"
|
||||
if create_asset:
|
||||
asset = _make_asset(session, hash_val=None)
|
||||
asset_id = asset.id
|
||||
now = get_utc_now()
|
||||
info = AssetInfo(
|
||||
owner_id="", name="test", asset_id=asset.id,
|
||||
created_at=now, updated_at=now, last_access_time=now
|
||||
)
|
||||
session.add(info)
|
||||
session.commit()
|
||||
|
||||
deleted = delete_orphaned_seed_asset(session, asset_id)
|
||||
if create_asset:
|
||||
session.commit()
|
||||
|
||||
assert deleted is expected_deleted
|
||||
assert session.query(Asset).count() == expected_count
|
||||
|
||||
|
||||
class TestBulkInsertCacheStatesIgnoreConflicts:
|
||||
def test_inserts_multiple_states(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
rows = [
|
||||
{"asset_id": asset.id, "file_path": "/bulk1.bin", "mtime_ns": 100},
|
||||
{"asset_id": asset.id, "file_path": "/bulk2.bin", "mtime_ns": 200},
|
||||
]
|
||||
bulk_insert_cache_states_ignore_conflicts(session, rows)
|
||||
session.commit()
|
||||
|
||||
assert session.query(AssetCacheState).count() == 2
|
||||
|
||||
def test_ignores_conflicts(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_cache_state(session, asset, "/existing.bin", mtime_ns=100)
|
||||
session.commit()
|
||||
|
||||
rows = [
|
||||
{"asset_id": asset.id, "file_path": "/existing.bin", "mtime_ns": 999},
|
||||
{"asset_id": asset.id, "file_path": "/new.bin", "mtime_ns": 200},
|
||||
]
|
||||
bulk_insert_cache_states_ignore_conflicts(session, rows)
|
||||
session.commit()
|
||||
|
||||
assert session.query(AssetCacheState).count() == 2
|
||||
existing = session.query(AssetCacheState).filter_by(file_path="/existing.bin").one()
|
||||
assert existing.mtime_ns == 100 # Original value preserved
|
||||
|
||||
def test_empty_list_is_noop(self, session: Session):
|
||||
bulk_insert_cache_states_ignore_conflicts(session, [])
|
||||
assert session.query(AssetCacheState).count() == 0
|
||||
|
||||
|
||||
class TestGetCacheStatesByPathsAndAssetIds:
|
||||
def test_returns_matching_paths(self, session: Session):
|
||||
asset1 = _make_asset(session, "hash1")
|
||||
asset2 = _make_asset(session, "hash2")
|
||||
|
||||
_make_cache_state(session, asset1, "/path1.bin")
|
||||
_make_cache_state(session, asset2, "/path2.bin")
|
||||
session.commit()
|
||||
|
||||
path_to_asset = {
|
||||
"/path1.bin": asset1.id,
|
||||
"/path2.bin": asset2.id,
|
||||
}
|
||||
winners = get_cache_states_by_paths_and_asset_ids(session, path_to_asset)
|
||||
|
||||
assert winners == {"/path1.bin", "/path2.bin"}
|
||||
|
||||
def test_excludes_non_matching_asset_ids(self, session: Session):
|
||||
asset1 = _make_asset(session, "hash1")
|
||||
asset2 = _make_asset(session, "hash2")
|
||||
|
||||
_make_cache_state(session, asset1, "/path1.bin")
|
||||
session.commit()
|
||||
|
||||
# Path exists but with different asset_id
|
||||
path_to_asset = {"/path1.bin": asset2.id}
|
||||
winners = get_cache_states_by_paths_and_asset_ids(session, path_to_asset)
|
||||
|
||||
assert winners == set()
|
||||
|
||||
def test_empty_dict_returns_empty(self, session: Session):
|
||||
winners = get_cache_states_by_paths_and_asset_ids(session, {})
|
||||
assert winners == set()
|
||||
184
tests-unit/assets_test/queries/test_metadata.py
Normal file
@@ -0,0 +1,184 @@
"""Tests for metadata filtering logic in asset_info queries."""
import pytest
from sqlalchemy.orm import Session

from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta
from app.assets.database.queries import list_asset_infos_page
from app.assets.database.queries.asset_info import convert_metadata_to_rows
from app.assets.helpers import get_utc_now


def _make_asset(session: Session, hash_val: str) -> Asset:
    asset = Asset(hash=hash_val, size_bytes=1024)
    session.add(asset)
    session.flush()
    return asset


def _make_asset_info(
    session: Session,
    asset: Asset,
    name: str,
    metadata: dict | None = None,
) -> AssetInfo:
    now = get_utc_now()
    info = AssetInfo(
        owner_id="",
        name=name,
        asset_id=asset.id,
        user_metadata=metadata,
        created_at=now,
        updated_at=now,
        last_access_time=now,
    )
    session.add(info)
    session.flush()

    if metadata:
        for key, val in metadata.items():
            for row in convert_metadata_to_rows(key, val):
                meta_row = AssetInfoMeta(
                    asset_info_id=info.id,
                    key=row["key"],
                    ordinal=row.get("ordinal", 0),
                    val_str=row.get("val_str"),
                    val_num=row.get("val_num"),
                    val_bool=row.get("val_bool"),
                    val_json=row.get("val_json"),
                )
                session.add(meta_row)
                session.flush()

    return info


class TestMetadataFilterByType:
|
||||
"""Table-driven tests for metadata filtering by different value types."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"match_meta,nomatch_meta,filter_key,filter_val",
|
||||
[
|
||||
# String matching
|
||||
({"category": "models"}, {"category": "images"}, "category", "models"),
|
||||
# Integer matching
|
||||
({"epoch": 5}, {"epoch": 10}, "epoch", 5),
|
||||
# Float matching
|
||||
({"score": 0.95}, {"score": 0.5}, "score", 0.95),
|
||||
# Boolean True matching
|
||||
({"enabled": True}, {"enabled": False}, "enabled", True),
|
||||
# Boolean False matching
|
||||
({"enabled": False}, {"enabled": True}, "enabled", False),
|
||||
],
|
||||
ids=["string", "int", "float", "bool_true", "bool_false"],
|
||||
)
|
||||
def test_filter_matches_correct_value(
|
||||
self, session: Session, match_meta, nomatch_meta, filter_key, filter_val
|
||||
):
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_asset_info(session, asset, "match", match_meta)
|
||||
_make_asset_info(session, asset, "nomatch", nomatch_meta)
|
||||
session.commit()
|
||||
|
||||
infos, _, total = list_asset_infos_page(
|
||||
session, metadata_filter={filter_key: filter_val}
|
||||
)
|
||||
assert total == 1
|
||||
assert infos[0].name == "match"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"stored_meta,filter_key,filter_val",
|
||||
[
|
||||
# String no match
|
||||
({"category": "models"}, "category", "other"),
|
||||
# Int no match
|
||||
({"epoch": 5}, "epoch", 99),
|
||||
# Float no match
|
||||
({"score": 0.5}, "score", 0.99),
|
||||
],
|
||||
ids=["string_no_match", "int_no_match", "float_no_match"],
|
||||
)
|
||||
def test_filter_returns_empty_when_no_match(
|
||||
self, session: Session, stored_meta, filter_key, filter_val
|
||||
):
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_asset_info(session, asset, "item", stored_meta)
|
||||
session.commit()
|
||||
|
||||
infos, _, total = list_asset_infos_page(
|
||||
session, metadata_filter={filter_key: filter_val}
|
||||
)
|
||||
assert total == 0
|
||||
|
||||
|
||||
class TestMetadataFilterNull:
|
||||
"""Tests for null/missing key filtering."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"match_name,match_meta,nomatch_name,nomatch_meta,filter_key",
|
||||
[
|
||||
# Null matches missing key
|
||||
("missing_key", {}, "has_key", {"optional": "value"}, "optional"),
|
||||
# Null matches explicit null
|
||||
("explicit_null", {"nullable": None}, "has_value", {"nullable": "present"}, "nullable"),
|
||||
],
|
||||
ids=["missing_key", "explicit_null"],
|
||||
)
|
||||
def test_null_filter_matches(
|
||||
self, session: Session, match_name, match_meta, nomatch_name, nomatch_meta, filter_key
|
||||
):
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_asset_info(session, asset, match_name, match_meta)
|
||||
_make_asset_info(session, asset, nomatch_name, nomatch_meta)
|
||||
session.commit()
|
||||
|
||||
infos, _, total = list_asset_infos_page(session, metadata_filter={filter_key: None})
|
||||
assert total == 1
|
||||
assert infos[0].name == match_name
|
||||
|
||||
|
||||
class TestMetadataFilterList:
|
||||
"""Tests for list-based (OR) filtering."""
|
||||
|
||||
def test_filter_by_list_matches_any(self, session: Session):
|
||||
"""List values should match ANY of the values (OR)."""
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_asset_info(session, asset, "cat_a", {"category": "a"})
|
||||
_make_asset_info(session, asset, "cat_b", {"category": "b"})
|
||||
_make_asset_info(session, asset, "cat_c", {"category": "c"})
|
||||
session.commit()
|
||||
|
||||
infos, _, total = list_asset_infos_page(session, metadata_filter={"category": ["a", "b"]})
|
||||
assert total == 2
|
||||
names = {i.name for i in infos}
|
||||
assert names == {"cat_a", "cat_b"}
|
||||
|
||||
|
||||
class TestMetadataFilterMultipleKeys:
|
||||
"""Tests for multiple filter keys (AND semantics)."""
|
||||
|
||||
def test_multiple_keys_must_all_match(self, session: Session):
|
||||
"""Multiple keys should ALL match (AND)."""
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_asset_info(session, asset, "match", {"type": "model", "version": 2})
|
||||
_make_asset_info(session, asset, "wrong_type", {"type": "config", "version": 2})
|
||||
_make_asset_info(session, asset, "wrong_version", {"type": "model", "version": 1})
|
||||
session.commit()
|
||||
|
||||
infos, _, total = list_asset_infos_page(
|
||||
session, metadata_filter={"type": "model", "version": 2}
|
||||
)
|
||||
assert total == 1
|
||||
assert infos[0].name == "match"
|
||||
|
||||
|
||||
class TestMetadataFilterEmptyDict:
|
||||
"""Tests for empty filter behavior."""
|
||||
|
||||
def test_empty_filter_returns_all(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_asset_info(session, asset, "a", {"key": "val"})
|
||||
_make_asset_info(session, asset, "b", {})
|
||||
session.commit()
|
||||
|
||||
infos, _, total = list_asset_infos_page(session, metadata_filter={})
|
||||
assert total == 2
|
||||
366
tests-unit/assets_test/queries/test_tags.py
Normal file
@@ -0,0 +1,366 @@
import pytest
from sqlalchemy.orm import Session

from app.assets.database.models import Asset, AssetInfo, AssetInfoTag, AssetInfoMeta, Tag
from app.assets.database.queries import (
    ensure_tags_exist,
    get_asset_tags,
    set_asset_info_tags,
    add_tags_to_asset_info,
    remove_tags_from_asset_info,
    add_missing_tag_for_asset_id,
    remove_missing_tag_for_asset_id,
    list_tags_with_usage,
    bulk_insert_tags_and_meta,
)
from app.assets.helpers import get_utc_now


def _make_asset(session: Session, hash_val: str | None = None) -> Asset:
    asset = Asset(hash=hash_val, size_bytes=1024)
    session.add(asset)
    session.flush()
    return asset


def _make_asset_info(session: Session, asset: Asset, name: str = "test", owner_id: str = "") -> AssetInfo:
    now = get_utc_now()
    info = AssetInfo(
        owner_id=owner_id,
        name=name,
        asset_id=asset.id,
        created_at=now,
        updated_at=now,
        last_access_time=now,
    )
    session.add(info)
    session.flush()
    return info


class TestEnsureTagsExist:
|
||||
def test_creates_new_tags(self, session: Session):
|
||||
ensure_tags_exist(session, ["alpha", "beta"], tag_type="user")
|
||||
session.commit()
|
||||
|
||||
tags = session.query(Tag).all()
|
||||
assert {t.name for t in tags} == {"alpha", "beta"}
|
||||
|
||||
def test_is_idempotent(self, session: Session):
|
||||
ensure_tags_exist(session, ["alpha"], tag_type="user")
|
||||
ensure_tags_exist(session, ["alpha"], tag_type="user")
|
||||
session.commit()
|
||||
|
||||
assert session.query(Tag).count() == 1
|
||||
|
||||
def test_normalizes_tags(self, session: Session):
|
||||
ensure_tags_exist(session, [" ALPHA ", "Beta", "alpha"])
|
||||
session.commit()
|
||||
|
||||
tags = session.query(Tag).all()
|
||||
assert {t.name for t in tags} == {"alpha", "beta"}
|
||||
|
||||
def test_empty_list_is_noop(self, session: Session):
|
||||
ensure_tags_exist(session, [])
|
||||
session.commit()
|
||||
assert session.query(Tag).count() == 0
|
||||
|
||||
def test_tag_type_is_set(self, session: Session):
|
||||
ensure_tags_exist(session, ["system-tag"], tag_type="system")
|
||||
session.commit()
|
||||
|
||||
tag = session.query(Tag).filter_by(name="system-tag").one()
|
||||
assert tag.tag_type == "system"
|
||||
|
||||
|
||||
class TestGetAssetTags:
|
||||
def test_returns_empty_for_no_tags(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
info = _make_asset_info(session, asset)
|
||||
|
||||
tags = get_asset_tags(session, asset_info_id=info.id)
|
||||
assert tags == []
|
||||
|
||||
def test_returns_tags_for_asset(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
info = _make_asset_info(session, asset)
|
||||
|
||||
ensure_tags_exist(session, ["tag1", "tag2"])
|
||||
session.add_all([
|
||||
AssetInfoTag(asset_info_id=info.id, tag_name="tag1", origin="manual", added_at=get_utc_now()),
|
||||
AssetInfoTag(asset_info_id=info.id, tag_name="tag2", origin="manual", added_at=get_utc_now()),
|
||||
])
|
||||
session.flush()
|
||||
|
||||
tags = get_asset_tags(session, asset_info_id=info.id)
|
||||
assert set(tags) == {"tag1", "tag2"}
|
||||
|
||||
|
||||
class TestSetAssetInfoTags:
|
||||
def test_adds_new_tags(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
info = _make_asset_info(session, asset)
|
||||
|
||||
result = set_asset_info_tags(session, asset_info_id=info.id, tags=["a", "b"])
|
||||
session.commit()
|
||||
|
||||
assert set(result["added"]) == {"a", "b"}
|
||||
assert result["removed"] == []
|
||||
assert set(result["total"]) == {"a", "b"}
|
||||
|
||||
def test_removes_old_tags(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
info = _make_asset_info(session, asset)
|
||||
|
||||
set_asset_info_tags(session, asset_info_id=info.id, tags=["a", "b", "c"])
|
||||
result = set_asset_info_tags(session, asset_info_id=info.id, tags=["a"])
|
||||
session.commit()
|
||||
|
||||
assert result["added"] == []
|
||||
assert set(result["removed"]) == {"b", "c"}
|
||||
assert result["total"] == ["a"]
|
||||
|
||||
def test_replaces_tags(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
info = _make_asset_info(session, asset)
|
||||
|
||||
set_asset_info_tags(session, asset_info_id=info.id, tags=["a", "b"])
|
||||
result = set_asset_info_tags(session, asset_info_id=info.id, tags=["b", "c"])
|
||||
session.commit()
|
||||
|
||||
assert result["added"] == ["c"]
|
||||
assert result["removed"] == ["a"]
|
||||
assert set(result["total"]) == {"b", "c"}
|
||||
|
||||
|
||||
class TestAddTagsToAssetInfo:
|
||||
def test_adds_tags(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
info = _make_asset_info(session, asset)
|
||||
|
||||
result = add_tags_to_asset_info(session, asset_info_id=info.id, tags=["x", "y"])
|
||||
session.commit()
|
||||
|
||||
assert set(result["added"]) == {"x", "y"}
|
||||
assert result["already_present"] == []
|
||||
|
||||
def test_reports_already_present(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
info = _make_asset_info(session, asset)
|
||||
|
||||
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["x"])
|
||||
result = add_tags_to_asset_info(session, asset_info_id=info.id, tags=["x", "y"])
|
||||
session.commit()
|
||||
|
||||
assert result["added"] == ["y"]
|
||||
assert result["already_present"] == ["x"]
|
||||
|
||||
def test_raises_for_missing_asset_info(self, session: Session):
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
add_tags_to_asset_info(session, asset_info_id="nonexistent", tags=["x"])
|
||||
|
||||
|
||||
class TestRemoveTagsFromAssetInfo:
|
||||
def test_removes_tags(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
info = _make_asset_info(session, asset)
|
||||
|
||||
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["a", "b", "c"])
|
||||
result = remove_tags_from_asset_info(session, asset_info_id=info.id, tags=["a", "b"])
|
||||
session.commit()
|
||||
|
||||
assert set(result["removed"]) == {"a", "b"}
|
||||
assert result["not_present"] == []
|
||||
assert result["total_tags"] == ["c"]
|
||||
|
||||
def test_reports_not_present(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
info = _make_asset_info(session, asset)
|
||||
|
||||
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["a"])
|
||||
result = remove_tags_from_asset_info(session, asset_info_id=info.id, tags=["a", "x"])
|
||||
session.commit()
|
||||
|
||||
assert result["removed"] == ["a"]
|
||||
assert result["not_present"] == ["x"]
|
||||
|
||||
def test_raises_for_missing_asset_info(self, session: Session):
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
remove_tags_from_asset_info(session, asset_info_id="nonexistent", tags=["x"])
|
||||
|
||||
|
||||
class TestMissingTagFunctions:
|
||||
def test_add_missing_tag_for_asset_id(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
info = _make_asset_info(session, asset)
|
||||
ensure_tags_exist(session, ["missing"], tag_type="system")
|
||||
|
||||
add_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
session.commit()
|
||||
|
||||
tags = get_asset_tags(session, asset_info_id=info.id)
|
||||
assert "missing" in tags
|
||||
|
||||
def test_add_missing_tag_is_idempotent(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
info = _make_asset_info(session, asset)
|
||||
ensure_tags_exist(session, ["missing"], tag_type="system")
|
||||
|
||||
add_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
add_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
session.commit()
|
||||
|
||||
links = session.query(AssetInfoTag).filter_by(asset_info_id=info.id, tag_name="missing").all()
|
||||
assert len(links) == 1
|
||||
|
||||
def test_remove_missing_tag_for_asset_id(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
info = _make_asset_info(session, asset)
|
||||
ensure_tags_exist(session, ["missing"], tag_type="system")
|
||||
add_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
|
||||
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
session.commit()
|
||||
|
||||
tags = get_asset_tags(session, asset_info_id=info.id)
|
||||
assert "missing" not in tags
|
||||
|
||||
|
||||
class TestListTagsWithUsage:
|
||||
def test_returns_tags_with_counts(self, session: Session):
|
||||
ensure_tags_exist(session, ["used", "unused"])
|
||||
|
||||
asset = _make_asset(session, "hash1")
|
||||
info = _make_asset_info(session, asset)
|
||||
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["used"])
|
||||
session.commit()
|
||||
|
||||
rows, total = list_tags_with_usage(session)
|
||||
|
||||
tag_dict = {name: count for name, _, count in rows}
|
||||
assert tag_dict["used"] == 1
|
||||
assert tag_dict["unused"] == 0
|
||||
assert total == 2
|
||||
|
||||
def test_exclude_zero_counts(self, session: Session):
|
||||
        ensure_tags_exist(session, ["used", "unused"])

        asset = _make_asset(session, "hash1")
        info = _make_asset_info(session, asset)
        add_tags_to_asset_info(session, asset_info_id=info.id, tags=["used"])
        session.commit()

        rows, total = list_tags_with_usage(session, include_zero=False)

        tag_names = {name for name, _, _ in rows}
        assert "used" in tag_names
        assert "unused" not in tag_names

    def test_prefix_filter(self, session: Session):
        ensure_tags_exist(session, ["alpha", "beta", "alphabet"])
        session.commit()

        rows, total = list_tags_with_usage(session, prefix="alph")

        tag_names = {name for name, _, _ in rows}
        assert tag_names == {"alpha", "alphabet"}

    def test_order_by_name(self, session: Session):
        ensure_tags_exist(session, ["zebra", "alpha", "middle"])
        session.commit()

        rows, _ = list_tags_with_usage(session, order="name_asc")

        names = [name for name, _, _ in rows]
        assert names == ["alpha", "middle", "zebra"]

    def test_owner_visibility(self, session: Session):
        ensure_tags_exist(session, ["shared-tag", "owner-tag"])

        asset = _make_asset(session, "hash1")
        shared_info = _make_asset_info(session, asset, name="shared", owner_id="")
        owner_info = _make_asset_info(session, asset, name="owned", owner_id="user1")

        add_tags_to_asset_info(session, asset_info_id=shared_info.id, tags=["shared-tag"])
        add_tags_to_asset_info(session, asset_info_id=owner_info.id, tags=["owner-tag"])
        session.commit()

        # Empty owner sees only shared
        rows, _ = list_tags_with_usage(session, owner_id="", include_zero=False)
        tag_dict = {name: count for name, _, count in rows}
        assert tag_dict.get("shared-tag", 0) == 1
        assert tag_dict.get("owner-tag", 0) == 0

        # User1 sees both
        rows, _ = list_tags_with_usage(session, owner_id="user1", include_zero=False)
        tag_dict = {name: count for name, _, count in rows}
        assert tag_dict.get("shared-tag", 0) == 1
        assert tag_dict.get("owner-tag", 0) == 1


class TestBulkInsertTagsAndMeta:
    def test_inserts_tags(self, session: Session):
        asset = _make_asset(session, "hash1")
        info = _make_asset_info(session, asset)
        ensure_tags_exist(session, ["bulk-tag1", "bulk-tag2"])
        session.commit()

        now = get_utc_now()
        tag_rows = [
            {"asset_info_id": info.id, "tag_name": "bulk-tag1", "origin": "manual", "added_at": now},
            {"asset_info_id": info.id, "tag_name": "bulk-tag2", "origin": "manual", "added_at": now},
        ]
        bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=[])
        session.commit()

        tags = get_asset_tags(session, asset_info_id=info.id)
        assert set(tags) == {"bulk-tag1", "bulk-tag2"}

    def test_inserts_meta(self, session: Session):
        asset = _make_asset(session, "hash1")
        info = _make_asset_info(session, asset)
        session.commit()

        meta_rows = [
            {
                "asset_info_id": info.id,
                "key": "meta-key",
                "ordinal": 0,
                "val_str": "meta-value",
                "val_num": None,
                "val_bool": None,
                "val_json": None,
            },
        ]
        bulk_insert_tags_and_meta(session, tag_rows=[], meta_rows=meta_rows)
        session.commit()

        meta = session.query(AssetInfoMeta).filter_by(asset_info_id=info.id).all()
        assert len(meta) == 1
        assert meta[0].key == "meta-key"
        assert meta[0].val_str == "meta-value"

    def test_ignores_conflicts(self, session: Session):
        asset = _make_asset(session, "hash1")
        info = _make_asset_info(session, asset)
        ensure_tags_exist(session, ["existing-tag"])
        add_tags_to_asset_info(session, asset_info_id=info.id, tags=["existing-tag"])
        session.commit()

        now = get_utc_now()
        tag_rows = [
            {"asset_info_id": info.id, "tag_name": "existing-tag", "origin": "duplicate", "added_at": now},
        ]
        bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=[])
        session.commit()

        # Should still have only one tag link
        links = session.query(AssetInfoTag).filter_by(asset_info_id=info.id, tag_name="existing-tag").all()
        assert len(links) == 1
        # Origin should be original, not overwritten
        assert links[0].origin == "manual"

    def test_empty_lists_is_noop(self, session: Session):
        bulk_insert_tags_and_meta(session, tag_rows=[], meta_rows=[])
        assert session.query(AssetInfoTag).count() == 0
        assert session.query(AssetInfoMeta).count() == 0
1
tests-unit/assets_test/services/__init__.py
Normal file
@@ -0,0 +1 @@
# Service layer tests
48
tests-unit/assets_test/services/conftest.py
Normal file
@@ -0,0 +1,48 @@
import tempfile
from pathlib import Path
from unittest.mock import patch

import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session

from app.assets.database.models import Base


@pytest.fixture
def db_engine():
    """In-memory SQLite engine for fast unit tests."""
    engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(engine)
    return engine


@pytest.fixture
def session(db_engine):
    """Session fixture for tests that need direct DB access."""
    with Session(db_engine) as sess:
        yield sess


@pytest.fixture
def mock_create_session(db_engine):
    """Patch create_session to use our in-memory database."""
    from contextlib import contextmanager
    from sqlalchemy.orm import Session as SASession

    @contextmanager
    def _create_session():
        with SASession(db_engine) as sess:
            yield sess

    with patch("app.assets.services.ingest.create_session", _create_session), \
         patch("app.assets.services.asset_management.create_session", _create_session), \
         patch("app.assets.services.tagging.create_session", _create_session):
        yield _create_session


@pytest.fixture
def temp_dir():
    """Temporary directory for file operations."""
    with tempfile.TemporaryDirectory() as tmpdir:
        yield Path(tmpdir)
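For readers skimming this diff: the mock_create_session fixture above swaps out the session factory that the service modules use internally. A minimal sketch of what the patched create_session helper is assumed to look like (names and the engine URL here are illustrative, not the shipped implementation):

# Hypothetical sketch of the factory being patched; the real create_session
# in the app's database layer may commit/rollback differently.
from contextlib import contextmanager
from sqlalchemy import create_engine
from sqlalchemy.orm import Session

engine = create_engine("sqlite:///assets.db")  # placeholder URL

@contextmanager
def create_session():
    """Yield a short-lived Session, committing on success, rolling back on error."""
    with Session(engine) as session:
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise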
264
tests-unit/assets_test/services/test_asset_management.py
Normal file
@@ -0,0 +1,264 @@
"""Tests for asset_management services."""
import pytest
from sqlalchemy.orm import Session

from app.assets.database.models import Asset, AssetInfo
from app.assets.database.queries import ensure_tags_exist, add_tags_to_asset_info
from app.assets.helpers import get_utc_now
from app.assets.services import (
    get_asset_detail,
    update_asset_metadata,
    delete_asset_reference,
    set_asset_preview,
)


def _make_asset(session: Session, hash_val: str = "blake3:test", size: int = 1024) -> Asset:
    asset = Asset(hash=hash_val, size_bytes=size, mime_type="application/octet-stream")
    session.add(asset)
    session.flush()
    return asset


def _make_asset_info(
    session: Session,
    asset: Asset,
    name: str = "test",
    owner_id: str = "",
) -> AssetInfo:
    now = get_utc_now()
    info = AssetInfo(
        owner_id=owner_id,
        name=name,
        asset_id=asset.id,
        created_at=now,
        updated_at=now,
        last_access_time=now,
    )
    session.add(info)
    session.flush()
    return info


class TestGetAssetDetail:
    def test_returns_none_for_nonexistent(self, mock_create_session):
        result = get_asset_detail(asset_info_id="nonexistent")
        assert result is None

    def test_returns_asset_with_tags(self, mock_create_session, session: Session):
        asset = _make_asset(session)
        info = _make_asset_info(session, asset, name="test.bin")
        ensure_tags_exist(session, ["alpha", "beta"])
        add_tags_to_asset_info(session, asset_info_id=info.id, tags=["alpha", "beta"])
        session.commit()

        result = get_asset_detail(asset_info_id=info.id)

        assert result is not None
        assert result.info.id == info.id
        assert result.asset.hash == asset.hash
        assert set(result.tags) == {"alpha", "beta"}

    def test_respects_owner_visibility(self, mock_create_session, session: Session):
        asset = _make_asset(session)
        info = _make_asset_info(session, asset, owner_id="user1")
        session.commit()

        # Wrong owner cannot see
        result = get_asset_detail(asset_info_id=info.id, owner_id="user2")
        assert result is None

        # Correct owner can see
        result = get_asset_detail(asset_info_id=info.id, owner_id="user1")
        assert result is not None


class TestUpdateAssetMetadata:
    def test_updates_name(self, mock_create_session, session: Session):
        asset = _make_asset(session)
        info = _make_asset_info(session, asset, name="old_name.bin")
        info_id = info.id
        session.commit()

        update_asset_metadata(
            asset_info_id=info_id,
            name="new_name.bin",
        )

        # Verify by re-fetching from DB
        session.expire_all()
        updated_info = session.get(AssetInfo, info_id)
        assert updated_info.name == "new_name.bin"

    def test_updates_tags(self, mock_create_session, session: Session):
        asset = _make_asset(session)
        info = _make_asset_info(session, asset)
        ensure_tags_exist(session, ["old"])
        add_tags_to_asset_info(session, asset_info_id=info.id, tags=["old"])
        session.commit()

        result = update_asset_metadata(
            asset_info_id=info.id,
            tags=["new1", "new2"],
        )

        assert set(result.tags) == {"new1", "new2"}
        assert "old" not in result.tags

    def test_updates_user_metadata(self, mock_create_session, session: Session):
        asset = _make_asset(session)
        info = _make_asset_info(session, asset)
        info_id = info.id
        session.commit()

        update_asset_metadata(
            asset_info_id=info_id,
            user_metadata={"key": "value", "num": 42},
        )

        # Verify by re-fetching from DB
        session.expire_all()
        updated_info = session.get(AssetInfo, info_id)
        assert updated_info.user_metadata["key"] == "value"
        assert updated_info.user_metadata["num"] == 42

    def test_raises_for_nonexistent(self, mock_create_session):
        with pytest.raises(ValueError, match="not found"):
            update_asset_metadata(asset_info_id="nonexistent", name="fail")

    def test_raises_for_wrong_owner(self, mock_create_session, session: Session):
        asset = _make_asset(session)
        info = _make_asset_info(session, asset, owner_id="user1")
        session.commit()

        with pytest.raises(PermissionError, match="not owner"):
            update_asset_metadata(
                asset_info_id=info.id,
                name="new",
                owner_id="user2",
            )


class TestDeleteAssetReference:
    def test_deletes_asset_info(self, mock_create_session, session: Session):
        asset = _make_asset(session)
        info = _make_asset_info(session, asset)
        info_id = info.id
        session.commit()

        result = delete_asset_reference(
            asset_info_id=info_id,
            owner_id="",
            delete_content_if_orphan=False,
        )

        assert result is True
        assert session.get(AssetInfo, info_id) is None

    def test_returns_false_for_nonexistent(self, mock_create_session):
        result = delete_asset_reference(
            asset_info_id="nonexistent",
            owner_id="",
        )
        assert result is False

    def test_returns_false_for_wrong_owner(self, mock_create_session, session: Session):
        asset = _make_asset(session)
        info = _make_asset_info(session, asset, owner_id="user1")
        info_id = info.id
        session.commit()

        result = delete_asset_reference(
            asset_info_id=info_id,
            owner_id="user2",
        )

        assert result is False
        assert session.get(AssetInfo, info_id) is not None

    def test_keeps_asset_if_other_infos_exist(self, mock_create_session, session: Session):
        asset = _make_asset(session)
        info1 = _make_asset_info(session, asset, name="info1")
        _make_asset_info(session, asset, name="info2")  # Second info keeps asset alive
        asset_id = asset.id
        session.commit()

        delete_asset_reference(
            asset_info_id=info1.id,
            owner_id="",
            delete_content_if_orphan=True,
        )

        # Asset should still exist
        assert session.get(Asset, asset_id) is not None

    def test_deletes_orphaned_asset(self, mock_create_session, session: Session):
        asset = _make_asset(session)
        info = _make_asset_info(session, asset)
        asset_id = asset.id
        info_id = info.id
        session.commit()

        delete_asset_reference(
            asset_info_id=info_id,
            owner_id="",
            delete_content_if_orphan=True,
        )

        # Both info and asset should be gone
        assert session.get(AssetInfo, info_id) is None
        assert session.get(Asset, asset_id) is None


class TestSetAssetPreview:
    def test_sets_preview(self, mock_create_session, session: Session):
        asset = _make_asset(session, hash_val="blake3:main")
        preview_asset = _make_asset(session, hash_val="blake3:preview")
        info = _make_asset_info(session, asset)
        info_id = info.id
        preview_id = preview_asset.id
        session.commit()

        set_asset_preview(
            asset_info_id=info_id,
            preview_asset_id=preview_id,
        )

        # Verify by re-fetching from DB
        session.expire_all()
        updated_info = session.get(AssetInfo, info_id)
        assert updated_info.preview_id == preview_id

    def test_clears_preview(self, mock_create_session, session: Session):
        asset = _make_asset(session)
        preview_asset = _make_asset(session, hash_val="blake3:preview")
        info = _make_asset_info(session, asset)
        info.preview_id = preview_asset.id
        info_id = info.id
        session.commit()

        set_asset_preview(
            asset_info_id=info_id,
            preview_asset_id=None,
        )

        # Verify by re-fetching from DB
        session.expire_all()
        updated_info = session.get(AssetInfo, info_id)
        assert updated_info.preview_id is None

    def test_raises_for_nonexistent_info(self, mock_create_session):
        with pytest.raises(ValueError, match="not found"):
            set_asset_preview(asset_info_id="nonexistent")

    def test_raises_for_wrong_owner(self, mock_create_session, session: Session):
        asset = _make_asset(session)
        info = _make_asset_info(session, asset, owner_id="user1")
        session.commit()

        with pytest.raises(PermissionError, match="not owner"):
            set_asset_preview(
                asset_info_id=info.id,
                preview_asset_id=None,
                owner_id="user2",
            )
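Condensed from the tests above, the service layer's ownership convention: owner_id="" is the shared/global scope, and any other owner_id restricts visibility and edits to that owner. A short hedged sketch of the typical call flow (the asset_info_id value is a placeholder):

from app.assets.services import get_asset_detail, update_asset_metadata, delete_asset_reference

some_id = "00000000-0000-0000-0000-000000000000"  # placeholder AssetInfo id

# Only the owning user (or the shared scope) sees the record.
detail = get_asset_detail(asset_info_id=some_id, owner_id="user1")
if detail is not None:
    update_asset_metadata(asset_info_id=some_id, name="renamed.safetensors", owner_id="user1")
    # Dropping the last reference can also remove the underlying content.
    delete_asset_reference(asset_info_id=some_id, owner_id="user1", delete_content_if_orphan=True)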
138
tests-unit/assets_test/services/test_bulk_ingest.py
Normal file
@@ -0,0 +1,138 @@
"""Tests for bulk ingest services."""

from pathlib import Path

from sqlalchemy.orm import Session

from app.assets.database.models import Asset
from app.assets.services.bulk_ingest import SeedAssetSpec, batch_insert_seed_assets


class TestBatchInsertSeedAssets:
    def test_populates_mime_type_for_model_files(self, session: Session, temp_dir: Path):
        """Verify mime_type is stored in the Asset table for model files."""
        file_path = temp_dir / "model.safetensors"
        file_path.write_bytes(b"fake safetensors content")

        specs: list[SeedAssetSpec] = [
            {
                "abs_path": str(file_path),
                "size_bytes": 24,
                "mtime_ns": 1234567890000000000,
                "info_name": "Test Model",
                "tags": ["models"],
                "fname": "model.safetensors",
                "metadata": None,
                "hash": None,
                "mime_type": "application/safetensors",
            }
        ]

        result = batch_insert_seed_assets(session, specs=specs, owner_id="")

        assert result.inserted_infos == 1

        # Verify Asset has mime_type populated
        assets = session.query(Asset).all()
        assert len(assets) == 1
        assert assets[0].mime_type == "application/safetensors"

    def test_mime_type_none_when_not_provided(self, session: Session, temp_dir: Path):
        """Verify mime_type is None when not provided in spec."""
        file_path = temp_dir / "unknown.bin"
        file_path.write_bytes(b"binary data")

        specs: list[SeedAssetSpec] = [
            {
                "abs_path": str(file_path),
                "size_bytes": 11,
                "mtime_ns": 1234567890000000000,
                "info_name": "Unknown File",
                "tags": [],
                "fname": "unknown.bin",
                "metadata": None,
                "hash": None,
                "mime_type": None,
            }
        ]

        result = batch_insert_seed_assets(session, specs=specs, owner_id="")

        assert result.inserted_infos == 1

        assets = session.query(Asset).all()
        assert len(assets) == 1
        assert assets[0].mime_type is None

    def test_various_model_mime_types(self, session: Session, temp_dir: Path):
        """Verify various model file types get correct mime_type."""
        test_cases = [
            ("model.safetensors", "application/safetensors"),
            ("model.pt", "application/pytorch"),
            ("model.ckpt", "application/pickle"),
            ("model.gguf", "application/gguf"),
        ]

        specs: list[SeedAssetSpec] = []
        for filename, mime_type in test_cases:
            file_path = temp_dir / filename
            file_path.write_bytes(b"content")
            specs.append(
                {
                    "abs_path": str(file_path),
                    "size_bytes": 7,
                    "mtime_ns": 1234567890000000000,
                    "info_name": filename,
                    "tags": [],
                    "fname": filename,
                    "metadata": None,
                    "hash": None,
                    "mime_type": mime_type,
                }
            )

        result = batch_insert_seed_assets(session, specs=specs, owner_id="")

        assert result.inserted_infos == len(test_cases)

        for filename, expected_mime in test_cases:
            from app.assets.database.models import AssetInfo
            info = session.query(AssetInfo).filter_by(name=filename).first()
            assert info is not None
            asset = session.query(Asset).filter_by(id=info.asset_id).first()
            assert asset.mime_type == expected_mime, f"Expected {expected_mime} for {filename}, got {asset.mime_type}"


class TestMetadataExtraction:
    def test_extracts_mime_type_for_model_files(self, temp_dir: Path):
        """Verify metadata extraction returns correct mime_type for model files."""
        from app.assets.services.metadata_extract import extract_file_metadata

        file_path = temp_dir / "model.safetensors"
        file_path.write_bytes(b"fake safetensors content")

        meta = extract_file_metadata(str(file_path))

        assert meta.content_type == "application/safetensors"

    def test_mime_type_for_various_model_formats(self, temp_dir: Path):
        """Verify various model file types get correct mime_type from metadata."""
        from app.assets.services.metadata_extract import extract_file_metadata

        test_cases = [
            ("model.safetensors", "application/safetensors"),
            ("model.sft", "application/safetensors"),
            ("model.pt", "application/pytorch"),
            ("model.pth", "application/pytorch"),
            ("model.ckpt", "application/pickle"),
            ("model.pkl", "application/pickle"),
            ("model.gguf", "application/gguf"),
        ]

        for filename, expected_mime in test_cases:
            file_path = temp_dir / filename
            file_path.write_bytes(b"content")

            meta = extract_file_metadata(str(file_path))

            assert meta.content_type == expected_mime, f"Expected {expected_mime} for {filename}, got {meta.content_type}"
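The extension-to-MIME expectations exercised above imply a simple lookup. A hedged sketch of that mapping, consistent with the test cases but not a copy of the shipped extract_file_metadata (the helper name below is illustrative):

# Illustrative only: a mapping consistent with the expectations above.
import os

MODEL_MIME_TYPES = {
    ".safetensors": "application/safetensors",
    ".sft": "application/safetensors",
    ".pt": "application/pytorch",
    ".pth": "application/pytorch",
    ".ckpt": "application/pickle",
    ".pkl": "application/pickle",
    ".gguf": "application/gguf",
}

def guess_model_mime_type(path: str) -> str | None:
    """Return a model MIME type from the file extension, or None when unknown."""
    return MODEL_MIME_TYPES.get(os.path.splitext(path)[1].lower())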
227
tests-unit/assets_test/services/test_ingest.py
Normal file
@@ -0,0 +1,227 @@
"""Tests for ingest services."""
from pathlib import Path

import pytest
from sqlalchemy.orm import Session

from app.assets.database.models import Asset, AssetCacheState, AssetInfo, Tag
from app.assets.database.queries import get_asset_tags
from app.assets.services.ingest import _ingest_file_from_path, _register_existing_asset


class TestIngestFileFromPath:
    def test_creates_asset_and_cache_state(self, mock_create_session, temp_dir: Path, session: Session):
        file_path = temp_dir / "test_file.bin"
        file_path.write_bytes(b"test content")

        result = _ingest_file_from_path(
            abs_path=str(file_path),
            asset_hash="blake3:abc123",
            size_bytes=12,
            mtime_ns=1234567890000000000,
            mime_type="application/octet-stream",
        )

        assert result.asset_created is True
        assert result.state_created is True
        assert result.asset_info_id is None  # no info_name provided

        # Verify DB state
        assets = session.query(Asset).all()
        assert len(assets) == 1
        assert assets[0].hash == "blake3:abc123"

        states = session.query(AssetCacheState).all()
        assert len(states) == 1
        assert states[0].file_path == str(file_path)

    def test_creates_asset_info_when_name_provided(self, mock_create_session, temp_dir: Path, session: Session):
        file_path = temp_dir / "model.safetensors"
        file_path.write_bytes(b"model data")

        result = _ingest_file_from_path(
            abs_path=str(file_path),
            asset_hash="blake3:def456",
            size_bytes=10,
            mtime_ns=1234567890000000000,
            mime_type="application/octet-stream",
            info_name="My Model",
            owner_id="user1",
        )

        assert result.asset_created is True
        assert result.asset_info_id is not None

        info = session.query(AssetInfo).first()
        assert info is not None
        assert info.name == "My Model"
        assert info.owner_id == "user1"

    def test_creates_tags_when_provided(self, mock_create_session, temp_dir: Path, session: Session):
        file_path = temp_dir / "tagged.bin"
        file_path.write_bytes(b"data")

        result = _ingest_file_from_path(
            abs_path=str(file_path),
            asset_hash="blake3:ghi789",
            size_bytes=4,
            mtime_ns=1234567890000000000,
            info_name="Tagged Asset",
            tags=["models", "checkpoints"],
        )

        assert result.asset_info_id is not None

        # Verify tags were created and linked
        tags = session.query(Tag).all()
        tag_names = {t.name for t in tags}
        assert "models" in tag_names
        assert "checkpoints" in tag_names

        asset_tags = get_asset_tags(session, asset_info_id=result.asset_info_id)
        assert set(asset_tags) == {"models", "checkpoints"}

    def test_idempotent_upsert(self, mock_create_session, temp_dir: Path, session: Session):
        file_path = temp_dir / "dup.bin"
        file_path.write_bytes(b"content")

        # First ingest
        r1 = _ingest_file_from_path(
            abs_path=str(file_path),
            asset_hash="blake3:repeat",
            size_bytes=7,
            mtime_ns=1234567890000000000,
        )
        assert r1.asset_created is True

        # Second ingest with same hash - should update, not create
        r2 = _ingest_file_from_path(
            abs_path=str(file_path),
            asset_hash="blake3:repeat",
            size_bytes=7,
            mtime_ns=1234567890000000001,  # different mtime
        )
        assert r2.asset_created is False
        assert r2.state_updated is True or r2.state_created is False

        # Still only one asset
        assets = session.query(Asset).all()
        assert len(assets) == 1

    def test_validates_preview_id(self, mock_create_session, temp_dir: Path, session: Session):
        file_path = temp_dir / "with_preview.bin"
        file_path.write_bytes(b"data")

        # Create a preview asset first
        preview_asset = Asset(hash="blake3:preview", size_bytes=100)
        session.add(preview_asset)
        session.commit()
        preview_id = preview_asset.id

        result = _ingest_file_from_path(
            abs_path=str(file_path),
            asset_hash="blake3:main",
            size_bytes=4,
            mtime_ns=1234567890000000000,
            info_name="With Preview",
            preview_id=preview_id,
        )

        assert result.asset_info_id is not None
        info = session.query(AssetInfo).filter_by(id=result.asset_info_id).first()
        assert info.preview_id == preview_id

    def test_invalid_preview_id_is_cleared(self, mock_create_session, temp_dir: Path, session: Session):
        file_path = temp_dir / "bad_preview.bin"
        file_path.write_bytes(b"data")

        result = _ingest_file_from_path(
            abs_path=str(file_path),
            asset_hash="blake3:badpreview",
            size_bytes=4,
            mtime_ns=1234567890000000000,
            info_name="Bad Preview",
            preview_id="nonexistent-uuid",
        )

        assert result.asset_info_id is not None
        info = session.query(AssetInfo).filter_by(id=result.asset_info_id).first()
        assert info.preview_id is None


class TestRegisterExistingAsset:
    def test_creates_info_for_existing_asset(self, mock_create_session, session: Session):
        # Create existing asset
        asset = Asset(hash="blake3:existing", size_bytes=1024, mime_type="image/png")
        session.add(asset)
        session.commit()

        result = _register_existing_asset(
            asset_hash="blake3:existing",
            name="Registered Asset",
            user_metadata={"key": "value"},
            tags=["models"],
        )

        assert result.created is True
        assert "models" in result.tags

        # Verify by re-fetching from DB
        session.expire_all()
        infos = session.query(AssetInfo).filter_by(name="Registered Asset").all()
        assert len(infos) == 1

    def test_returns_existing_info(self, mock_create_session, session: Session):
        # Create asset and info
        asset = Asset(hash="blake3:withinfo", size_bytes=512)
        session.add(asset)
        session.flush()

        from app.assets.helpers import get_utc_now
        info = AssetInfo(
            owner_id="",
            name="Existing Info",
            asset_id=asset.id,
            created_at=get_utc_now(),
            updated_at=get_utc_now(),
            last_access_time=get_utc_now(),
        )
        session.add(info)
        session.flush()  # Flush to get the ID
        info_id = info.id
        session.commit()

        result = _register_existing_asset(
            asset_hash="blake3:withinfo",
            name="Existing Info",
            owner_id="",
        )

        assert result.created is False

        # Verify only one AssetInfo exists for this name
        session.expire_all()
        infos = session.query(AssetInfo).filter_by(name="Existing Info").all()
        assert len(infos) == 1
        assert infos[0].id == info_id

    def test_raises_for_nonexistent_hash(self, mock_create_session):
        with pytest.raises(ValueError, match="No asset with hash"):
            _register_existing_asset(
                asset_hash="blake3:doesnotexist",
                name="Fail",
            )

    def test_applies_tags_to_new_info(self, mock_create_session, session: Session):
        asset = Asset(hash="blake3:tagged", size_bytes=256)
        session.add(asset)
        session.commit()

        result = _register_existing_asset(
            asset_hash="blake3:tagged",
            name="Tagged Info",
            tags=["alpha", "beta"],
        )

        assert result.created is True
        assert set(result.tags) == {"alpha", "beta"}
197
tests-unit/assets_test/services/test_tagging.py
Normal file
@@ -0,0 +1,197 @@
"""Tests for tagging services."""
import pytest
from sqlalchemy.orm import Session

from app.assets.database.models import Asset, AssetInfo
from app.assets.database.queries import ensure_tags_exist, add_tags_to_asset_info
from app.assets.helpers import get_utc_now
from app.assets.services import apply_tags, remove_tags, list_tags


def _make_asset(session: Session, hash_val: str = "blake3:test") -> Asset:
    asset = Asset(hash=hash_val, size_bytes=1024)
    session.add(asset)
    session.flush()
    return asset


def _make_asset_info(
    session: Session,
    asset: Asset,
    name: str = "test",
    owner_id: str = "",
) -> AssetInfo:
    now = get_utc_now()
    info = AssetInfo(
        owner_id=owner_id,
        name=name,
        asset_id=asset.id,
        created_at=now,
        updated_at=now,
        last_access_time=now,
    )
    session.add(info)
    session.flush()
    return info


class TestApplyTags:
    def test_adds_new_tags(self, mock_create_session, session: Session):
        asset = _make_asset(session)
        info = _make_asset_info(session, asset)
        session.commit()

        result = apply_tags(
            asset_info_id=info.id,
            tags=["alpha", "beta"],
        )

        assert set(result.added) == {"alpha", "beta"}
        assert result.already_present == []
        assert set(result.total_tags) == {"alpha", "beta"}

    def test_reports_already_present(self, mock_create_session, session: Session):
        asset = _make_asset(session)
        info = _make_asset_info(session, asset)
        ensure_tags_exist(session, ["existing"])
        add_tags_to_asset_info(session, asset_info_id=info.id, tags=["existing"])
        session.commit()

        result = apply_tags(
            asset_info_id=info.id,
            tags=["existing", "new"],
        )

        assert result.added == ["new"]
        assert result.already_present == ["existing"]

    def test_raises_for_nonexistent_info(self, mock_create_session):
        with pytest.raises(ValueError, match="not found"):
            apply_tags(asset_info_id="nonexistent", tags=["x"])

    def test_raises_for_wrong_owner(self, mock_create_session, session: Session):
        asset = _make_asset(session)
        info = _make_asset_info(session, asset, owner_id="user1")
        session.commit()

        with pytest.raises(PermissionError, match="not owner"):
            apply_tags(
                asset_info_id=info.id,
                tags=["new"],
                owner_id="user2",
            )


class TestRemoveTags:
    def test_removes_tags(self, mock_create_session, session: Session):
        asset = _make_asset(session)
        info = _make_asset_info(session, asset)
        ensure_tags_exist(session, ["a", "b", "c"])
        add_tags_to_asset_info(session, asset_info_id=info.id, tags=["a", "b", "c"])
        session.commit()

        result = remove_tags(
            asset_info_id=info.id,
            tags=["a", "b"],
        )

        assert set(result.removed) == {"a", "b"}
        assert result.not_present == []
        assert result.total_tags == ["c"]

    def test_reports_not_present(self, mock_create_session, session: Session):
        asset = _make_asset(session)
        info = _make_asset_info(session, asset)
        ensure_tags_exist(session, ["present"])
        add_tags_to_asset_info(session, asset_info_id=info.id, tags=["present"])
        session.commit()

        result = remove_tags(
            asset_info_id=info.id,
            tags=["present", "absent"],
        )

        assert result.removed == ["present"]
        assert result.not_present == ["absent"]

    def test_raises_for_nonexistent_info(self, mock_create_session):
        with pytest.raises(ValueError, match="not found"):
            remove_tags(asset_info_id="nonexistent", tags=["x"])

    def test_raises_for_wrong_owner(self, mock_create_session, session: Session):
        asset = _make_asset(session)
        info = _make_asset_info(session, asset, owner_id="user1")
        session.commit()

        with pytest.raises(PermissionError, match="not owner"):
            remove_tags(
                asset_info_id=info.id,
                tags=["x"],
                owner_id="user2",
            )


class TestListTags:
    def test_returns_tags_with_counts(self, mock_create_session, session: Session):
        ensure_tags_exist(session, ["used", "unused"])
        asset = _make_asset(session)
        info = _make_asset_info(session, asset)
        add_tags_to_asset_info(session, asset_info_id=info.id, tags=["used"])
        session.commit()

        rows, total = list_tags()

        tag_dict = {name: count for name, _, count in rows}
        assert tag_dict["used"] == 1
        assert tag_dict["unused"] == 0
        assert total == 2

    def test_excludes_zero_counts(self, mock_create_session, session: Session):
        ensure_tags_exist(session, ["used", "unused"])
        asset = _make_asset(session)
        info = _make_asset_info(session, asset)
        add_tags_to_asset_info(session, asset_info_id=info.id, tags=["used"])
        session.commit()

        rows, total = list_tags(include_zero=False)

        tag_names = {name for name, _, _ in rows}
        assert "used" in tag_names
        assert "unused" not in tag_names

    def test_prefix_filter(self, mock_create_session, session: Session):
        ensure_tags_exist(session, ["alpha", "beta", "alphabet"])
        session.commit()

        rows, _ = list_tags(prefix="alph")

        tag_names = {name for name, _, _ in rows}
        assert tag_names == {"alpha", "alphabet"}

    def test_order_by_name(self, mock_create_session, session: Session):
        ensure_tags_exist(session, ["zebra", "alpha", "middle"])
        session.commit()

        rows, _ = list_tags(order="name_asc")

        names = [name for name, _, _ in rows]
        assert names == ["alpha", "middle", "zebra"]

    def test_pagination(self, mock_create_session, session: Session):
        ensure_tags_exist(session, ["a", "b", "c", "d", "e"])
        session.commit()

        rows, total = list_tags(limit=2, offset=1, order="name_asc")

        assert total == 5
        assert len(rows) == 2
        names = [name for name, _, _ in rows]
        assert names == ["b", "c"]

    def test_clamps_limit(self, mock_create_session, session: Session):
        ensure_tags_exist(session, ["a"])
        session.commit()

        # Service should clamp limit to max 1000
        rows, _ = list_tags(limit=2000)
        assert len(rows) <= 1000
@@ -4,7 +4,7 @@ from pathlib import Path

import pytest
import requests
from conftest import get_asset_filename, trigger_sync_seed_assets
from helpers import get_asset_filename, trigger_sync_seed_assets
@@ -4,7 +4,7 @@ from pathlib import Path

import pytest
import requests
from conftest import get_asset_filename, trigger_sync_seed_assets
from helpers import get_asset_filename, trigger_sync_seed_assets


def test_create_from_hash_success(
@@ -126,42 +126,52 @@ def test_head_asset_bad_hash_returns_400_and_no_body(http: requests.Session, api
    assert body == b""


def test_delete_nonexistent_returns_404(http: requests.Session, api_base: str):
    bogus = str(uuid.uuid4())
    r = http.delete(f"{api_base}/api/assets/{bogus}", timeout=120)
@pytest.mark.parametrize(
    "method,endpoint_template,payload,expected_status,error_code",
    [
        # Delete nonexistent asset
        ("delete", "/api/assets/{uuid}", None, 404, "ASSET_NOT_FOUND"),
        # Bad hash algorithm in from-hash
        (
            "post",
            "/api/assets/from-hash",
            {"hash": "sha256:" + "0" * 64, "name": "x.bin", "tags": ["models", "checkpoints", "unit-tests"]},
            400,
            "INVALID_BODY",
        ),
        # Get with bad UUID format
        ("get", "/api/assets/not-a-uuid", None, 404, None),
        # Get content with bad UUID format
        ("get", "/api/assets/not-a-uuid/content", None, 404, None),
    ],
    ids=["delete_nonexistent", "bad_hash_algorithm", "get_bad_uuid", "content_bad_uuid"],
)
def test_error_responses(
    http: requests.Session, api_base: str, method, endpoint_template, payload, expected_status, error_code
):
    # Replace {uuid} placeholder with a random UUID for delete test
    endpoint = endpoint_template.replace("{uuid}", str(uuid.uuid4()))
    url = f"{api_base}{endpoint}"

    if method == "get":
        r = http.get(url, timeout=120)
    elif method == "post":
        r = http.post(url, json=payload, timeout=120)
    elif method == "delete":
        r = http.delete(url, timeout=120)

    assert r.status_code == expected_status
    if error_code:
        body = r.json()
        assert body["error"]["code"] == error_code


def test_create_from_hash_invalid_json(http: requests.Session, api_base: str):
    """Invalid JSON body requires special handling (data= instead of json=)."""
    r = http.post(f"{api_base}/api/assets/from-hash", data=b"{not json}", timeout=120)
    body = r.json()
    assert r.status_code == 404
    assert body["error"]["code"] == "ASSET_NOT_FOUND"


def test_create_from_hash_invalids(http: requests.Session, api_base: str):
    # Bad hash algorithm
    bad = {
        "hash": "sha256:" + "0" * 64,
        "name": "x.bin",
        "tags": ["models", "checkpoints", "unit-tests"],
    }
    r1 = http.post(f"{api_base}/api/assets/from-hash", json=bad, timeout=120)
    b1 = r1.json()
    assert r1.status_code == 400
    assert b1["error"]["code"] == "INVALID_BODY"

    # Invalid JSON body
    r2 = http.post(f"{api_base}/api/assets/from-hash", data=b"{not json}", timeout=120)
    b2 = r2.json()
    assert r2.status_code == 400
    assert b2["error"]["code"] == "INVALID_JSON"


def test_get_update_download_bad_ids(http: requests.Session, api_base: str):
    # All endpoints should be not found, as we use a UUID regex directly in the route definition.
    bad_id = "not-a-uuid"

    r1 = http.get(f"{api_base}/api/assets/{bad_id}", timeout=120)
    assert r1.status_code == 404

    r3 = http.get(f"{api_base}/api/assets/{bad_id}/content", timeout=120)
    assert r3.status_code == 404
    assert r.status_code == 400
    assert body["error"]["code"] == "INVALID_JSON"


def test_update_requires_at_least_one_field(http: requests.Session, api_base: str, seeded_asset: dict):
||||
@@ -6,7 +6,7 @@ from typing import Optional

import pytest
import requests
from conftest import get_asset_filename, trigger_sync_seed_assets
from helpers import get_asset_filename, trigger_sync_seed_assets


def test_download_attachment_and_inline(http: requests.Session, api_base: str, seeded_asset: dict):
55
tests-unit/assets_test/test_file_utils.py
Normal file
@@ -0,0 +1,55 @@
from app.assets.services.file_utils import is_visible, list_files_recursively


class TestIsVisible:
    def test_visible_file(self):
        assert is_visible("file.txt") is True

    def test_hidden_file(self):
        assert is_visible(".hidden") is False

    def test_hidden_directory(self):
        assert is_visible(".git") is False

    def test_visible_directory(self):
        assert is_visible("src") is True

    def test_dotdot_is_hidden(self):
        assert is_visible("..") is False

    def test_dot_is_hidden(self):
        assert is_visible(".") is False


class TestListFilesRecursively:
    def test_skips_hidden_files(self, tmp_path):
        (tmp_path / "visible.txt").write_text("a")
        (tmp_path / ".hidden").write_text("b")

        result = list_files_recursively(str(tmp_path))

        assert len(result) == 1
        assert result[0].endswith("visible.txt")

    def test_skips_hidden_directories(self, tmp_path):
        hidden_dir = tmp_path / ".hidden_dir"
        hidden_dir.mkdir()
        (hidden_dir / "file.txt").write_text("a")

        visible_dir = tmp_path / "visible_dir"
        visible_dir.mkdir()
        (visible_dir / "file.txt").write_text("b")

        result = list_files_recursively(str(tmp_path))

        assert len(result) == 1
        assert "visible_dir" in result[0]
        assert ".hidden_dir" not in result[0]

    def test_empty_directory(self, tmp_path):
        result = list_files_recursively(str(tmp_path))
        assert result == []

    def test_nonexistent_directory(self, tmp_path):
        result = list_files_recursively(str(tmp_path / "nonexistent"))
        assert result == []
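A minimal reference implementation consistent with the behaviors these tests pin down (hidden names start with a dot, hidden directories are pruned, missing roots yield an empty list). This is a sketch, not the shipped app.assets.services.file_utils code:

import os

def is_visible(name: str) -> bool:
    """A path component is visible unless it starts with a dot."""
    return not name.startswith(".")

def list_files_recursively(root: str) -> list[str]:
    """Walk root, skipping hidden files and directories; missing roots yield []."""
    found: list[str] = []
    for dirpath, dirnames, filenames in os.walk(root):
        # Prune hidden directories in place so os.walk never descends into them.
        dirnames[:] = [d for d in dirnames if is_visible(d)]
        found.extend(os.path.join(dirpath, f) for f in filenames if is_visible(f))
    return found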
@@ -1,6 +1,7 @@
import time
import uuid

import pytest
import requests
@@ -283,30 +284,21 @@ def test_list_assets_offset_beyond_total_and_limit_boundary(http, api_base, asse
    assert b2["has_more"] is False


def test_list_assets_offset_negative_and_limit_nonint_rejected(http, api_base):
    r1 = http.get(api_base + "/api/assets", params={"offset": "-1"}, timeout=120)
    b1 = r1.json()
    assert r1.status_code == 400
    assert b1["error"]["code"] == "INVALID_QUERY"

    r2 = http.get(api_base + "/api/assets", params={"limit": "abc"}, timeout=120)
    b2 = r2.json()
    assert r2.status_code == 400
    assert b2["error"]["code"] == "INVALID_QUERY"


def test_list_assets_invalid_query_rejected(http: requests.Session, api_base: str):
    # limit too small
    r1 = http.get(api_base + "/api/assets", params={"limit": "0"}, timeout=120)
    b1 = r1.json()
    assert r1.status_code == 400
    assert b1["error"]["code"] == "INVALID_QUERY"

    # bad metadata JSON
    r2 = http.get(api_base + "/api/assets", params={"metadata_filter": "{not json"}, timeout=120)
    b2 = r2.json()
    assert r2.status_code == 400
    assert b2["error"]["code"] == "INVALID_QUERY"
@pytest.mark.parametrize(
    "params,error_code",
    [
        ({"offset": "-1"}, "INVALID_QUERY"),
        ({"limit": "abc"}, "INVALID_QUERY"),
        ({"limit": "0"}, "INVALID_QUERY"),
        ({"metadata_filter": "{not json"}, "INVALID_QUERY"),
    ],
    ids=["negative_offset", "non_int_limit", "zero_limit", "invalid_metadata_json"],
)
def test_list_assets_invalid_query_rejected(http: requests.Session, api_base: str, params, error_code):
    r = http.get(api_base + "/api/assets", params=params, timeout=120)
    body = r.json()
    assert r.status_code == 400
    assert body["error"]["code"] == error_code


def test_list_assets_name_contains_literal_underscore(
@@ -3,7 +3,7 @@ from pathlib import Path

import pytest
import requests
from conftest import get_asset_filename, trigger_sync_seed_assets
from helpers import get_asset_filename, trigger_sync_seed_assets


@pytest.fixture
423
tests-unit/seeder_test/test_seeder.py
Normal file
@@ -0,0 +1,423 @@
"""Unit tests for the AssetSeeder background scanning class."""

import threading
import time
from unittest.mock import patch

import pytest

from app.assets.seeder import AssetSeeder, Progress, State


@pytest.fixture
def fresh_seeder():
    """Create a fresh AssetSeeder instance for testing (bypasses singleton)."""
    seeder = object.__new__(AssetSeeder)
    seeder._initialized = False
    seeder.__init__()
    yield seeder
    seeder.shutdown(timeout=1.0)


@pytest.fixture
def mock_dependencies():
    """Mock all external dependencies for isolated testing."""
    with (
        patch("app.assets.seeder.dependencies_available", return_value=True),
        patch("app.assets.seeder.sync_root_safely", return_value=set()),
        patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
        patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
        patch("app.assets.seeder.insert_asset_specs", return_value=0),
    ):
        yield


class TestSeederStateTransitions:
    """Test state machine transitions."""

    def test_initial_state_is_idle(self, fresh_seeder: AssetSeeder):
        assert fresh_seeder.get_status().state == State.IDLE

    def test_start_transitions_to_running(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        started = fresh_seeder.start(roots=("models",))
        assert started is True
        status = fresh_seeder.get_status()
        assert status.state in (State.RUNNING, State.IDLE)

    def test_start_while_running_returns_false(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        barrier = threading.Event()

        def slow_collect(*args):
            barrier.wait(timeout=5.0)
            return []

        with patch(
            "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
        ):
            fresh_seeder.start(roots=("models",))
            time.sleep(0.05)

            second_start = fresh_seeder.start(roots=("models",))
            assert second_start is False

            barrier.set()

    def test_cancel_transitions_to_cancelling(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        barrier = threading.Event()

        def slow_collect(*args):
            barrier.wait(timeout=5.0)
            return []

        with patch(
            "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
        ):
            fresh_seeder.start(roots=("models",))
            time.sleep(0.05)

            cancelled = fresh_seeder.cancel()
            assert cancelled is True
            assert fresh_seeder.get_status().state == State.CANCELLING

            barrier.set()

    def test_cancel_when_idle_returns_false(self, fresh_seeder: AssetSeeder):
        cancelled = fresh_seeder.cancel()
        assert cancelled is False

    def test_state_returns_to_idle_after_completion(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        fresh_seeder.start(roots=("models",))
        completed = fresh_seeder.wait(timeout=5.0)
        assert completed is True
        assert fresh_seeder.get_status().state == State.IDLE


class TestSeederWait:
    """Test wait() behavior."""

    def test_wait_blocks_until_complete(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        fresh_seeder.start(roots=("models",))
        completed = fresh_seeder.wait(timeout=5.0)
        assert completed is True
        assert fresh_seeder.get_status().state == State.IDLE

    def test_wait_returns_false_on_timeout(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        barrier = threading.Event()

        def slow_collect(*args):
            barrier.wait(timeout=10.0)
            return []

        with patch(
            "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
        ):
            fresh_seeder.start(roots=("models",))
            completed = fresh_seeder.wait(timeout=0.1)
            assert completed is False

            barrier.set()

    def test_wait_when_idle_returns_true(self, fresh_seeder: AssetSeeder):
        completed = fresh_seeder.wait(timeout=1.0)
        assert completed is True


class TestSeederProgress:
    """Test progress tracking."""

    def test_get_status_returns_progress_during_scan(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        progress_seen = []
        barrier = threading.Event()

        def slow_collect(*args):
            barrier.wait(timeout=5.0)
            return ["/path/file1.safetensors", "/path/file2.safetensors"]

        with patch(
            "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
        ):
            fresh_seeder.start(roots=("models",))
            time.sleep(0.05)

            status = fresh_seeder.get_status()
            assert status.progress is not None
            progress_seen.append(status.progress)

            barrier.set()

    def test_progress_callback_is_invoked(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        progress_updates: list[Progress] = []

        def callback(p: Progress):
            progress_updates.append(p)

        with patch(
            "app.assets.seeder.collect_paths_for_roots",
            return_value=[f"/path/file{i}.safetensors" for i in range(10)],
        ):
            fresh_seeder.start(roots=("models",), progress_callback=callback)
            fresh_seeder.wait(timeout=5.0)

        assert len(progress_updates) > 0


class TestSeederCancellation:
    """Test cancellation behavior."""

    def test_scan_commits_partial_progress_on_cancellation(
        self, fresh_seeder: AssetSeeder
    ):
        insert_count = 0
        barrier = threading.Event()

        def slow_insert(specs, tags):
            nonlocal insert_count
            insert_count += 1
            if insert_count >= 2:
                barrier.wait(timeout=5.0)
            return len(specs)

        paths = [f"/path/file{i}.safetensors" for i in range(1500)]
        specs = [
            {
                "abs_path": p,
                "size_bytes": 100,
                "mtime_ns": 0,
                "info_name": f"file{i}",
                "tags": [],
                "fname": f"file{i}",
            }
            for i, p in enumerate(paths)
        ]

        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch("app.assets.seeder.sync_root_safely", return_value=set()),
            patch("app.assets.seeder.collect_paths_for_roots", return_value=paths),
            patch(
                "app.assets.seeder.build_asset_specs", return_value=(specs, set(), 0)
            ),
            patch("app.assets.seeder.insert_asset_specs", side_effect=slow_insert),
        ):
            fresh_seeder.start(roots=("models",))
            time.sleep(0.1)

            fresh_seeder.cancel()
            barrier.set()
            fresh_seeder.wait(timeout=5.0)

            assert insert_count >= 1


class TestSeederErrorHandling:
    """Test error handling behavior."""

    def test_database_errors_captured_in_status(self, fresh_seeder: AssetSeeder):
        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch("app.assets.seeder.sync_root_safely", return_value=set()),
            patch(
                "app.assets.seeder.collect_paths_for_roots",
                return_value=["/path/file.safetensors"],
            ),
            patch(
                "app.assets.seeder.build_asset_specs",
                return_value=(
                    [
                        {
                            "abs_path": "/path/file.safetensors",
                            "size_bytes": 100,
                            "mtime_ns": 0,
                            "info_name": "file",
                            "tags": [],
                            "fname": "file",
                        }
                    ],
                    set(),
                    0,
                ),
            ),
            patch(
                "app.assets.seeder.insert_asset_specs",
                side_effect=Exception("DB connection failed"),
            ),
        ):
            fresh_seeder.start(roots=("models",))
            fresh_seeder.wait(timeout=5.0)

            status = fresh_seeder.get_status()
            assert len(status.errors) > 0
            assert "DB connection failed" in status.errors[0]

    def test_dependencies_unavailable_captured_in_errors(
        self, fresh_seeder: AssetSeeder
    ):
        with patch("app.assets.seeder.dependencies_available", return_value=False):
            fresh_seeder.start(roots=("models",))
            fresh_seeder.wait(timeout=5.0)

        status = fresh_seeder.get_status()
        assert len(status.errors) > 0
        assert "dependencies" in status.errors[0].lower()

    def test_thread_crash_resets_state_to_idle(self, fresh_seeder: AssetSeeder):
        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch(
                "app.assets.seeder.sync_root_safely",
                side_effect=RuntimeError("Unexpected crash"),
            ),
        ):
            fresh_seeder.start(roots=("models",))
            fresh_seeder.wait(timeout=5.0)

        status = fresh_seeder.get_status()
        assert status.state == State.IDLE
        assert len(status.errors) > 0


class TestSeederThreadSafety:
    """Test thread safety of concurrent operations."""

    def test_concurrent_start_calls_spawn_only_one_thread(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        barrier = threading.Event()

        def slow_collect(*args):
            barrier.wait(timeout=5.0)
            return []

        with patch(
            "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
        ):
            results = []

            def try_start():
                results.append(fresh_seeder.start(roots=("models",)))

            threads = [threading.Thread(target=try_start) for _ in range(10)]
            for t in threads:
                t.start()
            for t in threads:
                t.join()

            barrier.set()

        assert sum(results) == 1

    def test_get_status_safe_during_scan(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        barrier = threading.Event()

        def slow_collect(*args):
            barrier.wait(timeout=5.0)
            return []

        with patch(
            "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
        ):
            fresh_seeder.start(roots=("models",))

            statuses = []
            for _ in range(100):
                statuses.append(fresh_seeder.get_status())
                time.sleep(0.001)

            barrier.set()

        assert all(
            s.state in (State.RUNNING, State.IDLE, State.CANCELLING)
            for s in statuses
        )


class TestSeederMarkMissing:
    """Test mark_missing_outside_prefixes behavior."""

    def test_mark_missing_when_idle(self, fresh_seeder: AssetSeeder):
        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch(
                "app.assets.seeder.get_all_known_prefixes",
                return_value=["/models", "/input", "/output"],
            ),
            patch(
                "app.assets.seeder.mark_missing_outside_prefixes_safely", return_value=5
            ) as mock_mark,
        ):
            result = fresh_seeder.mark_missing_outside_prefixes()
            assert result == 5
            mock_mark.assert_called_once_with(["/models", "/input", "/output"])

    def test_mark_missing_returns_zero_when_running(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        barrier = threading.Event()

        def slow_collect(*args):
            barrier.wait(timeout=5.0)
            return []

        with patch(
            "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
        ):
            fresh_seeder.start(roots=("models",))
            time.sleep(0.05)

            result = fresh_seeder.mark_missing_outside_prefixes()
            assert result == 0

            barrier.set()

    def test_mark_missing_returns_zero_when_dependencies_unavailable(
        self, fresh_seeder: AssetSeeder
    ):
        with patch("app.assets.seeder.dependencies_available", return_value=False):
            result = fresh_seeder.mark_missing_outside_prefixes()
            assert result == 0

    def test_prune_first_flag_triggers_mark_missing_before_scan(
        self, fresh_seeder: AssetSeeder
    ):
        call_order = []

        def track_mark(prefixes):
            call_order.append("mark_missing")
            return 3

        def track_sync(root):
            call_order.append(f"sync_{root}")
            return set()

        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch("app.assets.seeder.get_all_known_prefixes", return_value=["/models"]),
            patch("app.assets.seeder.mark_missing_outside_prefixes_safely", side_effect=track_mark),
            patch("app.assets.seeder.sync_root_safely", side_effect=track_sync),
            patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
            patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
            patch("app.assets.seeder.insert_asset_specs", return_value=0),
        ):
            fresh_seeder.start(roots=("models",), prune_first=True)
            fresh_seeder.wait(timeout=5.0)

        assert call_order[0] == "mark_missing"
        assert "sync_models" in call_order
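For orientation, this is how the tests above drive the seeder end to end, condensed into a hedged sketch. It assumes only the AssetSeeder surface exercised in this diff (start/wait/cancel/get_status, the prune_first flag); timeouts and the singleton construction are illustrative:

from app.assets.seeder import AssetSeeder, State

seeder = AssetSeeder()  # singleton in normal use; tests bypass this via object.__new__
if seeder.start(roots=("models",), prune_first=True):
    seeder.wait(timeout=60.0)  # block until the background scan finishes or times out
status = seeder.get_status()
assert status.state == State.IDLE
for err in status.errors:
    print("seeding error:", err)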