Align local asset/tag endpoints with cloud API
Phase 2.1/2.4: Add include_public param to GET /api/assets and GET /api/tags
Phase 2.3: Update PUT /api/assets/{id} with mime_type and preview_id fields, remove separate preview endpoint
Phase 2.2: Add JSON URL upload schema (returns 501 - deferred)
Phase 1.3: Add GET /api/assets/tags/refine endpoint for tag histogram
Phase 1.1/1.2: Add stub endpoints for remote-metadata and download (501)
Phase 4: Add comprehensive tests for all schema changes
Amp-Thread-ID: https://ampcode.com/threads/T-019befd9-1a77-70eb-808d-c83aa0c26515
Co-authored-by: Amp <amp@ampcode.com>
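
For illustration only (not part of the diff below), a minimal client-side sketch of the endpoint changes summarized above. It assumes a local server at http://127.0.0.1:8188 and uses made-up asset and preview IDs; response payloads are defined by the schemas in the hunks that follow.

import asyncio
import aiohttp

BASE = "http://127.0.0.1:8188"  # assumed local server address
ASSET_ID = "550e8400-e29b-41d4-a716-446655440000"    # hypothetical AssetInfo id
PREVIEW_ID = "6fa459ea-ee8a-3ca4-894e-db77e160355e"  # hypothetical preview Asset id


async def main() -> None:
    async with aiohttp.ClientSession() as session:
        # Phase 2.1/2.4: include_public=false narrows listings to the caller's own assets/tags.
        async with session.get(f"{BASE}/api/assets", params={"include_public": "false"}) as resp:
            print(resp.status, await resp.text())
        async with session.get(f"{BASE}/api/tags", params={"include_public": "false"}) as resp:
            print(resp.status, await resp.text())

        # Phase 2.3: PUT now accepts mime_type and preview_id alongside name/tags/user_metadata.
        update = {"mime_type": "image/png", "preview_id": PREVIEW_ID}
        async with session.put(f"{BASE}/api/assets/{ASSET_ID}", json=update) as resp:
            print(resp.status, await resp.text())

        # Phase 1.3: tag histogram for a filtered asset set (the stub returns an empty tag list for now).
        async with session.get(f"{BASE}/api/assets/tags/refine", params={"include_tags": "models,loras"}) as resp:
            print(resp.status, await resp.text())


if __name__ == "__main__":
    asyncio.run(main())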
@@ -67,10 +67,29 @@ async def list_assets(request: web.Request) -> web.Response:
        sort=q.sort,
        order=q.order,
        owner_id=USER_MANAGER.get_request_user_id(request),
        include_public=q.include_public,
    )
    return web.json_response(payload.model_dump(mode="json"))


@ROUTES.get("/api/assets/remote-metadata")
async def get_remote_asset_metadata(request: web.Request) -> web.Response:
    """
    Fetch metadata from remote URLs (CivitAI, HuggingFace) without downloading.
    Status: Not implemented yet.
    """
    return _error_response(501, "NOT_IMPLEMENTED", "Remote metadata fetching is not yet supported.")


@ROUTES.post("/api/assets/download")
async def create_asset_download(request: web.Request) -> web.Response:
    """
    Initiate background download job for large files from HuggingFace or CivitAI.
    Status: Not implemented yet.
    """
    return _error_response(501, "NOT_IMPLEMENTED", "Background asset download is not yet supported.")


@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
async def get_asset(request: web.Request) -> web.Response:
    """
@@ -146,10 +165,22 @@ async def create_asset_from_hash(request: web.Request) -> web.Response:

@ROUTES.post("/api/assets")
async def upload_asset(request: web.Request) -> web.Response:
    """Multipart/form-data endpoint for Asset uploads."""
    """Asset upload endpoint supporting multipart/form-data (file upload) or application/json (URL-based)."""

    if not (request.content_type or "").lower().startswith("multipart/"):
        return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.")
    content_type = (request.content_type or "").lower()

    if content_type.startswith("application/json"):
        try:
            payload = await request.json()
            schemas_in.UploadAssetFromUrlBody.model_validate(payload)
        except ValidationError as ve:
            return _validation_error_response("INVALID_BODY", ve)
        except Exception:
            return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
        return _error_response(501, "NOT_IMPLEMENTED", "URL-based asset upload is not yet supported. Use multipart/form-data file upload.")

    if not content_type.startswith("multipart/"):
        return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads or application/json for URL-based uploads.")

    reader = await request.multipart()

@@ -340,7 +371,8 @@ async def update_asset(request: web.Request) -> web.Response:
        result = manager.update_asset(
            asset_info_id=asset_info_id,
            name=body.name,
            tags=body.tags,
            mime_type=body.mime_type,
            preview_id=body.preview_id,
            user_metadata=body.user_metadata,
            owner_id=USER_MANAGER.get_request_user_id(request),
        )
@@ -356,34 +388,6 @@ async def update_asset(request: web.Request) -> web.Response:
    return web.json_response(result.model_dump(mode="json"), status=200)


@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}/preview")
async def set_asset_preview(request: web.Request) -> web.Response:
    asset_info_id = str(uuid.UUID(request.match_info["id"]))
    try:
        body = schemas_in.SetPreviewBody.model_validate(await request.json())
    except ValidationError as ve:
        return _validation_error_response("INVALID_BODY", ve)
    except Exception:
        return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")

    try:
        result = manager.set_asset_preview(
            asset_info_id=asset_info_id,
            preview_asset_id=body.preview_id,
            owner_id=USER_MANAGER.get_request_user_id(request),
        )
    except (PermissionError, ValueError) as ve:
        return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
    except Exception:
        logging.exception(
            "set_asset_preview failed for asset_info_id=%s, owner_id=%s",
            asset_info_id,
            USER_MANAGER.get_request_user_id(request),
        )
        return _error_response(500, "INTERNAL", "Unexpected server error.")
    return web.json_response(result.model_dump(mode="json"), status=200)


@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
async def delete_asset(request: web.Request) -> web.Response:
    asset_info_id = str(uuid.UUID(request.match_info["id"]))
@@ -431,6 +435,7 @@ async def get_tags(request: web.Request) -> web.Response:
        order=query.order,
        include_zero=query.include_zero,
        owner_id=USER_MANAGER.get_request_user_id(request),
        include_public=query.include_public,
    )
    return web.json_response(result.model_dump(mode="json"))

@@ -495,6 +500,29 @@ async def delete_asset_tags(request: web.Request) -> web.Response:
    return web.json_response(result.model_dump(mode="json"), status=200)


@ROUTES.get("/api/assets/tags/refine")
async def get_asset_tag_histogram(request: web.Request) -> web.Response:
    """
    GET request to get a tag histogram for filtered assets.
    """
    query_dict = get_query_dict(request)
    try:
        q = schemas_in.TagsRefineQuery.model_validate(query_dict)
    except ValidationError as ve:
        return _validation_error_response("INVALID_QUERY", ve)

    payload = manager.get_tag_histogram(
        include_tags=q.include_tags,
        exclude_tags=q.exclude_tags,
        name_contains=q.name_contains,
        metadata_filter=q.metadata_filter,
        limit=q.limit,
        include_public=q.include_public,
        owner_id=USER_MANAGER.get_request_user_id(request),
    )
    return web.json_response(payload.model_dump(mode="json"))


@ROUTES.post("/api/assets/scan/seed")
async def seed_assets(request: web.Request) -> web.Response:
    try:

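For illustration only (not part of the diff), a sketch of how the refine endpoint's query string is expected to validate, using the TagsRefineQuery schema added in the schemas_in hunks below (module path app.assets.api.schemas_in, as imported by the new tests). CSV tag values are split and metadata_filter is parsed from a JSON string by the schema's validators; the filter keys used here are made up.

from app.assets.api import schemas_in

query = {
    "include_tags": "models,loras",               # CSV string -> ["models", "loras"]
    "exclude_tags": "archived",
    "metadata_filter": '{"base_model": "sdxl"}',  # JSON string -> dict (keys are hypothetical)
    "limit": "50",
    "include_public": "false",
}
q = schemas_in.TagsRefineQuery.model_validate(query)
assert q.include_tags == ["models", "loras"]
assert q.metadata_filter == {"base_model": "sdxl"}
assert q.include_public is False
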
@@ -27,6 +27,8 @@ class ListAssetsQuery(BaseModel):
    sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at"
    order: Literal["asc", "desc"] = "desc"

    include_public: bool = True

    @field_validator("include_tags", "exclude_tags", mode="before")
    @classmethod
    def _split_csv_tags(cls, v):
@@ -61,16 +63,28 @@ class ListAssetsQuery(BaseModel):

class UpdateAssetBody(BaseModel):
    name: str | None = None
    tags: list[str] | None = None
    mime_type: str | None = None
    preview_id: str | None = None
    user_metadata: dict[str, Any] | None = None

    @field_validator("preview_id", mode="before")
    @classmethod
    def _norm_uuid(cls, v):
        if v is None:
            return None
        s = str(v).strip()
        if not s:
            return None
        try:
            uuid.UUID(s)
        except Exception:
            raise ValueError("preview_id must be a UUID")
        return s

    @model_validator(mode="after")
    def _at_least_one(self):
        if self.name is None and self.tags is None and self.user_metadata is None:
            raise ValueError("Provide at least one of: name, tags, user_metadata.")
        if self.tags is not None:
            if not isinstance(self.tags, list) or not all(isinstance(t, str) for t in self.tags):
                raise ValueError("Field 'tags' must be an array of strings.")
        if self.name is None and self.mime_type is None and self.preview_id is None and self.user_metadata is None:
            raise ValueError("Provide at least one of: name, mime_type, preview_id, user_metadata.")
        return self

@@ -122,6 +136,7 @@ class TagsListQuery(BaseModel):
    offset: int = Field(0, ge=0, le=10_000_000)
    order: Literal["count_desc", "name_asc"] = "count_desc"
    include_zero: bool = True
    include_public: bool = True

    @field_validator("prefix")
    @classmethod
@@ -271,10 +286,49 @@ class UploadAssetSpec(BaseModel):
        return self


class SetPreviewBody(BaseModel):
    """Set or clear the preview for an AssetInfo. Provide an Asset.id or null."""
class UploadAssetFromUrlBody(BaseModel):
    """JSON body for URL-based asset upload."""
    model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)

    url: str = Field(..., description="HTTP/HTTPS URL to download the asset from")
    name: str = Field(..., max_length=512, description="Display name for the asset")
    tags: list[str] = Field(default_factory=list)
    user_metadata: dict[str, Any] = Field(default_factory=dict)
    preview_id: str | None = None

    @field_validator("url")
    @classmethod
    def _validate_url(cls, v):
        s = (v or "").strip()
        if not s:
            raise ValueError("url must not be empty")
        if not (s.startswith("http://") or s.startswith("https://")):
            raise ValueError("url must start with http:// or https://")
        return s

    @field_validator("tags", mode="before")
    @classmethod
    def _parse_tags(cls, v):
        if v is None:
            return []
        if isinstance(v, list):
            out = [str(t).strip().lower() for t in v if str(t).strip()]
            seen = set()
            dedup = []
            for t in out:
                if t not in seen:
                    seen.add(t)
                    dedup.append(t)
            return dedup
        return []

    @field_validator("user_metadata", mode="before")
    @classmethod
    def _parse_metadata(cls, v):
        if v is None or isinstance(v, dict):
            return v or {}
        return {}

    @field_validator("preview_id", mode="before")
    @classmethod
    def _norm_uuid(cls, v):
@@ -292,3 +346,45 @@ class SetPreviewBody(BaseModel):

class ScheduleAssetScanBody(BaseModel):
    roots: list[RootType] = Field(..., min_length=1)


class TagsRefineQuery(BaseModel):
    """Query parameters for tag histogram/refinement endpoint."""
    model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)

    include_tags: list[str] = Field(default_factory=list)
    exclude_tags: list[str] = Field(default_factory=list)
    name_contains: str | None = None
    metadata_filter: dict[str, Any] | None = None
    limit: conint(ge=1, le=1000) = 100
    include_public: bool = True

    @field_validator("include_tags", "exclude_tags", mode="before")
    @classmethod
    def _split_csv_tags(cls, v):
        if v is None:
            return []
        if isinstance(v, str):
            return [t.strip() for t in v.split(",") if t.strip()]
        if isinstance(v, list):
            out: list[str] = []
            for item in v:
                if isinstance(item, str):
                    out.extend([t.strip() for t in item.split(",") if t.strip()])
            return out
        return v

    @field_validator("metadata_filter", mode="before")
    @classmethod
    def _parse_metadata_json(cls, v):
        if v is None or isinstance(v, dict):
            return v
        if isinstance(v, str) and v.strip():
            try:
                parsed = json.loads(v)
            except Exception as e:
                raise ValueError(f"metadata_filter must be JSON: {e}") from e
            if not isinstance(parsed, dict):
                raise ValueError("metadata_filter must be a JSON object")
            return parsed
        return None

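For illustration only (not part of the diff), a sketch of the JSON body that the 501-stubbed URL upload route already validates with UploadAssetFromUrlBody; the URL, tags, and metadata values are made up. Tags are lowercased and deduplicated by the schema's _parse_tags validator.

from app.assets.api import schemas_in

body = schemas_in.UploadAssetFromUrlBody.model_validate({
    "url": "https://example.com/checkpoints/model.safetensors",  # made-up URL
    "name": "model.safetensors",
    "tags": ["Models", "LORAS", "models"],   # -> ["models", "loras"]
    "user_metadata": {"source": "example"},  # made-up metadata
})
assert body.tags == ["models", "loras"]
assert body.preview_id is None
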
@@ -91,3 +91,12 @@ class TagsRemove(BaseModel):
    removed: list[str] = Field(default_factory=list)
    not_present: list[str] = Field(default_factory=list)
    total_tags: list[str] = Field(default_factory=list)


class TagHistogramEntry(BaseModel):
    name: str
    count: int


class TagHistogramResponse(BaseModel):
    tags: list[TagHistogramEntry] = Field(default_factory=list)

@@ -683,9 +683,8 @@ def update_asset_info_full(
    *,
    asset_info_id: str,
    name: str | None = None,
    tags: Sequence[str] | None = None,
    mime_type: str | None = None,
    user_metadata: dict | None = None,
    tag_origin: str = "manual",
    asset_info_row: Any = None,
) -> AssetInfo:
    if not asset_info_row:
@@ -700,6 +699,11 @@ def update_asset_info_full(
        info.name = name
        touched = True

    if mime_type is not None and info.asset:
        if info.asset.mime_type != mime_type:
            info.asset.mime_type = mime_type
            touched = True

    computed_filename = None
    try:
        p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=info.asset_id))
@@ -727,15 +731,6 @@ def update_asset_info_full(
        )
        touched = True

    if tags is not None:
        set_asset_info_tags(
            session,
            asset_info_id=asset_info_id,
            tags=tags,
            origin=tag_origin,
        )
        touched = True

    if touched and user_metadata is None:
        info.updated_at = utcnow()
        session.flush()

@@ -71,6 +71,7 @@ def list_assets(
    sort: str = "created_at",
    order: str = "desc",
    owner_id: str = "",
    include_public: bool = True,
) -> schemas_out.AssetsList:
    sort = _safe_sort_field(sort)
    order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
@@ -289,7 +290,8 @@ def update_asset(
    *,
    asset_info_id: str,
    name: str | None = None,
    tags: list[str] | None = None,
    mime_type: str | None = None,
    preview_id: str | None = None,
    user_metadata: dict | None = None,
    owner_id: str = "",
) -> schemas_out.AssetUpdated:
@@ -304,12 +306,18 @@ def update_asset(
            session,
            asset_info_id=asset_info_id,
            name=name,
            tags=tags,
            mime_type=mime_type,
            user_metadata=user_metadata,
            tag_origin="manual",
            asset_info_row=info_row,
        )

        if preview_id is not None:
            set_asset_info_preview(
                session,
                asset_info_id=asset_info_id,
                preview_asset_id=preview_id if preview_id else None,
            )

        tag_names = get_asset_tags(session, asset_info_id=asset_info_id)
        session.commit()

@@ -490,6 +498,7 @@ def list_tags(
    order: str = "count_desc",
    include_zero: bool = True,
    owner_id: str = "",
    include_public: bool = True,
) -> schemas_out.TagsList:
    limit = max(1, min(1000, limit))
    offset = max(0, offset)
@@ -507,3 +516,17 @@ def list_tags(

    tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows]
    return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total)


def get_tag_histogram(
    *,
    include_tags: Sequence[str] | None = None,
    exclude_tags: Sequence[str] | None = None,
    name_contains: str | None = None,
    metadata_filter: dict | None = None,
    limit: int = 100,
    include_public: bool = True,
    owner_id: str = "",
) -> schemas_out.TagHistogramResponse:
    # TODO: Implement actual histogram query in queries.py
    return schemas_out.TagHistogramResponse(tags=[])

tests-unit/app_test/assets_api_test.py (new file, 177 lines)
@@ -0,0 +1,177 @@
"""Tests for Assets API endpoints (app/assets/api/routes.py)
|
||||
|
||||
Tests cover:
|
||||
- Schema validation for query parameters and request bodies
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.assets.api import schemas_in, schemas_out
|
||||
|
||||
|
||||
class TestListAssetsQuery:
    """Tests for ListAssetsQuery schema."""

    def test_defaults(self):
        """Test default values."""
        q = schemas_in.ListAssetsQuery()
        assert q.include_tags == []
        assert q.exclude_tags == []
        assert q.limit == 20
        assert q.offset == 0
        assert q.sort == "created_at"
        assert q.order == "desc"
        assert q.include_public == True

    def test_include_public_false(self):
        """Test include_public can be set to False."""
        q = schemas_in.ListAssetsQuery(include_public=False)
        assert q.include_public == False

    def test_csv_tags_parsing(self):
        """Test comma-separated tags are parsed correctly."""
        q = schemas_in.ListAssetsQuery.model_validate({"include_tags": "a,b,c"})
        assert q.include_tags == ["a", "b", "c"]

    def test_metadata_filter_json_string(self):
        """Test metadata_filter accepts JSON string."""
        q = schemas_in.ListAssetsQuery.model_validate({"metadata_filter": '{"key": "value"}'})
        assert q.metadata_filter == {"key": "value"}

class TestTagsListQuery:
    """Tests for TagsListQuery schema."""

    def test_defaults(self):
        """Test default values."""
        q = schemas_in.TagsListQuery()
        assert q.prefix is None
        assert q.limit == 100
        assert q.offset == 0
        assert q.order == "count_desc"
        assert q.include_zero == True
        assert q.include_public == True

    def test_include_public_false(self):
        """Test include_public can be set to False."""
        q = schemas_in.TagsListQuery(include_public=False)
        assert q.include_public == False

class TestUpdateAssetBody:
    """Tests for UpdateAssetBody schema."""

    def test_requires_at_least_one_field(self):
        """Test that at least one field is required."""
        with pytest.raises(ValidationError):
            schemas_in.UpdateAssetBody()

    def test_name_only(self):
        """Test updating name only."""
        body = schemas_in.UpdateAssetBody(name="new name")
        assert body.name == "new name"
        assert body.mime_type is None
        assert body.preview_id is None

    def test_mime_type_only(self):
        """Test updating mime_type only."""
        body = schemas_in.UpdateAssetBody(mime_type="image/png")
        assert body.mime_type == "image/png"

    def test_preview_id_only(self):
        """Test updating preview_id only."""
        body = schemas_in.UpdateAssetBody(preview_id="550e8400-e29b-41d4-a716-446655440000")
        assert body.preview_id == "550e8400-e29b-41d4-a716-446655440000"

    def test_preview_id_invalid_uuid(self):
        """Test invalid UUID for preview_id."""
        with pytest.raises(ValidationError):
            schemas_in.UpdateAssetBody(preview_id="not-a-uuid")

    def test_all_fields(self):
        """Test all fields together."""
        body = schemas_in.UpdateAssetBody(
            name="test",
            mime_type="application/json",
            preview_id="550e8400-e29b-41d4-a716-446655440000",
            user_metadata={"key": "value"}
        )
        assert body.name == "test"
        assert body.mime_type == "application/json"

class TestUploadAssetFromUrlBody:
    """Tests for UploadAssetFromUrlBody schema (JSON URL upload)."""

    def test_valid_url(self):
        """Test valid HTTP URL."""
        body = schemas_in.UploadAssetFromUrlBody(
            url="https://example.com/model.safetensors",
            name="model.safetensors"
        )
        assert body.url == "https://example.com/model.safetensors"
        assert body.name == "model.safetensors"

    def test_http_url(self):
        """Test HTTP URL (not just HTTPS)."""
        body = schemas_in.UploadAssetFromUrlBody(
            url="http://example.com/file.bin",
            name="file.bin"
        )
        assert body.url == "http://example.com/file.bin"

    def test_invalid_url_scheme(self):
        """Test invalid URL scheme raises error."""
        with pytest.raises(ValidationError):
            schemas_in.UploadAssetFromUrlBody(
                url="ftp://example.com/file.bin",
                name="file.bin"
            )

    def test_tags_normalized(self):
        """Test tags are normalized to lowercase."""
        body = schemas_in.UploadAssetFromUrlBody(
            url="https://example.com/model.safetensors",
            name="model",
            tags=["Models", "LORAS"]
        )
        assert body.tags == ["models", "loras"]

class TestTagsRefineQuery:
    """Tests for TagsRefineQuery schema."""

    def test_defaults(self):
        """Test default values."""
        q = schemas_in.TagsRefineQuery()
        assert q.include_tags == []
        assert q.exclude_tags == []
        assert q.limit == 100
        assert q.include_public == True

    def test_include_public_false(self):
        """Test include_public can be set to False."""
        q = schemas_in.TagsRefineQuery(include_public=False)
        assert q.include_public == False

class TestTagHistogramResponse:
    """Tests for TagHistogramResponse schema."""

    def test_empty_response(self):
        """Test empty response."""
        resp = schemas_out.TagHistogramResponse()
        assert resp.tags == []

    def test_with_entries(self):
        """Test response with entries."""
        resp = schemas_out.TagHistogramResponse(
            tags=[
                schemas_out.TagHistogramEntry(name="models", count=10),
                schemas_out.TagHistogramEntry(name="loras", count=5),
            ]
        )
        assert len(resp.tags) == 2
        assert resp.tags[0].name == "models"
        assert resp.tags[0].count == 10