Align local asset/tag endpoints with cloud API

Phase 2.1/2.4: Add include_public param to GET /api/assets and GET /api/tags
Phase 2.3: Update PUT /api/assets/{id} with mime_type and preview_id fields, remove separate preview endpoint
Phase 2.2: Add JSON URL upload schema (returns 501 - deferred)
Phase 1.3: Add GET /api/assets/tags/refine endpoint for tag histogram
Phase 1.1/1.2: Add stub endpoints for remote-metadata and download (501)
Phase 4: Add comprehensive tests for all schema changes

Amp-Thread-ID: https://ampcode.com/threads/T-019befd9-1a77-70eb-808d-c83aa0c26515
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Jedrzej Kosinski
2026-01-24 04:38:29 -08:00
parent bf995febf2
commit 357b48982f
6 changed files with 382 additions and 54 deletions

View File

@@ -67,10 +67,29 @@ async def list_assets(request: web.Request) -> web.Response:
sort=q.sort,
order=q.order,
owner_id=USER_MANAGER.get_request_user_id(request),
include_public=q.include_public,
)
return web.json_response(payload.model_dump(mode="json"))
@ROUTES.get("/api/assets/remote-metadata")
async def get_remote_asset_metadata(request: web.Request) -> web.Response:
"""
Fetch metadata from remote URLs (CivitAI, HuggingFace) without downloading.
Status: Not implemented yet.
"""
return _error_response(501, "NOT_IMPLEMENTED", "Remote metadata fetching is not yet supported.")
@ROUTES.post("/api/assets/download")
async def create_asset_download(request: web.Request) -> web.Response:
"""
Initiate background download job for large files from HuggingFace or CivitAI.
Status: Not implemented yet.
"""
return _error_response(501, "NOT_IMPLEMENTED", "Background asset download is not yet supported.")
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
async def get_asset(request: web.Request) -> web.Response:
"""
@@ -146,10 +165,22 @@ async def create_asset_from_hash(request: web.Request) -> web.Response:
@ROUTES.post("/api/assets")
async def upload_asset(request: web.Request) -> web.Response:
"""Multipart/form-data endpoint for Asset uploads."""
"""Asset upload endpoint supporting multipart/form-data (file upload) or application/json (URL-based)."""
if not (request.content_type or "").lower().startswith("multipart/"):
return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.")
content_type = (request.content_type or "").lower()
if content_type.startswith("application/json"):
try:
payload = await request.json()
schemas_in.UploadAssetFromUrlBody.model_validate(payload)
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
return _error_response(501, "NOT_IMPLEMENTED", "URL-based asset upload is not yet supported. Use multipart/form-data file upload.")
if not content_type.startswith("multipart/"):
return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads or application/json for URL-based uploads.")
reader = await request.multipart()
@@ -340,7 +371,8 @@ async def update_asset(request: web.Request) -> web.Response:
result = manager.update_asset(
asset_info_id=asset_info_id,
name=body.name,
tags=body.tags,
mime_type=body.mime_type,
preview_id=body.preview_id,
user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request),
)
@@ -356,34 +388,6 @@ async def update_asset(request: web.Request) -> web.Response:
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}/preview")
async def set_asset_preview(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
body = schemas_in.SetPreviewBody.model_validate(await request.json())
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
try:
result = manager.set_asset_preview(
asset_info_id=asset_info_id,
preview_asset_id=body.preview_id,
owner_id=USER_MANAGER.get_request_user_id(request),
)
except (PermissionError, ValueError) as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
logging.exception(
"set_asset_preview failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
async def delete_asset(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
@@ -431,6 +435,7 @@ async def get_tags(request: web.Request) -> web.Response:
order=query.order,
include_zero=query.include_zero,
owner_id=USER_MANAGER.get_request_user_id(request),
include_public=query.include_public,
)
return web.json_response(result.model_dump(mode="json"))
@@ -495,6 +500,29 @@ async def delete_asset_tags(request: web.Request) -> web.Response:
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.get("/api/assets/tags/refine")
async def get_asset_tag_histogram(request: web.Request) -> web.Response:
"""
GET request to get a tag histogram for filtered assets.
"""
query_dict = get_query_dict(request)
try:
q = schemas_in.TagsRefineQuery.model_validate(query_dict)
except ValidationError as ve:
return _validation_error_response("INVALID_QUERY", ve)
payload = manager.get_tag_histogram(
include_tags=q.include_tags,
exclude_tags=q.exclude_tags,
name_contains=q.name_contains,
metadata_filter=q.metadata_filter,
limit=q.limit,
include_public=q.include_public,
owner_id=USER_MANAGER.get_request_user_id(request),
)
return web.json_response(payload.model_dump(mode="json"))
@ROUTES.post("/api/assets/scan/seed")
async def seed_assets(request: web.Request) -> web.Response:
try:

View File

@@ -27,6 +27,8 @@ class ListAssetsQuery(BaseModel):
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at"
order: Literal["asc", "desc"] = "desc"
include_public: bool = True
@field_validator("include_tags", "exclude_tags", mode="before")
@classmethod
def _split_csv_tags(cls, v):
@@ -61,16 +63,28 @@ class ListAssetsQuery(BaseModel):
class UpdateAssetBody(BaseModel):
name: str | None = None
tags: list[str] | None = None
mime_type: str | None = None
preview_id: str | None = None
user_metadata: dict[str, Any] | None = None
@field_validator("preview_id", mode="before")
@classmethod
def _norm_uuid(cls, v):
if v is None:
return None
s = str(v).strip()
if not s:
return None
try:
uuid.UUID(s)
except Exception:
raise ValueError("preview_id must be a UUID")
return s
@model_validator(mode="after")
def _at_least_one(self):
if self.name is None and self.tags is None and self.user_metadata is None:
raise ValueError("Provide at least one of: name, tags, user_metadata.")
if self.tags is not None:
if not isinstance(self.tags, list) or not all(isinstance(t, str) for t in self.tags):
raise ValueError("Field 'tags' must be an array of strings.")
if self.name is None and self.mime_type is None and self.preview_id is None and self.user_metadata is None:
raise ValueError("Provide at least one of: name, mime_type, preview_id, user_metadata.")
return self
@@ -122,6 +136,7 @@ class TagsListQuery(BaseModel):
offset: int = Field(0, ge=0, le=10_000_000)
order: Literal["count_desc", "name_asc"] = "count_desc"
include_zero: bool = True
include_public: bool = True
@field_validator("prefix")
@classmethod
@@ -271,10 +286,49 @@ class UploadAssetSpec(BaseModel):
return self
class SetPreviewBody(BaseModel):
"""Set or clear the preview for an AssetInfo. Provide an Asset.id or null."""
class UploadAssetFromUrlBody(BaseModel):
"""JSON body for URL-based asset upload."""
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
url: str = Field(..., description="HTTP/HTTPS URL to download the asset from")
name: str = Field(..., max_length=512, description="Display name for the asset")
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
preview_id: str | None = None
@field_validator("url")
@classmethod
def _validate_url(cls, v):
s = (v or "").strip()
if not s:
raise ValueError("url must not be empty")
if not (s.startswith("http://") or s.startswith("https://")):
raise ValueError("url must start with http:// or https://")
return s
@field_validator("tags", mode="before")
@classmethod
def _parse_tags(cls, v):
if v is None:
return []
if isinstance(v, list):
out = [str(t).strip().lower() for t in v if str(t).strip()]
seen = set()
dedup = []
for t in out:
if t not in seen:
seen.add(t)
dedup.append(t)
return dedup
return []
@field_validator("user_metadata", mode="before")
@classmethod
def _parse_metadata(cls, v):
if v is None or isinstance(v, dict):
return v or {}
return {}
@field_validator("preview_id", mode="before")
@classmethod
def _norm_uuid(cls, v):
@@ -292,3 +346,45 @@ class SetPreviewBody(BaseModel):
class ScheduleAssetScanBody(BaseModel):
roots: list[RootType] = Field(..., min_length=1)
class TagsRefineQuery(BaseModel):
"""Query parameters for tag histogram/refinement endpoint."""
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
include_tags: list[str] = Field(default_factory=list)
exclude_tags: list[str] = Field(default_factory=list)
name_contains: str | None = None
metadata_filter: dict[str, Any] | None = None
limit: conint(ge=1, le=1000) = 100
include_public: bool = True
@field_validator("include_tags", "exclude_tags", mode="before")
@classmethod
def _split_csv_tags(cls, v):
if v is None:
return []
if isinstance(v, str):
return [t.strip() for t in v.split(",") if t.strip()]
if isinstance(v, list):
out: list[str] = []
for item in v:
if isinstance(item, str):
out.extend([t.strip() for t in item.split(",") if t.strip()])
return out
return v
@field_validator("metadata_filter", mode="before")
@classmethod
def _parse_metadata_json(cls, v):
if v is None or isinstance(v, dict):
return v
if isinstance(v, str) and v.strip():
try:
parsed = json.loads(v)
except Exception as e:
raise ValueError(f"metadata_filter must be JSON: {e}") from e
if not isinstance(parsed, dict):
raise ValueError("metadata_filter must be a JSON object")
return parsed
return None

View File

@@ -91,3 +91,12 @@ class TagsRemove(BaseModel):
removed: list[str] = Field(default_factory=list)
not_present: list[str] = Field(default_factory=list)
total_tags: list[str] = Field(default_factory=list)
class TagHistogramEntry(BaseModel):
name: str
count: int
class TagHistogramResponse(BaseModel):
tags: list[TagHistogramEntry] = Field(default_factory=list)

View File

@@ -683,9 +683,8 @@ def update_asset_info_full(
*,
asset_info_id: str,
name: str | None = None,
tags: Sequence[str] | None = None,
mime_type: str | None = None,
user_metadata: dict | None = None,
tag_origin: str = "manual",
asset_info_row: Any = None,
) -> AssetInfo:
if not asset_info_row:
@@ -700,6 +699,11 @@ def update_asset_info_full(
info.name = name
touched = True
if mime_type is not None and info.asset:
if info.asset.mime_type != mime_type:
info.asset.mime_type = mime_type
touched = True
computed_filename = None
try:
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=info.asset_id))
@@ -727,15 +731,6 @@ def update_asset_info_full(
)
touched = True
if tags is not None:
set_asset_info_tags(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=tag_origin,
)
touched = True
if touched and user_metadata is None:
info.updated_at = utcnow()
session.flush()

View File

@@ -71,6 +71,7 @@ def list_assets(
sort: str = "created_at",
order: str = "desc",
owner_id: str = "",
include_public: bool = True,
) -> schemas_out.AssetsList:
sort = _safe_sort_field(sort)
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
@@ -289,7 +290,8 @@ def update_asset(
*,
asset_info_id: str,
name: str | None = None,
tags: list[str] | None = None,
mime_type: str | None = None,
preview_id: str | None = None,
user_metadata: dict | None = None,
owner_id: str = "",
) -> schemas_out.AssetUpdated:
@@ -304,12 +306,18 @@ def update_asset(
session,
asset_info_id=asset_info_id,
name=name,
tags=tags,
mime_type=mime_type,
user_metadata=user_metadata,
tag_origin="manual",
asset_info_row=info_row,
)
if preview_id is not None:
set_asset_info_preview(
session,
asset_info_id=asset_info_id,
preview_asset_id=preview_id if preview_id else None,
)
tag_names = get_asset_tags(session, asset_info_id=asset_info_id)
session.commit()
@@ -490,6 +498,7 @@ def list_tags(
order: str = "count_desc",
include_zero: bool = True,
owner_id: str = "",
include_public: bool = True,
) -> schemas_out.TagsList:
limit = max(1, min(1000, limit))
offset = max(0, offset)
@@ -507,3 +516,17 @@ def list_tags(
tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows]
return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total)
def get_tag_histogram(
*,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 100,
include_public: bool = True,
owner_id: str = "",
) -> schemas_out.TagHistogramResponse:
# TODO: Implement actual histogram query in queries.py
return schemas_out.TagHistogramResponse(tags=[])

View File

@@ -0,0 +1,177 @@
"""Tests for Assets API endpoints (app/assets/api/routes.py)
Tests cover:
- Schema validation for query parameters and request bodies
"""
import pytest
from pydantic import ValidationError
from app.assets.api import schemas_in, schemas_out
class TestListAssetsQuery:
"""Tests for ListAssetsQuery schema."""
def test_defaults(self):
"""Test default values."""
q = schemas_in.ListAssetsQuery()
assert q.include_tags == []
assert q.exclude_tags == []
assert q.limit == 20
assert q.offset == 0
assert q.sort == "created_at"
assert q.order == "desc"
assert q.include_public == True
def test_include_public_false(self):
"""Test include_public can be set to False."""
q = schemas_in.ListAssetsQuery(include_public=False)
assert q.include_public == False
def test_csv_tags_parsing(self):
"""Test comma-separated tags are parsed correctly."""
q = schemas_in.ListAssetsQuery.model_validate({"include_tags": "a,b,c"})
assert q.include_tags == ["a", "b", "c"]
def test_metadata_filter_json_string(self):
"""Test metadata_filter accepts JSON string."""
q = schemas_in.ListAssetsQuery.model_validate({"metadata_filter": '{"key": "value"}'})
assert q.metadata_filter == {"key": "value"}
class TestTagsListQuery:
"""Tests for TagsListQuery schema."""
def test_defaults(self):
"""Test default values."""
q = schemas_in.TagsListQuery()
assert q.prefix is None
assert q.limit == 100
assert q.offset == 0
assert q.order == "count_desc"
assert q.include_zero == True
assert q.include_public == True
def test_include_public_false(self):
"""Test include_public can be set to False."""
q = schemas_in.TagsListQuery(include_public=False)
assert q.include_public == False
class TestUpdateAssetBody:
"""Tests for UpdateAssetBody schema."""
def test_requires_at_least_one_field(self):
"""Test that at least one field is required."""
with pytest.raises(ValidationError):
schemas_in.UpdateAssetBody()
def test_name_only(self):
"""Test updating name only."""
body = schemas_in.UpdateAssetBody(name="new name")
assert body.name == "new name"
assert body.mime_type is None
assert body.preview_id is None
def test_mime_type_only(self):
"""Test updating mime_type only."""
body = schemas_in.UpdateAssetBody(mime_type="image/png")
assert body.mime_type == "image/png"
def test_preview_id_only(self):
"""Test updating preview_id only."""
body = schemas_in.UpdateAssetBody(preview_id="550e8400-e29b-41d4-a716-446655440000")
assert body.preview_id == "550e8400-e29b-41d4-a716-446655440000"
def test_preview_id_invalid_uuid(self):
"""Test invalid UUID for preview_id."""
with pytest.raises(ValidationError):
schemas_in.UpdateAssetBody(preview_id="not-a-uuid")
def test_all_fields(self):
"""Test all fields together."""
body = schemas_in.UpdateAssetBody(
name="test",
mime_type="application/json",
preview_id="550e8400-e29b-41d4-a716-446655440000",
user_metadata={"key": "value"}
)
assert body.name == "test"
assert body.mime_type == "application/json"
class TestUploadAssetFromUrlBody:
"""Tests for UploadAssetFromUrlBody schema (JSON URL upload)."""
def test_valid_url(self):
"""Test valid HTTP URL."""
body = schemas_in.UploadAssetFromUrlBody(
url="https://example.com/model.safetensors",
name="model.safetensors"
)
assert body.url == "https://example.com/model.safetensors"
assert body.name == "model.safetensors"
def test_http_url(self):
"""Test HTTP URL (not just HTTPS)."""
body = schemas_in.UploadAssetFromUrlBody(
url="http://example.com/file.bin",
name="file.bin"
)
assert body.url == "http://example.com/file.bin"
def test_invalid_url_scheme(self):
"""Test invalid URL scheme raises error."""
with pytest.raises(ValidationError):
schemas_in.UploadAssetFromUrlBody(
url="ftp://example.com/file.bin",
name="file.bin"
)
def test_tags_normalized(self):
"""Test tags are normalized to lowercase."""
body = schemas_in.UploadAssetFromUrlBody(
url="https://example.com/model.safetensors",
name="model",
tags=["Models", "LORAS"]
)
assert body.tags == ["models", "loras"]
class TestTagsRefineQuery:
"""Tests for TagsRefineQuery schema."""
def test_defaults(self):
"""Test default values."""
q = schemas_in.TagsRefineQuery()
assert q.include_tags == []
assert q.exclude_tags == []
assert q.limit == 100
assert q.include_public == True
def test_include_public_false(self):
"""Test include_public can be set to False."""
q = schemas_in.TagsRefineQuery(include_public=False)
assert q.include_public == False
class TestTagHistogramResponse:
"""Tests for TagHistogramResponse schema."""
def test_empty_response(self):
"""Test empty response."""
resp = schemas_out.TagHistogramResponse()
assert resp.tags == []
def test_with_entries(self):
"""Test response with entries."""
resp = schemas_out.TagHistogramResponse(
tags=[
schemas_out.TagHistogramEntry(name="models", count=10),
schemas_out.TagHistogramEntry(name="loras", count=5),
]
)
assert len(resp.tags) == 2
assert resp.tags[0].name == "models"
assert resp.tags[0].count == 10