diff --git a/app/assets/api/routes.py b/app/assets/api/routes.py index fbeab392a..b59eeae44 100644 --- a/app/assets/api/routes.py +++ b/app/assets/api/routes.py @@ -67,10 +67,29 @@ async def list_assets(request: web.Request) -> web.Response: sort=q.sort, order=q.order, owner_id=USER_MANAGER.get_request_user_id(request), + include_public=q.include_public, ) return web.json_response(payload.model_dump(mode="json")) +@ROUTES.get("/api/assets/remote-metadata") +async def get_remote_asset_metadata(request: web.Request) -> web.Response: + """ + Fetch metadata from remote URLs (CivitAI, HuggingFace) without downloading. + Status: Not implemented yet. + """ + return _error_response(501, "NOT_IMPLEMENTED", "Remote metadata fetching is not yet supported.") + + +@ROUTES.post("/api/assets/download") +async def create_asset_download(request: web.Request) -> web.Response: + """ + Initiate background download job for large files from HuggingFace or CivitAI. + Status: Not implemented yet. + """ + return _error_response(501, "NOT_IMPLEMENTED", "Background asset download is not yet supported.") + + @ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}") async def get_asset(request: web.Request) -> web.Response: """ @@ -146,10 +165,22 @@ async def create_asset_from_hash(request: web.Request) -> web.Response: @ROUTES.post("/api/assets") async def upload_asset(request: web.Request) -> web.Response: - """Multipart/form-data endpoint for Asset uploads.""" + """Asset upload endpoint supporting multipart/form-data (file upload) or application/json (URL-based).""" - if not (request.content_type or "").lower().startswith("multipart/"): - return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.") + content_type = (request.content_type or "").lower() + + if content_type.startswith("application/json"): + try: + payload = await request.json() + schemas_in.UploadAssetFromUrlBody.model_validate(payload) + except ValidationError as ve: + return _validation_error_response("INVALID_BODY", ve) + except Exception: + return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + return _error_response(501, "NOT_IMPLEMENTED", "URL-based asset upload is not yet supported. Use multipart/form-data file upload.") + + if not content_type.startswith("multipart/"): + return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads or application/json for URL-based uploads.") reader = await request.multipart() @@ -340,7 +371,8 @@ async def update_asset(request: web.Request) -> web.Response: result = manager.update_asset( asset_info_id=asset_info_id, name=body.name, - tags=body.tags, + mime_type=body.mime_type, + preview_id=body.preview_id, user_metadata=body.user_metadata, owner_id=USER_MANAGER.get_request_user_id(request), ) @@ -356,34 +388,6 @@ async def update_asset(request: web.Request) -> web.Response: return web.json_response(result.model_dump(mode="json"), status=200) -@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}/preview") -async def set_asset_preview(request: web.Request) -> web.Response: - asset_info_id = str(uuid.UUID(request.match_info["id"])) - try: - body = schemas_in.SetPreviewBody.model_validate(await request.json()) - except ValidationError as ve: - return _validation_error_response("INVALID_BODY", ve) - except Exception: - return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") - - try: - result = manager.set_asset_preview( - asset_info_id=asset_info_id, - preview_asset_id=body.preview_id, - owner_id=USER_MANAGER.get_request_user_id(request), - ) - except (PermissionError, ValueError) as ve: - return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) - except Exception: - logging.exception( - "set_asset_preview failed for asset_info_id=%s, owner_id=%s", - asset_info_id, - USER_MANAGER.get_request_user_id(request), - ) - return _error_response(500, "INTERNAL", "Unexpected server error.") - return web.json_response(result.model_dump(mode="json"), status=200) - - @ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}") async def delete_asset(request: web.Request) -> web.Response: asset_info_id = str(uuid.UUID(request.match_info["id"])) @@ -431,6 +435,7 @@ async def get_tags(request: web.Request) -> web.Response: order=query.order, include_zero=query.include_zero, owner_id=USER_MANAGER.get_request_user_id(request), + include_public=query.include_public, ) return web.json_response(result.model_dump(mode="json")) @@ -495,6 +500,29 @@ async def delete_asset_tags(request: web.Request) -> web.Response: return web.json_response(result.model_dump(mode="json"), status=200) +@ROUTES.get("/api/assets/tags/refine") +async def get_asset_tag_histogram(request: web.Request) -> web.Response: + """ + GET request to get a tag histogram for filtered assets. + """ + query_dict = get_query_dict(request) + try: + q = schemas_in.TagsRefineQuery.model_validate(query_dict) + except ValidationError as ve: + return _validation_error_response("INVALID_QUERY", ve) + + payload = manager.get_tag_histogram( + include_tags=q.include_tags, + exclude_tags=q.exclude_tags, + name_contains=q.name_contains, + metadata_filter=q.metadata_filter, + limit=q.limit, + include_public=q.include_public, + owner_id=USER_MANAGER.get_request_user_id(request), + ) + return web.json_response(payload.model_dump(mode="json")) + + @ROUTES.post("/api/assets/scan/seed") async def seed_assets(request: web.Request) -> web.Response: try: diff --git a/app/assets/api/schemas_in.py b/app/assets/api/schemas_in.py index b0f731074..7f7fa9a85 100644 --- a/app/assets/api/schemas_in.py +++ b/app/assets/api/schemas_in.py @@ -27,6 +27,8 @@ class ListAssetsQuery(BaseModel): sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at" order: Literal["asc", "desc"] = "desc" + include_public: bool = True + @field_validator("include_tags", "exclude_tags", mode="before") @classmethod def _split_csv_tags(cls, v): @@ -61,16 +63,28 @@ class ListAssetsQuery(BaseModel): class UpdateAssetBody(BaseModel): name: str | None = None - tags: list[str] | None = None + mime_type: str | None = None + preview_id: str | None = None user_metadata: dict[str, Any] | None = None + @field_validator("preview_id", mode="before") + @classmethod + def _norm_uuid(cls, v): + if v is None: + return None + s = str(v).strip() + if not s: + return None + try: + uuid.UUID(s) + except Exception: + raise ValueError("preview_id must be a UUID") + return s + @model_validator(mode="after") def _at_least_one(self): - if self.name is None and self.tags is None and self.user_metadata is None: - raise ValueError("Provide at least one of: name, tags, user_metadata.") - if self.tags is not None: - if not isinstance(self.tags, list) or not all(isinstance(t, str) for t in self.tags): - raise ValueError("Field 'tags' must be an array of strings.") + if self.name is None and self.mime_type is None and self.preview_id is None and self.user_metadata is None: + raise ValueError("Provide at least one of: name, mime_type, preview_id, user_metadata.") return self @@ -122,6 +136,7 @@ class TagsListQuery(BaseModel): offset: int = Field(0, ge=0, le=10_000_000) order: Literal["count_desc", "name_asc"] = "count_desc" include_zero: bool = True + include_public: bool = True @field_validator("prefix") @classmethod @@ -271,10 +286,49 @@ class UploadAssetSpec(BaseModel): return self -class SetPreviewBody(BaseModel): - """Set or clear the preview for an AssetInfo. Provide an Asset.id or null.""" +class UploadAssetFromUrlBody(BaseModel): + """JSON body for URL-based asset upload.""" + model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) + + url: str = Field(..., description="HTTP/HTTPS URL to download the asset from") + name: str = Field(..., max_length=512, description="Display name for the asset") + tags: list[str] = Field(default_factory=list) + user_metadata: dict[str, Any] = Field(default_factory=dict) preview_id: str | None = None + @field_validator("url") + @classmethod + def _validate_url(cls, v): + s = (v or "").strip() + if not s: + raise ValueError("url must not be empty") + if not (s.startswith("http://") or s.startswith("https://")): + raise ValueError("url must start with http:// or https://") + return s + + @field_validator("tags", mode="before") + @classmethod + def _parse_tags(cls, v): + if v is None: + return [] + if isinstance(v, list): + out = [str(t).strip().lower() for t in v if str(t).strip()] + seen = set() + dedup = [] + for t in out: + if t not in seen: + seen.add(t) + dedup.append(t) + return dedup + return [] + + @field_validator("user_metadata", mode="before") + @classmethod + def _parse_metadata(cls, v): + if v is None or isinstance(v, dict): + return v or {} + return {} + @field_validator("preview_id", mode="before") @classmethod def _norm_uuid(cls, v): @@ -292,3 +346,45 @@ class SetPreviewBody(BaseModel): class ScheduleAssetScanBody(BaseModel): roots: list[RootType] = Field(..., min_length=1) + + +class TagsRefineQuery(BaseModel): + """Query parameters for tag histogram/refinement endpoint.""" + model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) + + include_tags: list[str] = Field(default_factory=list) + exclude_tags: list[str] = Field(default_factory=list) + name_contains: str | None = None + metadata_filter: dict[str, Any] | None = None + limit: conint(ge=1, le=1000) = 100 + include_public: bool = True + + @field_validator("include_tags", "exclude_tags", mode="before") + @classmethod + def _split_csv_tags(cls, v): + if v is None: + return [] + if isinstance(v, str): + return [t.strip() for t in v.split(",") if t.strip()] + if isinstance(v, list): + out: list[str] = [] + for item in v: + if isinstance(item, str): + out.extend([t.strip() for t in item.split(",") if t.strip()]) + return out + return v + + @field_validator("metadata_filter", mode="before") + @classmethod + def _parse_metadata_json(cls, v): + if v is None or isinstance(v, dict): + return v + if isinstance(v, str) and v.strip(): + try: + parsed = json.loads(v) + except Exception as e: + raise ValueError(f"metadata_filter must be JSON: {e}") from e + if not isinstance(parsed, dict): + raise ValueError("metadata_filter must be a JSON object") + return parsed + return None diff --git a/app/assets/api/schemas_out.py b/app/assets/api/schemas_out.py index b6fb3da0c..8176d06a6 100644 --- a/app/assets/api/schemas_out.py +++ b/app/assets/api/schemas_out.py @@ -91,3 +91,12 @@ class TagsRemove(BaseModel): removed: list[str] = Field(default_factory=list) not_present: list[str] = Field(default_factory=list) total_tags: list[str] = Field(default_factory=list) + + +class TagHistogramEntry(BaseModel): + name: str + count: int + + +class TagHistogramResponse(BaseModel): + tags: list[TagHistogramEntry] = Field(default_factory=list) diff --git a/app/assets/database/queries.py b/app/assets/database/queries.py index 441a4e14c..511446dbf 100644 --- a/app/assets/database/queries.py +++ b/app/assets/database/queries.py @@ -683,9 +683,8 @@ def update_asset_info_full( *, asset_info_id: str, name: str | None = None, - tags: Sequence[str] | None = None, + mime_type: str | None = None, user_metadata: dict | None = None, - tag_origin: str = "manual", asset_info_row: Any = None, ) -> AssetInfo: if not asset_info_row: @@ -700,6 +699,11 @@ def update_asset_info_full( info.name = name touched = True + if mime_type is not None and info.asset: + if info.asset.mime_type != mime_type: + info.asset.mime_type = mime_type + touched = True + computed_filename = None try: p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=info.asset_id)) @@ -727,15 +731,6 @@ def update_asset_info_full( ) touched = True - if tags is not None: - set_asset_info_tags( - session, - asset_info_id=asset_info_id, - tags=tags, - origin=tag_origin, - ) - touched = True - if touched and user_metadata is None: info.updated_at = utcnow() session.flush() diff --git a/app/assets/manager.py b/app/assets/manager.py index b1ed35815..1c8e23253 100644 --- a/app/assets/manager.py +++ b/app/assets/manager.py @@ -71,6 +71,7 @@ def list_assets( sort: str = "created_at", order: str = "desc", owner_id: str = "", + include_public: bool = True, ) -> schemas_out.AssetsList: sort = _safe_sort_field(sort) order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower() @@ -289,7 +290,8 @@ def update_asset( *, asset_info_id: str, name: str | None = None, - tags: list[str] | None = None, + mime_type: str | None = None, + preview_id: str | None = None, user_metadata: dict | None = None, owner_id: str = "", ) -> schemas_out.AssetUpdated: @@ -304,12 +306,18 @@ def update_asset( session, asset_info_id=asset_info_id, name=name, - tags=tags, + mime_type=mime_type, user_metadata=user_metadata, - tag_origin="manual", asset_info_row=info_row, ) + if preview_id is not None: + set_asset_info_preview( + session, + asset_info_id=asset_info_id, + preview_asset_id=preview_id if preview_id else None, + ) + tag_names = get_asset_tags(session, asset_info_id=asset_info_id) session.commit() @@ -490,6 +498,7 @@ def list_tags( order: str = "count_desc", include_zero: bool = True, owner_id: str = "", + include_public: bool = True, ) -> schemas_out.TagsList: limit = max(1, min(1000, limit)) offset = max(0, offset) @@ -507,3 +516,17 @@ def list_tags( tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows] return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total) + + +def get_tag_histogram( + *, + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, + name_contains: str | None = None, + metadata_filter: dict | None = None, + limit: int = 100, + include_public: bool = True, + owner_id: str = "", +) -> schemas_out.TagHistogramResponse: + # TODO: Implement actual histogram query in queries.py + return schemas_out.TagHistogramResponse(tags=[]) diff --git a/tests-unit/app_test/assets_api_test.py b/tests-unit/app_test/assets_api_test.py new file mode 100644 index 000000000..e6ab2856a --- /dev/null +++ b/tests-unit/app_test/assets_api_test.py @@ -0,0 +1,177 @@ +"""Tests for Assets API endpoints (app/assets/api/routes.py) + +Tests cover: +- Schema validation for query parameters and request bodies +""" + +import pytest +from pydantic import ValidationError + +from app.assets.api import schemas_in, schemas_out + + +class TestListAssetsQuery: + """Tests for ListAssetsQuery schema.""" + + def test_defaults(self): + """Test default values.""" + q = schemas_in.ListAssetsQuery() + assert q.include_tags == [] + assert q.exclude_tags == [] + assert q.limit == 20 + assert q.offset == 0 + assert q.sort == "created_at" + assert q.order == "desc" + assert q.include_public == True + + def test_include_public_false(self): + """Test include_public can be set to False.""" + q = schemas_in.ListAssetsQuery(include_public=False) + assert q.include_public == False + + def test_csv_tags_parsing(self): + """Test comma-separated tags are parsed correctly.""" + q = schemas_in.ListAssetsQuery.model_validate({"include_tags": "a,b,c"}) + assert q.include_tags == ["a", "b", "c"] + + def test_metadata_filter_json_string(self): + """Test metadata_filter accepts JSON string.""" + q = schemas_in.ListAssetsQuery.model_validate({"metadata_filter": '{"key": "value"}'}) + assert q.metadata_filter == {"key": "value"} + + +class TestTagsListQuery: + """Tests for TagsListQuery schema.""" + + def test_defaults(self): + """Test default values.""" + q = schemas_in.TagsListQuery() + assert q.prefix is None + assert q.limit == 100 + assert q.offset == 0 + assert q.order == "count_desc" + assert q.include_zero == True + assert q.include_public == True + + def test_include_public_false(self): + """Test include_public can be set to False.""" + q = schemas_in.TagsListQuery(include_public=False) + assert q.include_public == False + + +class TestUpdateAssetBody: + """Tests for UpdateAssetBody schema.""" + + def test_requires_at_least_one_field(self): + """Test that at least one field is required.""" + with pytest.raises(ValidationError): + schemas_in.UpdateAssetBody() + + def test_name_only(self): + """Test updating name only.""" + body = schemas_in.UpdateAssetBody(name="new name") + assert body.name == "new name" + assert body.mime_type is None + assert body.preview_id is None + + def test_mime_type_only(self): + """Test updating mime_type only.""" + body = schemas_in.UpdateAssetBody(mime_type="image/png") + assert body.mime_type == "image/png" + + def test_preview_id_only(self): + """Test updating preview_id only.""" + body = schemas_in.UpdateAssetBody(preview_id="550e8400-e29b-41d4-a716-446655440000") + assert body.preview_id == "550e8400-e29b-41d4-a716-446655440000" + + def test_preview_id_invalid_uuid(self): + """Test invalid UUID for preview_id.""" + with pytest.raises(ValidationError): + schemas_in.UpdateAssetBody(preview_id="not-a-uuid") + + def test_all_fields(self): + """Test all fields together.""" + body = schemas_in.UpdateAssetBody( + name="test", + mime_type="application/json", + preview_id="550e8400-e29b-41d4-a716-446655440000", + user_metadata={"key": "value"} + ) + assert body.name == "test" + assert body.mime_type == "application/json" + + +class TestUploadAssetFromUrlBody: + """Tests for UploadAssetFromUrlBody schema (JSON URL upload).""" + + def test_valid_url(self): + """Test valid HTTP URL.""" + body = schemas_in.UploadAssetFromUrlBody( + url="https://example.com/model.safetensors", + name="model.safetensors" + ) + assert body.url == "https://example.com/model.safetensors" + assert body.name == "model.safetensors" + + def test_http_url(self): + """Test HTTP URL (not just HTTPS).""" + body = schemas_in.UploadAssetFromUrlBody( + url="http://example.com/file.bin", + name="file.bin" + ) + assert body.url == "http://example.com/file.bin" + + def test_invalid_url_scheme(self): + """Test invalid URL scheme raises error.""" + with pytest.raises(ValidationError): + schemas_in.UploadAssetFromUrlBody( + url="ftp://example.com/file.bin", + name="file.bin" + ) + + def test_tags_normalized(self): + """Test tags are normalized to lowercase.""" + body = schemas_in.UploadAssetFromUrlBody( + url="https://example.com/model.safetensors", + name="model", + tags=["Models", "LORAS"] + ) + assert body.tags == ["models", "loras"] + + +class TestTagsRefineQuery: + """Tests for TagsRefineQuery schema.""" + + def test_defaults(self): + """Test default values.""" + q = schemas_in.TagsRefineQuery() + assert q.include_tags == [] + assert q.exclude_tags == [] + assert q.limit == 100 + assert q.include_public == True + + def test_include_public_false(self): + """Test include_public can be set to False.""" + q = schemas_in.TagsRefineQuery(include_public=False) + assert q.include_public == False + + +class TestTagHistogramResponse: + """Tests for TagHistogramResponse schema.""" + + def test_empty_response(self): + """Test empty response.""" + resp = schemas_out.TagHistogramResponse() + assert resp.tags == [] + + def test_with_entries(self): + """Test response with entries.""" + resp = schemas_out.TagHistogramResponse( + tags=[ + schemas_out.TagHistogramEntry(name="models", count=10), + schemas_out.TagHistogramEntry(name="loras", count=5), + ] + ) + assert len(resp.tags) == 2 + assert resp.tags[0].name == "models" + assert resp.tags[0].count == 10