Compare commits

20 Commits

Author SHA1 Message Date
bymyself
44dbb093c3 feat: pass filtering fields to frontend instead of server-side filtering
Remove server-side includeOnDistributions filtering and fileByDistribution
resolution. Instead, forward requiresCustomNodes and includeOnDistributions
in the info object so the frontend handles filtering client-side.

Amp-Thread-ID: https://ampcode.com/threads/T-019c6f43-6212-7308-bea6-bfc35a486cbf
2026-02-20 02:21:55 -08:00
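
The forwarded object's exact shape is not shown in this view; a minimal sketch of the idea in Python, with the entry-dict field names other than requiresCustomNodes/includeOnDistributions assumed for illustration:

def build_subgraph_info(entry: dict) -> dict:
    # Hypothetical sketch: rather than consuming these fields server-side
    # to filter the subgraph list, pass them through verbatim so the
    # frontend can filter client-side.
    return {
        "name": entry["name"],
        "requiresCustomNodes": entry.get("requiresCustomNodes", []),
        "includeOnDistributions": entry.get("includeOnDistributions"),
    }
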
bymyself
6797e20bcf feat: switch SubgraphManager to index.json-driven discovery with distribution filtering
Amp-Thread-ID: https://ampcode.com/threads/T-019c30d2-a605-708d-824f-35e8f3a0c2f5
2026-02-05 20:33:17 -08:00
AustinMroz
2b70ab9ad0 Add a Create List node (#12173) 2026-02-05 01:18:21 -05:00
Comfy Org PR Bot
00efcc6cd0 Bump comfyui-frontend-package to 1.38.13 (#12238) 2026-02-05 01:17:37 -05:00
comfyanonymous
cb459573c8 ComfyUI v0.12.3 2026-02-05 01:13:35 -05:00
comfyanonymous
35183543e0 Add VAE tiled decode node for audio. (#12299) 2026-02-05 01:12:04 -05:00
blepping
a246cc02b2 Improvements to ACE-Steps 1.5 text encoding (#12283) 2026-02-05 00:17:37 -05:00
comfyanonymous
a50c32d63f Disable sage attention on ace step 1.5 (#12297) 2026-02-04 22:15:30 -05:00
comfyanonymous
6125b80979 Add llm sampling options and make reference audio work on ace step 1.5 (#12295) 2026-02-04 21:29:22 -05:00
comfyanonymous
c8fcbd66ee Try to fix ace text encoder slowness on some configs. (#12290) 2026-02-04 19:37:05 -05:00
comfyanonymous
26dd7eb421 Fix ace step nan issue on some hardware/pytorch configs. (#12289) 2026-02-04 18:25:06 -05:00
Alexander Piskun
e77b34dfea add File3DAny output to Load3D node; extend SaveGLB to accept File3DAny as input (#12276)
* add File3DAny output to Load3D node; extend SaveGLB node to accept File3DAny as input

* fix(grammar): capitalize letter
2026-02-04 11:35:38 -08:00
rattus
ef73070ea4 mp: Fix checkpoint saving (#12268)
Fix regression in the recent model saving refactor. Pass the non-UNet
pieces down the layers so that checkpoints are complete.
2026-02-04 02:08:45 -05:00
rattus
d30c609f5a utils: safetensors: don't slice data on torch level (#12266)
Torch has alignment enforcement when viewing with data type changes,
but only relative to itself. Do all tensor constructions straight
off the memory-view individually so pytorch doesn't see an alignment
problem.

This is needed for handling misaligned safetensors weights, which are
reasonably common in third-party models.

This limits usage of this safetensors loader to GPU compute only,
as CPU kernels are very likely to bus error. But it works for
dynamic_vram, where we really don't want to take a deep copy and we
always use GPU copy_, which disentangles the misalignment.
2026-02-04 01:48:47 -05:00
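
Concretely, the pattern looks roughly like this (a sketch, not the loader's actual code; offsets, sizes, and shapes are illustrative):

import torch

def tensor_from_view(mv: memoryview, offset: int, nbytes: int,
                     dtype: torch.dtype, shape: tuple[int, ...]) -> torch.Tensor:
    # Problematic pattern: slice one big uint8 tensor, then .view(dtype);
    # torch checks dtype alignment against the parent storage and raises
    # on the misaligned payloads found in some third-party safetensors.
    # Constructing each tensor straight off the memoryview sidesteps that
    # check -- though a CPU kernel touching truly misaligned memory can
    # still bus error, hence the GPU-only caveat above.
    count = nbytes // dtype.itemsize
    return torch.frombuffer(mv, dtype=dtype, count=count, offset=offset).reshape(shape)
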
comfyanonymous
5087f1d497 ComfyUI v0.12.2 2026-02-04 00:08:59 -05:00
comfyanonymous
a31681564d Fix crash with ace step 1.5 (#12264) 2026-02-04 00:03:21 -05:00
rattus
855849c658 mm: Remove Aimdo exemption for empty_cache (#12260)
It's more important to get the torch caching allocator GC up and running
than supporting the pyt2.7 bug. Switch it on.

Defeature dynamic_vram + pyt2.7.
2026-02-03 21:39:19 -05:00
comfyanonymous
fe2511468d Support the 4B ace step 1.5 lm model. (#12257)
Can be used as an alternative to the 1.7B.
2026-02-03 19:01:38 -05:00
comfyanonymous
3be0175166 ComfyUI v0.12.1 2026-02-03 15:01:46 -05:00
comfyanonymous
b8315e66cb Fix tiled vae for ace step 1.5 (#12253) 2026-02-03 14:40:45 -05:00
68 changed files with 3080 additions and 6377 deletions

View File

@@ -1,36 +1,19 @@
import logging
import os
import urllib.parse
import uuid
from typing import Any
import urllib.parse
import os
import contextlib
from aiohttp import web
from pydantic import ValidationError
import folder_paths
import app.assets.manager as manager
from app import user_manager
from app.assets.api import schemas_in, schemas_out
from app.assets.api.schemas_in import (
AssetValidationError,
UploadError,
)
from app.assets.api.upload import parse_multipart_upload
from app.assets.scanner import seed_assets as scanner_seed_assets
from app.assets.services import (
DependencyMissingError,
HashMismatchError,
apply_tags,
asset_exists,
create_from_hash,
delete_asset_reference,
get_asset_detail,
list_assets_page,
list_tags,
remove_tags,
resolve_asset_for_download,
update_asset_metadata,
upload_from_temp_path,
)
from app.assets.api import schemas_in
from app.assets.helpers import get_query_dict
from app.assets.scanner import seed_assets
import folder_paths
ROUTES = web.RouteTableDef()
USER_MANAGER: user_manager.UserManager | None = None
@@ -38,78 +21,36 @@ USER_MANAGER: user_manager.UserManager | None = None
# UUID regex (canonical hyphenated form, case-insensitive)
UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
def get_query_dict(request: web.Request) -> dict[str, Any]:
"""
Gets a dictionary of query parameters from the request.
'request.query' is a MultiMapping[str], needs to be converted to a dictionary to be validated by Pydantic.
"""
query_dict = {
key: request.query.getall(key)
if len(request.query.getall(key)) > 1
else request.query.get(key)
for key in request.query.keys()
}
return query_dict
# Note to any custom node developers reading this code:
# The assets system is not yet fully implemented, do not rely on the code in /app/assets remaining the same.
def register_assets_system(
app: web.Application, user_manager_instance: user_manager.UserManager
) -> None:
def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None:
global USER_MANAGER
USER_MANAGER = user_manager_instance
app.add_routes(ROUTES)
def _build_error_response(
status: int, code: str, message: str, details: dict | None = None
) -> web.Response:
return web.json_response(
{"error": {"code": code, "message": message, "details": details or {}}},
status=status,
)
def _error_response(status: int, code: str, message: str, details: dict | None = None) -> web.Response:
return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status)
def _build_validation_error_response(code: str, ve: ValidationError) -> web.Response:
return _build_error_response(400, code, "Validation failed.", {"errors": ve.json()})
def _validate_sort_field(requested: str | None) -> str:
if not requested:
return "created_at"
v = requested.lower()
if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
return v
return "created_at"
def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
return _error_response(400, code, "Validation failed.", {"errors": ve.json()})
@ROUTES.head("/api/assets/hash/{hash}")
async def head_asset_by_hash(request: web.Request) -> web.Response:
hash_str = request.match_info.get("hash", "").strip().lower()
if not hash_str or ":" not in hash_str:
return _build_error_response(
400, "INVALID_HASH", "hash must be like 'blake3:<hex>'"
)
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
algo, digest = hash_str.split(":", 1)
if (
algo != "blake3"
or not digest
or any(c for c in digest if c not in "0123456789abcdef")
):
return _build_error_response(
400, "INVALID_HASH", "hash must be like 'blake3:<hex>'"
)
exists = asset_exists(hash_str)
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
exists = manager.asset_exists(asset_hash=hash_str)
return web.Response(status=200 if exists else 404)
@ROUTES.get("/api/assets")
async def list_assets_route(request: web.Request) -> web.Response:
async def list_assets(request: web.Request) -> web.Response:
"""
GET request to list assets.
"""
@@ -117,124 +58,66 @@ async def list_assets_route(request: web.Request) -> web.Response:
try:
q = schemas_in.ListAssetsQuery.model_validate(query_dict)
except ValidationError as ve:
return _build_validation_error_response("INVALID_QUERY", ve)
return _validation_error_response("INVALID_QUERY", ve)
sort = _validate_sort_field(q.sort)
order = (
"desc"
if (q.order or "desc").lower() not in {"asc", "desc"}
else q.order.lower()
)
result = list_assets_page(
owner_id=USER_MANAGER.get_request_user_id(request),
payload = manager.list_assets(
include_tags=q.include_tags,
exclude_tags=q.exclude_tags,
name_contains=q.name_contains,
metadata_filter=q.metadata_filter,
limit=q.limit,
offset=q.offset,
sort=sort,
order=order,
)
summaries = [
schemas_out.AssetSummary(
id=item.info.id,
name=item.info.name,
asset_hash=item.asset.hash if item.asset else None,
size=int(item.asset.size_bytes)
if item.asset and item.asset.size_bytes
else None,
mime_type=item.asset.mime_type if item.asset else None,
tags=item.tags,
created_at=item.info.created_at,
updated_at=item.info.updated_at,
last_access_time=item.info.last_access_time,
)
for item in result.items
]
payload = schemas_out.AssetsList(
assets=summaries,
total=result.total,
has_more=(q.offset + len(summaries)) < result.total,
sort=q.sort,
order=q.order,
owner_id=USER_MANAGER.get_request_user_id(request),
)
return web.json_response(payload.model_dump(mode="json", exclude_none=True))
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
async def get_asset_route(request: web.Request) -> web.Response:
async def get_asset(request: web.Request) -> web.Response:
"""
GET request to get an asset's info as JSON.
"""
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
result = get_asset_detail(
result = manager.get_asset(
asset_info_id=asset_info_id,
owner_id=USER_MANAGER.get_request_user_id(request),
)
if not result:
return _build_error_response(
404,
"ASSET_NOT_FOUND",
f"AssetInfo {asset_info_id} not found",
{"id": asset_info_id},
)
payload = schemas_out.AssetDetail(
id=result.info.id,
name=result.info.name,
asset_hash=result.asset.hash if result.asset else None,
size=int(result.asset.size_bytes)
if result.asset and result.asset.size_bytes is not None
else None,
mime_type=result.asset.mime_type if result.asset else None,
tags=result.tags,
user_metadata=result.info.user_metadata or {},
preview_id=result.info.preview_id,
created_at=result.info.created_at,
last_access_time=result.info.last_access_time,
)
except ValueError as e:
return _build_error_response(
404, "ASSET_NOT_FOUND", str(e), {"id": asset_info_id}
)
return _error_response(404, "ASSET_NOT_FOUND", str(e), {"id": asset_info_id})
except Exception:
logging.exception(
"get_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(payload.model_dump(mode="json"), status=200)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
async def download_asset_content(request: web.Request) -> web.Response:
# question: do we need disposition? could we just stick with one of these?
disposition = request.query.get("disposition", "attachment").lower().strip()
if disposition not in {"inline", "attachment"}:
disposition = "attachment"
try:
result = resolve_asset_for_download(
abs_path, content_type, filename = manager.resolve_asset_content_for_download(
asset_info_id=str(uuid.UUID(request.match_info["id"])),
owner_id=USER_MANAGER.get_request_user_id(request),
)
abs_path = result.abs_path
content_type = result.content_type
filename = result.download_name
except ValueError as ve:
return _build_error_response(404, "ASSET_NOT_FOUND", str(ve))
return _error_response(404, "ASSET_NOT_FOUND", str(ve))
except NotImplementedError as nie:
return _build_error_response(501, "BACKEND_UNSUPPORTED", str(nie))
return _error_response(501, "BACKEND_UNSUPPORTED", str(nie))
except FileNotFoundError:
return _build_error_response(
404, "FILE_NOT_FOUND", "Underlying file not found on disk."
)
return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.")
quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'")
cd = f"{disposition}; filename=\"{quoted}\"; filename*=UTF-8''{urllib.parse.quote(filename)}"
cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}'
file_size = os.path.getsize(abs_path)
logging.info(
@@ -246,7 +129,7 @@ async def download_asset_content(request: web.Request) -> web.Response:
filename,
)
async def stream_file_chunks():
async def file_sender():
chunk_size = 64 * 1024
with open(abs_path, "rb") as f:
while True:
@@ -256,7 +139,7 @@ async def download_asset_content(request: web.Request) -> web.Response:
yield chunk
return web.Response(
body=stream_file_chunks(),
body=file_sender(),
content_type=content_type,
headers={
"Content-Disposition": cd,
@@ -266,18 +149,16 @@ async def download_asset_content(request: web.Request) -> web.Response:
@ROUTES.post("/api/assets/from-hash")
async def create_asset_from_hash_route(request: web.Request) -> web.Response:
async def create_asset_from_hash(request: web.Request) -> web.Response:
try:
payload = await request.json()
body = schemas_in.CreateFromHashBody.model_validate(payload)
except ValidationError as ve:
return _build_validation_error_response("INVALID_BODY", ve)
return _validation_error_response("INVALID_BODY", ve)
except Exception:
return _build_error_response(
400, "INVALID_JSON", "Request body must be valid JSON."
)
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
result = create_from_hash(
result = manager.create_asset_from_hash(
hash_str=body.hash,
name=body.name,
tags=body.tags,
@@ -285,191 +166,228 @@ async def create_asset_from_hash_route(request: web.Request) -> web.Response:
owner_id=USER_MANAGER.get_request_user_id(request),
)
if result is None:
return _build_error_response(
404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist"
)
payload_out = schemas_out.AssetCreated(
id=result.info.id,
name=result.info.name,
asset_hash=result.asset.hash,
size=int(result.asset.size_bytes) if result.asset.size_bytes else None,
mime_type=result.asset.mime_type,
tags=result.tags,
user_metadata=result.info.user_metadata or {},
preview_id=result.info.preview_id,
created_at=result.info.created_at,
last_access_time=result.info.last_access_time,
created_new=result.created_new,
)
return web.json_response(payload_out.model_dump(mode="json"), status=201)
def _delete_temp_file_if_exists(path: str | None) -> None:
if path and os.path.exists(path):
try:
os.remove(path)
except Exception:
pass
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist")
return web.json_response(result.model_dump(mode="json"), status=201)
@ROUTES.post("/api/assets")
async def upload_asset(request: web.Request) -> web.Response:
"""Multipart/form-data endpoint for Asset uploads."""
if not (request.content_type or "").lower().startswith("multipart/"):
return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.")
reader = await request.multipart()
file_present = False
file_client_name: str | None = None
tags_raw: list[str] = []
provided_name: str | None = None
user_metadata_raw: str | None = None
provided_hash: str | None = None
provided_hash_exists: bool | None = None
file_written = 0
tmp_path: str | None = None
while True:
field = await reader.next()
if field is None:
break
fname = getattr(field, "name", "") or ""
if fname == "hash":
try:
s = ((await field.text()) or "").strip().lower()
except Exception:
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
if s:
if ":" not in s:
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
provided_hash = f"{algo}:{digest}"
try:
provided_hash_exists = manager.asset_exists(asset_hash=provided_hash)
except Exception:
provided_hash_exists = None # do not fail the whole request here
elif fname == "file":
file_present = True
file_client_name = (field.filename or "").strip()
if provided_hash and provided_hash_exists is True:
# If client supplied a hash that we know exists, drain but do not write to disk
try:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
if not chunk:
break
file_written += len(chunk)
except Exception:
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file.")
continue # Do not create temp file; we will create AssetInfo from the existing content
# Otherwise, store to temp for hashing/ingest
uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
os.makedirs(unique_dir, exist_ok=True)
tmp_path = os.path.join(unique_dir, ".upload.part")
try:
with open(tmp_path, "wb") as f:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
if not chunk:
break
f.write(chunk)
file_written += len(chunk)
except Exception:
try:
if os.path.exists(tmp_path or ""):
os.remove(tmp_path)
finally:
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.")
elif fname == "tags":
tags_raw.append((await field.text()) or "")
elif fname == "name":
provided_name = (await field.text()) or None
elif fname == "user_metadata":
user_metadata_raw = (await field.text()) or None
# If client did not send file, and we are not doing a from-hash fast path -> error
if not file_present and not (provided_hash and provided_hash_exists):
return _error_response(400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'.")
if file_present and file_written == 0 and not (provided_hash and provided_hash_exists):
# Empty upload is only acceptable if we are fast-pathing from existing hash
try:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
finally:
return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.")
try:
parsed = await parse_multipart_upload(request, check_hash_exists=asset_exists)
except UploadError as e:
return _build_error_response(e.status, e.code, e.message)
spec = schemas_in.UploadAssetSpec.model_validate({
"tags": tags_raw,
"name": provided_name,
"user_metadata": user_metadata_raw,
"hash": provided_hash,
})
except ValidationError as ve:
try:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
finally:
return _validation_error_response("INVALID_BODY", ve)
# Validate models category against configured folders (consistent with previous behavior)
if spec.tags and spec.tags[0] == "models":
if len(spec.tags) < 2 or spec.tags[1] not in folder_paths.folder_names_and_paths:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
return _error_response(
400, "INVALID_BODY", f"unknown models category '{spec.tags[1] if len(spec.tags) >= 2 else ''}'"
)
owner_id = USER_MANAGER.get_request_user_id(request)
try:
spec = schemas_in.UploadAssetSpec.model_validate(
{
"tags": parsed.tags_raw,
"name": parsed.provided_name,
"user_metadata": parsed.user_metadata_raw,
"hash": parsed.provided_hash,
}
)
except ValidationError as ve:
_delete_temp_file_if_exists(parsed.tmp_path)
return _build_error_response(
400, "INVALID_BODY", f"Validation failed: {ve.json()}"
)
if spec.tags and spec.tags[0] == "models":
if (
len(spec.tags) < 2
or spec.tags[1] not in folder_paths.folder_names_and_paths
):
_delete_temp_file_if_exists(parsed.tmp_path)
category = spec.tags[1] if len(spec.tags) >= 2 else ""
return _build_error_response(
400, "INVALID_BODY", f"unknown models category '{category}'"
)
try:
# Fast path: if a valid provided hash exists, create AssetInfo without writing anything
if spec.hash and parsed.provided_hash_exists is True:
result = create_from_hash(
# Fast path: if a valid provided hash exists, create AssetInfo without writing anything
if spec.hash and provided_hash_exists is True:
try:
result = manager.create_asset_from_hash(
hash_str=spec.hash,
name=spec.name or (spec.hash.split(":", 1)[1]),
tags=spec.tags,
user_metadata=spec.user_metadata or {},
owner_id=owner_id,
)
if result is None:
_delete_temp_file_if_exists(parsed.tmp_path)
return _build_error_response(
404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist"
)
_delete_temp_file_if_exists(parsed.tmp_path)
else:
# Otherwise, we must have a temp file path to ingest
if not parsed.tmp_path or not os.path.exists(parsed.tmp_path):
return _build_error_response(
404,
"ASSET_NOT_FOUND",
"Provided hash not found and no file uploaded.",
)
except Exception:
logging.exception("create_asset_from_hash failed for hash=%s, owner_id=%s", spec.hash, owner_id)
return _error_response(500, "INTERNAL", "Unexpected server error.")
result = upload_from_temp_path(
temp_path=parsed.tmp_path,
name=spec.name,
tags=spec.tags,
user_metadata=spec.user_metadata or {},
client_filename=parsed.file_client_name,
owner_id=owner_id,
expected_hash=spec.hash,
)
except AssetValidationError as e:
_delete_temp_file_if_exists(parsed.tmp_path)
return _build_error_response(400, e.code, str(e))
if result is None:
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist")
# Drain temp if we accidentally saved (e.g., hash field came after file)
if tmp_path and os.path.exists(tmp_path):
with contextlib.suppress(Exception):
os.remove(tmp_path)
status = 200 if (not result.created_new) else 201
return web.json_response(result.model_dump(mode="json"), status=status)
# Otherwise, we must have a temp file path to ingest
if not tmp_path or not os.path.exists(tmp_path):
# The only case we reach here without a temp file is: client sent a hash that does not exist and no file
return _error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.")
try:
created = manager.upload_asset_from_temp_path(
spec,
temp_path=tmp_path,
client_filename=file_client_name,
owner_id=owner_id,
expected_asset_hash=spec.hash,
)
status = 201 if created.created_new else 200
return web.json_response(created.model_dump(mode="json"), status=status)
except ValueError as e:
_delete_temp_file_if_exists(parsed.tmp_path)
return _build_error_response(400, "BAD_REQUEST", str(e))
except HashMismatchError as e:
_delete_temp_file_if_exists(parsed.tmp_path)
return _build_error_response(400, "HASH_MISMATCH", str(e))
except DependencyMissingError as e:
_delete_temp_file_if_exists(parsed.tmp_path)
return _build_error_response(503, "DEPENDENCY_MISSING", e.message)
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
msg = str(e)
if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH":
return _error_response(
400,
"HASH_MISMATCH",
"Uploaded file hash does not match provided hash.",
)
return _error_response(400, "BAD_REQUEST", "Invalid inputs.")
except Exception:
_delete_temp_file_if_exists(parsed.tmp_path)
logging.exception("upload_asset failed for owner_id=%s", owner_id)
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
payload = schemas_out.AssetCreated(
id=result.info.id,
name=result.info.name,
asset_hash=result.asset.hash,
size=int(result.asset.size_bytes) if result.asset.size_bytes else None,
mime_type=result.asset.mime_type,
tags=result.tags,
user_metadata=result.info.user_metadata or {},
preview_id=result.info.preview_id,
created_at=result.info.created_at,
last_access_time=result.info.last_access_time,
created_new=result.created_new,
)
status = 201 if result.created_new else 200
return web.json_response(payload.model_dump(mode="json"), status=status)
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
logging.exception("upload_asset_from_temp_path failed for tmp_path=%s, owner_id=%s", tmp_path, owner_id)
return _error_response(500, "INTERNAL", "Unexpected server error.")
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
async def update_asset_route(request: web.Request) -> web.Response:
async def update_asset(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
body = schemas_in.UpdateAssetBody.model_validate(await request.json())
except ValidationError as ve:
return _build_validation_error_response("INVALID_BODY", ve)
return _validation_error_response("INVALID_BODY", ve)
except Exception:
return _build_error_response(
400, "INVALID_JSON", "Request body must be valid JSON."
)
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
try:
result = update_asset_metadata(
result = manager.update_asset(
asset_info_id=asset_info_id,
name=body.name,
user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request),
)
payload = schemas_out.AssetUpdated(
id=result.info.id,
name=result.info.name,
asset_hash=result.asset.hash if result.asset else None,
tags=result.tags,
user_metadata=result.info.user_metadata or {},
updated_at=result.info.updated_at,
)
except (ValueError, PermissionError) as ve:
return _build_error_response(
404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}
)
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
logging.exception(
"update_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(payload.model_dump(mode="json"), status=200)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
async def delete_asset_route(request: web.Request) -> web.Response:
async def delete_asset(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
delete_content_param = request.query.get("delete_content")
delete_content = (
True
if delete_content_param is None
else delete_content_param.lower() not in {"0", "false", "no"}
)
delete_content = request.query.get("delete_content")
delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"}
try:
deleted = delete_asset_reference(
deleted = manager.delete_asset_reference(
asset_info_id=asset_info_id,
owner_id=USER_MANAGER.get_request_user_id(request),
delete_content_if_orphan=delete_content,
@@ -480,12 +398,10 @@ async def delete_asset_route(request: web.Request) -> web.Response:
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
return _error_response(500, "INTERNAL", "Unexpected server error.")
if not deleted:
return _build_error_response(
404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found."
)
return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.")
return web.Response(status=204)
@@ -500,17 +416,11 @@ async def get_tags(request: web.Request) -> web.Response:
query = schemas_in.TagsListQuery.model_validate(query_map)
except ValidationError as e:
return web.json_response(
{
"error": {
"code": "INVALID_QUERY",
"message": "Invalid query parameters",
"details": e.errors(),
}
},
{"error": {"code": "INVALID_QUERY", "message": "Invalid query parameters", "details": e.errors()}},
status=400,
)
rows, total = list_tags(
result = manager.list_tags(
prefix=query.prefix,
limit=query.limit,
offset=query.offset,
@@ -518,108 +428,72 @@ async def get_tags(request: web.Request) -> web.Response:
include_zero=query.include_zero,
owner_id=USER_MANAGER.get_request_user_id(request),
)
tags = [
schemas_out.TagUsage(name=name, count=count, type=tag_type)
for (name, tag_type, count) in rows
]
payload = schemas_out.TagsList(
tags=tags, total=total, has_more=(query.offset + len(tags)) < total
)
return web.json_response(payload.model_dump(mode="json"))
return web.json_response(result.model_dump(mode="json"))
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
async def add_asset_tags(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
json_payload = await request.json()
data = schemas_in.TagsAdd.model_validate(json_payload)
payload = await request.json()
data = schemas_in.TagsAdd.model_validate(payload)
except ValidationError as ve:
return _build_error_response(
400,
"INVALID_BODY",
"Invalid JSON body for tags add.",
{"errors": ve.errors()},
)
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()})
except Exception:
return _build_error_response(
400, "INVALID_JSON", "Request body must be valid JSON."
)
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
try:
result = apply_tags(
result = manager.add_tags_to_asset(
asset_info_id=asset_info_id,
tags=data.tags,
origin="manual",
owner_id=USER_MANAGER.get_request_user_id(request),
)
payload = schemas_out.TagsAdd(
added=result.added,
already_present=result.already_present,
total_tags=result.total_tags,
)
except (ValueError, PermissionError) as ve:
return _build_error_response(
404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}
)
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
logging.exception(
"add_tags_to_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(payload.model_dump(mode="json"), status=200)
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
async def delete_asset_tags(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
json_payload = await request.json()
data = schemas_in.TagsRemove.model_validate(json_payload)
payload = await request.json()
data = schemas_in.TagsRemove.model_validate(payload)
except ValidationError as ve:
return _build_error_response(
400,
"INVALID_BODY",
"Invalid JSON body for tags remove.",
{"errors": ve.errors()},
)
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()})
except Exception:
return _build_error_response(
400, "INVALID_JSON", "Request body must be valid JSON."
)
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
try:
result = remove_tags(
result = manager.remove_tags_from_asset(
asset_info_id=asset_info_id,
tags=data.tags,
owner_id=USER_MANAGER.get_request_user_id(request),
)
payload = schemas_out.TagsRemove(
removed=result.removed,
not_present=result.not_present,
total_tags=result.total_tags,
)
except ValueError as ve:
return _build_error_response(
404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}
)
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
logging.exception(
"remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(payload.model_dump(mode="json"), status=200)
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.post("/api/assets/seed")
async def seed_assets(request: web.Request) -> web.Response:
async def seed_assets_endpoint(request: web.Request) -> web.Response:
"""Trigger asset seeding for specified roots (models, input, output)."""
try:
payload = await request.json()
@@ -629,12 +503,12 @@ async def seed_assets(request: web.Request) -> web.Response:
valid_roots = [r for r in roots if r in ("models", "input", "output")]
if not valid_roots:
return _build_error_response(400, "INVALID_BODY", "No valid roots specified")
return _error_response(400, "INVALID_BODY", "No valid roots specified")
try:
scanner_seed_assets(tuple(valid_roots))
seed_assets(tuple(valid_roots))
except Exception:
logging.exception("scanner_seed_assets failed for roots=%s", valid_roots)
return _build_error_response(500, "INTERNAL", "Seed operation failed")
logging.exception("seed_assets failed for roots=%s", valid_roots)
return _error_response(500, "INTERNAL", "Seed operation failed")
return web.json_response({"seeded": valid_roots}, status=200)

View File

@@ -1,5 +1,4 @@
import json
from dataclasses import dataclass
from typing import Any, Literal
from pydantic import (
@@ -11,61 +10,6 @@ from pydantic import (
model_validator,
)
class UploadError(Exception):
"""Error during upload parsing with HTTP status and code (used in HTTP layer only)."""
def __init__(self, status: int, code: str, message: str):
super().__init__(message)
self.status = status
self.code = code
self.message = message
class AssetValidationError(Exception):
"""Validation error in asset processing (invalid tags, metadata, etc.)."""
def __init__(self, code: str, message: str):
super().__init__(message)
self.code = code
class AssetNotFoundError(Exception):
"""Asset or asset content not found."""
def __init__(self, message: str):
super().__init__(message)
class HashMismatchError(Exception):
"""Uploaded file hash does not match provided hash."""
pass
class DependencyMissingError(Exception):
"""A required dependency is not installed."""
def __init__(self, message: str):
super().__init__(message)
self.message = message
@dataclass
class ParsedUpload:
"""Result of parsing a multipart upload request."""
file_present: bool
file_written: int
file_client_name: str | None
tmp_path: str | None
tags_raw: list[str]
provided_name: str | None
user_metadata_raw: str | None
provided_hash: str | None
provided_hash_exists: bool | None
class ListAssetsQuery(BaseModel):
include_tags: list[str] = Field(default_factory=list)
exclude_tags: list[str] = Field(default_factory=list)
@@ -77,9 +21,7 @@ class ListAssetsQuery(BaseModel):
limit: conint(ge=1, le=500) = 20
offset: conint(ge=0) = 0
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = (
"created_at"
)
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at"
order: Literal["asc", "desc"] = "desc"
@field_validator("include_tags", "exclude_tags", mode="before")
@@ -119,7 +61,7 @@ class UpdateAssetBody(BaseModel):
user_metadata: dict[str, Any] | None = None
@model_validator(mode="after")
def _validate_at_least_one_field(self):
def _at_least_one(self):
if self.name is None and self.user_metadata is None:
raise ValueError("Provide at least one of: name, user_metadata.")
return self
@@ -148,7 +90,7 @@ class CreateFromHashBody(BaseModel):
@field_validator("tags", mode="before")
@classmethod
def _normalize_tags_field(cls, v):
def _tags_norm(cls, v):
if v is None:
return []
if isinstance(v, list):
@@ -221,7 +163,6 @@ class UploadAssetSpec(BaseModel):
Files created via this endpoint are stored on disk using the **content hash** as the filename stem
and the original extension is preserved when available.
"""
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
tags: list[str] = Field(..., min_length=1)
@@ -319,7 +260,5 @@ class UploadAssetSpec(BaseModel):
raise ValueError("first tag must be one of: models, input, output")
if root == "models":
if len(self.tags) < 2:
raise ValueError(
"models uploads require a category tag as the second tag"
)
raise ValueError("models uploads require a category tag as the second tag")
return self
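
The tag rules at the end of UploadAssetSpec are easiest to see by example (a sketch; pydantic surfaces these ValueErrors wrapped in a ValidationError):

from app.assets.api import schemas_in

# OK: "models" uploads carry a category as the second tag.
schemas_in.UploadAssetSpec.model_validate(
    {"tags": ["models", "checkpoints"], "name": "sd15.safetensors"}
)

# Fails: "models uploads require a category tag as the second tag".
schemas_in.UploadAssetSpec.model_validate({"tags": ["models"]})

# Fails: "first tag must be one of: models, input, output".
schemas_in.UploadAssetSpec.model_validate({"tags": ["loras"]})
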

View File

@@ -19,7 +19,7 @@ class AssetSummary(BaseModel):
model_config = ConfigDict(from_attributes=True)
@field_serializer("created_at", "updated_at", "last_access_time")
def _serialize_datetime(self, v: datetime | None, _info):
def _ser_dt(self, v: datetime | None, _info):
return v.isoformat() if v else None
@@ -40,7 +40,7 @@ class AssetUpdated(BaseModel):
model_config = ConfigDict(from_attributes=True)
@field_serializer("updated_at")
def _serialize_updated_at(self, v: datetime | None, _info):
def _ser_updated(self, v: datetime | None, _info):
return v.isoformat() if v else None
@@ -59,7 +59,7 @@ class AssetDetail(BaseModel):
model_config = ConfigDict(from_attributes=True)
@field_serializer("created_at", "last_access_time")
def _serialize_datetime(self, v: datetime | None, _info):
def _ser_dt(self, v: datetime | None, _info):
return v.isoformat() if v else None

View File

@@ -1,165 +0,0 @@
import os
import uuid
from aiohttp import web
import folder_paths
from app.assets.api.schemas_in import ParsedUpload, UploadError
def normalize_and_validate_hash(s: str) -> str:
"""
Validate and normalize a hash string.
Returns canonical 'blake3:<hex>' or raises UploadError.
"""
s = s.strip().lower()
if not s:
raise UploadError(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
if ":" not in s:
raise UploadError(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if (
algo != "blake3"
or not digest
or any(c for c in digest if c not in "0123456789abcdef")
):
raise UploadError(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
return f"{algo}:{digest}"
async def parse_multipart_upload(
request: web.Request,
check_hash_exists: callable,
) -> ParsedUpload:
"""
Parse a multipart/form-data upload request.
Args:
request: The aiohttp request
check_hash_exists: Callable(hash_str) -> bool to check if a hash exists
Returns:
ParsedUpload with parsed fields and temp file path
Raises:
UploadError: On validation or I/O errors
"""
if not (request.content_type or "").lower().startswith("multipart/"):
raise UploadError(
415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads."
)
reader = await request.multipart()
file_present = False
file_client_name: str | None = None
tags_raw: list[str] = []
provided_name: str | None = None
user_metadata_raw: str | None = None
provided_hash: str | None = None
provided_hash_exists: bool | None = None
file_written = 0
tmp_path: str | None = None
while True:
field = await reader.next()
if field is None:
break
fname = getattr(field, "name", "") or ""
if fname == "hash":
try:
s = ((await field.text()) or "").strip().lower()
except Exception:
raise UploadError(
400, "INVALID_HASH", "hash must be like 'blake3:<hex>'"
)
if s:
provided_hash = normalize_and_validate_hash(s)
try:
provided_hash_exists = check_hash_exists(provided_hash)
except Exception:
provided_hash_exists = None # do not fail the whole request here
elif fname == "file":
file_present = True
file_client_name = (field.filename or "").strip()
if provided_hash and provided_hash_exists is True:
# If client supplied a hash that we know exists, drain but do not write to disk
try:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
if not chunk:
break
file_written += len(chunk)
except Exception:
raise UploadError(
500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file."
)
continue
uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
os.makedirs(unique_dir, exist_ok=True)
tmp_path = os.path.join(unique_dir, ".upload.part")
try:
with open(tmp_path, "wb") as f:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
if not chunk:
break
f.write(chunk)
file_written += len(chunk)
except Exception:
_delete_temp_file_if_exists(tmp_path)
raise UploadError(
500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file."
)
elif fname == "tags":
tags_raw.append((await field.text()) or "")
elif fname == "name":
provided_name = (await field.text()) or None
elif fname == "user_metadata":
user_metadata_raw = (await field.text()) or None
if not file_present and not (provided_hash and provided_hash_exists):
raise UploadError(
400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'."
)
if (
file_present
and file_written == 0
and not (provided_hash and provided_hash_exists)
):
_delete_temp_file_if_exists(tmp_path)
raise UploadError(400, "EMPTY_UPLOAD", "Uploaded file is empty.")
return ParsedUpload(
file_present=file_present,
file_written=file_written,
file_client_name=file_client_name,
tmp_path=tmp_path,
tags_raw=tags_raw,
provided_name=provided_name,
user_metadata_raw=user_metadata_raw,
provided_hash=provided_hash,
provided_hash_exists=provided_hash_exists,
)
def _delete_temp_file_if_exists(tmp_path: str | None) -> None:
"""Safely remove a temp file if it exists."""
if tmp_path:
try:
if os.path.exists(tmp_path):
os.remove(tmp_path)
except Exception:
pass

View File

@@ -0,0 +1,204 @@
import os
import uuid
import sqlalchemy
from typing import Iterable
from sqlalchemy.orm import Session
from sqlalchemy.dialects import sqlite
from app.assets.helpers import utcnow
from app.assets.database.models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, AssetInfoMeta
MAX_BIND_PARAMS = 800
def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]:
if not rows:
return []
rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row))
for i in range(0, len(rows), rows_per_stmt):
yield rows[i:i + rows_per_stmt]
def _iter_chunks(seq, n: int):
for i in range(0, len(seq), n):
yield seq[i:i + n]
def _rows_per_stmt(cols: int) -> int:
return max(1, MAX_BIND_PARAMS // max(1, cols))
def seed_from_paths_batch(
session: Session,
*,
specs: list[dict],
owner_id: str = "",
) -> dict:
"""Each spec is a dict with keys:
- abs_path: str
- size_bytes: int
- mtime_ns: int
- info_name: str
- tags: list[str]
- fname: Optional[str]
"""
if not specs:
return {"inserted_infos": 0, "won_states": 0, "lost_states": 0}
now = utcnow()
asset_rows: list[dict] = []
state_rows: list[dict] = []
path_to_asset: dict[str, str] = {}
asset_to_info: dict[str, dict] = {} # asset_id -> prepared info row
path_list: list[str] = []
for sp in specs:
ap = os.path.abspath(sp["abs_path"])
aid = str(uuid.uuid4())
iid = str(uuid.uuid4())
path_list.append(ap)
path_to_asset[ap] = aid
asset_rows.append(
{
"id": aid,
"hash": None,
"size_bytes": sp["size_bytes"],
"mime_type": None,
"created_at": now,
}
)
state_rows.append(
{
"asset_id": aid,
"file_path": ap,
"mtime_ns": sp["mtime_ns"],
}
)
asset_to_info[aid] = {
"id": iid,
"owner_id": owner_id,
"name": sp["info_name"],
"asset_id": aid,
"preview_id": None,
"user_metadata": {"filename": sp["fname"]} if sp["fname"] else None,
"created_at": now,
"updated_at": now,
"last_access_time": now,
"_tags": sp["tags"],
"_filename": sp["fname"],
}
# insert all seed Assets (hash=NULL)
ins_asset = sqlite.insert(Asset)
for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)):
session.execute(ins_asset, chunk)
# try to claim AssetCacheState (file_path)
# Insert with ON CONFLICT DO NOTHING, then query to find which paths were actually inserted
ins_state = (
sqlite.insert(AssetCacheState)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
)
for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
session.execute(ins_state, chunk)
# Query to find which of our paths won (were actually inserted)
winners_by_path: set[str] = set()
for chunk in _iter_chunks(path_list, MAX_BIND_PARAMS):
result = session.execute(
sqlalchemy.select(AssetCacheState.file_path)
.where(AssetCacheState.file_path.in_(chunk))
.where(AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk]))
)
winners_by_path.update(result.scalars().all())
all_paths_set = set(path_list)
losers_by_path = all_paths_set - winners_by_path
lost_assets = [path_to_asset[p] for p in losers_by_path]
if lost_assets: # losers get their Asset removed
for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS):
session.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(id_chunk)))
if not winners_by_path:
return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}
# insert AssetInfo only for winners
# Insert with ON CONFLICT DO NOTHING, then query to find which were actually inserted
winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
ins_info = (
sqlite.insert(AssetInfo)
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
)
for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
session.execute(ins_info, chunk)
# Query to find which info rows were actually inserted (by matching our generated IDs)
all_info_ids = [row["id"] for row in winner_info_rows]
inserted_info_ids: set[str] = set()
for chunk in _iter_chunks(all_info_ids, MAX_BIND_PARAMS):
result = session.execute(
sqlalchemy.select(AssetInfo.id).where(AssetInfo.id.in_(chunk))
)
inserted_info_ids.update(result.scalars().all())
# build and insert tag + meta rows for the AssetInfo
tag_rows: list[dict] = []
meta_rows: list[dict] = []
if inserted_info_ids:
for row in winner_info_rows:
iid = row["id"]
if iid not in inserted_info_ids:
continue
for t in row["_tags"]:
tag_rows.append({
"asset_info_id": iid,
"tag_name": t,
"origin": "automatic",
"added_at": now,
})
if row["_filename"]:
meta_rows.append(
{
"asset_info_id": iid,
"key": "filename",
"ordinal": 0,
"val_str": row["_filename"],
"val_num": None,
"val_bool": None,
"val_json": None,
}
)
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS)
return {
"inserted_infos": len(inserted_info_ids),
"won_states": len(winners_by_path),
"lost_states": len(losers_by_path),
}
def bulk_insert_tags_and_meta(
session: Session,
*,
tag_rows: list[dict],
meta_rows: list[dict],
max_bind_params: int,
) -> None:
"""Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.
- tag_rows keys: asset_info_id, tag_name, origin, added_at
- meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
"""
if tag_rows:
ins_links = (
sqlite.insert(AssetInfoTag)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params):
session.execute(ins_links, chunk)
if meta_rows:
ins_meta = (
sqlite.insert(AssetInfoMeta)
.on_conflict_do_nothing(
index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
)
)
for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params):
session.execute(ins_meta, chunk)
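
The chunk sizes fall out of simple bind-parameter arithmetic: MAX_BIND_PARAMS = 800 stays below SQLite's historical 999-variable default, and dividing by the column count gives the rows per INSERT. A standalone sketch:

MAX_BIND_PARAMS = 800  # below SQLite's historical 999 bind-parameter limit

def rows_per_stmt(cols: int) -> int:
    return max(1, MAX_BIND_PARAMS // max(1, cols))

assert rows_per_stmt(9) == 88    # AssetInfo rows (9 columns)
assert rows_per_stmt(5) == 160   # Asset rows (5 columns)
assert rows_per_stmt(3) == 266   # AssetCacheState rows (3 columns)
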

View File

@@ -2,8 +2,8 @@ from __future__ import annotations
import uuid
from datetime import datetime
from typing import Any
from typing import Any
from sqlalchemy import (
JSON,
BigInteger,
@@ -20,21 +20,19 @@ from sqlalchemy import (
)
from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship
from app.assets.helpers import get_utc_now
from app.database.models import Base, to_dict
from app.assets.helpers import utcnow
from app.database.models import to_dict, Base
class Asset(Base):
__tablename__ = "assets"
id: Mapped[str] = mapped_column(
String(36), primary_key=True, default=lambda: str(uuid.uuid4())
)
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
hash: Mapped[str | None] = mapped_column(String(256), nullable=True)
size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
mime_type: Mapped[str | None] = mapped_column(String(255))
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=get_utc_now
DateTime(timezone=False), nullable=False, default=utcnow
)
infos: Mapped[list[AssetInfo]] = relationship(
@@ -77,9 +75,7 @@ class AssetCacheState(Base):
__tablename__ = "asset_cache_state"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
asset_id: Mapped[str] = mapped_column(
String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False
)
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False)
file_path: Mapped[str] = mapped_column(Text, nullable=False)
mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
@@ -89,9 +85,7 @@ class AssetCacheState(Base):
__table_args__ = (
Index("ix_asset_cache_state_file_path", "file_path"),
Index("ix_asset_cache_state_asset_id", "asset_id"),
CheckConstraint(
"(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"
),
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
)
@@ -105,29 +99,15 @@ class AssetCacheState(Base):
class AssetInfo(Base):
__tablename__ = "assets_info"
id: Mapped[str] = mapped_column(
String(36), primary_key=True, default=lambda: str(uuid.uuid4())
)
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
name: Mapped[str] = mapped_column(String(512), nullable=False)
asset_id: Mapped[str] = mapped_column(
String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False
)
preview_id: Mapped[str | None] = mapped_column(
String(36), ForeignKey("assets.id", ondelete="SET NULL")
)
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(
JSON(none_as_null=True)
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=get_utc_now
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=get_utc_now
)
last_access_time: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=get_utc_now
)
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False)
preview_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL"))
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON(none_as_null=True))
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
asset: Mapped[Asset] = relationship(
"Asset",
@@ -163,9 +143,7 @@ class AssetInfo(Base):
)
__table_args__ = (
UniqueConstraint(
"asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"
),
UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
Index("ix_assets_info_owner_name", "owner_id", "name"),
Index("ix_assets_info_owner_id", "owner_id"),
Index("ix_assets_info_asset_id", "asset_id"),
@@ -218,7 +196,7 @@ class AssetInfoTag(Base):
)
origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
added_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=get_utc_now
DateTime(timezone=False), nullable=False, default=utcnow
)
asset_info: Mapped[AssetInfo] = relationship(back_populates="tag_links")
@@ -247,7 +225,9 @@ class Tag(Base):
overlaps="asset_info_links,tag_links,tags,asset_info",
)
__table_args__ = (Index("ix_tags_tag_type", "tag_type"),)
__table_args__ = (
Index("ix_tags_tag_type", "tag_type"),
)
def __repr__(self) -> str:
return f"<Tag {self.name}>"

View File

@@ -0,0 +1,976 @@
import os
import logging
import sqlalchemy as sa
from collections import defaultdict
from datetime import datetime
from typing import Iterable, Any
from sqlalchemy import select, delete, exists, func
from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, contains_eager, noload
from app.assets.database.models import Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag
from app.assets.helpers import (
compute_relative_filename, escape_like_prefix, normalize_tags, project_kv, utcnow
)
from typing import Sequence
def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
"""Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
owner_id = (owner_id or "").strip()
if owner_id == "":
return AssetInfo.owner_id == ""
return AssetInfo.owner_id.in_(["", owner_id])
def pick_best_live_path(states: Sequence[AssetCacheState]) -> str:
"""
Return the best on-disk path among cache states:
1) Prefer a path that exists with needs_verify == False (already verified).
2) Otherwise, pick the first path that exists.
3) Otherwise return empty string.
"""
alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)]
if not alive:
return ""
for s in alive:
if not getattr(s, "needs_verify", False):
return s.file_path
return alive[0].file_path
def apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
) -> sa.sql.Select:
"""include_tags: every tag must be present; exclude_tags: none may be present."""
include_tags = normalize_tags(include_tags)
exclude_tags = normalize_tags(exclude_tags)
if include_tags:
for tag_name in include_tags:
stmt = stmt.where(
exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name == tag_name)
)
)
if exclude_tags:
stmt = stmt.where(
~exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name.in_(exclude_tags))
)
)
return stmt
def apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: dict | None = None,
) -> sa.sql.Select:
"""Apply filters using asset_info_meta projection table."""
if not metadata_filter:
return stmt
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
return sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
*preds,
)
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
if value is None:
no_row_for_key = sa.not_(
sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
)
)
null_row = _exists_for_pred(
key,
AssetInfoMeta.val_json.is_(None),
AssetInfoMeta.val_str.is_(None),
AssetInfoMeta.val_num.is_(None),
AssetInfoMeta.val_bool.is_(None),
)
return sa.or_(no_row_for_key, null_row)
if isinstance(value, bool):
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
if isinstance(value, (int, float)):
from decimal import Decimal
num = value if isinstance(value, Decimal) else Decimal(str(value))
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
if isinstance(value, str):
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
for k, v in metadata_filter.items():
if isinstance(v, list):
ors = [_exists_clause_for_value(k, elem) for elem in v]
if ors:
stmt = stmt.where(sa.or_(*ors))
else:
stmt = stmt.where(_exists_clause_for_value(k, v))
return stmt
def asset_exists_by_hash(
session: Session,
*,
asset_hash: str,
) -> bool:
"""
Check if an asset with a given hash exists in database.
"""
row = (
session.execute(
select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1)
)
).first()
return row is not None
def asset_info_exists_for_asset_id(
session: Session,
*,
asset_id: str,
) -> bool:
q = (
select(sa.literal(True))
.select_from(AssetInfo)
.where(AssetInfo.asset_id == asset_id)
.limit(1)
)
return (session.execute(q)).first() is not None
def get_asset_by_hash(
session: Session,
*,
asset_hash: str,
) -> Asset | None:
return (
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
).scalars().first()
def get_asset_info_by_id(
session: Session,
*,
asset_info_id: str,
) -> AssetInfo | None:
return session.get(AssetInfo, asset_info_id)
def list_asset_infos_page(
session: Session,
owner_id: str = "",
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
) -> tuple[list[AssetInfo], dict[str, list[str]], int]:
base = (
select(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
.where(visible_owner_clause(owner_id))
)
if name_contains:
escaped, esc = escape_like_prefix(name_contains)
base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
base = apply_tag_filters(base, include_tags, exclude_tags)
base = apply_metadata_filter(base, metadata_filter)
sort = (sort or "created_at").lower()
order = (order or "desc").lower()
sort_map = {
"name": AssetInfo.name,
"created_at": AssetInfo.created_at,
"updated_at": AssetInfo.updated_at,
"last_access_time": AssetInfo.last_access_time,
"size": Asset.size_bytes,
}
sort_col = sort_map.get(sort, AssetInfo.created_at)
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
base = base.order_by(sort_exp).limit(limit).offset(offset)
count_stmt = (
select(sa.func.count())
.select_from(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(visible_owner_clause(owner_id))
)
if name_contains:
escaped, esc = escape_like_prefix(name_contains)
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
total = int((session.execute(count_stmt)).scalar_one() or 0)
infos = (session.execute(base)).unique().scalars().all()
id_list: list[str] = [i.id for i in infos]
tag_map: dict[str, list[str]] = defaultdict(list)
if id_list:
rows = session.execute(
select(AssetInfoTag.asset_info_id, Tag.name)
.join(Tag, Tag.name == AssetInfoTag.tag_name)
.where(AssetInfoTag.asset_info_id.in_(id_list))
.order_by(AssetInfoTag.added_at)
)
for aid, tag_name in rows.all():
tag_map[aid].append(tag_name)
return infos, tag_map, total
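# Hedged usage sketch (make_session and the tag names are invented for
# illustration): page through an owner's checkpoints, newest first.
#
# with make_session() as session:
#     infos, tag_map, total = list_asset_infos_page(
#         session,
#         owner_id="alice",
#         include_tags=["models", "checkpoints"],
#         exclude_tags=["missing"],
#         limit=20,
#         offset=0,
#         sort="created_at",
#         order="desc",
#     )
#     page = [(i.name, tag_map.get(i.id, [])) for i in infos]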
def fetch_asset_info_asset_and_tags(
session: Session,
asset_info_id: str,
owner_id: str = "",
) -> tuple[AssetInfo, Asset, list[str]] | None:
stmt = (
select(AssetInfo, Asset, Tag.name)
.join(Asset, Asset.id == AssetInfo.asset_id)
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
.where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
.options(noload(AssetInfo.tags))
.order_by(Tag.name.asc())
)
rows = (session.execute(stmt)).all()
if not rows:
return None
first_info, first_asset, _ = rows[0]
tags: list[str] = []
seen: set[str] = set()
for _info, _asset, tag_name in rows:
if tag_name and tag_name not in seen:
seen.add(tag_name)
tags.append(tag_name)
return first_info, first_asset, tags
def fetch_asset_info_and_asset(
session: Session,
*,
asset_info_id: str,
owner_id: str = "",
) -> tuple[AssetInfo, Asset] | None:
stmt = (
select(AssetInfo, Asset)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
.limit(1)
.options(noload(AssetInfo.tags))
)
row = session.execute(stmt)
pair = row.first()
if not pair:
return None
return pair[0], pair[1]
def list_cache_states_by_asset_id(
session: Session, *, asset_id: str
) -> Sequence[AssetCacheState]:
return (
session.execute(
select(AssetCacheState)
.where(AssetCacheState.asset_id == asset_id)
.order_by(AssetCacheState.id.asc())
)
).scalars().all()
def touch_asset_info_by_id(
session: Session,
*,
asset_info_id: str,
ts: datetime | None = None,
only_if_newer: bool = True,
) -> None:
ts = ts or utcnow()
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
if only_if_newer:
stmt = stmt.where(
sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts)
)
session.execute(stmt.values(last_access_time=ts))
def create_asset_info_for_existing_asset(
session: Session,
*,
asset_hash: str,
name: str,
user_metadata: dict | None = None,
tags: Sequence[str] | None = None,
tag_origin: str = "manual",
owner_id: str = "",
) -> AssetInfo:
"""Create or return an existing AssetInfo for an Asset identified by asset_hash."""
now = utcnow()
asset = get_asset_by_hash(session, asset_hash=asset_hash)
if not asset:
raise ValueError(f"Unknown asset hash {asset_hash}")
info = AssetInfo(
owner_id=owner_id,
name=name,
asset_id=asset.id,
preview_id=None,
created_at=now,
updated_at=now,
last_access_time=now,
)
try:
with session.begin_nested():
session.add(info)
session.flush()
except IntegrityError:
existing = (
session.execute(
select(AssetInfo)
.options(noload(AssetInfo.tags))
.where(
AssetInfo.asset_id == asset.id,
AssetInfo.name == name,
AssetInfo.owner_id == owner_id,
)
.limit(1)
)
).unique().scalars().first()
if not existing:
raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
return existing
# metadata["filename"] hack
new_meta = dict(user_metadata or {})
computed_filename = None
try:
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
if p:
computed_filename = compute_relative_filename(p)
except Exception:
computed_filename = None
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta:
replace_asset_info_metadata_projection(
session,
asset_info_id=info.id,
user_metadata=new_meta,
)
if tags is not None:
set_asset_info_tags(
session,
asset_info_id=info.id,
tags=tags,
origin=tag_origin,
)
return info
def set_asset_info_tags(
session: Session,
*,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
) -> dict:
desired = normalize_tags(tags)
current = set(
tag_name for (tag_name,) in (
session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
).all()
)
to_add = [t for t in desired if t not in current]
to_remove = [t for t in current if t not in desired]
if to_add:
ensure_tags_exist(session, to_add, tag_type="user")
session.add_all([
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow())
for t in to_add
])
session.flush()
if to_remove:
session.execute(
delete(AssetInfoTag)
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
)
session.flush()
return {"added": to_add, "removed": to_remove, "total": desired}
def replace_asset_info_metadata_projection(
session: Session,
*,
asset_info_id: str,
user_metadata: dict | None = None,
) -> None:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info.user_metadata = user_metadata or {}
info.updated_at = utcnow()
session.flush()
session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
session.flush()
if not user_metadata:
return
rows: list[AssetInfoMeta] = []
for k, v in user_metadata.items():
for r in project_kv(k, v):
rows.append(
AssetInfoMeta(
asset_info_id=asset_info_id,
key=r["key"],
ordinal=int(r["ordinal"]),
val_str=r.get("val_str"),
val_num=r.get("val_num"),
val_bool=r.get("val_bool"),
val_json=r.get("val_json"),
)
)
if rows:
session.add_all(rows)
session.flush()
def ingest_fs_asset(
session: Session,
*,
asset_hash: str,
abs_path: str,
size_bytes: int,
mtime_ns: int,
mime_type: str | None = None,
info_name: str | None = None,
owner_id: str = "",
preview_id: str | None = None,
user_metadata: dict | None = None,
tags: Sequence[str] = (),
tag_origin: str = "manual",
require_existing_tags: bool = False,
) -> dict:
"""
Idempotently upsert:
- Asset by content hash (create if missing)
- AssetCacheState(file_path) pointing to asset_id
- Optionally AssetInfo + tag links and metadata projection
Returns flags and ids.
"""
locator = os.path.abspath(abs_path)
now = utcnow()
if preview_id:
if not session.get(Asset, preview_id):
preview_id = None
out: dict[str, Any] = {
"asset_created": False,
"asset_updated": False,
"state_created": False,
"state_updated": False,
"asset_info_id": None,
}
# 1) Asset by hash
asset = (
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
).scalars().first()
if not asset:
vals = {
"hash": asset_hash,
"size_bytes": int(size_bytes),
"mime_type": mime_type,
"created_at": now,
}
res = session.execute(
sqlite.insert(Asset)
.values(**vals)
.on_conflict_do_nothing(index_elements=[Asset.hash])
)
if int(res.rowcount or 0) > 0:
out["asset_created"] = True
asset = (
session.execute(
select(Asset).where(Asset.hash == asset_hash).limit(1)
)
).scalars().first()
if not asset:
raise RuntimeError("Asset row not found after upsert.")
else:
changed = False
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
asset.size_bytes = int(size_bytes)
changed = True
if mime_type and asset.mime_type != mime_type:
asset.mime_type = mime_type
changed = True
if changed:
out["asset_updated"] = True
# 2) AssetCacheState upsert by file_path (unique)
vals = {
"asset_id": asset.id,
"file_path": locator,
"mtime_ns": int(mtime_ns),
}
ins = (
sqlite.insert(AssetCacheState)
.values(**vals)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
)
res = session.execute(ins)
if int(res.rowcount or 0) > 0:
out["state_created"] = True
else:
upd = (
sa.update(AssetCacheState)
.where(AssetCacheState.file_path == locator)
.where(
sa.or_(
AssetCacheState.asset_id != asset.id,
AssetCacheState.mtime_ns.is_(None),
AssetCacheState.mtime_ns != int(mtime_ns),
)
)
.values(asset_id=asset.id, mtime_ns=int(mtime_ns))
)
res2 = session.execute(upd)
if int(res2.rowcount or 0) > 0:
out["state_updated"] = True
# 3) Optional AssetInfo + tags + metadata
if info_name:
try:
with session.begin_nested():
info = AssetInfo(
owner_id=owner_id,
name=info_name,
asset_id=asset.id,
preview_id=preview_id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
session.flush()
out["asset_info_id"] = info.id
except IntegrityError:
pass
existing_info = (
session.execute(
select(AssetInfo)
.where(
AssetInfo.asset_id == asset.id,
AssetInfo.name == info_name,
(AssetInfo.owner_id == owner_id),
)
.limit(1)
)
).unique().scalar_one_or_none()
if not existing_info:
raise RuntimeError("Failed to update or insert AssetInfo.")
if preview_id and existing_info.preview_id != preview_id:
existing_info.preview_id = preview_id
existing_info.updated_at = now
if existing_info.last_access_time < now:
existing_info.last_access_time = now
session.flush()
out["asset_info_id"] = existing_info.id
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
if norm and out["asset_info_id"] is not None:
if not require_existing_tags:
ensure_tags_exist(session, norm, tag_type="user")
existing_tag_names = set(
name for (name,) in (session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
)
missing = [t for t in norm if t not in existing_tag_names]
if missing and require_existing_tags:
raise ValueError(f"Unknown tags: {missing}")
existing_links = set(
tag_name
for (tag_name,) in (
session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
)
).all()
)
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
if to_add:
session.add_all(
[
AssetInfoTag(
asset_info_id=out["asset_info_id"],
tag_name=t,
origin=tag_origin,
added_at=now,
)
for t in to_add
]
)
session.flush()
# metadata["filename"] hack
if out["asset_info_id"] is not None:
primary_path = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
computed_filename = compute_relative_filename(primary_path) if primary_path else None
current_meta = existing_info.user_metadata or {}
new_meta = dict(current_meta)
if user_metadata is not None:
for k, v in user_metadata.items():
new_meta[k] = v
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta != current_meta:
replace_asset_info_metadata_projection(
session,
asset_info_id=out["asset_info_id"],
user_metadata=new_meta,
)
try:
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
except Exception:
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
return out
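# Hedged usage sketch (paths and the session factory are invented): register a
# file on disk and inspect the idempotency flags.
#
# with make_session() as session:
#     st = os.stat("/data/models/checkpoints/flux.safetensors")
#     out = ingest_fs_asset(
#         session,
#         asset_hash="<blake3 hex digest>",
#         abs_path="/data/models/checkpoints/flux.safetensors",
#         size_bytes=st.st_size,
#         mtime_ns=st.st_mtime_ns,
#         mime_type="application/octet-stream",
#         info_name="flux.safetensors",
#         tags=["models", "checkpoints"],
#     )
#     # Calling it again with identical arguments should leave every
#     # *_created / *_updated flag False and return the same asset_info_id.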
def update_asset_info_full(
session: Session,
*,
asset_info_id: str,
name: str | None = None,
tags: Sequence[str] | None = None,
user_metadata: dict | None = None,
tag_origin: str = "manual",
asset_info_row: Any = None,
) -> AssetInfo:
if not asset_info_row:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
else:
info = asset_info_row
touched = False
if name is not None and name != info.name:
info.name = name
touched = True
computed_filename = None
try:
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=info.asset_id))
if p:
computed_filename = compute_relative_filename(p)
except Exception:
computed_filename = None
if user_metadata is not None:
new_meta = dict(user_metadata)
if computed_filename:
new_meta["filename"] = computed_filename
replace_asset_info_metadata_projection(
session, asset_info_id=asset_info_id, user_metadata=new_meta
)
touched = True
else:
if computed_filename:
current_meta = info.user_metadata or {}
if current_meta.get("filename") != computed_filename:
new_meta = dict(current_meta)
new_meta["filename"] = computed_filename
replace_asset_info_metadata_projection(
session, asset_info_id=asset_info_id, user_metadata=new_meta
)
touched = True
if tags is not None:
set_asset_info_tags(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=tag_origin,
)
touched = True
if touched and user_metadata is None:
info.updated_at = utcnow()
session.flush()
return info
def delete_asset_info_by_id(
session: Session,
*,
asset_info_id: str,
owner_id: str,
) -> bool:
stmt = sa.delete(AssetInfo).where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
return int((session.execute(stmt)).rowcount or 0) > 0
def list_tags_with_usage(
session: Session,
prefix: str | None = None,
limit: int = 100,
offset: int = 0,
include_zero: bool = True,
order: str = "count_desc",
owner_id: str = "",
) -> tuple[list[tuple[str, str, int]], int]:
counts_sq = (
select(
AssetInfoTag.tag_name.label("tag_name"),
func.count(AssetInfoTag.asset_info_id).label("cnt"),
)
.select_from(AssetInfoTag)
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
.where(visible_owner_clause(owner_id))
.group_by(AssetInfoTag.tag_name)
.subquery()
)
q = (
select(
Tag.name,
Tag.tag_type,
func.coalesce(counts_sq.c.cnt, 0).label("count"),
)
.select_from(Tag)
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
)
if prefix:
escaped, esc = escape_like_prefix(prefix.strip().lower())
q = q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
if order == "name_asc":
q = q.order_by(Tag.name.asc())
else:
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
total_q = select(func.count()).select_from(Tag)
if prefix:
escaped, esc = escape_like_prefix(prefix.strip().lower())
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
total_q = total_q.where(
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
)
rows = (session.execute(q.limit(limit).offset(offset))).all()
total = (session.execute(total_q)).scalar_one()
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
return rows_norm, int(total or 0)
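# Illustrative return shape (values invented):
#     rows_norm -> [("checkpoints", "system", 12), ("flux", "user", 3)]
#     total     -> 57  (count of all Tag rows matching prefix/include_zero)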
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
wanted = normalize_tags(list(names))
if not wanted:
return
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
ins = (
sqlite.insert(Tag)
.values(rows)
.on_conflict_do_nothing(index_elements=[Tag.name])
)
session.execute(ins)
def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]:
return [
tag_name for (tag_name,) in (
session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
]
def add_tags_to_asset_info(
session: Session,
*,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
create_if_missing: bool = True,
asset_info_row: Any = None,
) -> dict:
if not asset_info_row:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"added": [], "already_present": [], "total_tags": total}
if create_if_missing:
ensure_tags_exist(session, norm, tag_type="user")
current = {
tag_name
for (tag_name,) in (
session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
}
want = set(norm)
to_add = sorted(want - current)
if to_add:
with session.begin_nested() as nested:
try:
session.add_all(
[
AssetInfoTag(
asset_info_id=asset_info_id,
tag_name=t,
origin=origin,
added_at=utcnow(),
)
for t in to_add
]
)
session.flush()
except IntegrityError:
nested.rollback()
after = set(get_asset_tags(session, asset_info_id=asset_info_id))
return {
"added": sorted(((after - current) & want)),
"already_present": sorted(want & current),
"total_tags": sorted(after),
}
def remove_tags_from_asset_info(
session: Session,
*,
asset_info_id: str,
tags: Sequence[str],
) -> dict:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": [], "not_present": [], "total_tags": total}
existing = {
tag_name
for (tag_name,) in (
session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
}
to_remove = sorted(set(t for t in norm if t in existing))
not_present = sorted(set(t for t in norm if t not in existing))
if to_remove:
session.execute(
delete(AssetInfoTag)
.where(
AssetInfoTag.asset_info_id == asset_info_id,
AssetInfoTag.tag_name.in_(to_remove),
)
)
session.flush()
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
def remove_missing_tag_for_asset_id(
session: Session,
*,
asset_id: str,
) -> None:
session.execute(
sa.delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
AssetInfoTag.tag_name == "missing",
)
)
def set_asset_info_preview(
session: Session,
*,
asset_info_id: str,
preview_asset_id: str | None = None,
) -> None:
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if preview_asset_id is None:
info.preview_id = None
else:
# validate preview asset exists
if not session.get(Asset, preview_asset_id):
raise ValueError(f"Preview Asset {preview_asset_id} not found")
info.preview_id = preview_asset_id
info.updated_at = utcnow()
session.flush()
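# Minimal usage sketch (identifiers invented): attach a previously ingested
# thumbnail to an AssetInfo, then clear it again.
#
# set_asset_info_preview(session, asset_info_id=info_id, preview_asset_id=thumb_id)
# set_asset_info_preview(session, asset_info_id=info_id, preview_asset_id=None)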

View File

@@ -1,99 +0,0 @@
from app.assets.database.queries.asset import (
asset_exists_by_hash,
bulk_insert_assets,
get_asset_by_hash,
upsert_asset,
)
from app.assets.database.queries.asset_info import (
asset_info_exists_for_asset_id,
bulk_insert_asset_infos_ignore_conflicts,
delete_asset_info_by_id,
fetch_asset_info_and_asset,
fetch_asset_info_asset_and_tags,
get_asset_info_by_id,
get_asset_info_ids_by_ids,
get_or_create_asset_info,
insert_asset_info,
list_asset_infos_page,
set_asset_info_metadata,
set_asset_info_preview,
update_asset_info_access_time,
update_asset_info_name,
update_asset_info_timestamps,
update_asset_info_updated_at,
)
from app.assets.database.queries.cache_state import (
CacheStateRow,
bulk_insert_cache_states_ignore_conflicts,
bulk_set_needs_verify,
delete_assets_by_ids,
delete_cache_states_by_ids,
delete_cache_states_outside_prefixes,
delete_orphaned_seed_asset,
get_cache_states_by_paths_and_asset_ids,
get_cache_states_for_prefixes,
get_orphaned_seed_asset_ids,
list_cache_states_by_asset_id,
upsert_cache_state,
)
from app.assets.database.queries.tags import (
AddTagsDict,
RemoveTagsDict,
SetTagsDict,
add_missing_tag_for_asset_id,
add_tags_to_asset_info,
bulk_insert_tags_and_meta,
ensure_tags_exist,
get_asset_tags,
list_tags_with_usage,
remove_missing_tag_for_asset_id,
remove_tags_from_asset_info,
set_asset_info_tags,
)
__all__ = [
"AddTagsDict",
"CacheStateRow",
"RemoveTagsDict",
"SetTagsDict",
"add_missing_tag_for_asset_id",
"add_tags_to_asset_info",
"asset_exists_by_hash",
"asset_info_exists_for_asset_id",
"bulk_insert_asset_infos_ignore_conflicts",
"bulk_insert_assets",
"bulk_insert_cache_states_ignore_conflicts",
"bulk_insert_tags_and_meta",
"bulk_set_needs_verify",
"delete_asset_info_by_id",
"delete_assets_by_ids",
"delete_cache_states_by_ids",
"delete_cache_states_outside_prefixes",
"delete_orphaned_seed_asset",
"ensure_tags_exist",
"fetch_asset_info_and_asset",
"fetch_asset_info_asset_and_tags",
"get_asset_by_hash",
"get_asset_info_by_id",
"get_asset_info_ids_by_ids",
"get_asset_tags",
"get_cache_states_by_paths_and_asset_ids",
"get_cache_states_for_prefixes",
"get_or_create_asset_info",
"get_orphaned_seed_asset_ids",
"insert_asset_info",
"list_asset_infos_page",
"list_cache_states_by_asset_id",
"list_tags_with_usage",
"remove_missing_tag_for_asset_id",
"remove_tags_from_asset_info",
"set_asset_info_metadata",
"set_asset_info_preview",
"set_asset_info_tags",
"update_asset_info_access_time",
"update_asset_info_name",
"update_asset_info_timestamps",
"update_asset_info_updated_at",
"upsert_asset",
"upsert_cache_state",
]

View File

@@ -1,90 +0,0 @@
import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy.dialects import sqlite
from sqlalchemy.orm import Session
from app.assets.database.models import Asset
from app.assets.database.queries.common import calculate_rows_per_statement, iter_chunks
def asset_exists_by_hash(
session: Session,
asset_hash: str,
) -> bool:
"""
Check if an asset with a given hash exists in the database.
"""
row = (
session.execute(
select(sa.literal(True))
.select_from(Asset)
.where(Asset.hash == asset_hash)
.limit(1)
)
).first()
return row is not None
def get_asset_by_hash(
session: Session,
asset_hash: str,
) -> Asset | None:
return (
(session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)))
.scalars()
.first()
)
def upsert_asset(
session: Session,
asset_hash: str,
size_bytes: int,
mime_type: str | None = None,
) -> tuple[Asset, bool, bool]:
"""Upsert an Asset by hash. Returns (asset, created, updated)."""
vals = {"hash": asset_hash, "size_bytes": int(size_bytes)}
if mime_type:
vals["mime_type"] = mime_type
ins = (
sqlite.insert(Asset)
.values(**vals)
.on_conflict_do_nothing(index_elements=[Asset.hash])
)
res = session.execute(ins)
created = int(res.rowcount or 0) > 0
asset = (
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
.scalars()
.first()
)
if not asset:
raise RuntimeError("Asset row not found after upsert.")
updated = False
if not created:
changed = False
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
asset.size_bytes = int(size_bytes)
changed = True
if mime_type and asset.mime_type != mime_type:
asset.mime_type = mime_type
changed = True
if changed:
updated = True
return asset, created, updated
def bulk_insert_assets(
session: Session,
rows: list[dict],
) -> None:
"""Bulk insert Asset rows. Each dict should have: id, hash, size_bytes, mime_type, created_at."""
if not rows:
return
ins = sqlite.insert(Asset)
for chunk in iter_chunks(rows, calculate_rows_per_statement(5)):
session.execute(ins, chunk)

View File

@@ -1,527 +0,0 @@
from collections import defaultdict
from datetime import datetime
from decimal import Decimal
from typing import Sequence
import sqlalchemy as sa
from sqlalchemy import delete, exists, select
from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, contains_eager, noload
from app.assets.database.models import (
Asset,
AssetInfo,
AssetInfoMeta,
AssetInfoTag,
Tag,
)
from app.assets.database.queries.common import (
MAX_BIND_PARAMS,
build_visible_owner_clause,
calculate_rows_per_statement,
iter_chunks,
)
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
def _check_is_scalar(v):
if v is None:
return True
if isinstance(v, bool):
return True
if isinstance(v, (int, float, Decimal, str)):
return True
return False
def _scalar_to_row(key: str, ordinal: int, value) -> dict:
"""Convert a scalar value to a typed projection row."""
if value is None:
return {
"key": key,
"ordinal": ordinal,
"val_str": None,
"val_num": None,
"val_bool": None,
"val_json": None,
}
if isinstance(value, bool):
return {"key": key, "ordinal": ordinal, "val_bool": bool(value)}
if isinstance(value, (int, float, Decimal)):
num = value if isinstance(value, Decimal) else Decimal(str(value))
return {"key": key, "ordinal": ordinal, "val_num": num}
if isinstance(value, str):
return {"key": key, "ordinal": ordinal, "val_str": value}
return {"key": key, "ordinal": ordinal, "val_json": value}
def convert_metadata_to_rows(key: str, value) -> list[dict]:
"""
Turn a metadata key/value into typed projection rows.
Returns list[dict] with keys:
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
"""
if value is None:
return [_scalar_to_row(key, 0, None)]
if _check_is_scalar(value):
return [_scalar_to_row(key, 0, value)]
if isinstance(value, list):
if all(_check_is_scalar(x) for x in value):
return [_scalar_to_row(key, i, x) for i, x in enumerate(value)]
return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value)]
return [{"key": key, "ordinal": 0, "val_json": value}]
def _apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
) -> sa.sql.Select:
"""include_tags: every tag must be present; exclude_tags: none may be present."""
include_tags = normalize_tags(include_tags)
exclude_tags = normalize_tags(exclude_tags)
if include_tags:
for tag_name in include_tags:
stmt = stmt.where(
exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name == tag_name)
)
)
if exclude_tags:
stmt = stmt.where(
~exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name.in_(exclude_tags))
)
)
return stmt
def _apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: dict | None = None,
) -> sa.sql.Select:
"""Apply filters using asset_info_meta projection table."""
if not metadata_filter:
return stmt
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
return sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
*preds,
)
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
if value is None:
no_row_for_key = sa.not_(
sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
)
)
null_row = _exists_for_pred(
key,
AssetInfoMeta.val_json.is_(None),
AssetInfoMeta.val_str.is_(None),
AssetInfoMeta.val_num.is_(None),
AssetInfoMeta.val_bool.is_(None),
)
return sa.or_(no_row_for_key, null_row)
if isinstance(value, bool):
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
if isinstance(value, (int, float)):
num = value if isinstance(value, Decimal) else Decimal(str(value))
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
if isinstance(value, str):
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
for k, v in metadata_filter.items():
if isinstance(v, list):
ors = [_exists_clause_for_value(k, elem) for elem in v]
if ors:
stmt = stmt.where(sa.or_(*ors))
else:
stmt = stmt.where(_exists_clause_for_value(k, v))
return stmt
def asset_info_exists_for_asset_id(
session: Session,
asset_id: str,
) -> bool:
q = (
select(sa.literal(True))
.select_from(AssetInfo)
.where(AssetInfo.asset_id == asset_id)
.limit(1)
)
return (session.execute(q)).first() is not None
def get_asset_info_by_id(
session: Session,
asset_info_id: str,
) -> AssetInfo | None:
return session.get(AssetInfo, asset_info_id)
def insert_asset_info(
session: Session,
asset_id: str,
owner_id: str,
name: str,
preview_id: str | None = None,
) -> AssetInfo | None:
"""Insert a new AssetInfo. Returns None if unique constraint violated."""
now = get_utc_now()
try:
with session.begin_nested():
info = AssetInfo(
owner_id=owner_id,
name=name,
asset_id=asset_id,
preview_id=preview_id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
session.flush()
return info
except IntegrityError:
return None
def get_or_create_asset_info(
session: Session,
asset_id: str,
owner_id: str,
name: str,
preview_id: str | None = None,
) -> tuple[AssetInfo, bool]:
"""Get existing or create new AssetInfo. Returns (info, created)."""
info = insert_asset_info(
session,
asset_id=asset_id,
owner_id=owner_id,
name=name,
preview_id=preview_id,
)
if info:
return info, True
existing = (
session.execute(
select(AssetInfo)
.where(
AssetInfo.asset_id == asset_id,
AssetInfo.name == name,
AssetInfo.owner_id == owner_id,
)
.limit(1)
)
.unique()
.scalar_one_or_none()
)
if not existing:
raise RuntimeError("Failed to find AssetInfo after insert conflict.")
return existing, False
def update_asset_info_timestamps(
session: Session,
asset_info: AssetInfo,
preview_id: str | None = None,
) -> None:
"""Update timestamps and optionally preview_id on existing AssetInfo."""
now = get_utc_now()
if preview_id and asset_info.preview_id != preview_id:
asset_info.preview_id = preview_id
asset_info.updated_at = now
if asset_info.last_access_time < now:
asset_info.last_access_time = now
session.flush()
def list_asset_infos_page(
session: Session,
owner_id: str = "",
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
) -> tuple[list[AssetInfo], dict[str, list[str]], int]:
base = (
select(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
.where(build_visible_owner_clause(owner_id))
)
if name_contains:
escaped, esc = escape_sql_like_string(name_contains)
base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
base = _apply_tag_filters(base, include_tags, exclude_tags)
base = _apply_metadata_filter(base, metadata_filter)
sort = (sort or "created_at").lower()
order = (order or "desc").lower()
sort_map = {
"name": AssetInfo.name,
"created_at": AssetInfo.created_at,
"updated_at": AssetInfo.updated_at,
"last_access_time": AssetInfo.last_access_time,
"size": Asset.size_bytes,
}
sort_col = sort_map.get(sort, AssetInfo.created_at)
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
base = base.order_by(sort_exp).limit(limit).offset(offset)
count_stmt = (
select(sa.func.count())
.select_from(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(build_visible_owner_clause(owner_id))
)
if name_contains:
escaped, esc = escape_sql_like_string(name_contains)
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = _apply_metadata_filter(count_stmt, metadata_filter)
total = int((session.execute(count_stmt)).scalar_one() or 0)
infos = (session.execute(base)).unique().scalars().all()
id_list: list[str] = [i.id for i in infos]
tag_map: dict[str, list[str]] = defaultdict(list)
if id_list:
rows = session.execute(
select(AssetInfoTag.asset_info_id, Tag.name)
.join(Tag, Tag.name == AssetInfoTag.tag_name)
.where(AssetInfoTag.asset_info_id.in_(id_list))
.order_by(AssetInfoTag.added_at)
)
for aid, tag_name in rows.all():
tag_map[aid].append(tag_name)
return infos, tag_map, total
def fetch_asset_info_asset_and_tags(
session: Session,
asset_info_id: str,
owner_id: str = "",
) -> tuple[AssetInfo, Asset, list[str]] | None:
stmt = (
select(AssetInfo, Asset, Tag.name)
.join(Asset, Asset.id == AssetInfo.asset_id)
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
.where(
AssetInfo.id == asset_info_id,
build_visible_owner_clause(owner_id),
)
.options(noload(AssetInfo.tags))
.order_by(Tag.name.asc())
)
rows = (session.execute(stmt)).all()
if not rows:
return None
first_info, first_asset, _ = rows[0]
tags: list[str] = []
seen: set[str] = set()
for _info, _asset, tag_name in rows:
if tag_name and tag_name not in seen:
seen.add(tag_name)
tags.append(tag_name)
return first_info, first_asset, tags
def fetch_asset_info_and_asset(
session: Session,
asset_info_id: str,
owner_id: str = "",
) -> tuple[AssetInfo, Asset] | None:
stmt = (
select(AssetInfo, Asset)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(
AssetInfo.id == asset_info_id,
build_visible_owner_clause(owner_id),
)
.limit(1)
.options(noload(AssetInfo.tags))
)
row = session.execute(stmt)
pair = row.first()
if not pair:
return None
return pair[0], pair[1]
def update_asset_info_access_time(
session: Session,
asset_info_id: str,
ts: datetime | None = None,
only_if_newer: bool = True,
) -> None:
ts = ts or get_utc_now()
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
if only_if_newer:
stmt = stmt.where(
sa.or_(
AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts
)
)
session.execute(stmt.values(last_access_time=ts))
def update_asset_info_name(
session: Session,
asset_info_id: str,
name: str,
) -> None:
"""Update the name of an AssetInfo."""
now = get_utc_now()
session.execute(
sa.update(AssetInfo)
.where(AssetInfo.id == asset_info_id)
.values(name=name, updated_at=now)
)
def update_asset_info_updated_at(
session: Session,
asset_info_id: str,
ts: datetime | None = None,
) -> None:
"""Update the updated_at timestamp of an AssetInfo."""
ts = ts or get_utc_now()
session.execute(
sa.update(AssetInfo).where(AssetInfo.id == asset_info_id).values(updated_at=ts)
)
def set_asset_info_metadata(
session: Session,
asset_info_id: str,
user_metadata: dict | None = None,
) -> None:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info.user_metadata = user_metadata or {}
info.updated_at = get_utc_now()
session.flush()
session.execute(
delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id)
)
session.flush()
if not user_metadata:
return
rows: list[AssetInfoMeta] = []
for k, v in user_metadata.items():
for r in convert_metadata_to_rows(k, v):
rows.append(
AssetInfoMeta(
asset_info_id=asset_info_id,
key=r["key"],
ordinal=int(r["ordinal"]),
val_str=r.get("val_str"),
val_num=r.get("val_num"),
val_bool=r.get("val_bool"),
val_json=r.get("val_json"),
)
)
if rows:
session.add_all(rows)
session.flush()
def delete_asset_info_by_id(
session: Session,
asset_info_id: str,
owner_id: str,
) -> bool:
stmt = sa.delete(AssetInfo).where(
AssetInfo.id == asset_info_id,
build_visible_owner_clause(owner_id),
)
return int((session.execute(stmt)).rowcount or 0) > 0
def set_asset_info_preview(
session: Session,
asset_info_id: str,
preview_asset_id: str | None = None,
) -> None:
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if preview_asset_id is None:
info.preview_id = None
else:
if not session.get(Asset, preview_asset_id):
raise ValueError(f"Preview Asset {preview_asset_id} not found")
info.preview_id = preview_asset_id
info.updated_at = get_utc_now()
session.flush()
def bulk_insert_asset_infos_ignore_conflicts(
session: Session,
rows: list[dict],
) -> None:
"""Bulk insert AssetInfo rows with ON CONFLICT DO NOTHING.
Each dict should have: id, owner_id, name, asset_id, preview_id,
user_metadata, created_at, updated_at, last_access_time
"""
if not rows:
return
ins = sqlite.insert(AssetInfo).on_conflict_do_nothing(
index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name]
)
for chunk in iter_chunks(rows, calculate_rows_per_statement(9)):
session.execute(ins, chunk)
def get_asset_info_ids_by_ids(
session: Session,
info_ids: list[str],
) -> set[str]:
"""Query to find which AssetInfo IDs exist in the database."""
if not info_ids:
return set()
found: set[str] = set()
for chunk in iter_chunks(info_ids, MAX_BIND_PARAMS):
result = session.execute(select(AssetInfo.id).where(AssetInfo.id.in_(chunk)))
found.update(result.scalars().all())
return found

View File

@@ -1,280 +0,0 @@
import os
from typing import NamedTuple, Sequence
import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy.dialects import sqlite
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetCacheState, AssetInfo
from app.assets.database.queries.common import (
MAX_BIND_PARAMS,
calculate_rows_per_statement,
iter_chunks,
)
from app.assets.helpers import escape_sql_like_string
class CacheStateRow(NamedTuple):
"""Row from cache state query with joined asset data."""
state_id: int
file_path: str
mtime_ns: int | None
needs_verify: bool
asset_id: str
asset_hash: str | None
size_bytes: int
def list_cache_states_by_asset_id(
session: Session, *, asset_id: str
) -> Sequence[AssetCacheState]:
return (
(
session.execute(
select(AssetCacheState)
.where(AssetCacheState.asset_id == asset_id)
.order_by(AssetCacheState.id.asc())
)
)
.scalars()
.all()
)
def upsert_cache_state(
session: Session,
asset_id: str,
file_path: str,
mtime_ns: int,
) -> tuple[bool, bool]:
"""Upsert a cache state by file_path. Returns (created, updated)."""
vals = {
"asset_id": asset_id,
"file_path": file_path,
"mtime_ns": int(mtime_ns),
}
ins = (
sqlite.insert(AssetCacheState)
.values(**vals)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
)
res = session.execute(ins)
created = int(res.rowcount or 0) > 0
if created:
return True, False
upd = (
sa.update(AssetCacheState)
.where(AssetCacheState.file_path == file_path)
.where(
sa.or_(
AssetCacheState.asset_id != asset_id,
AssetCacheState.mtime_ns.is_(None),
AssetCacheState.mtime_ns != int(mtime_ns),
)
)
.values(asset_id=asset_id, mtime_ns=int(mtime_ns))
)
res2 = session.execute(upd)
updated = int(res2.rowcount or 0) > 0
return False, updated
def delete_cache_states_outside_prefixes(
session: Session, valid_prefixes: list[str]
) -> int:
"""Delete cache states with file_path not matching any of the valid prefixes.
Args:
session: Database session
valid_prefixes: List of absolute directory prefixes that are valid
Returns:
Number of cache states deleted
"""
if not valid_prefixes:
return 0
def make_prefix_condition(prefix: str):
base = prefix if prefix.endswith(os.sep) else prefix + os.sep
escaped, esc = escape_sql_like_string(base)
return AssetCacheState.file_path.like(escaped + "%", escape=esc)
matches_valid_prefix = sa.or_(*[make_prefix_condition(p) for p in valid_prefixes])
result = session.execute(sa.delete(AssetCacheState).where(~matches_valid_prefix))
return result.rowcount
def get_orphaned_seed_asset_ids(session: Session) -> list[str]:
"""Get IDs of seed assets (hash is None) with no remaining cache states.
Returns:
List of asset IDs that are orphaned
"""
orphan_subq = (
sa.select(Asset.id)
.outerjoin(AssetCacheState, AssetCacheState.asset_id == Asset.id)
.where(Asset.hash.is_(None), AssetCacheState.id.is_(None))
)
return [row[0] for row in session.execute(orphan_subq).all()]
def delete_assets_by_ids(session: Session, asset_ids: list[str]) -> int:
"""Delete assets and their AssetInfos by ID.
Args:
session: Database session
asset_ids: List of asset IDs to delete
Returns:
Number of assets deleted
"""
if not asset_ids:
return 0
session.execute(sa.delete(AssetInfo).where(AssetInfo.asset_id.in_(asset_ids)))
result = session.execute(sa.delete(Asset).where(Asset.id.in_(asset_ids)))
return result.rowcount
def get_cache_states_for_prefixes(
session: Session,
prefixes: list[str],
) -> list[CacheStateRow]:
"""Get all cache states with paths matching any of the given prefixes.
Args:
session: Database session
prefixes: List of absolute directory prefixes to match
Returns:
List of cache state rows with joined asset data, ordered by asset_id, state_id
"""
if not prefixes:
return []
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
escaped, esc = escape_sql_like_string(base)
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
rows = session.execute(
sa.select(
AssetCacheState.id,
AssetCacheState.file_path,
AssetCacheState.mtime_ns,
AssetCacheState.needs_verify,
AssetCacheState.asset_id,
Asset.hash,
Asset.size_bytes,
)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(sa.or_(*conds))
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
).all()
return [
CacheStateRow(
state_id=row[0],
file_path=row[1],
mtime_ns=row[2],
needs_verify=row[3],
asset_id=row[4],
asset_hash=row[5],
size_bytes=int(row[6] or 0),
)
for row in rows
]
def bulk_set_needs_verify(session: Session, state_ids: list[int], value: bool) -> int:
"""Set needs_verify flag for multiple cache states.
Returns: Number of rows updated
"""
if not state_ids:
return 0
result = session.execute(
sa.update(AssetCacheState)
.where(AssetCacheState.id.in_(state_ids))
.values(needs_verify=value)
)
return result.rowcount
def delete_cache_states_by_ids(session: Session, state_ids: list[int]) -> int:
"""Delete cache states by their IDs.
Returns: Number of rows deleted
"""
if not state_ids:
return 0
result = session.execute(
sa.delete(AssetCacheState).where(AssetCacheState.id.in_(state_ids))
)
return result.rowcount
def delete_orphaned_seed_asset(session: Session, asset_id: str) -> bool:
"""Delete a seed asset (hash is None) and its AssetInfos.
Returns: True if asset was deleted, False if not found
"""
session.execute(sa.delete(AssetInfo).where(AssetInfo.asset_id == asset_id))
asset = session.get(Asset, asset_id)
if asset:
session.delete(asset)
return True
return False
def bulk_insert_cache_states_ignore_conflicts(
session: Session,
rows: list[dict],
) -> None:
"""Bulk insert cache state rows with ON CONFLICT DO NOTHING on file_path.
Each dict should have: asset_id, file_path, mtime_ns
"""
if not rows:
return
ins = sqlite.insert(AssetCacheState).on_conflict_do_nothing(
index_elements=[AssetCacheState.file_path]
)
for chunk in iter_chunks(rows, calculate_rows_per_statement(3)):
session.execute(ins, chunk)
def get_cache_states_by_paths_and_asset_ids(
session: Session,
path_to_asset: dict[str, str],
) -> set[str]:
"""Query cache states to find paths where our asset_id won the insert.
Args:
path_to_asset: Mapping of file_path -> asset_id we tried to insert
Returns:
Set of file_paths where our asset_id is present
"""
if not path_to_asset:
return set()
paths = list(path_to_asset.keys())
winners: set[str] = set()
for chunk in iter_chunks(paths, MAX_BIND_PARAMS):
result = session.execute(
select(AssetCacheState.file_path).where(
AssetCacheState.file_path.in_(chunk),
AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk]),
)
)
winners.update(result.scalars().all())
return winners

View File

@@ -1,37 +0,0 @@
"""Shared utilities for database query modules."""
from typing import Iterable
import sqlalchemy as sa
from app.assets.database.models import AssetInfo
MAX_BIND_PARAMS = 800
def calculate_rows_per_statement(cols: int) -> int:
"""Calculate how many rows can fit in one statement given column count."""
return max(1, MAX_BIND_PARAMS // max(1, cols))
def iter_chunks(seq, n: int):
"""Yield successive n-sized chunks from seq."""
for i in range(0, len(seq), n):
yield seq[i : i + n]
def iter_row_chunks(rows: list[dict], cols_per_row: int) -> Iterable[list[dict]]:
"""Yield chunks of rows sized to fit within bind param limits."""
if not rows:
return []
rows_per_stmt = max(1, MAX_BIND_PARAMS // max(1, cols_per_row))
for i in range(0, len(rows), rows_per_stmt):
yield rows[i : i + rows_per_stmt]
def build_visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
"""Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
owner_id = (owner_id or "").strip()
if owner_id == "":
return AssetInfo.owner_id == ""
return AssetInfo.owner_id.in_(["", owner_id])

View File

@@ -1,349 +0,0 @@
from typing import Iterable, Sequence, TypedDict
import sqlalchemy as sa
from sqlalchemy import delete, func, select
from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from app.assets.database.models import AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
from app.assets.database.queries.common import (
build_visible_owner_clause,
iter_row_chunks,
)
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
class AddTagsDict(TypedDict):
added: list[str]
already_present: list[str]
total_tags: list[str]
class RemoveTagsDict(TypedDict):
removed: list[str]
not_present: list[str]
total_tags: list[str]
class SetTagsDict(TypedDict):
added: list[str]
removed: list[str]
total: list[str]
def ensure_tags_exist(
session: Session, names: Iterable[str], tag_type: str = "user"
) -> None:
wanted = normalize_tags(list(names))
if not wanted:
return
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
ins = (
sqlite.insert(Tag)
.values(rows)
.on_conflict_do_nothing(index_elements=[Tag.name])
)
session.execute(ins)
def get_asset_tags(session: Session, asset_info_id: str) -> list[str]:
return [
tag_name
for (tag_name,) in (
session.execute(
select(AssetInfoTag.tag_name).where(
AssetInfoTag.asset_info_id == asset_info_id
)
)
).all()
]
def set_asset_info_tags(
session: Session,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
) -> SetTagsDict:
desired = normalize_tags(tags)
current = set(
tag_name
for (tag_name,) in (
session.execute(
select(AssetInfoTag.tag_name).where(
AssetInfoTag.asset_info_id == asset_info_id
)
)
).all()
)
to_add = [t for t in desired if t not in current]
to_remove = [t for t in current if t not in desired]
if to_add:
ensure_tags_exist(session, to_add, tag_type="user")
session.add_all(
[
AssetInfoTag(
asset_info_id=asset_info_id,
tag_name=t,
origin=origin,
added_at=get_utc_now(),
)
for t in to_add
]
)
session.flush()
if to_remove:
session.execute(
delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id == asset_info_id,
AssetInfoTag.tag_name.in_(to_remove),
)
)
session.flush()
return {"added": to_add, "removed": to_remove, "total": desired}
def add_tags_to_asset_info(
session: Session,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
create_if_missing: bool = True,
asset_info_row: AssetInfo | None = None,
) -> AddTagsDict:
if not asset_info_row:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"added": [], "already_present": [], "total_tags": total}
if create_if_missing:
ensure_tags_exist(session, norm, tag_type="user")
current = {
tag_name
for (tag_name,) in (
session.execute(
sa.select(AssetInfoTag.tag_name).where(
AssetInfoTag.asset_info_id == asset_info_id
)
)
).all()
}
want = set(norm)
to_add = sorted(want - current)
if to_add:
with session.begin_nested() as nested:
try:
session.add_all(
[
AssetInfoTag(
asset_info_id=asset_info_id,
tag_name=t,
origin=origin,
added_at=get_utc_now(),
)
for t in to_add
]
)
session.flush()
except IntegrityError:
nested.rollback()
after = set(get_asset_tags(session, asset_info_id=asset_info_id))
return {
"added": sorted(((after - current) & want)),
"already_present": sorted(want & current),
"total_tags": sorted(after),
}
def remove_tags_from_asset_info(
session: Session,
asset_info_id: str,
tags: Sequence[str],
) -> RemoveTagsDict:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": [], "not_present": [], "total_tags": total}
existing = {
tag_name
for (tag_name,) in (
session.execute(
sa.select(AssetInfoTag.tag_name).where(
AssetInfoTag.asset_info_id == asset_info_id
)
)
).all()
}
to_remove = sorted(set(t for t in norm if t in existing))
not_present = sorted(set(t for t in norm if t not in existing))
if to_remove:
session.execute(
delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id == asset_info_id,
AssetInfoTag.tag_name.in_(to_remove),
)
)
session.flush()
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
def add_missing_tag_for_asset_id(
session: Session,
asset_id: str,
origin: str = "automatic",
) -> None:
select_rows = (
sa.select(
AssetInfo.id.label("asset_info_id"),
sa.literal("missing").label("tag_name"),
sa.literal(origin).label("origin"),
sa.literal(get_utc_now()).label("added_at"),
)
.where(AssetInfo.asset_id == asset_id)
.where(
sa.not_(
sa.exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name == "missing")
)
)
)
)
session.execute(
sqlite.insert(AssetInfoTag)
.from_select(
["asset_info_id", "tag_name", "origin", "added_at"],
select_rows,
)
.on_conflict_do_nothing(
index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]
)
)
def remove_missing_tag_for_asset_id(
session: Session,
asset_id: str,
) -> None:
session.execute(
sa.delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id.in_(
sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)
),
AssetInfoTag.tag_name == "missing",
)
)
def list_tags_with_usage(
session: Session,
prefix: str | None = None,
limit: int = 100,
offset: int = 0,
include_zero: bool = True,
order: str = "count_desc",
owner_id: str = "",
) -> tuple[list[tuple[str, str, int]], int]:
counts_sq = (
select(
AssetInfoTag.tag_name.label("tag_name"),
func.count(AssetInfoTag.asset_info_id).label("cnt"),
)
.select_from(AssetInfoTag)
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
.where(build_visible_owner_clause(owner_id))
.group_by(AssetInfoTag.tag_name)
.subquery()
)
q = (
select(
Tag.name,
Tag.tag_type,
func.coalesce(counts_sq.c.cnt, 0).label("count"),
)
.select_from(Tag)
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
)
if prefix:
escaped, esc = escape_sql_like_string(prefix.strip().lower())
q = q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
if order == "name_asc":
q = q.order_by(Tag.name.asc())
else:
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
total_q = select(func.count()).select_from(Tag)
if prefix:
escaped, esc = escape_sql_like_string(prefix.strip().lower())
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
total_q = total_q.where(
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
)
rows = (session.execute(q.limit(limit).offset(offset))).all()
total = (session.execute(total_q)).scalar_one()
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
return rows_norm, int(total or 0)
def bulk_insert_tags_and_meta(
session: Session,
tag_rows: list[dict],
meta_rows: list[dict],
) -> None:
"""Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.
Args:
session: Database session
tag_rows: List of dicts with keys: asset_info_id, tag_name, origin, added_at
meta_rows: List of dicts with keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
"""
if tag_rows:
ins_tags = sqlite.insert(AssetInfoTag).on_conflict_do_nothing(
index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]
)
for chunk in iter_row_chunks(tag_rows, cols_per_row=4):
session.execute(ins_tags, chunk)
if meta_rows:
ins_meta = sqlite.insert(AssetInfoMeta).on_conflict_do_nothing(
index_elements=[
AssetInfoMeta.asset_info_id,
AssetInfoMeta.key,
AssetInfoMeta.ordinal,
]
)
for chunk in iter_row_chunks(meta_rows, cols_per_row=7):
session.execute(ins_meta, chunk)

View File

@@ -0,0 +1,62 @@
from typing import Iterable
import sqlalchemy
from sqlalchemy.orm import Session
from sqlalchemy.dialects import sqlite
from app.assets.helpers import normalize_tags, utcnow
from app.assets.database.models import Tag, AssetInfoTag, AssetInfo
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
wanted = normalize_tags(list(names))
if not wanted:
return
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
ins = (
sqlite.insert(Tag)
.values(rows)
.on_conflict_do_nothing(index_elements=[Tag.name])
)
session.execute(ins)
def add_missing_tag_for_asset_id(
session: Session,
*,
asset_id: str,
origin: str = "automatic",
) -> None:
select_rows = (
sqlalchemy.select(
AssetInfo.id.label("asset_info_id"),
sqlalchemy.literal("missing").label("tag_name"),
sqlalchemy.literal(origin).label("origin"),
sqlalchemy.literal(utcnow()).label("added_at"),
)
.where(AssetInfo.asset_id == asset_id)
.where(
sqlalchemy.not_(
sqlalchemy.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing"))
)
)
)
session.execute(
sqlite.insert(AssetInfoTag)
.from_select(
["asset_info_id", "tag_name", "origin", "added_at"],
select_rows,
)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
def remove_missing_tag_for_asset_id(
session: Session,
*,
asset_id: str,
) -> None:
session.execute(
sqlalchemy.delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id.in_(sqlalchemy.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
AssetInfoTag.tag_name == "missing",
)
)
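# Hedged usage sketch: a filesystem verifier (hypothetical) would pair these
# helpers when a cached file disappears or reappears on disk.
#
# if not os.path.isfile(state.file_path):
#     add_missing_tag_for_asset_id(session, asset_id=state.asset_id)
# else:
#     remove_missing_tag_for_asset_id(session, asset_id=state.asset_id)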

app/assets/hashing.py (new file, 75 lines)
View File

@@ -0,0 +1,75 @@
from blake3 import blake3
from typing import IO
import os
import asyncio
DEFAULT_CHUNK = 8 * 1024 * 1024  # 8MB
# NOTE: accepts either a filesystem path or an open binary file object
def blake3_hash(
fp: str | IO[bytes],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
"""
Returns a BLAKE3 hex digest for ``fp``, which may be:
- a filename (str/bytes) or PathLike
- an open binary file object
If ``fp`` is a file object, it must be opened in **binary** mode and support
``read``, ``seek``, and ``tell``. The function will seek to the start before
reading and will attempt to restore the original position afterward.
"""
# duck typing to check if input is a file-like object
if hasattr(fp, "read"):
return _hash_file_obj(fp, chunk_size)
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj(f, chunk_size)
async def blake3_hash_async(
fp: str | IO[bytes],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
"""Async wrapper for ``blake3_hash_sync``.
Uses a worker thread so the event loop remains responsive.
"""
# File-like objects are hashed in a worker thread as-is; bare paths are opened
# inside the worker so the file I/O also stays off the event loop.
if hasattr(fp, "read"):
return await asyncio.to_thread(blake3_hash, fp, chunk_size)
def _worker() -> str:
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj(f, chunk_size)
return await asyncio.to_thread(_worker)
def _hash_file_obj(file_obj: IO, chunk_size: int = DEFAULT_CHUNK) -> str:
"""
Hash an already-open binary file object by streaming in chunks.
- Seeks to the beginning before reading (if supported).
- Restores the original position afterward (if tell/seek are supported).
"""
if chunk_size <= 0:
chunk_size = DEFAULT_CHUNK
# the file object may already be open and positioned mid-stream; record the position so it can be restored after hashing
orig_pos = file_obj.tell()
try:
# seek to the beginning before reading
if orig_pos != 0:
file_obj.seek(0)
h = blake3()
while True:
chunk = file_obj.read(chunk_size)
if not chunk:
break
h.update(chunk)
return h.hexdigest()
finally:
# restore original position in file object, if needed
if orig_pos != 0:
file_obj.seek(orig_pos)
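# Illustrative usage (not part of the module): both call forms are expected to
# yield the same digest for the same bytes, and the file position is restored.
#
# digest_a = blake3_hash("model.safetensors")
# with open("model.safetensors", "rb") as f:
#     f.seek(128)              # deliberately mid-stream
#     digest_b = blake3_hash(f)
#     assert f.tell() == 128   # position restored by _hash_file_obj
# assert digest_a == digest_b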

View File

@@ -1,36 +1,52 @@
import contextlib
import os
from decimal import Decimal
from aiohttp import web
from datetime import datetime, timezone
from typing import Literal, Sequence
from pathlib import Path
from typing import Literal, Any
import folder_paths
def select_best_live_path(states: Sequence) -> str:
"""
Return the best on-disk path among cache states:
1) Prefer a path that exists with needs_verify == False (already verified).
2) Otherwise, pick the first path that exists.
3) Otherwise return empty string.
"""
alive = [
s
for s in states
if getattr(s, "file_path", None) and os.path.isfile(s.file_path)
]
if not alive:
return ""
for s in alive:
if not getattr(s, "needs_verify", False):
return s.file_path
return alive[0].file_path
RootType = Literal["models", "input", "output"]
ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")
def get_query_dict(request: web.Request) -> dict[str, Any]:
"""
Gets a dictionary of query parameters from the request.
'request.query' is a MultiMapping[str]; it must be converted to a plain dictionary before Pydantic can validate it.
"""
query_dict = {
key: request.query.getall(key) if len(request.query.getall(key)) > 1 else request.query.get(key)
for key in request.query.keys()
}
return query_dict
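# Illustrative only: for a request like /api/assets?tag=models&tag=vae&limit=20
# the helper yields {"tag": ["models", "vae"], "limit": "20"}; repeated keys
# become lists while single-valued keys stay plain strings.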
def list_tree(base_dir: str) -> list[str]:
out: list[str] = []
base_abs = os.path.abspath(base_dir)
if not os.path.isdir(base_abs):
return out
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
for name in filenames:
out.append(os.path.abspath(os.path.join(dirpath, name)))
return out
ALLOWED_ROOTS: tuple[Literal["models", "input", "output"], ...] = (
"models",
"input",
"output",
)
def prefixes_for_root(root: RootType) -> list[str]:
if root == "models":
bases: list[str] = []
for _bucket, paths in get_comfy_models_folders():
bases.extend(paths)
return [os.path.abspath(p) for p in bases]
if root == "input":
return [os.path.abspath(folder_paths.get_input_directory())]
if root == "output":
return [os.path.abspath(folder_paths.get_output_directory())]
return []
def escape_sql_like_string(s: str, escape: str = "!") -> tuple[str, str]:
def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]:
"""Escapes %, _ and the escape char itself in a LIKE prefix.
Returns (escaped_prefix, escape_char). Caller should append '%' and pass escape=escape_char to .like().
"""
@@ -38,11 +54,173 @@ def escape_sql_like_string(s: str, escape: str = "!") -> tuple[str, str]:
s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards
return s, escape
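# Illustrative sketch: a user-supplied prefix containing LIKE wildcards is
# neutralized before '%' is appended; the .like() call shows the assumed
# SQLAlchemy usage.
def _demo_escape_like_prefix() -> None:
    escaped, esc = escape_like_prefix("50%_off")
    assert (escaped, esc) == ("50!%!_off", "!")
    # assumed usage: SomeModel.file_path.like(escaped + "%", escape=esc)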
def fast_asset_file_check(
*,
mtime_db: int | None,
size_db: int | None,
stat_result: os.stat_result,
) -> bool:
if mtime_db is None:
return False
actual_mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000))
if int(mtime_db) != int(actual_mtime_ns):
return False
sz = int(size_db or 0)
if sz > 0:
return int(stat_result.st_size) == sz
return True
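# Illustrative sketch: comparing a fresh stat result against DB-recorded
# mtime/size. `some_path`, `db_mtime_ns` and `db_size` are hypothetical values.
def _demo_fast_check(some_path: str, db_mtime_ns: int, db_size: int) -> bool:
    st = os.stat(some_path, follow_symlinks=True)
    return fast_asset_file_check(mtime_db=db_mtime_ns, size_db=db_size, stat_result=st)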
def get_utc_now() -> datetime:
def utcnow() -> datetime:
"""Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC."""
return datetime.now(timezone.utc).replace(tzinfo=None)
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
"""Build a list of (folder_name, base_paths[]) categories that are configured for model locations.
We trust `folder_paths.folder_names_and_paths` and include a category if
*any* of its base paths lies under the Comfy `models_dir`.
"""
targets: list[tuple[str, list[str]]] = []
models_root = os.path.abspath(folder_paths.models_dir)
for name, values in folder_paths.folder_names_and_paths.items():
paths, _exts = values[0], values[1] # NOTE: this prevents nodepacks that hackily edit folder_... from breaking ComfyUI
if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
targets.append((name, paths))
return targets
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
root = tags[0]
if root == "models":
if len(tags) < 2:
raise ValueError("at least two tags required for model asset")
try:
bases = folder_paths.folder_names_and_paths[tags[1]][0]
except KeyError:
raise ValueError(f"unknown model category '{tags[1]}'")
if not bases:
raise ValueError(f"no base path configured for category '{tags[1]}'")
base_dir = os.path.abspath(bases[0])
raw_subdirs = tags[2:]
else:
base_dir = os.path.abspath(
folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory()
)
raw_subdirs = tags[1:]
for i in raw_subdirs:
if i in (".", ".."):
raise ValueError("invalid path component in tags")
return base_dir, raw_subdirs if raw_subdirs else []
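# Illustrative sketch (assumes "checkpoints" is a configured model category):
# the first tag picks the base directory, the rest become subdirectories.
def _demo_resolve_destination() -> None:
    # ["models", "checkpoints", "sdxl"] -> (<first checkpoints base>, ["sdxl"])
    base_dir, subdirs = resolve_destination_from_tags(["models", "checkpoints", "sdxl"])
    dest_dir = os.path.join(base_dir, *subdirs)
    ensure_within_base(dest_dir, base_dir)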
def ensure_within_base(candidate: str, base: str) -> None:
cand_abs = os.path.abspath(candidate)
base_abs = os.path.abspath(base)
try:
if os.path.commonpath([cand_abs, base_abs]) != base_abs:
raise ValueError("destination escapes base directory")
except Exception:
raise ValueError("invalid destination path")
def compute_relative_filename(file_path: str) -> str | None:
"""
Return the model's path relative to the last well-known folder (the model category),
using forward slashes, eg:
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
For non-model paths, returns None.
NOTE: this is a temporary helper, used only for initializing metadata["filename"] field.
"""
try:
root_category, rel_path = get_relative_to_root_category_path_of_asset(file_path)
except ValueError:
return None
p = Path(rel_path)
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
if not parts:
return None
if root_category == "models":
# parts[0] is the category ("checkpoints", "vae", etc.); drop it
inside = parts[1:] if len(parts) > 1 else [parts[0]]
return "/".join(inside)
return "/".join(parts) # input/output: keep all parts
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
"""Given an absolute or relative file path, determine which root category the path belongs to:
- 'input' if the file resides under `folder_paths.get_input_directory()`
- 'output' if the file resides under `folder_paths.get_output_directory()`
- 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()`
Returns:
(root_category, relative_path_inside_that_root)
For 'models', the relative path is prefixed with the category name:
e.g. ('models', 'vae/test/sub/ae.safetensors')
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
"""
fp_abs = os.path.abspath(file_path)
def _is_within(child: str, parent: str) -> bool:
try:
return os.path.commonpath([child, parent]) == parent
except Exception:
return False
def _rel(child: str, parent: str) -> str:
return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep)
# 1) input
input_base = os.path.abspath(folder_paths.get_input_directory())
if _is_within(fp_abs, input_base):
return "input", _rel(fp_abs, input_base)
# 2) output
output_base = os.path.abspath(folder_paths.get_output_directory())
if _is_within(fp_abs, output_base):
return "output", _rel(fp_abs, output_base)
# 3) models (check deepest matching base to avoid ambiguity)
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
for bucket, bases in get_comfy_models_folders():
for b in bases:
base_abs = os.path.abspath(b)
if not _is_within(fp_abs, base_abs):
continue
cand = (len(base_abs), bucket, _rel(fp_abs, base_abs))
if best is None or cand[0] > best[0]:
best = cand
if best is not None:
_, bucket, rel_inside = best
combined = os.path.join(bucket, rel_inside)
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}")
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
"""Return a tuple (name, tags) derived from a filesystem path.
Semantics:
- Root category is determined by `get_relative_to_root_category_path_of_asset`.
- The returned `name` is the base filename with extension from the relative path.
- The returned `tags` are:
[root_category] + parent folders of the relative path (in order)
For 'models', this means:
file '/.../ModelsDir/vae/test_tag/ae.safetensors'
-> root_category='models', some_path='vae/test_tag/ae.safetensors'
-> name='ae.safetensors', tags=['models', 'vae', 'test_tag']
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
"""
root_category, some_path = get_relative_to_root_category_path_of_asset(file_path)
p = Path(some_path)
parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)]
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
def normalize_tags(tags: list[str] | None) -> list[str]:
"""
@@ -51,3 +229,84 @@ def normalize_tags(tags: list[str] | None) -> list[str]:
- Note: duplicates are not removed here; callers dedupe (e.g. via dict.fromkeys).
"""
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
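# Illustrative sketch: whitespace and case are normalized and empty entries
# dropped, while duplicates survive and are deduped by the callers above.
def _demo_normalize_tags() -> None:
    assert normalize_tags(["  Models ", "VAE", ""]) == ["models", "vae"]
    assert normalize_tags(["a", "A"]) == ["a", "a"]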
def collect_models_files() -> list[str]:
out: list[str] = []
for folder_name, bases in get_comfy_models_folders():
rel_files = folder_paths.get_filename_list(folder_name) or []
for rel_path in rel_files:
abs_path = folder_paths.get_full_path(folder_name, rel_path)
if not abs_path:
continue
abs_path = os.path.abspath(abs_path)
allowed = False
for b in bases:
base_abs = os.path.abspath(b)
with contextlib.suppress(Exception):
if os.path.commonpath([abs_path, base_abs]) == base_abs:
allowed = True
break
if allowed:
out.append(abs_path)
return out
def is_scalar(v):
if v is None:
return True
if isinstance(v, bool):
return True
if isinstance(v, (int, float, Decimal, str)):
return True
return False
def project_kv(key: str, value):
"""
Turn a metadata key/value into typed projection rows.
Returns list[dict] with keys:
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
"""
rows: list[dict] = []
def _null_row(ordinal: int) -> dict:
return {
"key": key, "ordinal": ordinal,
"val_str": None, "val_num": None, "val_bool": None, "val_json": None
}
if value is None:
rows.append(_null_row(0))
return rows
if is_scalar(value):
if isinstance(value, bool):
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
elif isinstance(value, (int, float, Decimal)):
num = value if isinstance(value, Decimal) else Decimal(str(value))
rows.append({"key": key, "ordinal": 0, "val_num": num})
elif isinstance(value, str):
rows.append({"key": key, "ordinal": 0, "val_str": value})
else:
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows
if isinstance(value, list):
if all(is_scalar(x) for x in value):
for i, x in enumerate(value):
if x is None:
rows.append(_null_row(i))
elif isinstance(x, bool):
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
elif isinstance(x, (int, float, Decimal)):
num = x if isinstance(x, Decimal) else Decimal(str(x))
rows.append({"key": key, "ordinal": i, "val_num": num})
elif isinstance(x, str):
rows.append({"key": key, "ordinal": i, "val_str": x})
else:
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
for i, x in enumerate(value):
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows
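# Illustrative sketch of the projection shapes: a scalar becomes one typed row,
# a scalar list fans out by ordinal, and nested structures fall back to val_json.
def _demo_project_kv() -> None:
    assert project_kv("steps", 30)[0]["val_num"] == Decimal("30")
    rows = project_kv("loras", ["a", "b"])
    assert [r["val_str"] for r in rows] == ["a", "b"]
    assert [r["ordinal"] for r in rows] == [0, 1]
    assert project_kv("extra", {"k": 1})[0]["val_json"] == {"k": 1}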

516
app/assets/manager.py Normal file
View File

@@ -0,0 +1,516 @@
import os
import mimetypes
import contextlib
from typing import Sequence
from app.database.db import create_session
from app.assets.api import schemas_out, schemas_in
from app.assets.database.queries import (
asset_exists_by_hash,
asset_info_exists_for_asset_id,
get_asset_by_hash,
get_asset_info_by_id,
fetch_asset_info_asset_and_tags,
fetch_asset_info_and_asset,
create_asset_info_for_existing_asset,
touch_asset_info_by_id,
update_asset_info_full,
delete_asset_info_by_id,
list_cache_states_by_asset_id,
list_asset_infos_page,
list_tags_with_usage,
get_asset_tags,
add_tags_to_asset_info,
remove_tags_from_asset_info,
pick_best_live_path,
ingest_fs_asset,
set_asset_info_preview,
)
from app.assets.helpers import resolve_destination_from_tags, ensure_within_base
from app.assets.database.models import Asset
def _safe_sort_field(requested: str | None) -> str:
if not requested:
return "created_at"
v = requested.lower()
if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
return v
return "created_at"
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
st = os.stat(path, follow_symlinks=True)
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
def _safe_filename(name: str | None, fallback: str) -> str:
n = os.path.basename((name or "").strip() or fallback)
if n:
return n
return fallback
def asset_exists(*, asset_hash: str) -> bool:
"""
Check if an asset with a given hash exists in database.
"""
with create_session() as session:
return asset_exists_by_hash(session, asset_hash=asset_hash)
def list_assets(
*,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
owner_id: str = "",
) -> schemas_out.AssetsList:
sort = _safe_sort_field(sort)
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
with create_session() as session:
infos, tag_map, total = list_asset_infos_page(
session,
owner_id=owner_id,
include_tags=include_tags,
exclude_tags=exclude_tags,
name_contains=name_contains,
metadata_filter=metadata_filter,
limit=limit,
offset=offset,
sort=sort,
order=order,
)
summaries: list[schemas_out.AssetSummary] = []
for info in infos:
asset = info.asset
tags = tag_map.get(info.id, [])
summaries.append(
schemas_out.AssetSummary(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset else None,
mime_type=asset.mime_type if asset else None,
tags=tags,
created_at=info.created_at,
updated_at=info.updated_at,
last_access_time=info.last_access_time,
)
)
return schemas_out.AssetsList(
assets=summaries,
total=total,
has_more=(offset + len(summaries)) < total,
)
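# Illustrative call sketch (tag names hypothetical): first page of checkpoint
# assets, newest first.
def _demo_list_assets() -> schemas_out.AssetsList:
    return list_assets(include_tags=["models", "checkpoints"], limit=20, sort="created_at", order="desc")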
def get_asset(
*,
asset_info_id: str,
owner_id: str = "",
) -> schemas_out.AssetDetail:
with create_session() as session:
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not res:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info, asset, tag_names = res
preview_id = info.preview_id
return schemas_out.AssetDetail(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
mime_type=asset.mime_type if asset else None,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
)
def resolve_asset_content_for_download(
*,
asset_info_id: str,
owner_id: str = "",
) -> tuple[str, str, str]:
with create_session() as session:
pair = fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not pair:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info, asset = pair
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
abs_path = pick_best_live_path(states)
if not abs_path:
raise FileNotFoundError
touch_asset_info_by_id(session, asset_info_id=asset_info_id)
session.commit()
ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
download_name = info.name or os.path.basename(abs_path)
return abs_path, ctype, download_name
def upload_asset_from_temp_path(
spec: schemas_in.UploadAssetSpec,
*,
temp_path: str,
client_filename: str | None = None,
owner_id: str = "",
expected_asset_hash: str | None = None,
) -> schemas_out.AssetCreated:
"""
Create new asset or update existing asset from a temporary file path.
"""
try:
# NOTE: blake3 is an optional dependency, so this import fails if it is not installed in the local environment
import app.assets.hashing as hashing
digest = hashing.blake3_hash(temp_path)
except Exception as e:
raise RuntimeError(f"failed to hash uploaded file: {e}")
asset_hash = "blake3:" + digest
if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower():
raise ValueError("HASH_MISMATCH")
with create_session() as session:
existing = get_asset_by_hash(session, asset_hash=asset_hash)
if existing is not None:
with contextlib.suppress(Exception):
if temp_path and os.path.exists(temp_path):
os.remove(temp_path)
display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest)
info = create_asset_info_for_existing_asset(
session,
asset_hash=asset_hash,
name=display_name,
user_metadata=spec.user_metadata or {},
tags=spec.tags or [],
tag_origin="manual",
owner_id=owner_id,
)
tag_names = get_asset_tags(session, asset_info_id=info.id)
session.commit()
return schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=existing.hash,
size=int(existing.size_bytes) if existing.size_bytes is not None else None,
mime_type=existing.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=False,
)
base_dir, subdirs = resolve_destination_from_tags(spec.tags)
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
os.makedirs(dest_dir, exist_ok=True)
src_for_ext = (client_filename or spec.name or "").strip()
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
ext = _ext if 0 < len(_ext) <= 16 else ""
hashed_basename = f"{digest}{ext}"
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
ensure_within_base(dest_abs, base_dir)
content_type = (
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
or mimetypes.guess_type(hashed_basename, strict=False)[0]
or "application/octet-stream"
)
try:
os.replace(temp_path, dest_abs)
except Exception as e:
raise RuntimeError(f"failed to move uploaded file into place: {e}")
try:
size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs)
except OSError as e:
raise RuntimeError(f"failed to stat destination file: {e}")
with create_session() as session:
result = ingest_fs_asset(
session,
asset_hash=asset_hash,
abs_path=dest_abs,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
mime_type=content_type,
info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest),
owner_id=owner_id,
preview_id=None,
user_metadata=spec.user_metadata or {},
tags=spec.tags,
tag_origin="manual",
require_existing_tags=False,
)
info_id = result["asset_info_id"]
if not info_id:
raise RuntimeError("failed to create asset metadata")
pair = fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id)
if not pair:
raise RuntimeError("inconsistent DB state after ingest")
info, asset = pair
tag_names = get_asset_tags(session, asset_info_id=info.id)
created_result = schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=asset.hash,
size=int(asset.size_bytes),
mime_type=asset.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=result["asset_created"],
)
session.commit()
return created_result
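# Illustrative sketch (UploadAssetSpec constructor kwargs assumed from the
# attribute usage above): registering a file already streamed to a temp path.
def _demo_upload(temp_path: str) -> schemas_out.AssetCreated:
    spec = schemas_in.UploadAssetSpec(
        name="model.safetensors",            # hypothetical display name
        tags=["models", "checkpoints"],      # destination is derived from tags
        user_metadata={},
    )
    return upload_asset_from_temp_path(spec, temp_path=temp_path, owner_id="")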
def update_asset(
*,
asset_info_id: str,
name: str | None = None,
tags: list[str] | None = None,
user_metadata: dict | None = None,
owner_id: str = "",
) -> schemas_out.AssetUpdated:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
info = update_asset_info_full(
session,
asset_info_id=asset_info_id,
name=name,
tags=tags,
user_metadata=user_metadata,
tag_origin="manual",
asset_info_row=info_row,
)
tag_names = get_asset_tags(session, asset_info_id=asset_info_id)
result = schemas_out.AssetUpdated(
id=info.id,
name=info.name,
asset_hash=info.asset.hash if info.asset else None,
tags=tag_names,
user_metadata=info.user_metadata or {},
updated_at=info.updated_at,
)
session.commit()
return result
def set_asset_preview(
*,
asset_info_id: str,
preview_asset_id: str | None = None,
owner_id: str = "",
) -> schemas_out.AssetDetail:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
set_asset_info_preview(
session,
asset_info_id=asset_info_id,
preview_asset_id=preview_asset_id,
)
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not res:
raise RuntimeError("State changed during preview update")
info, asset, tags = res
result = schemas_out.AssetDetail(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
mime_type=asset.mime_type if asset else None,
tags=tags,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
)
session.commit()
return result
def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
asset_id = info_row.asset_id if info_row else None
deleted = delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not deleted:
session.commit()
return False
if not delete_content_if_orphan or not asset_id:
session.commit()
return True
still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id)
if still_exists:
session.commit()
return True
states = list_cache_states_by_asset_id(session, asset_id=asset_id)
file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]
asset_row = session.get(Asset, asset_id)
if asset_row is not None:
session.delete(asset_row)
session.commit()
for p in file_paths:
with contextlib.suppress(Exception):
if p and os.path.isfile(p):
os.remove(p)
return True
def create_asset_from_hash(
*,
hash_str: str,
name: str,
tags: list[str] | None = None,
user_metadata: dict | None = None,
owner_id: str = "",
) -> schemas_out.AssetCreated | None:
canonical = hash_str.strip().lower()
with create_session() as session:
asset = get_asset_by_hash(session, asset_hash=canonical)
if not asset:
return None
info = create_asset_info_for_existing_asset(
session,
asset_hash=canonical,
name=_safe_filename(name, fallback=canonical.split(":", 1)[1]),
user_metadata=user_metadata or {},
tags=tags or [],
tag_origin="manual",
owner_id=owner_id,
)
tag_names = get_asset_tags(session, asset_info_id=info.id)
result = schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=asset.hash,
size=int(asset.size_bytes),
mime_type=asset.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=False,
)
session.commit()
return result
def add_tags_to_asset(
*,
asset_info_id: str,
tags: list[str],
origin: str = "manual",
owner_id: str = "",
) -> schemas_out.TagsAdd:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = add_tags_to_asset_info(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=origin,
create_if_missing=True,
asset_info_row=info_row,
)
session.commit()
return schemas_out.TagsAdd(**data)
def remove_tags_from_asset(
*,
asset_info_id: str,
tags: list[str],
owner_id: str = "",
) -> schemas_out.TagsRemove:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = remove_tags_from_asset_info(
session,
asset_info_id=asset_info_id,
tags=tags,
)
session.commit()
return schemas_out.TagsRemove(**data)
def list_tags(
prefix: str | None = None,
limit: int = 100,
offset: int = 0,
order: str = "count_desc",
include_zero: bool = True,
owner_id: str = "",
) -> schemas_out.TagsList:
limit = max(1, min(1000, limit))
offset = max(0, offset)
with create_session() as session:
rows, total = list_tags_with_usage(
session,
prefix=prefix,
limit=limit,
offset=offset,
include_zero=include_zero,
order=order,
owner_id=owner_id,
)
tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows]
return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total)

View File

@@ -1,318 +1,263 @@
import contextlib
import time
import logging
import os
import time
from typing import Literal, TypedDict
import sqlalchemy
import folder_paths
from app.assets.database.queries import (
add_missing_tag_for_asset_id,
bulk_set_needs_verify,
delete_cache_states_by_ids,
delete_orphaned_seed_asset,
ensure_tags_exist,
get_cache_states_for_prefixes,
remove_missing_tag_for_asset_id,
)
from app.assets.services.bulk_ingest import (
SeedAssetSpec,
batch_insert_seed_assets,
prune_orphaned_assets,
)
from app.assets.services.file_utils import (
get_mtime_ns,
list_files_recursively,
verify_file_unchanged,
)
from app.assets.services.path_utils import (
compute_relative_filename,
get_comfy_models_folders,
get_name_and_tags_from_asset_path,
)
from app.database.db import create_session, dependencies_available
class _StateInfo(TypedDict):
sid: int
fp: str
exists: bool
fast_ok: bool
needs_verify: bool
class _AssetAccumulator(TypedDict):
hash: str | None
size_db: int
states: list[_StateInfo]
RootType = Literal["models", "input", "output"]
def get_prefixes_for_root(root: RootType) -> list[str]:
if root == "models":
bases: list[str] = []
for _bucket, paths in get_comfy_models_folders():
bases.extend(paths)
return [os.path.abspath(p) for p in bases]
if root == "input":
return [os.path.abspath(folder_paths.get_input_directory())]
if root == "output":
return [os.path.abspath(folder_paths.get_output_directory())]
return []
def collect_models_files() -> list[str]:
out: list[str] = []
for folder_name, bases in get_comfy_models_folders():
rel_files = folder_paths.get_filename_list(folder_name) or []
for rel_path in rel_files:
abs_path = folder_paths.get_full_path(folder_name, rel_path)
if not abs_path:
continue
abs_path = os.path.abspath(abs_path)
allowed = False
for b in bases:
base_abs = os.path.abspath(b)
with contextlib.suppress(Exception):
if os.path.commonpath([abs_path, base_abs]) == base_abs:
allowed = True
break
if allowed:
out.append(abs_path)
return out
def sync_cache_states_with_filesystem(
session,
root: RootType,
collect_existing_paths: bool = False,
update_missing_tags: bool = False,
) -> set[str] | None:
"""Reconcile cache states with filesystem for a root.
- Toggle needs_verify per state using fast mtime/size check
- For hashed assets with at least one fast-ok state in this root: delete stale missing states
- For seed assets with all states missing: delete Asset and its AssetInfos
- Optionally add/remove 'missing' tags based on fast-ok in this root
- Optionally return surviving absolute paths
Args:
session: Database session
root: Root type to scan
collect_existing_paths: If True, return set of surviving file paths
update_missing_tags: If True, update 'missing' tags based on file status
Returns:
Set of surviving absolute paths if collect_existing_paths=True, else None
"""
prefixes = get_prefixes_for_root(root)
if not prefixes:
return set() if collect_existing_paths else None
rows = get_cache_states_for_prefixes(session, prefixes)
by_asset: dict[str, _AssetAccumulator] = {}
for row in rows:
acc = by_asset.get(row.asset_id)
if acc is None:
acc = {"hash": row.asset_hash, "size_db": row.size_bytes, "states": []}
by_asset[row.asset_id] = acc
fast_ok = False
try:
exists = True
fast_ok = verify_file_unchanged(
mtime_db=row.mtime_ns,
size_db=acc["size_db"],
stat_result=os.stat(row.file_path, follow_symlinks=True),
)
except FileNotFoundError:
exists = False
except OSError:
exists = False
acc["states"].append(
{
"sid": row.state_id,
"fp": row.file_path,
"exists": exists,
"fast_ok": fast_ok,
"needs_verify": row.needs_verify,
}
)
to_set_verify: list[int] = []
to_clear_verify: list[int] = []
stale_state_ids: list[int] = []
survivors: set[str] = set()
for aid, acc in by_asset.items():
a_hash = acc["hash"]
states = acc["states"]
any_fast_ok = any(s["fast_ok"] for s in states)
all_missing = all(not s["exists"] for s in states)
for s in states:
if not s["exists"]:
continue
if s["fast_ok"] and s["needs_verify"]:
to_clear_verify.append(s["sid"])
if not s["fast_ok"] and not s["needs_verify"]:
to_set_verify.append(s["sid"])
if a_hash is None:
if states and all_missing:
delete_orphaned_seed_asset(session, aid)
else:
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
continue
if any_fast_ok:
for s in states:
if not s["exists"]:
stale_state_ids.append(s["sid"])
if update_missing_tags:
with contextlib.suppress(Exception):
remove_missing_tag_for_asset_id(session, asset_id=aid)
elif update_missing_tags:
with contextlib.suppress(Exception):
add_missing_tag_for_asset_id(session, asset_id=aid, origin="automatic")
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
delete_cache_states_by_ids(session, stale_state_ids)
bulk_set_needs_verify(session, to_set_verify, value=True)
bulk_set_needs_verify(session, to_clear_verify, value=False)
return survivors if collect_existing_paths else None
def _sync_root_safely(root: RootType) -> set[str]:
"""Sync a single root's cache states with the filesystem.
Returns survivors (existing paths) or empty set on failure.
"""
try:
with create_session() as sess:
survivors = sync_cache_states_with_filesystem(
sess,
root,
collect_existing_paths=True,
update_missing_tags=True,
)
sess.commit()
return survivors or set()
except Exception as e:
logging.exception("fast DB scan failed for %s: %s", root, e)
return set()
def _prune_orphans_safely(prefixes: list[str]) -> int:
"""Prune orphaned assets outside the given prefixes.
Returns count pruned or 0 on failure.
"""
try:
with create_session() as sess:
count = prune_orphaned_assets(sess, prefixes)
sess.commit()
return count
except Exception as e:
logging.exception("orphan pruning failed: %s", e)
return 0
def _collect_paths_for_roots(roots: tuple[RootType, ...]) -> list[str]:
"""Collect all file paths for the given roots."""
paths: list[str] = []
if "models" in roots:
paths.extend(collect_models_files())
if "input" in roots:
paths.extend(list_files_recursively(folder_paths.get_input_directory()))
if "output" in roots:
paths.extend(list_files_recursively(folder_paths.get_output_directory()))
return paths
def _build_asset_specs(
paths: list[str],
existing_paths: set[str],
) -> tuple[list[SeedAssetSpec], set[str], int]:
"""Build asset specs from paths, returning (specs, tag_pool, skipped_count)."""
specs: list[SeedAssetSpec] = []
tag_pool: set[str] = set()
skipped = 0
for p in paths:
abs_p = os.path.abspath(p)
if abs_p in existing_paths:
skipped += 1
continue
try:
stat_p = os.stat(abs_p, follow_symlinks=False)
except OSError:
continue
if not stat_p.st_size:
continue
name, tags = get_name_and_tags_from_asset_path(abs_p)
specs.append(
{
"abs_path": abs_p,
"size_bytes": stat_p.st_size,
"mtime_ns": get_mtime_ns(stat_p),
"info_name": name,
"tags": tags,
"fname": compute_relative_filename(abs_p),
}
)
tag_pool.update(tags)
return specs, tag_pool, skipped
def _insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int:
"""Insert asset specs into database, returning count of created infos."""
if not specs:
return 0
with create_session() as sess:
if tag_pool:
ensure_tags_exist(sess, tag_pool, tag_type="user")
result = batch_insert_seed_assets(sess, specs=specs, owner_id="")
sess.commit()
return result.inserted_infos
from app.assets.helpers import (
    RootType,
    collect_models_files,
    compute_relative_filename,
    escape_like_prefix,
    fast_asset_file_check,
    get_name_and_tags_from_asset_path,
    list_tree,
    prefixes_for_root,
)
from app.assets.database.tags import add_missing_tag_for_asset_id, ensure_tags_exist, remove_missing_tag_for_asset_id
from app.assets.database.bulk_ops import seed_from_paths_batch
from app.assets.database.models import Asset, AssetCacheState, AssetInfo
def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> None:
"""Scan the given roots and seed the assets into the database."""
"""
Scan the given roots and seed the assets into the database.
"""
if not dependencies_available():
if enable_logging:
logging.warning("Database dependencies not available, skipping assets scan")
return
t_start = time.perf_counter()
created = 0
skipped_existing = 0
orphans_pruned = 0
paths: list[str] = []
try:
existing_paths: set[str] = set()
for r in roots:
try:
survivors: set[str] = _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True)
if survivors:
existing_paths.update(survivors)
except Exception as e:
logging.exception("fast DB scan failed for %s: %s", r, e)
existing_paths: set[str] = set()
for r in roots:
existing_paths.update(_sync_root_safely(r))
try:
orphans_pruned = _prune_orphaned_assets(roots)
except Exception as e:
logging.exception("orphan pruning failed: %s", e)
all_prefixes = [os.path.abspath(p) for r in roots for p in get_prefixes_for_root(r)]
orphans_pruned = _prune_orphans_safely(all_prefixes)
if "models" in roots:
paths.extend(collect_models_files())
if "input" in roots:
paths.extend(list_tree(folder_paths.get_input_directory()))
if "output" in roots:
paths.extend(list_tree(folder_paths.get_output_directory()))
paths = _collect_paths_for_roots(roots)
specs, tag_pool, skipped_existing = _build_asset_specs(paths, existing_paths)
created = _insert_asset_specs(specs, tag_pool)
specs: list[dict] = []
tag_pool: set[str] = set()
for p in paths:
abs_p = os.path.abspath(p)
if abs_p in existing_paths:
skipped_existing += 1
continue
try:
stat_p = os.stat(abs_p, follow_symlinks=False)
except OSError:
continue
# skip empty files
if not stat_p.st_size:
continue
name, tags = get_name_and_tags_from_asset_path(abs_p)
specs.append(
{
"abs_path": abs_p,
"size_bytes": stat_p.st_size,
"mtime_ns": getattr(stat_p, "st_mtime_ns", int(stat_p.st_mtime * 1_000_000_000)),
"info_name": name,
"tags": tags,
"fname": compute_relative_filename(abs_p),
}
)
for t in tags:
tag_pool.add(t)
# if no file specs, nothing to do
if not specs:
return
with create_session() as sess:
if tag_pool:
ensure_tags_exist(sess, tag_pool, tag_type="user")
if enable_logging:
logging.info(
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, orphans_pruned=%d, total_seen=%d)",
roots,
time.perf_counter() - t_start,
created,
skipped_existing,
orphans_pruned,
len(paths),
)
result = seed_from_paths_batch(sess, specs=specs, owner_id="")
created += result["inserted_infos"]
sess.commit()
finally:
if enable_logging:
logging.info(
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, orphans_pruned=%d, total_seen=%d)",
roots,
time.perf_counter() - t_start,
created,
skipped_existing,
orphans_pruned,
len(paths),
)
def _prune_orphaned_assets(roots: tuple[RootType, ...]) -> int:
"""Prune cache states outside configured prefixes, then delete orphaned seed assets."""
all_prefixes = [os.path.abspath(p) for r in roots for p in prefixes_for_root(r)]
if not all_prefixes:
return 0
def make_prefix_condition(prefix: str):
base = prefix if prefix.endswith(os.sep) else prefix + os.sep
escaped, esc = escape_like_prefix(base)
return AssetCacheState.file_path.like(escaped + "%", escape=esc)
matches_valid_prefix = sqlalchemy.or_(*[make_prefix_condition(p) for p in all_prefixes])
orphan_subq = (
sqlalchemy.select(Asset.id)
.outerjoin(AssetCacheState, AssetCacheState.asset_id == Asset.id)
.where(Asset.hash.is_(None), AssetCacheState.id.is_(None))
).scalar_subquery()
with create_session() as sess:
sess.execute(sqlalchemy.delete(AssetCacheState).where(~matches_valid_prefix))
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id.in_(orphan_subq)))
result = sess.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(orphan_subq)))
sess.commit()
return result.rowcount
def _fast_db_consistency_pass(
root: RootType,
*,
collect_existing_paths: bool = False,
update_missing_tags: bool = False,
) -> set[str] | None:
"""Fast DB+FS pass for a root:
- Toggle needs_verify per state using fast check
- For hashed assets with at least one fast-ok state in this root: delete stale missing states
- For seed assets with all states missing: delete Asset and its AssetInfos
- Optionally add/remove 'missing' tags based on fast-ok in this root
- Optionally return surviving absolute paths
"""
prefixes = prefixes_for_root(root)
if not prefixes:
return set() if collect_existing_paths else None
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
escaped, esc = escape_like_prefix(base)
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
with create_session() as sess:
rows = (
sess.execute(
sqlalchemy.select(
AssetCacheState.id,
AssetCacheState.file_path,
AssetCacheState.mtime_ns,
AssetCacheState.needs_verify,
AssetCacheState.asset_id,
Asset.hash,
Asset.size_bytes,
)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(sqlalchemy.or_(*conds))
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
)
).all()
by_asset: dict[str, dict] = {}
for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows:
acc = by_asset.get(aid)
if acc is None:
acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []}
by_asset[aid] = acc
fast_ok = False
try:
exists = True
fast_ok = fast_asset_file_check(
mtime_db=mtime_db,
size_db=acc["size_db"],
stat_result=os.stat(fp, follow_symlinks=True),
)
except FileNotFoundError:
exists = False
except OSError:
exists = False
acc["states"].append({
"sid": sid,
"fp": fp,
"exists": exists,
"fast_ok": fast_ok,
"needs_verify": bool(needs_verify),
})
to_set_verify: list[int] = []
to_clear_verify: list[int] = []
stale_state_ids: list[int] = []
survivors: set[str] = set()
for aid, acc in by_asset.items():
a_hash = acc["hash"]
states = acc["states"]
any_fast_ok = any(s["fast_ok"] for s in states)
all_missing = all(not s["exists"] for s in states)
for s in states:
if not s["exists"]:
continue
if s["fast_ok"] and s["needs_verify"]:
to_clear_verify.append(s["sid"])
if not s["fast_ok"] and not s["needs_verify"]:
to_set_verify.append(s["sid"])
if a_hash is None:
if states and all_missing: # remove seed Asset completely, if no valid AssetCache exists
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id == aid))
asset = sess.get(Asset, aid)
if asset:
sess.delete(asset)
else:
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
continue
if any_fast_ok: # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records
for s in states:
if not s["exists"]:
stale_state_ids.append(s["sid"])
if update_missing_tags:
with contextlib.suppress(Exception):
remove_missing_tag_for_asset_id(sess, asset_id=aid)
elif update_missing_tags:
with contextlib.suppress(Exception):
add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
if stale_state_ids:
sess.execute(sqlalchemy.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids)))
if to_set_verify:
sess.execute(
sqlalchemy.update(AssetCacheState)
.where(AssetCacheState.id.in_(to_set_verify))
.values(needs_verify=True)
)
if to_clear_verify:
sess.execute(
sqlalchemy.update(AssetCacheState)
.where(AssetCacheState.id.in_(to_clear_verify))
.values(needs_verify=False)
)
sess.commit()
return survivors if collect_existing_paths else None

View File

@@ -1,91 +0,0 @@
from app.assets.services.asset_management import (
asset_exists,
delete_asset_reference,
get_asset_by_hash,
get_asset_detail,
list_assets_page,
resolve_asset_for_download,
set_asset_preview,
update_asset_metadata,
)
from app.assets.services.bulk_ingest import (
BulkInsertResult,
batch_insert_seed_assets,
prune_orphaned_assets,
)
from app.assets.services.file_utils import (
get_mtime_ns,
get_size_and_mtime_ns,
list_files_recursively,
verify_file_unchanged,
)
from app.assets.services.ingest import (
DependencyMissingError,
HashMismatchError,
create_from_hash,
ingest_file_from_path,
register_existing_asset,
upload_from_temp_path,
)
from app.assets.services.schemas import (
AddTagsResult,
AssetData,
AssetDetailResult,
AssetInfoData,
AssetSummaryData,
DownloadResolutionResult,
IngestResult,
ListAssetsResult,
RegisterAssetResult,
RemoveTagsResult,
SetTagsResult,
TagUsage,
UploadResult,
UserMetadata,
)
from app.assets.services.tagging import (
apply_tags,
list_tags,
remove_tags,
)
__all__ = [
"AddTagsResult",
"AssetData",
"AssetDetailResult",
"AssetInfoData",
"AssetSummaryData",
"BulkInsertResult",
"DependencyMissingError",
"DownloadResolutionResult",
"HashMismatchError",
"IngestResult",
"ListAssetsResult",
"RegisterAssetResult",
"RemoveTagsResult",
"SetTagsResult",
"TagUsage",
"UploadResult",
"UserMetadata",
"apply_tags",
"asset_exists",
"batch_insert_seed_assets",
"create_from_hash",
"delete_asset_reference",
"get_asset_by_hash",
"get_asset_detail",
"get_mtime_ns",
"get_size_and_mtime_ns",
"ingest_file_from_path",
"list_assets_page",
"list_files_recursively",
"list_tags",
"prune_orphaned_assets",
"register_existing_asset",
"remove_tags",
"resolve_asset_for_download",
"set_asset_preview",
"update_asset_metadata",
"upload_from_temp_path",
"verify_file_unchanged",
]

View File

@@ -1,290 +0,0 @@
import contextlib
import mimetypes
import os
from typing import Sequence
from app.assets.database.models import Asset
from app.assets.database.queries import (
asset_exists_by_hash,
asset_info_exists_for_asset_id,
delete_asset_info_by_id,
fetch_asset_info_and_asset,
fetch_asset_info_asset_and_tags,
get_asset_by_hash as queries_get_asset_by_hash,
get_asset_info_by_id,
list_asset_infos_page,
list_cache_states_by_asset_id,
set_asset_info_metadata,
set_asset_info_preview,
set_asset_info_tags,
update_asset_info_access_time,
update_asset_info_name,
update_asset_info_updated_at,
)
from app.assets.helpers import select_best_live_path
from app.assets.services.path_utils import compute_filename_for_asset
from app.assets.services.schemas import (
AssetData,
AssetDetailResult,
AssetSummaryData,
DownloadResolutionResult,
ListAssetsResult,
UserMetadata,
extract_asset_data,
extract_info_data,
)
from app.database.db import create_session
def get_asset_detail(
asset_info_id: str,
owner_id: str = "",
) -> AssetDetailResult | None:
with create_session() as session:
result = fetch_asset_info_asset_and_tags(
session,
asset_info_id=asset_info_id,
owner_id=owner_id,
)
if not result:
return None
info, asset, tags = result
return AssetDetailResult(
info=extract_info_data(info),
asset=extract_asset_data(asset),
tags=tags,
)
def update_asset_metadata(
asset_info_id: str,
name: str | None = None,
tags: Sequence[str] | None = None,
user_metadata: UserMetadata = None,
tag_origin: str = "manual",
owner_id: str = "",
) -> AssetDetailResult:
with create_session() as session:
info = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info.owner_id and info.owner_id != owner_id:
raise PermissionError("not owner")
touched = False
if name is not None and name != info.name:
update_asset_info_name(session, asset_info_id=asset_info_id, name=name)
touched = True
computed_filename = compute_filename_for_asset(session, info.asset_id)
new_meta: dict | None = None
if user_metadata is not None:
new_meta = dict(user_metadata)
elif computed_filename:
current_meta = info.user_metadata or {}
if current_meta.get("filename") != computed_filename:
new_meta = dict(current_meta)
if new_meta is not None:
if computed_filename:
new_meta["filename"] = computed_filename
set_asset_info_metadata(
session, asset_info_id=asset_info_id, user_metadata=new_meta
)
touched = True
if tags is not None:
set_asset_info_tags(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=tag_origin,
)
touched = True
if touched and user_metadata is None:
update_asset_info_updated_at(session, asset_info_id=asset_info_id)
result = fetch_asset_info_asset_and_tags(
session,
asset_info_id=asset_info_id,
owner_id=owner_id,
)
if not result:
raise RuntimeError("State changed during update")
info, asset, tag_list = result
detail = AssetDetailResult(
info=extract_info_data(info),
asset=extract_asset_data(asset),
tags=tag_list,
)
session.commit()
return detail
def delete_asset_reference(
asset_info_id: str,
owner_id: str,
delete_content_if_orphan: bool = True,
) -> bool:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
asset_id = info_row.asset_id if info_row else None
deleted = delete_asset_info_by_id(
session, asset_info_id=asset_info_id, owner_id=owner_id
)
if not deleted:
session.commit()
return False
if not delete_content_if_orphan or not asset_id:
session.commit()
return True
still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id)
if still_exists:
session.commit()
return True
# Orphaned asset - delete it and its files
states = list_cache_states_by_asset_id(session, asset_id=asset_id)
file_paths = [
s.file_path for s in (states or []) if getattr(s, "file_path", None)
]
asset_row = session.get(Asset, asset_id)
if asset_row is not None:
session.delete(asset_row)
session.commit()
# Delete files after commit
for p in file_paths:
with contextlib.suppress(Exception):
if p and os.path.isfile(p):
os.remove(p)
return True
def set_asset_preview(
asset_info_id: str,
preview_asset_id: str | None = None,
owner_id: str = "",
) -> AssetDetailResult:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
set_asset_info_preview(
session,
asset_info_id=asset_info_id,
preview_asset_id=preview_asset_id,
)
result = fetch_asset_info_asset_and_tags(
session, asset_info_id=asset_info_id, owner_id=owner_id
)
if not result:
raise RuntimeError("State changed during preview update")
info, asset, tags = result
detail = AssetDetailResult(
info=extract_info_data(info),
asset=extract_asset_data(asset),
tags=tags,
)
session.commit()
return detail
def asset_exists(asset_hash: str) -> bool:
with create_session() as session:
return asset_exists_by_hash(session, asset_hash=asset_hash)
def get_asset_by_hash(asset_hash: str) -> AssetData | None:
with create_session() as session:
asset = queries_get_asset_by_hash(session, asset_hash=asset_hash)
return extract_asset_data(asset)
def list_assets_page(
owner_id: str = "",
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
) -> ListAssetsResult:
with create_session() as session:
infos, tag_map, total = list_asset_infos_page(
session,
owner_id=owner_id,
include_tags=include_tags,
exclude_tags=exclude_tags,
name_contains=name_contains,
metadata_filter=metadata_filter,
limit=limit,
offset=offset,
sort=sort,
order=order,
)
items: list[AssetSummaryData] = []
for info in infos:
items.append(
AssetSummaryData(
info=extract_info_data(info),
asset=extract_asset_data(info.asset),
tags=tag_map.get(info.id, []),
)
)
return ListAssetsResult(items=items, total=total)
def resolve_asset_for_download(
asset_info_id: str,
owner_id: str = "",
) -> DownloadResolutionResult:
with create_session() as session:
pair = fetch_asset_info_and_asset(
session, asset_info_id=asset_info_id, owner_id=owner_id
)
if not pair:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info, asset = pair
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
abs_path = select_best_live_path(states)
if not abs_path:
raise FileNotFoundError
update_asset_info_access_time(session, asset_info_id=asset_info_id)
session.commit()
ctype = (
asset.mime_type
or mimetypes.guess_type(info.name or abs_path)[0]
or "application/octet-stream"
)
download_name = info.name or os.path.basename(abs_path)
return DownloadResolutionResult(
abs_path=abs_path,
content_type=ctype,
download_name=download_name,
)

View File

@@ -1,203 +0,0 @@
import os
import uuid
from dataclasses import dataclass
from typing import TypedDict
from sqlalchemy.orm import Session
class SeedAssetSpec(TypedDict):
"""Spec for seeding an asset from filesystem."""
abs_path: str
size_bytes: int
mtime_ns: int
info_name: str
tags: list[str]
fname: str
from app.assets.database.queries import (
bulk_insert_asset_infos_ignore_conflicts,
bulk_insert_assets,
bulk_insert_cache_states_ignore_conflicts,
bulk_insert_tags_and_meta,
delete_assets_by_ids,
delete_cache_states_outside_prefixes,
get_asset_info_ids_by_ids,
get_cache_states_by_paths_and_asset_ids,
get_orphaned_seed_asset_ids,
)
from app.assets.helpers import get_utc_now
@dataclass
class BulkInsertResult:
"""Result of bulk asset insertion."""
inserted_infos: int
won_states: int
lost_states: int
def batch_insert_seed_assets(
session: Session,
specs: list[SeedAssetSpec],
owner_id: str = "",
) -> BulkInsertResult:
"""Seed assets from filesystem specs in batch.
Each spec is a dict with keys:
- abs_path: str
- size_bytes: int
- mtime_ns: int
- info_name: str
- tags: list[str]
- fname: Optional[str]
This function orchestrates:
1. Insert seed Assets (hash=NULL)
2. Claim cache states with ON CONFLICT DO NOTHING
3. Query to find winners (paths where our asset_id was inserted)
4. Delete Assets for losers (path already claimed by another asset)
5. Insert AssetInfo for winners
6. Insert tags and metadata for successfully inserted AssetInfos
Returns:
BulkInsertResult with inserted_infos, won_states, lost_states
"""
if not specs:
return BulkInsertResult(inserted_infos=0, won_states=0, lost_states=0)
now = get_utc_now()
asset_rows: list[dict] = []
state_rows: list[dict] = []
path_to_asset: dict[str, str] = {}
asset_to_info: dict[str, dict] = {}
path_list: list[str] = []
for sp in specs:
ap = os.path.abspath(sp["abs_path"])
aid = str(uuid.uuid4())
iid = str(uuid.uuid4())
path_list.append(ap)
path_to_asset[ap] = aid
asset_rows.append(
{
"id": aid,
"hash": None,
"size_bytes": sp["size_bytes"],
"mime_type": None,
"created_at": now,
}
)
state_rows.append(
{
"asset_id": aid,
"file_path": ap,
"mtime_ns": sp["mtime_ns"],
}
)
asset_to_info[aid] = {
"id": iid,
"owner_id": owner_id,
"name": sp["info_name"],
"asset_id": aid,
"preview_id": None,
"user_metadata": {"filename": sp["fname"]} if sp["fname"] else None,
"created_at": now,
"updated_at": now,
"last_access_time": now,
"_tags": sp["tags"],
"_filename": sp["fname"],
}
bulk_insert_assets(session, asset_rows)
bulk_insert_cache_states_ignore_conflicts(session, state_rows)
winners_by_path = get_cache_states_by_paths_and_asset_ids(session, path_to_asset)
all_paths_set = set(path_list)
losers_by_path = all_paths_set - winners_by_path
lost_assets = [path_to_asset[p] for p in losers_by_path]
if lost_assets:
delete_assets_by_ids(session, lost_assets)
if not winners_by_path:
return BulkInsertResult(
inserted_infos=0,
won_states=0,
lost_states=len(losers_by_path),
)
winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
db_info_rows = [
{
"id": row["id"],
"owner_id": row["owner_id"],
"name": row["name"],
"asset_id": row["asset_id"],
"preview_id": row["preview_id"],
"user_metadata": row["user_metadata"],
"created_at": row["created_at"],
"updated_at": row["updated_at"],
"last_access_time": row["last_access_time"],
}
for row in winner_info_rows
]
bulk_insert_asset_infos_ignore_conflicts(session, db_info_rows)
all_info_ids = [row["id"] for row in winner_info_rows]
inserted_info_ids = get_asset_info_ids_by_ids(session, all_info_ids)
tag_rows: list[dict] = []
meta_rows: list[dict] = []
if inserted_info_ids:
for row in winner_info_rows:
iid = row["id"]
if iid not in inserted_info_ids:
continue
for t in row["_tags"]:
tag_rows.append(
{
"asset_info_id": iid,
"tag_name": t,
"origin": "automatic",
"added_at": now,
}
)
if row["_filename"]:
meta_rows.append(
{
"asset_info_id": iid,
"key": "filename",
"ordinal": 0,
"val_str": row["_filename"],
"val_num": None,
"val_bool": None,
"val_json": None,
}
)
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows)
return BulkInsertResult(
inserted_infos=len(inserted_info_ids),
won_states=len(winners_by_path),
lost_states=len(losers_by_path),
)
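# Illustrative sketch of the orchestration above: seeding one discovered file.
# Path and metadata are hypothetical; the caller owns the session and commit.
def _demo_seed_one(session: Session) -> BulkInsertResult:
    spec: SeedAssetSpec = {
        "abs_path": "/comfy/models/checkpoints/sdxl.safetensors",
        "size_bytes": 123,
        "mtime_ns": 0,
        "info_name": "sdxl.safetensors",
        "tags": ["models", "checkpoints"],
        "fname": "sdxl.safetensors",
    }
    return batch_insert_seed_assets(session, [spec], owner_id="")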
def prune_orphaned_assets(session: Session, valid_prefixes: list[str]) -> int:
"""Prune cache states outside valid prefixes, then delete orphaned seed assets.
Args:
session: Database session
valid_prefixes: List of absolute directory prefixes that are valid
Returns:
Number of orphaned assets deleted
"""
delete_cache_states_outside_prefixes(session, valid_prefixes)
orphan_ids = get_orphaned_seed_asset_ids(session)
return delete_assets_by_ids(session, orphan_ids)

View File

@@ -1,49 +0,0 @@
import os
def get_mtime_ns(stat_result: os.stat_result) -> int:
"""Extract mtime in nanoseconds from a stat result."""
return getattr(
stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000)
)
def get_size_and_mtime_ns(path: str, follow_symlinks: bool = True) -> tuple[int, int]:
"""Get file size in bytes and mtime in nanoseconds."""
st = os.stat(path, follow_symlinks=follow_symlinks)
return st.st_size, get_mtime_ns(st)
def verify_file_unchanged(
mtime_db: int | None,
size_db: int | None,
stat_result: os.stat_result,
) -> bool:
"""Check if a file is unchanged based on mtime and size.
Returns True if the file's mtime and size match the database values.
Returns False if mtime_db is None or values don't match.
"""
if mtime_db is None:
return False
actual_mtime_ns = get_mtime_ns(stat_result)
if int(mtime_db) != int(actual_mtime_ns):
return False
sz = int(size_db or 0)
if sz > 0:
return int(stat_result.st_size) == sz
return True
def list_files_recursively(base_dir: str) -> list[str]:
"""Recursively list all files in a directory."""
out: list[str] = []
base_abs = os.path.abspath(base_dir)
if not os.path.isdir(base_abs):
return out
for dirpath, _subdirs, filenames in os.walk(
base_abs, topdown=True, followlinks=False
):
for name in filenames:
out.append(os.path.abspath(os.path.join(dirpath, name)))
return out

View File

@@ -1,54 +0,0 @@
import asyncio
import os
from typing import IO
from blake3 import blake3
DEFAULT_CHUNK = 8 * 1024 * 1024
def compute_blake3_hash(
fp: str | IO[bytes],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
if hasattr(fp, "read"):
return _hash_file_obj(fp, chunk_size)
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj(f, chunk_size)
async def compute_blake3_hash_async(
fp: str | IO[bytes],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
if hasattr(fp, "read"):
return await asyncio.to_thread(compute_blake3_hash, fp, chunk_size)
def _worker() -> str:
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj(f, chunk_size)
return await asyncio.to_thread(_worker)
def _hash_file_obj(file_obj: IO, chunk_size: int = DEFAULT_CHUNK) -> str:
if chunk_size <= 0:
chunk_size = DEFAULT_CHUNK
orig_pos = file_obj.tell()
try:
if orig_pos != 0:
file_obj.seek(0)
h = blake3()
while True:
chunk = file_obj.read(chunk_size)
if not chunk:
break
h.update(chunk)
return h.hexdigest()
finally:
if orig_pos != 0:
file_obj.seek(orig_pos)

View File

@@ -1,388 +0,0 @@
import contextlib
import logging
import mimetypes
import os
from typing import Sequence
from sqlalchemy import select
from sqlalchemy.orm import Session
import app.assets.services.hashing as hashing
from app.assets.database.models import Asset, AssetInfo, Tag
from app.assets.database.queries import (
add_tags_to_asset_info,
fetch_asset_info_and_asset,
get_asset_by_hash,
get_asset_tags,
get_or_create_asset_info,
remove_missing_tag_for_asset_id,
set_asset_info_metadata,
set_asset_info_tags,
update_asset_info_timestamps,
upsert_asset,
upsert_cache_state,
)
from app.assets.helpers import normalize_tags
from app.assets.services.file_utils import get_size_and_mtime_ns
from app.assets.services.path_utils import (
compute_filename_for_asset,
resolve_destination_from_tags,
validate_path_within_base,
)
from app.assets.services.schemas import (
IngestResult,
RegisterAssetResult,
UploadResult,
UserMetadata,
extract_asset_data,
extract_info_data,
)
from app.database.db import create_session
def ingest_file_from_path(
abs_path: str,
asset_hash: str,
size_bytes: int,
mtime_ns: int,
mime_type: str | None = None,
info_name: str | None = None,
owner_id: str = "",
preview_id: str | None = None,
user_metadata: UserMetadata = None,
tags: Sequence[str] = (),
tag_origin: str = "manual",
require_existing_tags: bool = False,
) -> IngestResult:
locator = os.path.abspath(abs_path)
asset_created = False
asset_updated = False
state_created = False
state_updated = False
asset_info_id: str | None = None
with create_session() as session:
if preview_id:
if not session.get(Asset, preview_id):
preview_id = None
asset, asset_created, asset_updated = upsert_asset(
session,
asset_hash=asset_hash,
size_bytes=size_bytes,
mime_type=mime_type,
)
state_created, state_updated = upsert_cache_state(
session,
asset_id=asset.id,
file_path=locator,
mtime_ns=mtime_ns,
)
if info_name:
info, info_created = get_or_create_asset_info(
session,
asset_id=asset.id,
owner_id=owner_id,
name=info_name,
preview_id=preview_id,
)
if info_created:
asset_info_id = info.id
else:
update_asset_info_timestamps(
session, asset_info=info, preview_id=preview_id
)
asset_info_id = info.id
norm = normalize_tags(list(tags))
if norm and asset_info_id:
if require_existing_tags:
_validate_tags_exist(session, norm)
add_tags_to_asset_info(
session,
asset_info_id=asset_info_id,
tags=norm,
origin=tag_origin,
create_if_missing=not require_existing_tags,
)
if asset_info_id:
_update_metadata_with_filename(
session,
asset_info_id=asset_info_id,
asset_id=asset.id,
info=info,
user_metadata=user_metadata,
)
try:
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
except Exception:
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
session.commit()
return IngestResult(
asset_created=asset_created,
asset_updated=asset_updated,
state_created=state_created,
state_updated=state_updated,
asset_info_id=asset_info_id,
)
def register_existing_asset(
asset_hash: str,
name: str,
user_metadata: UserMetadata = None,
tags: list[str] | None = None,
tag_origin: str = "manual",
owner_id: str = "",
) -> RegisterAssetResult:
with create_session() as session:
asset = get_asset_by_hash(session, asset_hash=asset_hash)
if not asset:
raise ValueError(f"No asset with hash {asset_hash}")
info, info_created = get_or_create_asset_info(
session,
asset_id=asset.id,
owner_id=owner_id,
name=name,
preview_id=None,
)
if not info_created:
tag_names = get_asset_tags(session, asset_info_id=info.id)
result = RegisterAssetResult(
info=extract_info_data(info),
asset=extract_asset_data(asset),
tags=tag_names,
created=False,
)
session.commit()
return result
new_meta = dict(user_metadata or {})
computed_filename = compute_filename_for_asset(session, asset.id)
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta:
set_asset_info_metadata(
session,
asset_info_id=info.id,
user_metadata=new_meta,
)
if tags is not None:
set_asset_info_tags(
session,
asset_info_id=info.id,
tags=tags,
origin=tag_origin,
)
tag_names = get_asset_tags(session, asset_info_id=info.id)
session.refresh(info)
result = RegisterAssetResult(
info=extract_info_data(info),
asset=extract_asset_data(asset),
tags=tag_names,
created=True,
)
session.commit()
return result
def _validate_tags_exist(session: Session, tags: list[str]) -> None:
existing_tag_names = set(
name
for (name,) in session.execute(select(Tag.name).where(Tag.name.in_(tags))).all()
)
missing = [t for t in tags if t not in existing_tag_names]
if missing:
raise ValueError(f"Unknown tags: {missing}")
def _update_metadata_with_filename(
session: Session,
asset_info_id: str,
asset_id: str,
info: AssetInfo,
user_metadata: UserMetadata,
) -> None:
computed_filename = compute_filename_for_asset(session, asset_id)
current_meta = info.user_metadata or {}
new_meta = dict(current_meta)
if user_metadata:
for k, v in user_metadata.items():
new_meta[k] = v
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta != current_meta:
set_asset_info_metadata(
session,
asset_info_id=asset_info_id,
user_metadata=new_meta,
)
def _sanitize_filename(name: str | None, fallback: str) -> str:
n = os.path.basename((name or "").strip() or fallback)
return n if n else fallback
class HashMismatchError(Exception):
pass
class DependencyMissingError(Exception):
def __init__(self, message: str):
self.message = message
super().__init__(message)
def upload_from_temp_path(
temp_path: str,
name: str | None = None,
tags: list[str] | None = None,
user_metadata: dict | None = None,
client_filename: str | None = None,
owner_id: str = "",
expected_hash: str | None = None,
) -> UploadResult:
try:
digest = hashing.compute_blake3_hash(temp_path)
except ImportError as e:
raise DependencyMissingError(str(e))
except Exception as e:
raise RuntimeError(f"failed to hash uploaded file: {e}")
asset_hash = "blake3:" + digest
if expected_hash and asset_hash != expected_hash.strip().lower():
raise HashMismatchError("Uploaded file hash does not match provided hash.")
with create_session() as session:
existing = get_asset_by_hash(session, asset_hash=asset_hash)
if existing is not None:
with contextlib.suppress(Exception):
if temp_path and os.path.exists(temp_path):
os.remove(temp_path)
display_name = _sanitize_filename(name or client_filename, fallback=digest)
result = register_existing_asset(
asset_hash=asset_hash,
name=display_name,
user_metadata=user_metadata or {},
tags=tags or [],
tag_origin="manual",
owner_id=owner_id,
)
return UploadResult(
info=result.info,
asset=result.asset,
tags=result.tags,
created_new=False,
)
base_dir, subdirs = resolve_destination_from_tags(tags)
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
os.makedirs(dest_dir, exist_ok=True)
src_for_ext = (client_filename or name or "").strip()
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
ext = _ext if 0 < len(_ext) <= 16 else ""
hashed_basename = f"{digest}{ext}"
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
validate_path_within_base(dest_abs, base_dir)
content_type = (
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
or mimetypes.guess_type(hashed_basename, strict=False)[0]
or "application/octet-stream"
)
try:
os.replace(temp_path, dest_abs)
except Exception as e:
raise RuntimeError(f"failed to move uploaded file into place: {e}")
try:
size_bytes, mtime_ns = get_size_and_mtime_ns(dest_abs)
except OSError as e:
raise RuntimeError(f"failed to stat destination file: {e}")
ingest_result = ingest_file_from_path(
asset_hash=asset_hash,
abs_path=dest_abs,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
mime_type=content_type,
info_name=_sanitize_filename(name or client_filename, fallback=digest),
owner_id=owner_id,
preview_id=None,
user_metadata=user_metadata or {},
tags=tags,
tag_origin="manual",
require_existing_tags=False,
)
info_id = ingest_result.asset_info_id
if not info_id:
raise RuntimeError("failed to create asset metadata")
with create_session() as session:
pair = fetch_asset_info_and_asset(
session, asset_info_id=info_id, owner_id=owner_id
)
if not pair:
raise RuntimeError("inconsistent DB state after ingest")
info, asset = pair
tag_names = get_asset_tags(session, asset_info_id=info.id)
return UploadResult(
info=extract_info_data(info),
asset=extract_asset_data(asset),
tags=tag_names,
created_new=ingest_result.asset_created,
)
def create_from_hash(
hash_str: str,
name: str,
tags: list[str] | None = None,
user_metadata: dict | None = None,
owner_id: str = "",
) -> UploadResult | None:
canonical = hash_str.strip().lower()
with create_session() as session:
asset = get_asset_by_hash(session, asset_hash=canonical)
if not asset:
return None
result = register_existing_asset(
asset_hash=canonical,
name=_sanitize_filename(
name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical
),
user_metadata=user_metadata or {},
tags=tags or [],
tag_origin="manual",
owner_id=owner_id,
)
return UploadResult(
info=result.info,
asset=result.asset,
tags=result.tags,
created_new=False,
)
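A hedged usage sketch for the services above (path, names, and tags illustrative): upload a temp file once, then register the same content under a second name without copying it again.

    result = upload_from_temp_path(
        "/tmp/upload-9f3c",          # hypothetical temp file written by the server
        name="ae.safetensors",
        tags=["models", "vae"],      # the "models" root requires a category tag
    )
    alias = create_from_hash(result.asset.hash, name="ae-alias.safetensors")
    print(result.created_new, alias.created_new)  # True on first ingest, then False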

View File

@@ -1,184 +0,0 @@
import os
from pathlib import Path
from typing import Literal
import folder_paths
from app.assets.helpers import normalize_tags
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
"""Build a list of (folder_name, base_paths[]) categories that are configured for model locations.
We trust `folder_paths.folder_names_and_paths` and include a category if
*any* of its base paths lies under the Comfy `models_dir`.
"""
targets: list[tuple[str, list[str]]] = []
models_root = os.path.abspath(folder_paths.models_dir)
for name, values in folder_paths.folder_names_and_paths.items():
paths, _exts = (
values[0],
values[1],
) # NOTE: this prevents nodepacks that hackily edit folder_... from breaking ComfyUI
if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
targets.append((name, paths))
return targets
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
root = tags[0]
if root == "models":
if len(tags) < 2:
raise ValueError("at least two tags required for model asset")
try:
bases = folder_paths.folder_names_and_paths[tags[1]][0]
except KeyError:
raise ValueError(f"unknown model category '{tags[1]}'")
if not bases:
raise ValueError(f"no base path configured for category '{tags[1]}'")
base_dir = os.path.abspath(bases[0])
raw_subdirs = tags[2:]
else:
base_dir = os.path.abspath(
folder_paths.get_input_directory()
if root == "input"
else folder_paths.get_output_directory()
)
raw_subdirs = tags[1:]
for i in raw_subdirs:
if i in (".", ".."):
raise ValueError("invalid path component in tags")
return base_dir, raw_subdirs if raw_subdirs else []
def validate_path_within_base(candidate: str, base: str) -> None:
cand_abs = os.path.abspath(candidate)
base_abs = os.path.abspath(base)
try:
escapes = os.path.commonpath([cand_abs, base_abs]) != base_abs
except Exception:
raise ValueError("invalid destination path")
if escapes:
raise ValueError("destination escapes base directory")
def compute_relative_filename(file_path: str) -> str | None:
"""
Return the model's path relative to the last well-known folder (the model category),
using forward slashes, e.g.:
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
For non-model paths, returns None.
NOTE: this is a temporary helper, used only for initializing the metadata["filename"] field.
"""
try:
root_category, rel_path = get_asset_category_and_relative_path(file_path)
except ValueError:
return None
p = Path(rel_path)
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
if not parts:
return None
if root_category == "models":
# parts[0] is the category ("checkpoints", "vae", etc.); drop it
inside = parts[1:] if len(parts) > 1 else [parts[0]]
return "/".join(inside)
return "/".join(parts) # input/output: keep all parts
def get_asset_category_and_relative_path(
file_path: str,
) -> tuple[Literal["input", "output", "models"], str]:
"""Given an absolute or relative file path, determine which root category the path belongs to:
- 'input' if the file resides under `folder_paths.get_input_directory()`
- 'output' if the file resides under `folder_paths.get_output_directory()`
- 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()`
Returns:
(root_category, relative_path_inside_that_root)
For 'models', the relative path is prefixed with the category name:
e.g. ('models', 'vae/test/sub/ae.safetensors')
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
"""
fp_abs = os.path.abspath(file_path)
def _check_is_within(child: str, parent: str) -> bool:
try:
return os.path.commonpath([child, parent]) == parent
except Exception:
return False
def _compute_relative(child: str, parent: str) -> str:
return os.path.relpath(
os.path.join(os.sep, os.path.relpath(child, parent)), os.sep
)
# 1) input
input_base = os.path.abspath(folder_paths.get_input_directory())
if _check_is_within(fp_abs, input_base):
return "input", _compute_relative(fp_abs, input_base)
# 2) output
output_base = os.path.abspath(folder_paths.get_output_directory())
if _check_is_within(fp_abs, output_base):
return "output", _compute_relative(fp_abs, output_base)
# 3) models (check deepest matching base to avoid ambiguity)
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
for bucket, bases in get_comfy_models_folders():
for b in bases:
base_abs = os.path.abspath(b)
if not _check_is_within(fp_abs, base_abs):
continue
cand = (len(base_abs), bucket, _compute_relative(fp_abs, base_abs))
if best is None or cand[0] > best[0]:
best = cand
if best is not None:
_, bucket, rel_inside = best
combined = os.path.join(bucket, rel_inside)
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
raise ValueError(
f"Path is not within input, output, or configured model bases: {file_path}"
)
def compute_filename_for_asset(session, asset_id: str) -> str | None:
"""Compute the relative filename for an asset from its best live cache state path."""
from app.assets.database.queries import list_cache_states_by_asset_id
from app.assets.helpers import select_best_live_path
primary_path = select_best_live_path(
list_cache_states_by_asset_id(session, asset_id=asset_id)
)
return compute_relative_filename(primary_path) if primary_path else None
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
"""Return a tuple (name, tags) derived from a filesystem path.
Semantics:
- Root category is determined by `get_asset_category_and_relative_path`.
- The returned `name` is the base filename with extension from the relative path.
- The returned `tags` are:
[root_category] + parent folders of the relative path (in order)
For 'models', this means:
file '/.../ModelsDir/vae/test_tag/ae.safetensors'
-> root_category='models', some_path='vae/test_tag/ae.safetensors'
-> name='ae.safetensors', tags=['models', 'vae', 'test_tag']
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
"""
root_category, some_path = get_asset_category_and_relative_path(file_path)
p = Path(some_path)
parent_parts = [
part for part in p.parent.parts if part not in (".", "..", p.anchor)
]
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
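A worked example for the helper above, mirroring the docstring (path hypothetical):

    name, tags = get_name_and_tags_from_asset_path(
        "/ComfyUI/models/vae/test_tag/ae.safetensors"
    )
    # With a stock models directory this yields:
    #   name == "ae.safetensors"
    #   tags == ["models", "vae", "test_tag"]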

View File

@@ -1,126 +0,0 @@
from dataclasses import dataclass
from datetime import datetime
from typing import Any, NamedTuple
from app.assets.database.models import Asset, AssetInfo
UserMetadata = dict[str, Any] | None
@dataclass(frozen=True)
class AssetData:
hash: str
size_bytes: int | None
mime_type: str | None
@dataclass(frozen=True)
class AssetInfoData:
id: str
name: str
user_metadata: UserMetadata
preview_id: str | None
created_at: datetime
updated_at: datetime
last_access_time: datetime | None
@dataclass(frozen=True)
class AssetDetailResult:
info: AssetInfoData
asset: AssetData | None
tags: list[str]
@dataclass(frozen=True)
class RegisterAssetResult:
info: AssetInfoData
asset: AssetData
tags: list[str]
created: bool
@dataclass(frozen=True)
class IngestResult:
asset_created: bool
asset_updated: bool
state_created: bool
state_updated: bool
asset_info_id: str | None
@dataclass(frozen=True)
class AddTagsResult:
added: list[str]
already_present: list[str]
total_tags: list[str]
@dataclass(frozen=True)
class RemoveTagsResult:
removed: list[str]
not_present: list[str]
total_tags: list[str]
@dataclass(frozen=True)
class SetTagsResult:
added: list[str]
removed: list[str]
total: list[str]
class TagUsage(NamedTuple):
name: str
tag_type: str
count: int
@dataclass(frozen=True)
class AssetSummaryData:
info: AssetInfoData
asset: AssetData | None
tags: list[str]
@dataclass(frozen=True)
class ListAssetsResult:
items: list[AssetSummaryData]
total: int
@dataclass(frozen=True)
class DownloadResolutionResult:
abs_path: str
content_type: str
download_name: str
@dataclass(frozen=True)
class UploadResult:
info: AssetInfoData
asset: AssetData
tags: list[str]
created_new: bool
def extract_info_data(info: AssetInfo) -> AssetInfoData:
return AssetInfoData(
id=info.id,
name=info.name,
user_metadata=info.user_metadata,
preview_id=info.preview_id,
created_at=info.created_at,
updated_at=info.updated_at,
last_access_time=info.last_access_time,
)
def extract_asset_data(asset: Asset | None) -> AssetData | None:
if asset is None:
return None
return AssetData(
hash=asset.hash,
size_bytes=asset.size_bytes,
mime_type=asset.mime_type,
)

View File

@@ -1,89 +0,0 @@
from app.assets.database.queries import (
add_tags_to_asset_info,
get_asset_info_by_id,
list_tags_with_usage,
remove_tags_from_asset_info,
)
from app.assets.services.schemas import AddTagsResult, RemoveTagsResult, TagUsage
from app.database.db import create_session
def apply_tags(
asset_info_id: str,
tags: list[str],
origin: str = "manual",
owner_id: str = "",
) -> AddTagsResult:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = add_tags_to_asset_info(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=origin,
create_if_missing=True,
asset_info_row=info_row,
)
session.commit()
return AddTagsResult(
added=data["added"],
already_present=data["already_present"],
total_tags=data["total_tags"],
)
def remove_tags(
asset_info_id: str,
tags: list[str],
owner_id: str = "",
) -> RemoveTagsResult:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = remove_tags_from_asset_info(
session,
asset_info_id=asset_info_id,
tags=tags,
)
session.commit()
return RemoveTagsResult(
removed=data["removed"],
not_present=data["not_present"],
total_tags=data["total_tags"],
)
def list_tags(
prefix: str | None = None,
limit: int = 100,
offset: int = 0,
order: str = "count_desc",
include_zero: bool = True,
owner_id: str = "",
) -> tuple[list[TagUsage], int]:
limit = max(1, min(1000, limit))
offset = max(0, offset)
with create_session() as session:
rows, total = list_tags_with_usage(
session,
prefix=prefix,
limit=limit,
offset=offset,
include_zero=include_zero,
order=order,
owner_id=owner_id,
)
return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total
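A short sketch of the round trip these helpers support (IDs and tag names hypothetical):

    added = apply_tags("info-123", ["favorite", "wip"])
    removed = remove_tags("info-123", ["wip"])
    print(added.added, removed.removed, removed.total_tags)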

View File

@@ -1,6 +1,8 @@
from __future__ import annotations
from typing import TypedDict
import json
import logging
import os
import folder_paths
import glob
@@ -90,15 +92,58 @@ class SubgraphManager:
return subgraphs_dict
async def get_blueprint_subgraphs(self, force_reload=False):
"""Load subgraphs from the blueprints directory."""
"""Load subgraphs from the blueprints directory using index.json for discovery."""
if not force_reload and self.cached_blueprint_subgraphs is not None:
return self.cached_blueprint_subgraphs
subgraphs_dict: dict[str, SubgraphEntry] = {}
blueprints_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'blueprints')
if os.path.exists(blueprints_dir):
index_path = os.path.join(blueprints_dir, "index.json")
if os.path.isfile(index_path):
try:
with open(index_path, "r", encoding="utf-8") as f:
categories = json.load(f)
except (json.JSONDecodeError, OSError) as e:
logging.error("Failed to load blueprint index %s: %s", index_path, e)
categories = []
if not isinstance(categories, list):
logging.error("Blueprint index.json is not a list: %s", index_path)
categories = []
for category in categories:
module_name = category.get("moduleName", "default")
for blueprint in category.get("blueprints", []):
name = blueprint.get("name")
if not name:
logging.warning("Blueprint entry missing 'name' in category '%s', skipping", module_name)
continue
filename = f"{name}.json"
filepath = os.path.realpath(os.path.join(blueprints_dir, filename))
if not filepath.startswith(os.path.realpath(blueprints_dir) + os.sep):
logging.warning("Blueprint path escapes blueprints directory: %s", filepath)
continue
if not os.path.isfile(filepath):
logging.warning("Blueprint file not found: %s", filepath)
continue
entry_id, entry = self._create_entry(filepath, Source.templates, module_name)
info = entry["info"]
include_on = blueprint.get("includeOnDistributions")
if include_on is not None:
info["includeOnDistributions"] = include_on
requires = blueprint.get("requiresCustomNodes")
if requires is not None:
info["requiresCustomNodes"] = requires
subgraphs_dict[entry_id] = entry
elif os.path.exists(blueprints_dir):
logging.warning("No blueprint index.json found at %s, falling back to glob", index_path)
for file in glob.glob(os.path.join(blueprints_dir, "*.json")):
if os.path.basename(file) == "index.json":
continue
file = file.replace('\\', '/')
entry_id, entry = self._create_entry(file, Source.templates, "comfyui")
subgraphs_dict[entry_id] = entry
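For reference, a minimal index.json consistent with the parser above (module and blueprint names illustrative); each "name" must correspond to a sibling <name>.json file in the blueprints directory:

    [
        {
            "moduleName": "comfyui",
            "blueprints": [
                {"name": "upscale_pass"},
                {
                    "name": "desktop_only_helper",
                    "includeOnDistributions": ["desktop"],
                    "requiresCustomNodes": ["example-node-pack"]
                }
            ]
        }
    ]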

View File

@@ -183,7 +183,7 @@ class AceStepAttention(nn.Module):
else:
attn_bias = window_bias
attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True)
attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True, low_precision_attention=False)
attn_output = self.o_proj(attn_output)
return attn_output
@@ -1035,8 +1035,7 @@ class AceStepConditionGenerationModel(nn.Module):
audio_codes = torch.nn.functional.pad(audio_codes, (0, math.ceil(src_latents.shape[1] / 5) - audio_codes.shape[1]), "constant", 35847)
lm_hints_5Hz = self.tokenizer.quantizer.get_output_from_indices(audio_codes, dtype=text_hidden_states.dtype)
else:
assert False
# TODO ?
lm_hints_5Hz, indices = self.tokenizer.tokenize(refer_audio_acoustic_hidden_states_packed)
lm_hints = self.detokenizer(lm_hints_5Hz)

View File

@@ -524,6 +524,9 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
@wrap_attn
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if kwargs.get("low_precision_attention", True) is False:
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
exception_fallback = False
if skip_reshape:
b, _, _, dim_head = q.shape

View File

@@ -1548,6 +1548,7 @@ class ACEStep15(BaseModel):
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
device = kwargs["device"]
noise = kwargs["noise"]
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
@@ -1571,15 +1572,19 @@ class ACEStep15(BaseModel):
1.4844e-01, 9.4727e-02, 3.8477e-01, -1.2578e+00, -3.3203e-01,
-8.5547e-01, 4.3359e-01, 4.2383e-01, -8.9453e-01, -5.0391e-01,
-5.6152e-02, -2.9219e+00, -2.4658e-02, 5.0391e-01, 9.8438e-01,
7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, 750)
7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, noise.shape[2])
pass_audio_codes = True
else:
refer_audio = refer_audio[-1]
refer_audio = refer_audio[-1][:, :, :noise.shape[2]]
pass_audio_codes = False
if pass_audio_codes:
audio_codes = kwargs.get("audio_codes", None)
if audio_codes is not None:
out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device))
refer_audio = refer_audio[:, :, :750]
out['refer_audio'] = comfy.conds.CONDRegular(refer_audio)
audio_codes = kwargs.get("audio_codes", None)
if audio_codes is not None:
out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device))
return out
class Omnigen2(BaseModel):

View File

@@ -1724,11 +1724,9 @@ def soft_empty_cache(force=False):
elif is_mlu():
torch.mlu.empty_cache()
elif torch.cuda.is_available():
if comfy.memory_management.aimdo_allocator is None:
#Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist
torch.cuda.synchronize()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
torch.cuda.synchronize()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
def unload_all_models():
free_memory(1e30, get_torch_device())

View File

@@ -1400,7 +1400,7 @@ class ModelPatcher:
continue
key = "diffusion_model." + k
unet_state_dict[k] = LazyCastingParam(self, key, comfy.utils.get_attr(self.model, key))
return self.model.state_dict_for_saving(unet_state_dict)
return self.model.state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
def __del__(self):
self.unpin_all_weights()

View File

@@ -54,6 +54,8 @@ try:
SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
if q.nelement() < 1024 * 128: # arbitrary number, for small inputs cudnn attention seems slower
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
else:

View File

@@ -554,6 +554,8 @@ class VAE:
elif "decoder.layers.1.layers.0.beta" in sd:
config = {}
param_key = None
self.upscale_ratio = 2048
self.downscale_ratio = 2048
if "decoder.layers.2.layers.1.weight_v" in sd:
param_key = "decoder.layers.2.layers.1.weight_v"
if "decoder.layers.2.layers.1.parametrizations.weight.original1" in sd:
@@ -562,6 +564,8 @@ class VAE:
if sd[param_key].shape[-1] == 12:
config["strides"] = [2, 4, 4, 6, 10]
self.audio_sample_rate = 48000
self.upscale_ratio = 1920
self.downscale_ratio = 1920
self.first_stage_model = AudioOobleckVAE(**config)
self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype)
@@ -569,8 +573,6 @@ class VAE:
self.latent_channels = 64
self.output_channels = 2
self.pad_channel_value = "replicate"
self.upscale_ratio = 2048
self.downscale_ratio = 2048
self.latent_dim = 1
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
@@ -870,7 +872,7 @@ class VAE:
/ 3.0)
return output
def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
def decode_tiled_1d(self, samples, tile_x=256, overlap=32):
if samples.ndim == 3:
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
else:
@@ -974,7 +976,7 @@ class VAE:
if overlap is not None:
args["overlap"] = overlap
if dims == 1:
if dims == 1 or self.extra_1d_channel is not None:
args.pop("tile_y")
output = self.decode_tiled_1d(samples, **args)
elif dims == 2:
@@ -1442,7 +1444,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
tokenizer_data["gemma_spiece_model"] = clip_data_gemma.get("spiece_model", None)
tokenizer_data["jina_spiece_model"] = clip_data_jina.get("spiece_model", None)
elif clip_type == CLIPType.ACE:
clip_target.clip = comfy.text_encoders.ace15.te(**llama_detect(clip_data))
te_models = [detect_te_model(clip_data[0]), detect_te_model(clip_data[1])]
if TEModel.QWEN3_4B in te_models:
model_type = "qwen3_4b"
else:
model_type = "qwen3_2b"
clip_target.clip = comfy.text_encoders.ace15.te(lm_model=model_type, **llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.ace15.ACE15Tokenizer
else:
clip_target.clip = sdxl_clip.SDXLClipModel

View File

@@ -1625,8 +1625,16 @@ class ACEStep15(supported_models_base.BASE):
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_2b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**hunyuan_detect))
detect_2b = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_2b.transformer.".format(pref))
detect_4b = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
if "dtype_llama" in detect_2b:
detect = detect_2b
detect["lm_model"] = "qwen3_2b"
elif "dtype_llama" in detect_4b:
detect = detect_4b
detect["lm_model"] = "qwen3_4b"
return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]

View File

@@ -3,6 +3,7 @@ import comfy.text_encoders.llama
from comfy import sd1_clip
import torch
import math
import yaml
import comfy.utils
@@ -19,6 +20,7 @@ def sample_manual_loop_no_classes(
min_tokens: int = 1,
max_new_tokens: int = 2048,
audio_start_id: int = 151669, # The cutoff ID for audio codes
audio_end_id: int = 215669,
eos_token_id: int = 151645,
):
device = model.execution_device
@@ -60,6 +62,7 @@ def sample_manual_loop_no_classes(
remove_logit_value = torch.finfo(cfg_logits.dtype).min
# Only generate audio tokens
cfg_logits[:, :audio_start_id] = remove_logit_value
cfg_logits[:, audio_end_id:] = remove_logit_value
if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
cfg_logits[:, eos_token_id] = eos_score
@@ -99,9 +102,7 @@ def sample_manual_loop_no_classes(
return output_audio_codes
def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0):
cfg_scale = 2.0
def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0):
positive = [[token for token, _ in inner_list] for inner_list in positive]
negative = [[token for token, _ in inner_list] for inner_list in negative]
positive = positive[0]
@@ -118,34 +119,80 @@ def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=102
positive = [model.special_tokens["pad"]] * pos_pad + positive
paddings = [pos_pad, neg_pad]
return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_06b", tokenizer=Qwen3Tokenizer)
def _metas_to_cot(self, *, return_yaml: bool = False, **kwargs) -> str:
user_metas = {
k: kwargs.pop(k)
for k in ("bpm", "duration", "keyscale", "timesignature", "language", "caption")
if k in kwargs
}
timesignature = user_metas.get("timesignature")
if isinstance(timesignature, str) and timesignature.endswith("/4"):
user_metas["timesignature"] = timesignature.rsplit("/", 1)[0]
user_metas = {
k: v if not isinstance(v, str) or not v.isdigit() else int(v)
for k, v in user_metas.items()
if v not in {"unspecified", None}
}
if len(user_metas):
meta_yaml = yaml.dump(user_metas, allow_unicode=True, sort_keys=True).strip()
else:
meta_yaml = ""
return f"<think>\n{meta_yaml}\n</think>" if not return_yaml else meta_yaml
def _metas_to_cap(self, **kwargs) -> str:
use_keys = ("bpm", "duration", "keyscale", "timesignature")
user_metas = { k: kwargs.pop(k, "N/A") for k in use_keys }
duration = user_metas["duration"]
if duration == "N/A":
user_metas["duration"] = "30 seconds"
elif isinstance(duration, (str, int, float)):
user_metas["duration"] = f"{math.ceil(float(duration))} seconds"
else:
raise TypeError("Unexpected type for duration key, must be str, int or float")
return "\n".join(f"- {k}: {user_metas[k]}" for k in use_keys)
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
out = {}
lyrics = kwargs.get("lyrics", "")
bpm = kwargs.get("bpm", 120)
duration = kwargs.get("duration", 120)
keyscale = kwargs.get("keyscale", "C major")
timesignature = kwargs.get("timesignature", 2)
language = kwargs.get("language", "en")
language = kwargs.get("language")
seed = kwargs.get("seed", 0)
generate_audio_codes = kwargs.get("generate_audio_codes", True)
cfg_scale = kwargs.get("cfg_scale", 2.0)
temperature = kwargs.get("temperature", 0.85)
top_p = kwargs.get("top_p", 0.9)
top_k = kwargs.get("top_k", 0.0)
duration = math.ceil(duration)
meta_lm = 'bpm: {}\nduration: {}\nkeyscale: {}\ntimesignature: {}'.format(bpm, duration, keyscale, timesignature)
lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n{}\n<|im_end|>\n<|im_start|>assistant\n<think>\n{}\n</think>\n\n<|im_end|>\n"
kwargs["duration"] = duration
meta_cap = '- bpm: {}\n- timesignature: {}\n- keyscale: {}\n- duration: {}\n'.format(bpm, timesignature, keyscale, duration)
out["lm_prompt"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, meta_lm), disable_weights=True)
out["lm_prompt_negative"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, ""), disable_weights=True)
cot_text = self._metas_to_cot(caption = text, **kwargs)
meta_cap = self._metas_to_cap(**kwargs)
out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric{}<|endoftext|><|endoftext|>".format(language, lyrics), return_word_ids, disable_weights=True, **kwargs)
out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}# Metas\n{}<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs)
out["lm_metadata"] = {"min_tokens": duration * 5, "seed": seed}
lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n<|im_end|>\n"
out["lm_prompt"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, cot_text), disable_weights=True)
out["lm_prompt_negative"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, "<think>\n</think>"), disable_weights=True)
out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric\n{}<|endoftext|><|endoftext|>".format(language if language is not None else "", lyrics), return_word_ids, disable_weights=True, **kwargs)
out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}\n# Metas\n{}\n<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs)
out["lm_metadata"] = {"min_tokens": duration * 5,
"seed": seed,
"generate_audio_codes": generate_audio_codes,
"cfg_scale": cfg_scale,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
}
return out
@@ -162,14 +209,34 @@ class Qwen3_2B_ACE15(sd1_clip.SDClipModel):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_2B_ACE15_lm, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class Qwen3_4B_ACE15(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B_ACE15_lm, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class ACE15TEModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, dtype_llama=None, model_options={}):
def __init__(self, device="cpu", dtype=None, dtype_llama=None, lm_model=None, model_options={}):
super().__init__()
if dtype_llama is None:
dtype_llama = dtype
model = None
self.constant = 0.4375
if lm_model == "qwen3_4b":
model = Qwen3_4B_ACE15
self.constant = 0.5625
elif lm_model == "qwen3_2b":
model = Qwen3_2B_ACE15
self.lm_model = lm_model
self.qwen3_06b = Qwen3_06BModel(device=device, dtype=dtype, model_options=model_options)
self.qwen3_2b = Qwen3_2B_ACE15(device=device, dtype=dtype_llama, model_options=model_options)
if model is not None:
setattr(self, self.lm_model, model(device=device, dtype=dtype_llama, model_options=model_options))
self.dtypes = set([dtype, dtype_llama])
def encode_token_weights(self, token_weight_pairs):
@@ -181,18 +248,26 @@ class ACE15TEModel(torch.nn.Module):
self.qwen3_06b.set_clip_options({"layer": [0]})
lyrics_embeds, _, extra_l = self.qwen3_06b.encode_token_weights(token_weight_pairs_lyrics)
lm_metadata = token_weight_pairs["lm_metadata"]
audio_codes = generate_audio_codes(self.qwen3_2b, token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"])
out = {"conditioning_lyrics": lyrics_embeds[:, 0]}
return base_out, None, {"conditioning_lyrics": lyrics_embeds[:, 0], "audio_codes": [audio_codes]}
lm_metadata = token_weight_pairs["lm_metadata"]
if lm_metadata["generate_audio_codes"]:
audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"])
out["audio_codes"] = [audio_codes]
return base_out, None, out
def set_clip_options(self, options):
self.qwen3_06b.set_clip_options(options)
self.qwen3_2b.set_clip_options(options)
lm_model = getattr(self, self.lm_model, None)
if lm_model is not None:
lm_model.set_clip_options(options)
def reset_clip_options(self):
self.qwen3_06b.reset_clip_options()
self.qwen3_2b.reset_clip_options()
lm_model = getattr(self, self.lm_model, None)
if lm_model is not None:
lm_model.reset_clip_options()
def load_sd(self, sd):
if "model.layers.0.post_attention_layernorm.weight" in sd:
@@ -200,11 +275,11 @@ class ACE15TEModel(torch.nn.Module):
if shape[0] == 1024:
return self.qwen3_06b.load_sd(sd)
else:
return self.qwen3_2b.load_sd(sd)
return getattr(self, self.lm_model).load_sd(sd)
def memory_estimation_function(self, token_weight_pairs, device=None):
lm_metadata = token_weight_pairs["lm_metadata"]
constant = 0.4375
constant = self.constant
if comfy.model_management.should_use_bf16(device):
constant *= 0.5
@@ -213,11 +288,11 @@ class ACE15TEModel(torch.nn.Module):
num_tokens += lm_metadata['min_tokens']
return num_tokens * constant * 1024 * 1024
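Rough worked number for the estimator above (token count illustrative): with the 2B LM constant of 0.4375 and roughly 1500 tokens (300 seconds of 5 Hz codes plus the prompt), the estimate is 1500 * 0.4375 MiB, about 656 MiB, and half that again when bf16 is in use.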
def te(dtype_llama=None, llama_quantization_metadata=None):
def te(dtype_llama=None, llama_quantization_metadata=None, lm_model="qwen3_2b"):
class ACE15TEModel_(ACE15TEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["llama_quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, dtype_llama=dtype_llama, dtype=dtype, model_options=model_options)
super().__init__(device=device, dtype_llama=dtype_llama, lm_model=lm_model, dtype=dtype, model_options=model_options)
return ACE15TEModel_

View File

@@ -150,6 +150,29 @@ class Qwen3_2B_ACE15_lm_Config:
final_norm: bool = True
lm_head: bool = False
@dataclass
class Qwen3_4B_ACE15_lm_Config:
vocab_size: int = 217204
hidden_size: int = 2560
intermediate_size: int = 9728
num_hidden_layers: int = 36
num_attention_heads: int = 32
num_key_value_heads: int = 8
max_position_embeddings: int = 40960
rms_norm_eps: float = 1e-6
rope_theta: float = 1000000.0
transformer_type: str = "llama"
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = False
rope_dims = None
q_norm = "gemma3"
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
lm_head: bool = False
@dataclass
class Qwen3_4BConfig:
vocab_size: int = 151936
@@ -628,10 +651,10 @@ class Llama2_(nn.Module):
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min)
mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min / 4)
if seq_len > 1:
causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(torch.finfo(x.dtype).min).triu_(1)
causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(torch.finfo(x.dtype).min / 4).triu_(1)
if mask is not None:
mask += causal_mask
else:
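The min/4 change above matters because the padding mask and the causal mask are summed: two full-scale minima overflow to -inf in half precision, and a softmax over a fully masked row then yields NaN. A quick check of the overflow (dtype illustrative):

    import torch
    m = torch.tensor(torch.finfo(torch.float16).min, dtype=torch.float16)
    print(m + m)          # tensor(-inf, dtype=torch.float16): overflow
    print(m / 4 + m / 4)  # tensor(-32752., dtype=torch.float16): still finite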
@@ -739,6 +762,21 @@ class BaseLlama:
def forward(self, input_ids, *args, **kwargs):
return self.model(input_ids, *args, **kwargs)
class BaseQwen3:
def logits(self, x):
input = x[:, -1:]
module = self.model.embed_tokens
offload_stream = None
if module.comfy_cast_weights:
weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True)
else:
weight = self.model.embed_tokens.weight.to(x)
x = torch.nn.functional.linear(input, weight, None)
comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
return x
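# The shared logits() above is a tied LM head: the token-embedding matrix is
# reused as the output projection for the final position only. A minimal
# equivalent, ignoring the comfy cast/offload plumbing ("hidden" and "model"
# here are placeholders):
#
#     last = hidden[:, -1:]
#     logits = torch.nn.functional.linear(last, model.embed_tokens.weight.to(last))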
class Llama2(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
@@ -767,7 +805,7 @@ class Qwen25_3B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen3_06B(BaseLlama, torch.nn.Module):
class Qwen3_06B(BaseLlama, BaseQwen3, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_06BConfig(**config_dict)
@@ -776,7 +814,7 @@ class Qwen3_06B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen3_06B_ACE15(BaseLlama, torch.nn.Module):
class Qwen3_06B_ACE15(BaseLlama, BaseQwen3, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_06B_ACE15_Config(**config_dict)
@@ -785,7 +823,7 @@ class Qwen3_06B_ACE15(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen3_2B_ACE15_lm(BaseLlama, torch.nn.Module):
class Qwen3_2B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_2B_ACE15_lm_Config(**config_dict)
@@ -794,22 +832,7 @@ class Qwen3_2B_ACE15_lm(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
def logits(self, x):
input = x[:, -1:]
module = self.model.embed_tokens
offload_stream = None
if module.comfy_cast_weights:
weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True)
else:
weight = self.model.embed_tokens.weight.to(x)
x = torch.nn.functional.linear(input, weight, None)
comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
return x
class Qwen3_4B(BaseLlama, torch.nn.Module):
class Qwen3_4B(BaseLlama, BaseQwen3, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_4BConfig(**config_dict)
@@ -818,7 +841,16 @@ class Qwen3_4B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen3_8B(BaseLlama, torch.nn.Module):
class Qwen3_4B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_4B_ACE15_lm_Config(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen3_8B(BaseLlama, BaseQwen3, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_8BConfig(**config_dict)

View File

@@ -82,14 +82,12 @@ _TYPES = {
def load_safetensors(ckpt):
f = open(ckpt, "rb")
mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
mv = memoryview(mapping)
header_size = struct.unpack("<Q", mapping[:8])[0]
header = json.loads(mapping[8:8+header_size].decode("utf-8"))
with warnings.catch_warnings():
#We are working with read-only RAM by design
warnings.filterwarnings("ignore", message="The given buffer is not writable")
data_area = torch.frombuffer(mapping, dtype=torch.uint8)[8 + header_size:]
mv = mv[8 + header_size:]
sd = {}
for name, info in header.items():
@@ -97,7 +95,13 @@ def load_safetensors(ckpt):
continue
start, end = info["data_offsets"]
sd[name] = data_area[start:end].view(_TYPES[info["dtype"]]).view(info["shape"])
if start == end:
sd[name] = torch.empty(info["shape"], dtype=_TYPES[info["dtype"]])
else:
with warnings.catch_warnings():
#We are working with read-only RAM by design
warnings.filterwarnings("ignore", message="The given buffer is not writable")
sd[name] = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"])
return sd, header.get("__metadata__", {}),
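Constructing each tensor straight off the memoryview sidesteps torch's dtype-view alignment check, which is enforced relative to a tensor's own storage offset. A small illustration of the difference (buffer contents arbitrary):

    import torch
    mv = memoryview(bytearray(12))
    ok = torch.frombuffer(mv[2:10], dtype=torch.float32)  # offset hidden from torch
    # whereas slicing a uint8 tensor and re-viewing trips the alignment check:
    # torch.frombuffer(mv, dtype=torch.uint8)[2:10].view(torch.float32)  # RuntimeError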

View File

@@ -44,13 +44,18 @@ class TextEncodeAceStepAudio15(io.ComfyNode):
io.Combo.Input("timesignature", options=['2', '3', '4', '6']),
io.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]),
io.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]),
io.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True),
io.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True),
io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
io.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
],
outputs=[io.Conditioning.Output()],
)
@classmethod
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale) -> io.NodeOutput:
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed)
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k) -> io.NodeOutput:
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k)
conditioning = clip.encode_from_tokens_scheduled(tokens)
return io.NodeOutput(conditioning)
@@ -100,14 +105,15 @@ class EmptyAceStep15LatentAudio(io.ComfyNode):
latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device())
return io.NodeOutput({"samples": latent, "type": "audio"})
class ReferenceTimbreAudio(io.ComfyNode):
class ReferenceAudio(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ReferenceTimbreAudio",
display_name="Reference Audio",
category="advanced/conditioning/audio",
is_experimental=True,
description="This node sets the reference audio for timbre (for ace step 1.5)",
description="This node sets the reference audio for ace step 1.5",
inputs=[
io.Conditioning.Input("conditioning"),
io.Latent.Input("latent", optional=True),
@@ -131,7 +137,7 @@ class AceExtension(ComfyExtension):
EmptyAceStepLatentAudio,
TextEncodeAceStepAudio15,
EmptyAceStep15LatentAudio,
ReferenceTimbreAudio,
ReferenceAudio,
]
async def comfy_entrypoint() -> AceExtension:

View File

@@ -94,6 +94,19 @@ class VAEEncodeAudio(IO.ComfyNode):
encode = execute # TODO: remove
def vae_decode_audio(vae, samples, tile=None, overlap=None):
if tile is not None:
audio = vae.decode_tiled(samples["samples"], tile_y=tile, overlap=overlap).movedim(-1, 1)
else:
audio = vae.decode(samples["samples"]).movedim(-1, 1)
std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0
std[std < 1.0] = 1.0
audio /= std
vae_sample_rate = getattr(vae, "audio_sample_rate", 44100)
return {"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]}
class VAEDecodeAudio(IO.ComfyNode):
@classmethod
def define_schema(cls):
@@ -111,16 +124,33 @@ class VAEDecodeAudio(IO.ComfyNode):
@classmethod
def execute(cls, vae, samples) -> IO.NodeOutput:
audio = vae.decode(samples["samples"]).movedim(-1, 1)
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
std[std < 1.0] = 1.0
audio /= std
vae_sample_rate = getattr(vae, "audio_sample_rate", 44100)
return IO.NodeOutput({"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]})
return IO.NodeOutput(vae_decode_audio(vae, samples))
decode = execute # TODO: remove
class VAEDecodeAudioTiled(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="VAEDecodeAudioTiled",
search_aliases=["latent to audio"],
display_name="VAE Decode Audio (Tiled)",
category="latent/audio",
inputs=[
IO.Latent.Input("samples"),
IO.Vae.Input("vae"),
IO.Int.Input("tile_size", default=512, min=32, max=8192, step=8),
IO.Int.Input("overlap", default=64, min=0, max=1024, step=8),
],
outputs=[IO.Audio.Output()],
)
@classmethod
def execute(cls, vae, samples, tile_size, overlap) -> IO.NodeOutput:
return IO.NodeOutput(vae_decode_audio(vae, samples, tile_size, overlap))
class SaveAudio(IO.ComfyNode):
@classmethod
def define_schema(cls):
@@ -675,6 +705,7 @@ class AudioExtension(ComfyExtension):
EmptyLatentAudio,
VAEEncodeAudio,
VAEDecodeAudio,
VAEDecodeAudioTiled,
SaveAudio,
SaveAudioMP3,
SaveAudioOpus,

View File

@@ -618,6 +618,7 @@ class SaveGLB(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="SaveGLB",
display_name="Save 3D Model",
search_aliases=["export 3d model", "save mesh"],
category="3d",
is_output_node=True,
@@ -626,8 +627,14 @@ class SaveGLB(IO.ComfyNode):
IO.Mesh.Input("mesh"),
types=[
IO.File3DGLB,
IO.File3DGLTF,
IO.File3DOBJ,
IO.File3DFBX,
IO.File3DSTL,
IO.File3DUSDZ,
IO.File3DAny,
],
tooltip="Mesh or GLB file to save",
tooltip="Mesh or 3D file to save",
),
IO.String.Input("filename_prefix", default="mesh/ComfyUI"),
],
@@ -649,7 +656,8 @@ class SaveGLB(IO.ComfyNode):
if isinstance(mesh, Types.File3D):
# Handle File3D input - save BytesIO data to output folder
f = f"{filename}_{counter:05}_.glb"
ext = mesh.format or "glb"
f = f"{filename}_{counter:05}_.{ext}"
mesh.save_to(os.path.join(full_output_folder, f))
results.append({
"filename": f,

View File

@@ -45,6 +45,7 @@ class Load3D(IO.ComfyNode):
IO.Image.Output(display_name="normal"),
IO.Load3DCamera.Output(display_name="camera_info"),
IO.Video.Output(display_name="recording_video"),
IO.File3DAny.Output(display_name="model_3d"),
],
)
@@ -66,7 +67,8 @@ class Load3D(IO.ComfyNode):
video = InputImpl.VideoFromFile(recording_video_path)
return IO.NodeOutput(output_image, output_mask, model_file, normal_image, image['camera_info'], video)
file_3d = Types.File3D(folder_paths.get_annotated_filepath(model_file))
return IO.NodeOutput(output_image, output_mask, model_file, normal_image, image['camera_info'], video, file_3d)
process = execute # TODO: remove

View File

@@ -0,0 +1,47 @@
from __future__ import annotations
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class CreateList(io.ComfyNode):
@classmethod
def define_schema(cls):
template_matchtype = io.MatchType.Template("type")
template_autogrow = io.Autogrow.TemplatePrefix(
input=io.MatchType.Input("input", template=template_matchtype),
prefix="input",
)
return io.Schema(
node_id="CreateList",
display_name="Create List",
category="logic",
is_input_list=True,
search_aliases=["Image Iterator", "Text Iterator", "Iterator"],
inputs=[io.Autogrow.Input("inputs", template=template_autogrow)],
outputs=[
io.MatchType.Output(
template=template_matchtype,
is_output_list=True,
display_name="list",
),
],
)
@classmethod
def execute(cls, inputs: io.Autogrow.Type) -> io.NodeOutput:
output_list = []
for input in inputs.values():
output_list += input
return io.NodeOutput(output_list)
class ToolkitExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
CreateList,
]
async def comfy_entrypoint() -> ToolkitExtension:
return ToolkitExtension()
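Because the schema sets is_input_list=True, each connected input reaches execute as a Python list, and the node simply concatenates them. Roughly, with hypothetical values:

    inputs = {"input_1": [image_a], "input_2": [image_b, image_c]}
    # execute returns [image_a, image_b, image_c], emitted as a list because
    # the MatchType output is declared with is_output_list=True.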

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.12.0"
__version__ = "0.12.3"

View File

@@ -192,7 +192,10 @@ import comfy_aimdo.control
import comfy_aimdo.torch
if enables_dynamic_vram():
if comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
if comfy.model_management.torch_version_numeric < (2, 8):
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
comfy.memory_management.aimdo_allocator = None
elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
if args.verbose == 'DEBUG':
comfy_aimdo.control.set_log_debug()
elif args.verbose == 'CRITICAL':
@@ -208,7 +211,7 @@ if enables_dynamic_vram():
comfy.memory_management.aimdo_allocator = comfy_aimdo.torch.get_torch_allocator()
logging.info("DynamicVRAM support detected and enabled")
else:
logging.info("No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
logging.warning("No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
comfy.memory_management.aimdo_allocator = None

View File

@@ -2433,7 +2433,8 @@ async def init_builtin_extra_nodes():
"nodes_image_compare.py",
"nodes_zimage.py",
"nodes_lora_debug.py",
"nodes_color.py"
"nodes_color.py",
"nodes_toolkit.py",
]
import_failed = []

View File

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.12.0"
version = "0.12.3"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"

View File

@@ -1,4 +1,4 @@
comfyui-frontend-package==1.37.11
comfyui-frontend-package==1.38.13
comfyui-workflow-templates==0.8.31
comfyui-embedded-docs==0.4.0
torch

View File

@@ -259,3 +259,13 @@ def autoclean_unit_test_assets(http: requests.Session, api_base: str):
for aid in ids:
with contextlib.suppress(Exception):
http.delete(f"{api_base}/api/assets/{aid}", timeout=30)
def trigger_sync_seed_assets(session: requests.Session, base_url: str) -> None:
"""Force a fast sync/seed pass by calling the seed endpoint."""
session.post(base_url + "/api/assets/seed", json={"roots": ["models", "input", "output"]}, timeout=30)
time.sleep(0.2)
def get_asset_filename(asset_hash: str, extension: str) -> str:
return asset_hash.removeprefix("blake3:") + extension

View File

@@ -1,14 +0,0 @@
"""Helper functions for assets integration tests."""
import time
import requests
def trigger_sync_seed_assets(session: requests.Session, base_url: str) -> None:
"""Force a fast sync/seed pass by calling the seed endpoint."""
session.post(base_url + "/api/assets/seed", json={"roots": ["models", "input", "output"]}, timeout=30)
time.sleep(0.2)
def get_asset_filename(asset_hash: str, extension: str) -> str:
return asset_hash.removeprefix("blake3:") + extension

View File

@@ -1,20 +0,0 @@
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from app.assets.database.models import Base
@pytest.fixture
def session():
"""In-memory SQLite session for fast unit tests."""
engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
with Session(engine) as sess:
yield sess
@pytest.fixture(autouse=True)
def autoclean_unit_test_assets():
"""Override parent autouse fixture - query tests don't need server cleanup."""
yield

View File

@@ -1,142 +0,0 @@
import uuid
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset
from app.assets.database.queries import (
asset_exists_by_hash,
get_asset_by_hash,
upsert_asset,
bulk_insert_assets,
)
class TestAssetExistsByHash:
@pytest.mark.parametrize(
"setup_hash,query_hash,expected",
[
(None, "nonexistent", False), # No asset exists
("blake3:abc123", "blake3:abc123", True), # Asset exists with matching hash
(None, "", False), # Null hash in DB doesn't match empty string
],
ids=["nonexistent", "existing", "null_hash_no_match"],
)
def test_exists_by_hash(self, session: Session, setup_hash, query_hash, expected):
# Create asset with given hash (including None for null hash test)
if setup_hash is not None or query_hash == "":
asset = Asset(hash=setup_hash, size_bytes=100)
session.add(asset)
session.commit()
assert asset_exists_by_hash(session, asset_hash=query_hash) is expected
class TestGetAssetByHash:
@pytest.mark.parametrize(
"setup_hash,query_hash,should_find",
[
(None, "nonexistent", False),
("blake3:def456", "blake3:def456", True),
],
ids=["nonexistent", "existing"],
)
def test_get_by_hash(self, session: Session, setup_hash, query_hash, should_find):
if setup_hash is not None:
asset = Asset(hash=setup_hash, size_bytes=200, mime_type="image/png")
session.add(asset)
session.commit()
result = get_asset_by_hash(session, asset_hash=query_hash)
if should_find:
assert result is not None
assert result.size_bytes == 200
assert result.mime_type == "image/png"
else:
assert result is None
class TestUpsertAsset:
@pytest.mark.parametrize(
"first_size,first_mime,second_size,second_mime,expect_created,expect_updated,final_size,final_mime",
[
# New asset creation
(None, None, 1024, "application/octet-stream", True, False, 1024, "application/octet-stream"),
# Existing asset, same values - no update
(500, "text/plain", 500, "text/plain", False, False, 500, "text/plain"),
# Existing asset with size 0, update with new values
(0, None, 2048, "image/png", False, True, 2048, "image/png"),
# Existing asset, second call with size 0 - no update
(1000, None, 0, None, False, False, 1000, None),
],
ids=["new_asset", "existing_no_change", "update_from_zero", "zero_size_no_update"],
)
def test_upsert_scenarios(
self,
session: Session,
first_size,
first_mime,
second_size,
second_mime,
expect_created,
expect_updated,
final_size,
final_mime,
):
asset_hash = f"blake3:test_{first_size}_{second_size}"
# Seed the asset with a first upsert when first_size is provided; the call below is the one under test
if first_size is not None:
upsert_asset(
session,
asset_hash=asset_hash,
size_bytes=first_size,
mime_type=first_mime,
)
session.commit()
# The upsert call we're testing
asset, created, updated = upsert_asset(
session,
asset_hash=asset_hash,
size_bytes=second_size,
mime_type=second_mime,
)
session.commit()
assert created is expect_created
assert updated is expect_updated
assert asset.size_bytes == final_size
assert asset.mime_type == final_mime
class TestBulkInsertAssets:
def test_inserts_multiple_assets(self, session: Session):
rows = [
{"id": str(uuid.uuid4()), "hash": "blake3:bulk1", "size_bytes": 100, "mime_type": "text/plain", "created_at": None},
{"id": str(uuid.uuid4()), "hash": "blake3:bulk2", "size_bytes": 200, "mime_type": "image/png", "created_at": None},
{"id": str(uuid.uuid4()), "hash": "blake3:bulk3", "size_bytes": 300, "mime_type": None, "created_at": None},
]
bulk_insert_assets(session, rows)
session.commit()
assets = session.query(Asset).all()
assert len(assets) == 3
hashes = {a.hash for a in assets}
assert hashes == {"blake3:bulk1", "blake3:bulk2", "blake3:bulk3"}
def test_empty_list_is_noop(self, session: Session):
bulk_insert_assets(session, [])
session.commit()
assert session.query(Asset).count() == 0
def test_handles_large_batch(self, session: Session):
"""Test chunking logic with more rows than MAX_BIND_PARAMS allows."""
rows = [
{"id": str(uuid.uuid4()), "hash": f"blake3:large{i}", "size_bytes": i, "mime_type": None, "created_at": None}
for i in range(200)
]
bulk_insert_assets(session, rows)
session.commit()
assert session.query(Asset).count() == 200

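Taken together, TestUpsertAsset pins the write contract down tightly. A sketch reconstructed purely from those assertions (an assumption about upsert_asset's behavior, not its code): create when absent, fill in stored values only when the existing size is 0, and never let an incoming size of 0 overwrite anything:

def upsert_semantics(existing: dict | None, size_bytes: int, mime_type: str | None):
    """Returns (asset, created, updated), mirroring the tested contract."""
    if existing is None:
        return {"size_bytes": size_bytes, "mime_type": mime_type}, True, False
    if size_bytes and existing["size_bytes"] == 0:
        existing.update(size_bytes=size_bytes, mime_type=mime_type)
        return existing, False, True
    return existing, False, False
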
View File

@@ -1,511 +0,0 @@
import time
import uuid
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta
from app.assets.database.queries import (
asset_info_exists_for_asset_id,
get_asset_info_by_id,
insert_asset_info,
get_or_create_asset_info,
update_asset_info_timestamps,
list_asset_infos_page,
fetch_asset_info_asset_and_tags,
fetch_asset_info_and_asset,
update_asset_info_access_time,
set_asset_info_metadata,
delete_asset_info_by_id,
set_asset_info_preview,
bulk_insert_asset_infos_ignore_conflicts,
get_asset_info_ids_by_ids,
ensure_tags_exist,
add_tags_to_asset_info,
)
from app.assets.helpers import get_utc_now
def _make_asset(session: Session, hash_val: str | None = None, size: int = 1024) -> Asset:
asset = Asset(hash=hash_val, size_bytes=size, mime_type="application/octet-stream")
session.add(asset)
session.flush()
return asset
def _make_asset_info(
session: Session,
asset: Asset,
name: str = "test",
owner_id: str = "",
) -> AssetInfo:
now = get_utc_now()
info = AssetInfo(
owner_id=owner_id,
name=name,
asset_id=asset.id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
session.flush()
return info
class TestAssetInfoExistsForAssetId:
def test_returns_false_when_no_info(self, session: Session):
asset = _make_asset(session, "hash1")
assert asset_info_exists_for_asset_id(session, asset_id=asset.id) is False
def test_returns_true_when_info_exists(self, session: Session):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset)
assert asset_info_exists_for_asset_id(session, asset_id=asset.id) is True
class TestGetAssetInfoById:
def test_returns_none_for_nonexistent(self, session: Session):
assert get_asset_info_by_id(session, asset_info_id="nonexistent") is None
def test_returns_info(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset, name="myfile.txt")
result = get_asset_info_by_id(session, asset_info_id=info.id)
assert result is not None
assert result.name == "myfile.txt"
class TestListAssetInfosPage:
def test_empty_db(self, session: Session):
infos, tag_map, total = list_asset_infos_page(session)
assert infos == []
assert tag_map == {}
assert total == 0
def test_returns_infos_with_tags(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset, name="test.bin")
ensure_tags_exist(session, ["alpha", "beta"])
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["alpha", "beta"])
session.commit()
infos, tag_map, total = list_asset_infos_page(session)
assert len(infos) == 1
assert infos[0].id == info.id
assert set(tag_map[info.id]) == {"alpha", "beta"}
assert total == 1
def test_name_contains_filter(self, session: Session):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, name="model_v1.safetensors")
_make_asset_info(session, asset, name="config.json")
session.commit()
infos, _, total = list_asset_infos_page(session, name_contains="model")
assert total == 1
assert infos[0].name == "model_v1.safetensors"
def test_owner_visibility(self, session: Session):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, name="public", owner_id="")
_make_asset_info(session, asset, name="private", owner_id="user1")
session.commit()
# Empty owner sees only public
infos, _, total = list_asset_infos_page(session, owner_id="")
assert total == 1
assert infos[0].name == "public"
# Owner sees both
infos, _, total = list_asset_infos_page(session, owner_id="user1")
assert total == 2
def test_include_tags_filter(self, session: Session):
asset = _make_asset(session, "hash1")
info1 = _make_asset_info(session, asset, name="tagged")
_make_asset_info(session, asset, name="untagged")
ensure_tags_exist(session, ["wanted"])
add_tags_to_asset_info(session, asset_info_id=info1.id, tags=["wanted"])
session.commit()
infos, _, total = list_asset_infos_page(session, include_tags=["wanted"])
assert total == 1
assert infos[0].name == "tagged"
def test_exclude_tags_filter(self, session: Session):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, name="keep")
info_exclude = _make_asset_info(session, asset, name="exclude")
ensure_tags_exist(session, ["bad"])
add_tags_to_asset_info(session, asset_info_id=info_exclude.id, tags=["bad"])
session.commit()
infos, _, total = list_asset_infos_page(session, exclude_tags=["bad"])
assert total == 1
assert infos[0].name == "keep"
def test_sorting(self, session: Session):
asset = _make_asset(session, "hash1", size=100)
asset2 = _make_asset(session, "hash2", size=500)
_make_asset_info(session, asset, name="small")
_make_asset_info(session, asset2, name="large")
session.commit()
infos, _, _ = list_asset_infos_page(session, sort="size", order="desc")
assert infos[0].name == "large"
infos, _, _ = list_asset_infos_page(session, sort="name", order="asc")
assert infos[0].name == "large"
class TestFetchAssetInfoAssetAndTags:
def test_returns_none_for_nonexistent(self, session: Session):
result = fetch_asset_info_asset_and_tags(session, "nonexistent")
assert result is None
def test_returns_tuple(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset, name="test.bin")
ensure_tags_exist(session, ["tag1"])
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["tag1"])
session.commit()
result = fetch_asset_info_asset_and_tags(session, info.id)
assert result is not None
ret_info, ret_asset, ret_tags = result
assert ret_info.id == info.id
assert ret_asset.id == asset.id
assert ret_tags == ["tag1"]
class TestFetchAssetInfoAndAsset:
def test_returns_none_for_nonexistent(self, session: Session):
result = fetch_asset_info_and_asset(session, asset_info_id="nonexistent")
assert result is None
def test_returns_tuple(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
session.commit()
result = fetch_asset_info_and_asset(session, asset_info_id=info.id)
assert result is not None
ret_info, ret_asset = result
assert ret_info.id == info.id
assert ret_asset.id == asset.id
class TestUpdateAssetInfoAccessTime:
def test_updates_last_access_time(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
original_time = info.last_access_time
session.commit()
time.sleep(0.01)
update_asset_info_access_time(session, asset_info_id=info.id)
session.commit()
session.refresh(info)
assert info.last_access_time > original_time
class TestDeleteAssetInfoById:
def test_deletes_existing(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
session.commit()
result = delete_asset_info_by_id(session, asset_info_id=info.id, owner_id="")
assert result is True
assert get_asset_info_by_id(session, asset_info_id=info.id) is None
def test_returns_false_for_nonexistent(self, session: Session):
result = delete_asset_info_by_id(session, asset_info_id="nonexistent", owner_id="")
assert result is False
def test_respects_owner_visibility(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset, owner_id="user1")
session.commit()
result = delete_asset_info_by_id(session, asset_info_id=info.id, owner_id="user2")
assert result is False
assert get_asset_info_by_id(session, asset_info_id=info.id) is not None
class TestSetAssetInfoPreview:
def test_sets_preview(self, session: Session):
asset = _make_asset(session, "hash1")
preview_asset = _make_asset(session, "preview_hash")
info = _make_asset_info(session, asset)
session.commit()
set_asset_info_preview(session, asset_info_id=info.id, preview_asset_id=preview_asset.id)
session.commit()
session.refresh(info)
assert info.preview_id == preview_asset.id
def test_clears_preview(self, session: Session):
asset = _make_asset(session, "hash1")
preview_asset = _make_asset(session, "preview_hash")
info = _make_asset_info(session, asset)
info.preview_id = preview_asset.id
session.commit()
set_asset_info_preview(session, asset_info_id=info.id, preview_asset_id=None)
session.commit()
session.refresh(info)
assert info.preview_id is None
def test_raises_for_nonexistent_info(self, session: Session):
with pytest.raises(ValueError, match="not found"):
set_asset_info_preview(session, asset_info_id="nonexistent", preview_asset_id=None)
def test_raises_for_nonexistent_preview(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
session.commit()
with pytest.raises(ValueError, match="Preview Asset"):
set_asset_info_preview(session, asset_info_id=info.id, preview_asset_id="nonexistent")
class TestInsertAssetInfo:
def test_creates_new_info(self, session: Session):
asset = _make_asset(session, "hash1")
info = insert_asset_info(
session, asset_id=asset.id, owner_id="user1", name="test.bin"
)
session.commit()
assert info is not None
assert info.name == "test.bin"
assert info.owner_id == "user1"
def test_returns_none_on_conflict(self, session: Session):
asset = _make_asset(session, "hash1")
insert_asset_info(session, asset_id=asset.id, owner_id="user1", name="dup.bin")
session.commit()
# Attempt duplicate with same (asset_id, owner_id, name)
result = insert_asset_info(
session, asset_id=asset.id, owner_id="user1", name="dup.bin"
)
assert result is None
class TestGetOrCreateAssetInfo:
def test_creates_new_info(self, session: Session):
asset = _make_asset(session, "hash1")
info, created = get_or_create_asset_info(
session, asset_id=asset.id, owner_id="user1", name="new.bin"
)
session.commit()
assert created is True
assert info.name == "new.bin"
def test_returns_existing_info(self, session: Session):
asset = _make_asset(session, "hash1")
info1, created1 = get_or_create_asset_info(
session, asset_id=asset.id, owner_id="user1", name="existing.bin"
)
session.commit()
info2, created2 = get_or_create_asset_info(
session, asset_id=asset.id, owner_id="user1", name="existing.bin"
)
session.commit()
assert created1 is True
assert created2 is False
assert info1.id == info2.id
class TestUpdateAssetInfoTimestamps:
def test_updates_timestamps(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
original_updated_at = info.updated_at
session.commit()
time.sleep(0.01)
update_asset_info_timestamps(session, info)
session.commit()
session.refresh(info)
assert info.updated_at > original_updated_at
def test_updates_preview_id(self, session: Session):
asset = _make_asset(session, "hash1")
preview_asset = _make_asset(session, "preview_hash")
info = _make_asset_info(session, asset)
session.commit()
update_asset_info_timestamps(session, info, preview_id=preview_asset.id)
session.commit()
session.refresh(info)
assert info.preview_id == preview_asset.id
class TestSetAssetInfoMetadata:
def test_sets_metadata(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
session.commit()
set_asset_info_metadata(
session, asset_info_id=info.id, user_metadata={"key": "value"}
)
session.commit()
session.refresh(info)
assert info.user_metadata == {"key": "value"}
# Check metadata table
meta = session.query(AssetInfoMeta).filter_by(asset_info_id=info.id).all()
assert len(meta) == 1
assert meta[0].key == "key"
assert meta[0].val_str == "value"
def test_replaces_existing_metadata(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
session.commit()
set_asset_info_metadata(
session, asset_info_id=info.id, user_metadata={"old": "data"}
)
session.commit()
set_asset_info_metadata(
session, asset_info_id=info.id, user_metadata={"new": "data"}
)
session.commit()
meta = session.query(AssetInfoMeta).filter_by(asset_info_id=info.id).all()
assert len(meta) == 1
assert meta[0].key == "new"
def test_clears_metadata_with_empty_dict(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
session.commit()
set_asset_info_metadata(
session, asset_info_id=info.id, user_metadata={"key": "value"}
)
session.commit()
set_asset_info_metadata(
session, asset_info_id=info.id, user_metadata={}
)
session.commit()
session.refresh(info)
assert info.user_metadata == {}
meta = session.query(AssetInfoMeta).filter_by(asset_info_id=info.id).all()
assert len(meta) == 0
def test_raises_for_nonexistent(self, session: Session):
with pytest.raises(ValueError, match="not found"):
set_asset_info_metadata(
session, asset_info_id="nonexistent", user_metadata={"key": "value"}
)
class TestBulkInsertAssetInfosIgnoreConflicts:
def test_inserts_multiple_infos(self, session: Session):
asset = _make_asset(session, "hash1")
now = get_utc_now()
rows = [
{
"id": str(uuid.uuid4()),
"owner_id": "",
"name": "bulk1.bin",
"asset_id": asset.id,
"preview_id": None,
"user_metadata": {},
"created_at": now,
"updated_at": now,
"last_access_time": now,
},
{
"id": str(uuid.uuid4()),
"owner_id": "",
"name": "bulk2.bin",
"asset_id": asset.id,
"preview_id": None,
"user_metadata": {},
"created_at": now,
"updated_at": now,
"last_access_time": now,
},
]
bulk_insert_asset_infos_ignore_conflicts(session, rows)
session.commit()
infos = session.query(AssetInfo).all()
assert len(infos) == 2
def test_ignores_conflicts(self, session: Session):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, name="existing.bin", owner_id="")
session.commit()
now = get_utc_now()
rows = [
{
"id": str(uuid.uuid4()),
"owner_id": "",
"name": "existing.bin",
"asset_id": asset.id,
"preview_id": None,
"user_metadata": {},
"created_at": now,
"updated_at": now,
"last_access_time": now,
},
{
"id": str(uuid.uuid4()),
"owner_id": "",
"name": "new.bin",
"asset_id": asset.id,
"preview_id": None,
"user_metadata": {},
"created_at": now,
"updated_at": now,
"last_access_time": now,
},
]
bulk_insert_asset_infos_ignore_conflicts(session, rows)
session.commit()
infos = session.query(AssetInfo).all()
assert len(infos) == 2 # existing + new, not 3
def test_empty_list_is_noop(self, session: Session):
bulk_insert_asset_infos_ignore_conflicts(session, [])
assert session.query(AssetInfo).count() == 0
class TestGetAssetInfoIdsByIds:
def test_returns_existing_ids(self, session: Session):
asset = _make_asset(session, "hash1")
info1 = _make_asset_info(session, asset, name="a.bin")
info2 = _make_asset_info(session, asset, name="b.bin")
session.commit()
found = get_asset_info_ids_by_ids(session, [info1.id, info2.id, "nonexistent"])
assert found == {info1.id, info2.id}
def test_empty_list_returns_empty(self, session: Session):
found = get_asset_info_ids_by_ids(session, [])
assert found == set()

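The owner-visibility tests above all encode the same rule. A sketch of it (assumed from the assertions, not the shipped query code): a row is visible when it is public (owner_id == "") or owned by the caller:

def visible_to(row_owner_id: str, caller_owner_id: str) -> bool:
    return row_owner_id == "" or row_owner_id == caller_owner_id

assert visible_to("", "") is True            # public row, anonymous caller
assert visible_to("user1", "") is False      # private row hidden from anonymous
assert visible_to("", "user1") is True       # owners still see public rows
assert visible_to("user1", "user1") is True  # owners see their own rows
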
View File

@@ -1,416 +0,0 @@
"""Tests for cache_state query functions."""
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetCacheState, AssetInfo
from app.assets.database.queries import (
list_cache_states_by_asset_id,
upsert_cache_state,
delete_cache_states_outside_prefixes,
get_orphaned_seed_asset_ids,
delete_assets_by_ids,
get_cache_states_for_prefixes,
bulk_set_needs_verify,
delete_cache_states_by_ids,
delete_orphaned_seed_asset,
bulk_insert_cache_states_ignore_conflicts,
get_cache_states_by_paths_and_asset_ids,
)
from app.assets.helpers import select_best_live_path, get_utc_now
def _make_asset(session: Session, hash_val: str | None = None, size: int = 1024) -> Asset:
asset = Asset(hash=hash_val, size_bytes=size)
session.add(asset)
session.flush()
return asset
def _make_cache_state(
session: Session,
asset: Asset,
file_path: str,
mtime_ns: int | None = None,
needs_verify: bool = False,
) -> AssetCacheState:
state = AssetCacheState(
asset_id=asset.id,
file_path=file_path,
mtime_ns=mtime_ns,
needs_verify=needs_verify,
)
session.add(state)
session.flush()
return state
class TestListCacheStatesByAssetId:
def test_returns_empty_for_no_states(self, session: Session):
asset = _make_asset(session, "hash1")
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
assert list(states) == []
def test_returns_states_for_asset(self, session: Session):
asset = _make_asset(session, "hash1")
_make_cache_state(session, asset, "/path/a.bin")
_make_cache_state(session, asset, "/path/b.bin")
session.commit()
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
paths = [s.file_path for s in states]
assert set(paths) == {"/path/a.bin", "/path/b.bin"}
def test_does_not_return_other_assets_states(self, session: Session):
asset1 = _make_asset(session, "hash1")
asset2 = _make_asset(session, "hash2")
_make_cache_state(session, asset1, "/path/asset1.bin")
_make_cache_state(session, asset2, "/path/asset2.bin")
session.commit()
states = list_cache_states_by_asset_id(session, asset_id=asset1.id)
paths = [s.file_path for s in states]
assert paths == ["/path/asset1.bin"]
class TestSelectBestLivePath:
def test_returns_empty_for_empty_list(self):
result = select_best_live_path([])
assert result == ""
def test_returns_empty_when_no_files_exist(self, session: Session):
asset = _make_asset(session, "hash1")
state = _make_cache_state(session, asset, "/nonexistent/path.bin")
session.commit()
result = select_best_live_path([state])
assert result == ""
def test_prefers_verified_path(self, session: Session, tmp_path):
"""needs_verify=False should be preferred."""
asset = _make_asset(session, "hash1")
verified_file = tmp_path / "verified.bin"
verified_file.write_bytes(b"data")
unverified_file = tmp_path / "unverified.bin"
unverified_file.write_bytes(b"data")
state_verified = _make_cache_state(
session, asset, str(verified_file), needs_verify=False
)
state_unverified = _make_cache_state(
session, asset, str(unverified_file), needs_verify=True
)
session.commit()
states = [state_unverified, state_verified]
result = select_best_live_path(states)
assert result == str(verified_file)
def test_falls_back_to_existing_unverified(self, session: Session, tmp_path):
"""If all states need verification, return first existing path."""
asset = _make_asset(session, "hash1")
existing_file = tmp_path / "exists.bin"
existing_file.write_bytes(b"data")
state = _make_cache_state(session, asset, str(existing_file), needs_verify=True)
session.commit()
result = select_best_live_path([state])
assert result == str(existing_file)
class TestSelectBestLivePathWithMocking:
def test_handles_missing_file_path_attr(self):
"""Gracefully handle states with None file_path."""
class MockState:
file_path = None
needs_verify = False
result = select_best_live_path([MockState()])
assert result == ""
class TestUpsertCacheState:
@pytest.mark.parametrize(
"initial_mtime,second_mtime,expect_created,expect_updated,final_mtime",
[
# New state creation
(None, 12345, True, False, 12345),
# Existing state, same mtime - no update
(100, 100, False, False, 100),
# Existing state, different mtime - update
(100, 200, False, True, 200),
],
ids=["new_state", "existing_no_change", "existing_update_mtime"],
)
def test_upsert_scenarios(
self, session: Session, initial_mtime, second_mtime, expect_created, expect_updated, final_mtime
):
asset = _make_asset(session, "hash1")
file_path = f"/path_{initial_mtime}_{second_mtime}.bin"
# Create initial state if needed
if initial_mtime is not None:
upsert_cache_state(session, asset_id=asset.id, file_path=file_path, mtime_ns=initial_mtime)
session.commit()
# The upsert call we're testing
created, updated = upsert_cache_state(
session, asset_id=asset.id, file_path=file_path, mtime_ns=second_mtime
)
session.commit()
assert created is expect_created
assert updated is expect_updated
state = session.query(AssetCacheState).filter_by(file_path=file_path).one()
assert state.mtime_ns == final_mtime
class TestDeleteCacheStatesOutsidePrefixes:
def test_deletes_states_outside_prefixes(self, session: Session, tmp_path):
asset = _make_asset(session, "hash1")
valid_dir = tmp_path / "valid"
valid_dir.mkdir()
invalid_dir = tmp_path / "invalid"
invalid_dir.mkdir()
valid_path = str(valid_dir / "file.bin")
invalid_path = str(invalid_dir / "file.bin")
_make_cache_state(session, asset, valid_path)
_make_cache_state(session, asset, invalid_path)
session.commit()
deleted = delete_cache_states_outside_prefixes(session, [str(valid_dir)])
session.commit()
assert deleted == 1
remaining = session.query(AssetCacheState).all()
assert len(remaining) == 1
assert remaining[0].file_path == valid_path
def test_empty_prefixes_deletes_nothing(self, session: Session):
asset = _make_asset(session, "hash1")
_make_cache_state(session, asset, "/some/path.bin")
session.commit()
deleted = delete_cache_states_outside_prefixes(session, [])
assert deleted == 0
class TestGetOrphanedSeedAssetIds:
def test_returns_orphaned_seed_assets(self, session: Session):
# Seed asset (hash=None) with no cache states
orphan = _make_asset(session, hash_val=None)
# Seed asset with cache state (not orphaned)
with_state = _make_asset(session, hash_val=None)
_make_cache_state(session, with_state, "/has/state.bin")
# Regular asset (hash not None) - should not be returned
_make_asset(session, hash_val="blake3:regular")
session.commit()
orphaned = get_orphaned_seed_asset_ids(session)
assert orphan.id in orphaned
assert with_state.id not in orphaned
class TestDeleteAssetsByIds:
def test_deletes_assets_and_infos(self, session: Session):
asset = _make_asset(session, "hash1")
now = get_utc_now()
info = AssetInfo(
owner_id="", name="test", asset_id=asset.id,
created_at=now, updated_at=now, last_access_time=now
)
session.add(info)
session.commit()
deleted = delete_assets_by_ids(session, [asset.id])
session.commit()
assert deleted == 1
assert session.query(Asset).count() == 0
assert session.query(AssetInfo).count() == 0
def test_empty_list_deletes_nothing(self, session: Session):
_make_asset(session, "hash1")
session.commit()
deleted = delete_assets_by_ids(session, [])
assert deleted == 0
assert session.query(Asset).count() == 1
class TestGetCacheStatesForPrefixes:
def test_returns_states_matching_prefix(self, session: Session, tmp_path):
asset = _make_asset(session, "hash1")
dir1 = tmp_path / "dir1"
dir1.mkdir()
dir2 = tmp_path / "dir2"
dir2.mkdir()
path1 = str(dir1 / "file.bin")
path2 = str(dir2 / "file.bin")
_make_cache_state(session, asset, path1, mtime_ns=100)
_make_cache_state(session, asset, path2, mtime_ns=200)
session.commit()
rows = get_cache_states_for_prefixes(session, [str(dir1)])
assert len(rows) == 1
assert rows[0].file_path == path1
def test_empty_prefixes_returns_empty(self, session: Session):
asset = _make_asset(session, "hash1")
_make_cache_state(session, asset, "/some/path.bin")
session.commit()
rows = get_cache_states_for_prefixes(session, [])
assert rows == []
class TestBulkSetNeedsVerify:
def test_sets_needs_verify_flag(self, session: Session):
asset = _make_asset(session, "hash1")
state1 = _make_cache_state(session, asset, "/path1.bin", needs_verify=False)
state2 = _make_cache_state(session, asset, "/path2.bin", needs_verify=False)
session.commit()
updated = bulk_set_needs_verify(session, [state1.id, state2.id], True)
session.commit()
assert updated == 2
session.refresh(state1)
session.refresh(state2)
assert state1.needs_verify is True
assert state2.needs_verify is True
def test_empty_list_updates_nothing(self, session: Session):
updated = bulk_set_needs_verify(session, [], True)
assert updated == 0
class TestDeleteCacheStatesByIds:
def test_deletes_states_by_id(self, session: Session):
asset = _make_asset(session, "hash1")
state1 = _make_cache_state(session, asset, "/path1.bin")
_make_cache_state(session, asset, "/path2.bin")
session.commit()
deleted = delete_cache_states_by_ids(session, [state1.id])
session.commit()
assert deleted == 1
assert session.query(AssetCacheState).count() == 1
def test_empty_list_deletes_nothing(self, session: Session):
deleted = delete_cache_states_by_ids(session, [])
assert deleted == 0
class TestDeleteOrphanedSeedAsset:
@pytest.mark.parametrize(
"create_asset,expected_deleted,expected_count",
[
(True, True, 0), # Existing asset gets deleted
(False, False, 0), # Nonexistent returns False
],
ids=["deletes_existing", "nonexistent_returns_false"],
)
def test_delete_orphaned_seed_asset(
self, session: Session, create_asset, expected_deleted, expected_count
):
asset_id = "nonexistent-id"
if create_asset:
asset = _make_asset(session, hash_val=None)
asset_id = asset.id
now = get_utc_now()
info = AssetInfo(
owner_id="", name="test", asset_id=asset.id,
created_at=now, updated_at=now, last_access_time=now
)
session.add(info)
session.commit()
deleted = delete_orphaned_seed_asset(session, asset_id)
if create_asset:
session.commit()
assert deleted is expected_deleted
assert session.query(Asset).count() == expected_count
class TestBulkInsertCacheStatesIgnoreConflicts:
def test_inserts_multiple_states(self, session: Session):
asset = _make_asset(session, "hash1")
rows = [
{"asset_id": asset.id, "file_path": "/bulk1.bin", "mtime_ns": 100},
{"asset_id": asset.id, "file_path": "/bulk2.bin", "mtime_ns": 200},
]
bulk_insert_cache_states_ignore_conflicts(session, rows)
session.commit()
assert session.query(AssetCacheState).count() == 2
def test_ignores_conflicts(self, session: Session):
asset = _make_asset(session, "hash1")
_make_cache_state(session, asset, "/existing.bin", mtime_ns=100)
session.commit()
rows = [
{"asset_id": asset.id, "file_path": "/existing.bin", "mtime_ns": 999},
{"asset_id": asset.id, "file_path": "/new.bin", "mtime_ns": 200},
]
bulk_insert_cache_states_ignore_conflicts(session, rows)
session.commit()
assert session.query(AssetCacheState).count() == 2
existing = session.query(AssetCacheState).filter_by(file_path="/existing.bin").one()
assert existing.mtime_ns == 100 # Original value preserved
def test_empty_list_is_noop(self, session: Session):
bulk_insert_cache_states_ignore_conflicts(session, [])
assert session.query(AssetCacheState).count() == 0
class TestGetCacheStatesByPathsAndAssetIds:
def test_returns_matching_paths(self, session: Session):
asset1 = _make_asset(session, "hash1")
asset2 = _make_asset(session, "hash2")
_make_cache_state(session, asset1, "/path1.bin")
_make_cache_state(session, asset2, "/path2.bin")
session.commit()
path_to_asset = {
"/path1.bin": asset1.id,
"/path2.bin": asset2.id,
}
winners = get_cache_states_by_paths_and_asset_ids(session, path_to_asset)
assert winners == {"/path1.bin", "/path2.bin"}
def test_excludes_non_matching_asset_ids(self, session: Session):
asset1 = _make_asset(session, "hash1")
asset2 = _make_asset(session, "hash2")
_make_cache_state(session, asset1, "/path1.bin")
session.commit()
# Path exists but with different asset_id
path_to_asset = {"/path1.bin": asset2.id}
winners = get_cache_states_by_paths_and_asset_ids(session, path_to_asset)
assert winners == set()
def test_empty_dict_returns_empty(self, session: Session):
winners = get_cache_states_by_paths_and_asset_ids(session, {})
assert winners == set()

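TestSelectBestLivePath fixes the selection order precisely enough to restate it. A sketch of the rule (inferred from the tests, not the shipped helper): among states whose file actually exists on disk, prefer one already verified; otherwise take the first existing one; with nothing on disk, return the empty string:

import os

def select_best_live_path_sketch(states) -> str:
    existing = [s for s in states
                if getattr(s, "file_path", None) and os.path.isfile(s.file_path)]
    for state in existing:
        if not state.needs_verify:
            return state.file_path
    return existing[0].file_path if existing else ""
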
View File

@@ -1,184 +0,0 @@
"""Tests for metadata filtering logic in asset_info queries."""
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta
from app.assets.database.queries import list_asset_infos_page
from app.assets.database.queries.asset_info import convert_metadata_to_rows
from app.assets.helpers import get_utc_now
def _make_asset(session: Session, hash_val: str) -> Asset:
asset = Asset(hash=hash_val, size_bytes=1024)
session.add(asset)
session.flush()
return asset
def _make_asset_info(
session: Session,
asset: Asset,
name: str,
metadata: dict | None = None,
) -> AssetInfo:
now = get_utc_now()
info = AssetInfo(
owner_id="",
name=name,
asset_id=asset.id,
user_metadata=metadata,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
session.flush()
if metadata:
for key, val in metadata.items():
for row in convert_metadata_to_rows(key, val):
meta_row = AssetInfoMeta(
asset_info_id=info.id,
key=row["key"],
ordinal=row.get("ordinal", 0),
val_str=row.get("val_str"),
val_num=row.get("val_num"),
val_bool=row.get("val_bool"),
val_json=row.get("val_json"),
)
session.add(meta_row)
session.flush()
return info
class TestMetadataFilterByType:
"""Table-driven tests for metadata filtering by different value types."""
@pytest.mark.parametrize(
"match_meta,nomatch_meta,filter_key,filter_val",
[
# String matching
({"category": "models"}, {"category": "images"}, "category", "models"),
# Integer matching
({"epoch": 5}, {"epoch": 10}, "epoch", 5),
# Float matching
({"score": 0.95}, {"score": 0.5}, "score", 0.95),
# Boolean True matching
({"enabled": True}, {"enabled": False}, "enabled", True),
# Boolean False matching
({"enabled": False}, {"enabled": True}, "enabled", False),
],
ids=["string", "int", "float", "bool_true", "bool_false"],
)
def test_filter_matches_correct_value(
self, session: Session, match_meta, nomatch_meta, filter_key, filter_val
):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, "match", match_meta)
_make_asset_info(session, asset, "nomatch", nomatch_meta)
session.commit()
infos, _, total = list_asset_infos_page(
session, metadata_filter={filter_key: filter_val}
)
assert total == 1
assert infos[0].name == "match"
@pytest.mark.parametrize(
"stored_meta,filter_key,filter_val",
[
# String no match
({"category": "models"}, "category", "other"),
# Int no match
({"epoch": 5}, "epoch", 99),
# Float no match
({"score": 0.5}, "score", 0.99),
],
ids=["string_no_match", "int_no_match", "float_no_match"],
)
def test_filter_returns_empty_when_no_match(
self, session: Session, stored_meta, filter_key, filter_val
):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, "item", stored_meta)
session.commit()
infos, _, total = list_asset_infos_page(
session, metadata_filter={filter_key: filter_val}
)
assert total == 0
class TestMetadataFilterNull:
"""Tests for null/missing key filtering."""
@pytest.mark.parametrize(
"match_name,match_meta,nomatch_name,nomatch_meta,filter_key",
[
# Null matches missing key
("missing_key", {}, "has_key", {"optional": "value"}, "optional"),
# Null matches explicit null
("explicit_null", {"nullable": None}, "has_value", {"nullable": "present"}, "nullable"),
],
ids=["missing_key", "explicit_null"],
)
def test_null_filter_matches(
self, session: Session, match_name, match_meta, nomatch_name, nomatch_meta, filter_key
):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, match_name, match_meta)
_make_asset_info(session, asset, nomatch_name, nomatch_meta)
session.commit()
infos, _, total = list_asset_infos_page(session, metadata_filter={filter_key: None})
assert total == 1
assert infos[0].name == match_name
class TestMetadataFilterList:
"""Tests for list-based (OR) filtering."""
def test_filter_by_list_matches_any(self, session: Session):
"""List values should match ANY of the values (OR)."""
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, "cat_a", {"category": "a"})
_make_asset_info(session, asset, "cat_b", {"category": "b"})
_make_asset_info(session, asset, "cat_c", {"category": "c"})
session.commit()
infos, _, total = list_asset_infos_page(session, metadata_filter={"category": ["a", "b"]})
assert total == 2
names = {i.name for i in infos}
assert names == {"cat_a", "cat_b"}
class TestMetadataFilterMultipleKeys:
"""Tests for multiple filter keys (AND semantics)."""
def test_multiple_keys_must_all_match(self, session: Session):
"""Multiple keys should ALL match (AND)."""
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, "match", {"type": "model", "version": 2})
_make_asset_info(session, asset, "wrong_type", {"type": "config", "version": 2})
_make_asset_info(session, asset, "wrong_version", {"type": "model", "version": 1})
session.commit()
infos, _, total = list_asset_infos_page(
session, metadata_filter={"type": "model", "version": 2}
)
assert total == 1
assert infos[0].name == "match"
class TestMetadataFilterEmptyDict:
"""Tests for empty filter behavior."""
def test_empty_filter_returns_all(self, session: Session):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, "a", {"key": "val"})
_make_asset_info(session, asset, "b", {})
session.commit()
infos, _, total = list_asset_infos_page(session, metadata_filter={})
assert total == 2

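_make_asset_info above routes metadata through convert_metadata_to_rows and the filters read back typed columns, which implies a per-type row layout. A sketch of that layout (an inference, only an assumption about the real function): one typed column per scalar, bool checked before int/float because bool subclasses int, and lists fanned out by ordinal:

def convert_rows_sketch(key, value):
    def row(**cols):
        base = {"key": key, "ordinal": 0, "val_str": None,
                "val_num": None, "val_bool": None, "val_json": None}
        base.update(cols)
        return base
    if isinstance(value, bool):
        return [row(val_bool=value)]
    if isinstance(value, (int, float)):
        return [row(val_num=float(value))]
    if isinstance(value, str):
        return [row(val_str=value)]
    if isinstance(value, list):
        return [dict(convert_rows_sketch(key, v)[0], ordinal=i)
                for i, v in enumerate(value)]
    return [row(val_json=value)]
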
View File

@@ -1,366 +0,0 @@
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetInfo, AssetInfoTag, AssetInfoMeta, Tag
from app.assets.database.queries import (
ensure_tags_exist,
get_asset_tags,
set_asset_info_tags,
add_tags_to_asset_info,
remove_tags_from_asset_info,
add_missing_tag_for_asset_id,
remove_missing_tag_for_asset_id,
list_tags_with_usage,
bulk_insert_tags_and_meta,
)
from app.assets.helpers import get_utc_now
def _make_asset(session: Session, hash_val: str | None = None) -> Asset:
asset = Asset(hash=hash_val, size_bytes=1024)
session.add(asset)
session.flush()
return asset
def _make_asset_info(session: Session, asset: Asset, name: str = "test", owner_id: str = "") -> AssetInfo:
now = get_utc_now()
info = AssetInfo(
owner_id=owner_id,
name=name,
asset_id=asset.id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
session.flush()
return info
class TestEnsureTagsExist:
def test_creates_new_tags(self, session: Session):
ensure_tags_exist(session, ["alpha", "beta"], tag_type="user")
session.commit()
tags = session.query(Tag).all()
assert {t.name for t in tags} == {"alpha", "beta"}
def test_is_idempotent(self, session: Session):
ensure_tags_exist(session, ["alpha"], tag_type="user")
ensure_tags_exist(session, ["alpha"], tag_type="user")
session.commit()
assert session.query(Tag).count() == 1
def test_normalizes_tags(self, session: Session):
ensure_tags_exist(session, [" ALPHA ", "Beta", "alpha"])
session.commit()
tags = session.query(Tag).all()
assert {t.name for t in tags} == {"alpha", "beta"}
def test_empty_list_is_noop(self, session: Session):
ensure_tags_exist(session, [])
session.commit()
assert session.query(Tag).count() == 0
def test_tag_type_is_set(self, session: Session):
ensure_tags_exist(session, ["system-tag"], tag_type="system")
session.commit()
tag = session.query(Tag).filter_by(name="system-tag").one()
assert tag.tag_type == "system"
class TestGetAssetTags:
def test_returns_empty_for_no_tags(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
tags = get_asset_tags(session, asset_info_id=info.id)
assert tags == []
def test_returns_tags_for_asset(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
ensure_tags_exist(session, ["tag1", "tag2"])
session.add_all([
AssetInfoTag(asset_info_id=info.id, tag_name="tag1", origin="manual", added_at=get_utc_now()),
AssetInfoTag(asset_info_id=info.id, tag_name="tag2", origin="manual", added_at=get_utc_now()),
])
session.flush()
tags = get_asset_tags(session, asset_info_id=info.id)
assert set(tags) == {"tag1", "tag2"}
class TestSetAssetInfoTags:
def test_adds_new_tags(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
result = set_asset_info_tags(session, asset_info_id=info.id, tags=["a", "b"])
session.commit()
assert set(result["added"]) == {"a", "b"}
assert result["removed"] == []
assert set(result["total"]) == {"a", "b"}
def test_removes_old_tags(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
set_asset_info_tags(session, asset_info_id=info.id, tags=["a", "b", "c"])
result = set_asset_info_tags(session, asset_info_id=info.id, tags=["a"])
session.commit()
assert result["added"] == []
assert set(result["removed"]) == {"b", "c"}
assert result["total"] == ["a"]
def test_replaces_tags(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
set_asset_info_tags(session, asset_info_id=info.id, tags=["a", "b"])
result = set_asset_info_tags(session, asset_info_id=info.id, tags=["b", "c"])
session.commit()
assert result["added"] == ["c"]
assert result["removed"] == ["a"]
assert set(result["total"]) == {"b", "c"}
class TestAddTagsToAssetInfo:
def test_adds_tags(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
result = add_tags_to_asset_info(session, asset_info_id=info.id, tags=["x", "y"])
session.commit()
assert set(result["added"]) == {"x", "y"}
assert result["already_present"] == []
def test_reports_already_present(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["x"])
result = add_tags_to_asset_info(session, asset_info_id=info.id, tags=["x", "y"])
session.commit()
assert result["added"] == ["y"]
assert result["already_present"] == ["x"]
def test_raises_for_missing_asset_info(self, session: Session):
with pytest.raises(ValueError, match="not found"):
add_tags_to_asset_info(session, asset_info_id="nonexistent", tags=["x"])
class TestRemoveTagsFromAssetInfo:
def test_removes_tags(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["a", "b", "c"])
result = remove_tags_from_asset_info(session, asset_info_id=info.id, tags=["a", "b"])
session.commit()
assert set(result["removed"]) == {"a", "b"}
assert result["not_present"] == []
assert result["total_tags"] == ["c"]
def test_reports_not_present(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["a"])
result = remove_tags_from_asset_info(session, asset_info_id=info.id, tags=["a", "x"])
session.commit()
assert result["removed"] == ["a"]
assert result["not_present"] == ["x"]
def test_raises_for_missing_asset_info(self, session: Session):
with pytest.raises(ValueError, match="not found"):
remove_tags_from_asset_info(session, asset_info_id="nonexistent", tags=["x"])
class TestMissingTagFunctions:
def test_add_missing_tag_for_asset_id(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
ensure_tags_exist(session, ["missing"], tag_type="system")
add_missing_tag_for_asset_id(session, asset_id=asset.id)
session.commit()
tags = get_asset_tags(session, asset_info_id=info.id)
assert "missing" in tags
def test_add_missing_tag_is_idempotent(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
ensure_tags_exist(session, ["missing"], tag_type="system")
add_missing_tag_for_asset_id(session, asset_id=asset.id)
add_missing_tag_for_asset_id(session, asset_id=asset.id)
session.commit()
links = session.query(AssetInfoTag).filter_by(asset_info_id=info.id, tag_name="missing").all()
assert len(links) == 1
def test_remove_missing_tag_for_asset_id(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
ensure_tags_exist(session, ["missing"], tag_type="system")
add_missing_tag_for_asset_id(session, asset_id=asset.id)
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
session.commit()
tags = get_asset_tags(session, asset_info_id=info.id)
assert "missing" not in tags
class TestListTagsWithUsage:
def test_returns_tags_with_counts(self, session: Session):
ensure_tags_exist(session, ["used", "unused"])
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["used"])
session.commit()
rows, total = list_tags_with_usage(session)
tag_dict = {name: count for name, _, count in rows}
assert tag_dict["used"] == 1
assert tag_dict["unused"] == 0
assert total == 2
def test_exclude_zero_counts(self, session: Session):
ensure_tags_exist(session, ["used", "unused"])
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["used"])
session.commit()
rows, total = list_tags_with_usage(session, include_zero=False)
tag_names = {name for name, _, _ in rows}
assert "used" in tag_names
assert "unused" not in tag_names
def test_prefix_filter(self, session: Session):
ensure_tags_exist(session, ["alpha", "beta", "alphabet"])
session.commit()
rows, total = list_tags_with_usage(session, prefix="alph")
tag_names = {name for name, _, _ in rows}
assert tag_names == {"alpha", "alphabet"}
def test_order_by_name(self, session: Session):
ensure_tags_exist(session, ["zebra", "alpha", "middle"])
session.commit()
rows, _ = list_tags_with_usage(session, order="name_asc")
names = [name for name, _, _ in rows]
assert names == ["alpha", "middle", "zebra"]
def test_owner_visibility(self, session: Session):
ensure_tags_exist(session, ["shared-tag", "owner-tag"])
asset = _make_asset(session, "hash1")
shared_info = _make_asset_info(session, asset, name="shared", owner_id="")
owner_info = _make_asset_info(session, asset, name="owned", owner_id="user1")
add_tags_to_asset_info(session, asset_info_id=shared_info.id, tags=["shared-tag"])
add_tags_to_asset_info(session, asset_info_id=owner_info.id, tags=["owner-tag"])
session.commit()
# Empty owner sees only shared
rows, _ = list_tags_with_usage(session, owner_id="", include_zero=False)
tag_dict = {name: count for name, _, count in rows}
assert tag_dict.get("shared-tag", 0) == 1
assert tag_dict.get("owner-tag", 0) == 0
# User1 sees both
rows, _ = list_tags_with_usage(session, owner_id="user1", include_zero=False)
tag_dict = {name: count for name, _, count in rows}
assert tag_dict.get("shared-tag", 0) == 1
assert tag_dict.get("owner-tag", 0) == 1
class TestBulkInsertTagsAndMeta:
def test_inserts_tags(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
ensure_tags_exist(session, ["bulk-tag1", "bulk-tag2"])
session.commit()
now = get_utc_now()
tag_rows = [
{"asset_info_id": info.id, "tag_name": "bulk-tag1", "origin": "manual", "added_at": now},
{"asset_info_id": info.id, "tag_name": "bulk-tag2", "origin": "manual", "added_at": now},
]
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=[])
session.commit()
tags = get_asset_tags(session, asset_info_id=info.id)
assert set(tags) == {"bulk-tag1", "bulk-tag2"}
def test_inserts_meta(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
session.commit()
meta_rows = [
{
"asset_info_id": info.id,
"key": "meta-key",
"ordinal": 0,
"val_str": "meta-value",
"val_num": None,
"val_bool": None,
"val_json": None,
},
]
bulk_insert_tags_and_meta(session, tag_rows=[], meta_rows=meta_rows)
session.commit()
meta = session.query(AssetInfoMeta).filter_by(asset_info_id=info.id).all()
assert len(meta) == 1
assert meta[0].key == "meta-key"
assert meta[0].val_str == "meta-value"
def test_ignores_conflicts(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
ensure_tags_exist(session, ["existing-tag"])
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["existing-tag"])
session.commit()
now = get_utc_now()
tag_rows = [
{"asset_info_id": info.id, "tag_name": "existing-tag", "origin": "duplicate", "added_at": now},
]
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=[])
session.commit()
# Should still have only one tag link
links = session.query(AssetInfoTag).filter_by(asset_info_id=info.id, tag_name="existing-tag").all()
assert len(links) == 1
# Origin should be original, not overwritten
assert links[0].origin == "manual"
def test_empty_lists_is_noop(self, session: Session):
bulk_insert_tags_and_meta(session, tag_rows=[], meta_rows=[])
assert session.query(AssetInfoTag).count() == 0
assert session.query(AssetInfoMeta).count() == 0

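test_normalizes_tags gives the normalization rule away. A sketch consistent with it (assumed, not the shipped code): trim whitespace, lowercase, and de-duplicate while preserving first-seen order:

def normalize_tags_sketch(names):
    seen, out = set(), []
    for name in names:
        cleaned = name.strip().lower()
        if cleaned and cleaned not in seen:
            seen.add(cleaned)
            out.append(cleaned)
    return out

assert normalize_tags_sketch([" ALPHA ", "Beta", "alpha"]) == ["alpha", "beta"]
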
View File

@@ -1 +0,0 @@
# Service layer tests

View File

@@ -1,48 +0,0 @@
import tempfile
from pathlib import Path
from unittest.mock import patch
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from app.assets.database.models import Base
@pytest.fixture
def db_engine():
"""In-memory SQLite engine for fast unit tests."""
engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
return engine
@pytest.fixture
def session(db_engine):
"""Session fixture for tests that need direct DB access."""
with Session(db_engine) as sess:
yield sess
@pytest.fixture
def mock_create_session(db_engine):
"""Patch create_session to use our in-memory database."""
from contextlib import contextmanager
from sqlalchemy.orm import Session as SASession
@contextmanager
def _create_session():
with SASession(db_engine) as sess:
yield sess
with patch("app.assets.services.ingest.create_session", _create_session), \
patch("app.assets.services.asset_management.create_session", _create_session), \
patch("app.assets.services.tagging.create_session", _create_session):
yield _create_session
@pytest.fixture
def temp_dir():
"""Temporary directory for file operations."""
with tempfile.TemporaryDirectory() as tmpdir:
yield Path(tmpdir)

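The mock_create_session fixture above works because the services acquire sessions through a single factory. A minimal sketch of that pattern (illustrative names, not the real service module): swap the one context manager and every caller lands on the in-memory engine:

from contextlib import contextmanager

@contextmanager
def create_session():
    session = object()   # stand-in for a real SQLAlchemy Session
    try:
        yield session
    finally:
        pass             # a real factory would commit/close here

def some_service():
    with create_session() as sess:  # the seam the tests patch
        return sess is not None

assert some_service() is True
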
View File

@@ -1,264 +0,0 @@
"""Tests for asset_management services."""
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetInfo
from app.assets.database.queries import ensure_tags_exist, add_tags_to_asset_info
from app.assets.helpers import get_utc_now
from app.assets.services import (
get_asset_detail,
update_asset_metadata,
delete_asset_reference,
set_asset_preview,
)
def _make_asset(session: Session, hash_val: str = "blake3:test", size: int = 1024) -> Asset:
asset = Asset(hash=hash_val, size_bytes=size, mime_type="application/octet-stream")
session.add(asset)
session.flush()
return asset
def _make_asset_info(
session: Session,
asset: Asset,
name: str = "test",
owner_id: str = "",
) -> AssetInfo:
now = get_utc_now()
info = AssetInfo(
owner_id=owner_id,
name=name,
asset_id=asset.id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
session.flush()
return info
class TestGetAssetDetail:
def test_returns_none_for_nonexistent(self, mock_create_session):
result = get_asset_detail(asset_info_id="nonexistent")
assert result is None
def test_returns_asset_with_tags(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset, name="test.bin")
ensure_tags_exist(session, ["alpha", "beta"])
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["alpha", "beta"])
session.commit()
result = get_asset_detail(asset_info_id=info.id)
assert result is not None
assert result.info.id == info.id
assert result.asset.hash == asset.hash
assert set(result.tags) == {"alpha", "beta"}
def test_respects_owner_visibility(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset, owner_id="user1")
session.commit()
# Wrong owner cannot see
result = get_asset_detail(asset_info_id=info.id, owner_id="user2")
assert result is None
# Correct owner can see
result = get_asset_detail(asset_info_id=info.id, owner_id="user1")
assert result is not None
class TestUpdateAssetMetadata:
def test_updates_name(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset, name="old_name.bin")
info_id = info.id
session.commit()
update_asset_metadata(
asset_info_id=info_id,
name="new_name.bin",
)
# Verify by re-fetching from DB
session.expire_all()
updated_info = session.get(AssetInfo, info_id)
assert updated_info.name == "new_name.bin"
def test_updates_tags(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset)
ensure_tags_exist(session, ["old"])
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["old"])
session.commit()
result = update_asset_metadata(
asset_info_id=info.id,
tags=["new1", "new2"],
)
assert set(result.tags) == {"new1", "new2"}
assert "old" not in result.tags
def test_updates_user_metadata(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset)
info_id = info.id
session.commit()
update_asset_metadata(
asset_info_id=info_id,
user_metadata={"key": "value", "num": 42},
)
# Verify by re-fetching from DB
session.expire_all()
updated_info = session.get(AssetInfo, info_id)
assert updated_info.user_metadata["key"] == "value"
assert updated_info.user_metadata["num"] == 42
def test_raises_for_nonexistent(self, mock_create_session):
with pytest.raises(ValueError, match="not found"):
update_asset_metadata(asset_info_id="nonexistent", name="fail")
def test_raises_for_wrong_owner(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset, owner_id="user1")
session.commit()
with pytest.raises(PermissionError, match="not owner"):
update_asset_metadata(
asset_info_id=info.id,
name="new",
owner_id="user2",
)
class TestDeleteAssetReference:
def test_deletes_asset_info(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset)
info_id = info.id
session.commit()
result = delete_asset_reference(
asset_info_id=info_id,
owner_id="",
delete_content_if_orphan=False,
)
assert result is True
assert session.get(AssetInfo, info_id) is None
def test_returns_false_for_nonexistent(self, mock_create_session):
result = delete_asset_reference(
asset_info_id="nonexistent",
owner_id="",
)
assert result is False
def test_returns_false_for_wrong_owner(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset, owner_id="user1")
info_id = info.id
session.commit()
result = delete_asset_reference(
asset_info_id=info_id,
owner_id="user2",
)
assert result is False
assert session.get(AssetInfo, info_id) is not None
def test_keeps_asset_if_other_infos_exist(self, mock_create_session, session: Session):
asset = _make_asset(session)
info1 = _make_asset_info(session, asset, name="info1")
_make_asset_info(session, asset, name="info2") # Second info keeps asset alive
asset_id = asset.id
session.commit()
delete_asset_reference(
asset_info_id=info1.id,
owner_id="",
delete_content_if_orphan=True,
)
# Asset should still exist
assert session.get(Asset, asset_id) is not None
def test_deletes_orphaned_asset(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset)
asset_id = asset.id
info_id = info.id
session.commit()
delete_asset_reference(
asset_info_id=info_id,
owner_id="",
delete_content_if_orphan=True,
)
# Both info and asset should be gone
assert session.get(AssetInfo, info_id) is None
assert session.get(Asset, asset_id) is None
class TestSetAssetPreview:
def test_sets_preview(self, mock_create_session, session: Session):
asset = _make_asset(session, hash_val="blake3:main")
preview_asset = _make_asset(session, hash_val="blake3:preview")
info = _make_asset_info(session, asset)
info_id = info.id
preview_id = preview_asset.id
session.commit()
set_asset_preview(
asset_info_id=info_id,
preview_asset_id=preview_id,
)
# Verify by re-fetching from DB
session.expire_all()
updated_info = session.get(AssetInfo, info_id)
assert updated_info.preview_id == preview_id
def test_clears_preview(self, mock_create_session, session: Session):
asset = _make_asset(session)
preview_asset = _make_asset(session, hash_val="blake3:preview")
info = _make_asset_info(session, asset)
info.preview_id = preview_asset.id
info_id = info.id
session.commit()
set_asset_preview(
asset_info_id=info_id,
preview_asset_id=None,
)
# Verify by re-fetching from DB
session.expire_all()
updated_info = session.get(AssetInfo, info_id)
assert updated_info.preview_id is None
def test_raises_for_nonexistent_info(self, mock_create_session):
with pytest.raises(ValueError, match="not found"):
set_asset_preview(asset_info_id="nonexistent")
def test_raises_for_wrong_owner(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset, owner_id="user1")
session.commit()
with pytest.raises(PermissionError, match="not owner"):
set_asset_preview(
asset_info_id=info.id,
preview_asset_id=None,
owner_id="user2",
)

View File

@@ -1,227 +0,0 @@
"""Tests for ingest services."""
from pathlib import Path
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetCacheState, AssetInfo, Tag
from app.assets.database.queries import get_asset_tags
from app.assets.services import ingest_file_from_path, register_existing_asset
class TestIngestFileFromPath:
def test_creates_asset_and_cache_state(self, mock_create_session, temp_dir: Path, session: Session):
file_path = temp_dir / "test_file.bin"
file_path.write_bytes(b"test content")
result = ingest_file_from_path(
abs_path=str(file_path),
asset_hash="blake3:abc123",
size_bytes=12,
mtime_ns=1234567890000000000,
mime_type="application/octet-stream",
)
assert result.asset_created is True
assert result.state_created is True
assert result.asset_info_id is None # no info_name provided
# Verify DB state
assets = session.query(Asset).all()
assert len(assets) == 1
assert assets[0].hash == "blake3:abc123"
states = session.query(AssetCacheState).all()
assert len(states) == 1
assert states[0].file_path == str(file_path)
def test_creates_asset_info_when_name_provided(self, mock_create_session, temp_dir: Path, session: Session):
file_path = temp_dir / "model.safetensors"
file_path.write_bytes(b"model data")
result = ingest_file_from_path(
abs_path=str(file_path),
asset_hash="blake3:def456",
size_bytes=10,
mtime_ns=1234567890000000000,
mime_type="application/octet-stream",
info_name="My Model",
owner_id="user1",
)
assert result.asset_created is True
assert result.asset_info_id is not None
info = session.query(AssetInfo).first()
assert info is not None
assert info.name == "My Model"
assert info.owner_id == "user1"
def test_creates_tags_when_provided(self, mock_create_session, temp_dir: Path, session: Session):
file_path = temp_dir / "tagged.bin"
file_path.write_bytes(b"data")
result = ingest_file_from_path(
abs_path=str(file_path),
asset_hash="blake3:ghi789",
size_bytes=4,
mtime_ns=1234567890000000000,
info_name="Tagged Asset",
tags=["models", "checkpoints"],
)
assert result.asset_info_id is not None
# Verify tags were created and linked
tags = session.query(Tag).all()
tag_names = {t.name for t in tags}
assert "models" in tag_names
assert "checkpoints" in tag_names
asset_tags = get_asset_tags(session, asset_info_id=result.asset_info_id)
assert set(asset_tags) == {"models", "checkpoints"}
def test_idempotent_upsert(self, mock_create_session, temp_dir: Path, session: Session):
file_path = temp_dir / "dup.bin"
file_path.write_bytes(b"content")
# First ingest
r1 = ingest_file_from_path(
abs_path=str(file_path),
asset_hash="blake3:repeat",
size_bytes=7,
mtime_ns=1234567890000000000,
)
assert r1.asset_created is True
# Second ingest with same hash - should update, not create
r2 = ingest_file_from_path(
abs_path=str(file_path),
asset_hash="blake3:repeat",
size_bytes=7,
mtime_ns=1234567890000000001, # different mtime
)
assert r2.asset_created is False
assert r2.state_updated is True or r2.state_created is False
# Still only one asset
assets = session.query(Asset).all()
assert len(assets) == 1
def test_validates_preview_id(self, mock_create_session, temp_dir: Path, session: Session):
file_path = temp_dir / "with_preview.bin"
file_path.write_bytes(b"data")
# Create a preview asset first
preview_asset = Asset(hash="blake3:preview", size_bytes=100)
session.add(preview_asset)
session.commit()
preview_id = preview_asset.id
result = ingest_file_from_path(
abs_path=str(file_path),
asset_hash="blake3:main",
size_bytes=4,
mtime_ns=1234567890000000000,
info_name="With Preview",
preview_id=preview_id,
)
assert result.asset_info_id is not None
info = session.query(AssetInfo).filter_by(id=result.asset_info_id).first()
assert info.preview_id == preview_id
def test_invalid_preview_id_is_cleared(self, mock_create_session, temp_dir: Path, session: Session):
file_path = temp_dir / "bad_preview.bin"
file_path.write_bytes(b"data")
result = ingest_file_from_path(
abs_path=str(file_path),
asset_hash="blake3:badpreview",
size_bytes=4,
mtime_ns=1234567890000000000,
info_name="Bad Preview",
preview_id="nonexistent-uuid",
)
assert result.asset_info_id is not None
info = session.query(AssetInfo).filter_by(id=result.asset_info_id).first()
assert info.preview_id is None
class TestRegisterExistingAsset:
def test_creates_info_for_existing_asset(self, mock_create_session, session: Session):
# Create existing asset
asset = Asset(hash="blake3:existing", size_bytes=1024, mime_type="image/png")
session.add(asset)
session.commit()
result = register_existing_asset(
asset_hash="blake3:existing",
name="Registered Asset",
user_metadata={"key": "value"},
tags=["models"],
)
assert result.created is True
assert "models" in result.tags
# Verify by re-fetching from DB
session.expire_all()
infos = session.query(AssetInfo).filter_by(name="Registered Asset").all()
assert len(infos) == 1
def test_returns_existing_info(self, mock_create_session, session: Session):
# Create asset and info
asset = Asset(hash="blake3:withinfo", size_bytes=512)
session.add(asset)
session.flush()
from app.assets.helpers import get_utc_now
info = AssetInfo(
owner_id="",
name="Existing Info",
asset_id=asset.id,
created_at=get_utc_now(),
updated_at=get_utc_now(),
last_access_time=get_utc_now(),
)
session.add(info)
session.flush() # Flush to get the ID
info_id = info.id
session.commit()
result = register_existing_asset(
asset_hash="blake3:withinfo",
name="Existing Info",
owner_id="",
)
assert result.created is False
# Verify only one AssetInfo exists for this name
session.expire_all()
infos = session.query(AssetInfo).filter_by(name="Existing Info").all()
assert len(infos) == 1
assert infos[0].id == info_id
def test_raises_for_nonexistent_hash(self, mock_create_session):
with pytest.raises(ValueError, match="No asset with hash"):
register_existing_asset(
asset_hash="blake3:doesnotexist",
name="Fail",
)
def test_applies_tags_to_new_info(self, mock_create_session, session: Session):
asset = Asset(hash="blake3:tagged", size_bytes=256)
session.add(asset)
session.commit()
result = register_existing_asset(
asset_hash="blake3:tagged",
name="Tagged Info",
tags=["alpha", "beta"],
)
assert result.created is True
assert set(result.tags) == {"alpha", "beta"}

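Aside: test_idempotent_upsert is the heart of the ingest contract above - re-ingesting the same hash must update existing rows rather than create duplicates. A sketch of the get-or-create core that behavior implies; the helper name and internals are assumptions, only the observable result comes from the tests:

# Hypothetical upsert core implied by test_idempotent_upsert; Asset is the
# SQLAlchemy model from the deleted file above, keyed on its content hash.
def _upsert_asset(session, asset_hash: str, size_bytes: int):
    asset = session.query(Asset).filter_by(hash=asset_hash).one_or_none()
    created = asset is None
    if created:
        asset = Asset(hash=asset_hash, size_bytes=size_bytes)
        session.add(asset)
        session.flush()  # assign asset.id before cache-state rows reference it
    return asset, created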
View File

@@ -1,197 +0,0 @@
"""Tests for tagging services."""
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetInfo
from app.assets.database.queries import ensure_tags_exist, add_tags_to_asset_info
from app.assets.helpers import get_utc_now
from app.assets.services import apply_tags, remove_tags, list_tags
def _make_asset(session: Session, hash_val: str = "blake3:test") -> Asset:
asset = Asset(hash=hash_val, size_bytes=1024)
session.add(asset)
session.flush()
return asset
def _make_asset_info(
session: Session,
asset: Asset,
name: str = "test",
owner_id: str = "",
) -> AssetInfo:
now = get_utc_now()
info = AssetInfo(
owner_id=owner_id,
name=name,
asset_id=asset.id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
session.flush()
return info
class TestApplyTags:
def test_adds_new_tags(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset)
session.commit()
result = apply_tags(
asset_info_id=info.id,
tags=["alpha", "beta"],
)
assert set(result.added) == {"alpha", "beta"}
assert result.already_present == []
assert set(result.total_tags) == {"alpha", "beta"}
def test_reports_already_present(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset)
ensure_tags_exist(session, ["existing"])
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["existing"])
session.commit()
result = apply_tags(
asset_info_id=info.id,
tags=["existing", "new"],
)
assert result.added == ["new"]
assert result.already_present == ["existing"]
def test_raises_for_nonexistent_info(self, mock_create_session):
with pytest.raises(ValueError, match="not found"):
apply_tags(asset_info_id="nonexistent", tags=["x"])
def test_raises_for_wrong_owner(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset, owner_id="user1")
session.commit()
with pytest.raises(PermissionError, match="not owner"):
apply_tags(
asset_info_id=info.id,
tags=["new"],
owner_id="user2",
)
class TestRemoveTags:
def test_removes_tags(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset)
ensure_tags_exist(session, ["a", "b", "c"])
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["a", "b", "c"])
session.commit()
result = remove_tags(
asset_info_id=info.id,
tags=["a", "b"],
)
assert set(result.removed) == {"a", "b"}
assert result.not_present == []
assert result.total_tags == ["c"]
def test_reports_not_present(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset)
ensure_tags_exist(session, ["present"])
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["present"])
session.commit()
result = remove_tags(
asset_info_id=info.id,
tags=["present", "absent"],
)
assert result.removed == ["present"]
assert result.not_present == ["absent"]
def test_raises_for_nonexistent_info(self, mock_create_session):
with pytest.raises(ValueError, match="not found"):
remove_tags(asset_info_id="nonexistent", tags=["x"])
def test_raises_for_wrong_owner(self, mock_create_session, session: Session):
asset = _make_asset(session)
info = _make_asset_info(session, asset, owner_id="user1")
session.commit()
with pytest.raises(PermissionError, match="not owner"):
remove_tags(
asset_info_id=info.id,
tags=["x"],
owner_id="user2",
)
class TestListTags:
def test_returns_tags_with_counts(self, mock_create_session, session: Session):
ensure_tags_exist(session, ["used", "unused"])
asset = _make_asset(session)
info = _make_asset_info(session, asset)
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["used"])
session.commit()
rows, total = list_tags()
tag_dict = {name: count for name, _, count in rows}
assert tag_dict["used"] == 1
assert tag_dict["unused"] == 0
assert total == 2
def test_excludes_zero_counts(self, mock_create_session, session: Session):
ensure_tags_exist(session, ["used", "unused"])
asset = _make_asset(session)
info = _make_asset_info(session, asset)
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["used"])
session.commit()
rows, total = list_tags(include_zero=False)
tag_names = {name for name, _, _ in rows}
assert "used" in tag_names
assert "unused" not in tag_names
def test_prefix_filter(self, mock_create_session, session: Session):
ensure_tags_exist(session, ["alpha", "beta", "alphabet"])
session.commit()
rows, _ = list_tags(prefix="alph")
tag_names = {name for name, _, _ in rows}
assert tag_names == {"alpha", "alphabet"}
def test_order_by_name(self, mock_create_session, session: Session):
ensure_tags_exist(session, ["zebra", "alpha", "middle"])
session.commit()
rows, _ = list_tags(order="name_asc")
names = [name for name, _, _ in rows]
assert names == ["alpha", "middle", "zebra"]
def test_pagination(self, mock_create_session, session: Session):
ensure_tags_exist(session, ["a", "b", "c", "d", "e"])
session.commit()
rows, total = list_tags(limit=2, offset=1, order="name_asc")
assert total == 5
assert len(rows) == 2
names = [name for name, _, _ in rows]
assert names == ["b", "c"]
def test_clamps_limit(self, mock_create_session, session: Session):
ensure_tags_exist(session, ["a"])
session.commit()
# Service should clamp limit to max 1000
rows, _ = list_tags(limit=2000)
assert len(rows) <= 1000

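Aside: test_pagination and test_clamps_limit together fix the paging contract for list_tags - total counts all matching tags regardless of paging, and limit is silently capped. A sketch of the corresponding query tail; the cap of 1000 comes from the test, the helper name and everything else is assumed:

MAX_TAG_LIMIT = 1000  # implied by test_clamps_limit

# Hypothetical paging helper over a SQLAlchemy query of tags.
def _page_tags(query, limit: int, offset: int):
    limit = max(1, min(limit, MAX_TAG_LIMIT))  # clamp rather than reject
    total = query.count()                      # count before slicing
    rows = query.offset(offset).limit(limit).all()
    return rows, total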
View File

@@ -4,7 +4,7 @@ from pathlib import Path
import pytest
import requests
-from helpers import get_asset_filename, trigger_sync_seed_assets
+from conftest import get_asset_filename, trigger_sync_seed_assets

View File

@@ -4,7 +4,7 @@ from pathlib import Path
import pytest
import requests
-from helpers import get_asset_filename, trigger_sync_seed_assets
+from conftest import get_asset_filename, trigger_sync_seed_assets
def test_create_from_hash_success(
@@ -126,52 +126,42 @@ def test_head_asset_bad_hash_returns_400_and_no_body(http: requests.Session, api
assert body == b""
-@pytest.mark.parametrize(
-    "method,endpoint_template,payload,expected_status,error_code",
-    [
-        # Delete nonexistent asset
-        ("delete", "/api/assets/{uuid}", None, 404, "ASSET_NOT_FOUND"),
-        # Bad hash algorithm in from-hash
-        (
-            "post",
-            "/api/assets/from-hash",
-            {"hash": "sha256:" + "0" * 64, "name": "x.bin", "tags": ["models", "checkpoints", "unit-tests"]},
-            400,
-            "INVALID_BODY",
-        ),
-        # Get with bad UUID format
-        ("get", "/api/assets/not-a-uuid", None, 404, None),
-        # Get content with bad UUID format
-        ("get", "/api/assets/not-a-uuid/content", None, 404, None),
-    ],
-    ids=["delete_nonexistent", "bad_hash_algorithm", "get_bad_uuid", "content_bad_uuid"],
-)
-def test_error_responses(
-    http: requests.Session, api_base: str, method, endpoint_template, payload, expected_status, error_code
-):
-    # Replace the {uuid} placeholder with a random UUID for the delete test
-    endpoint = endpoint_template.replace("{uuid}", str(uuid.uuid4()))
-    url = f"{api_base}{endpoint}"
-    if method == "get":
-        r = http.get(url, timeout=120)
-    elif method == "post":
-        r = http.post(url, json=payload, timeout=120)
-    elif method == "delete":
-        r = http.delete(url, timeout=120)
-    assert r.status_code == expected_status
-    if error_code:
-        body = r.json()
-        assert body["error"]["code"] == error_code
-def test_create_from_hash_invalid_json(http: requests.Session, api_base: str):
-    """Invalid JSON body requires special handling (data= instead of json=)."""
-    r = http.post(f"{api_base}/api/assets/from-hash", data=b"{not json}", timeout=120)
-    body = r.json()
-    assert r.status_code == 400
-    assert body["error"]["code"] == "INVALID_JSON"
+def test_delete_nonexistent_returns_404(http: requests.Session, api_base: str):
+    bogus = str(uuid.uuid4())
+    r = http.delete(f"{api_base}/api/assets/{bogus}", timeout=120)
+    body = r.json()
+    assert r.status_code == 404
+    assert body["error"]["code"] == "ASSET_NOT_FOUND"
+def test_create_from_hash_invalids(http: requests.Session, api_base: str):
+    # Bad hash algorithm
+    bad = {
+        "hash": "sha256:" + "0" * 64,
+        "name": "x.bin",
+        "tags": ["models", "checkpoints", "unit-tests"],
+    }
+    r1 = http.post(f"{api_base}/api/assets/from-hash", json=bad, timeout=120)
+    b1 = r1.json()
+    assert r1.status_code == 400
+    assert b1["error"]["code"] == "INVALID_BODY"
+    # Invalid JSON body
+    r2 = http.post(f"{api_base}/api/assets/from-hash", data=b"{not json}", timeout=120)
+    b2 = r2.json()
+    assert r2.status_code == 400
+    assert b2["error"]["code"] == "INVALID_JSON"
+def test_get_update_download_bad_ids(http: requests.Session, api_base: str):
+    # All endpoints should return 404, since the route definitions validate UUIDs with a regex.
+    bad_id = "not-a-uuid"
+    r1 = http.get(f"{api_base}/api/assets/{bad_id}", timeout=120)
+    assert r1.status_code == 404
+    r3 = http.get(f"{api_base}/api/assets/{bad_id}/content", timeout=120)
+    assert r3.status_code == 404
def test_update_requires_at_least_one_field(http: requests.Session, api_base: str, seeded_asset: dict):

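Aside: the rewritten test_get_update_download_bad_ids expects a plain 404 rather than a 400 for "not-a-uuid". That falls out naturally when the UUID check lives in the route pattern itself; assuming an aiohttp server (which ComfyUI uses), a regex-constrained path variable gives exactly this behavior - the handler names below are illustrative, not taken from the diff:

from aiohttp import web

UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"

async def get_asset(request: web.Request) -> web.Response:
    # Illustrative handler; only reached when the path segment matched UUID_RE.
    return web.json_response({"id": request.match_info["asset_id"]})

app = web.Application()
# "/api/assets/not-a-uuid" matches no route, so aiohttp answers 404
# before any handler or validation code runs.
app.router.add_get("/api/assets/{asset_id:" + UUID_RE + "}", get_asset)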
View File

@@ -6,7 +6,7 @@ from typing import Optional
import pytest
import requests
-from helpers import get_asset_filename, trigger_sync_seed_assets
+from conftest import get_asset_filename, trigger_sync_seed_assets
def test_download_attachment_and_inline(http: requests.Session, api_base: str, seeded_asset: dict):

View File

@@ -1,7 +1,6 @@
import time
import uuid
import pytest
import requests
@@ -284,21 +283,30 @@ def test_list_assets_offset_beyond_total_and_limit_boundary(http, api_base, asse
assert b2["has_more"] is False
-@pytest.mark.parametrize(
-    "params,error_code",
-    [
-        ({"offset": "-1"}, "INVALID_QUERY"),
-        ({"limit": "abc"}, "INVALID_QUERY"),
-        ({"limit": "0"}, "INVALID_QUERY"),
-        ({"metadata_filter": "{not json"}, "INVALID_QUERY"),
-    ],
-    ids=["negative_offset", "non_int_limit", "zero_limit", "invalid_metadata_json"],
-)
-def test_list_assets_invalid_query_rejected(http: requests.Session, api_base: str, params, error_code):
-    r = http.get(api_base + "/api/assets", params=params, timeout=120)
-    body = r.json()
-    assert r.status_code == 400
-    assert body["error"]["code"] == error_code
+def test_list_assets_offset_negative_and_limit_nonint_rejected(http, api_base):
+    r1 = http.get(api_base + "/api/assets", params={"offset": "-1"}, timeout=120)
+    b1 = r1.json()
+    assert r1.status_code == 400
+    assert b1["error"]["code"] == "INVALID_QUERY"
+    r2 = http.get(api_base + "/api/assets", params={"limit": "abc"}, timeout=120)
+    b2 = r2.json()
+    assert r2.status_code == 400
+    assert b2["error"]["code"] == "INVALID_QUERY"
+def test_list_assets_invalid_query_rejected(http: requests.Session, api_base: str):
+    # limit too small
+    r1 = http.get(api_base + "/api/assets", params={"limit": "0"}, timeout=120)
+    b1 = r1.json()
+    assert r1.status_code == 400
+    assert b1["error"]["code"] == "INVALID_QUERY"
+    # bad metadata JSON
+    r2 = http.get(api_base + "/api/assets", params={"metadata_filter": "{not json"}, timeout=120)
+    b2 = r2.json()
+    assert r2.status_code == 400
+    assert b2["error"]["code"] == "INVALID_QUERY"
def test_list_assets_name_contains_literal_underscore(

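Aside: the list-assets failure tests above expect HTTP 400 with code INVALID_QUERY whether the problem is a non-integer limit, a zero or negative bound, or an unparseable metadata_filter. A minimal, self-contained validation sketch consistent with those tests; the function name and the default limit are invented for illustration:

import json
from typing import Optional

def validate_list_query(params: dict) -> Optional[str]:
    """Return "INVALID_QUERY" on bad input, None when the query is acceptable."""
    try:
        limit = int(params.get("limit", "100"))
        offset = int(params.get("offset", "0"))
    except ValueError:
        return "INVALID_QUERY"          # non-integer limit/offset
    if limit < 1 or offset < 0:
        return "INVALID_QUERY"          # zero/negative bounds
    if "metadata_filter" in params:
        try:
            json.loads(params["metadata_filter"])
        except json.JSONDecodeError:
            return "INVALID_QUERY"      # unparseable JSON filter
    return None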
View File

@@ -3,7 +3,7 @@ from pathlib import Path
import pytest
import requests
-from helpers import get_asset_filename, trigger_sync_seed_assets
+from conftest import get_asset_filename, trigger_sync_seed_assets
@pytest.fixture