mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-27 10:24:06 +00:00
dev: refactor; populate models in more nodes; use Pydantic in endpoints for input validation
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
import json
|
||||
from typing import Sequence
|
||||
from aiohttp import web
|
||||
from typing import Optional
|
||||
|
||||
from app import assets_manager
|
||||
from aiohttp import web
|
||||
from pydantic import ValidationError
|
||||
|
||||
from .. import assets_manager
|
||||
from .schemas_in import ListAssetsQuery, UpdateAssetBody
|
||||
|
||||
|
||||
ROUTES = web.RouteTableDef()
|
||||
@@ -10,38 +12,22 @@ ROUTES = web.RouteTableDef()
|
||||
|
||||
@ROUTES.get("/api/assets")
|
||||
async def list_assets(request: web.Request) -> web.Response:
|
||||
q = request.rel_url.query
|
||||
query_dict = dict(request.rel_url.query)
|
||||
|
||||
include_tags: Sequence[str] = _parse_csv_tags(q.get("include_tags"))
|
||||
exclude_tags: Sequence[str] = _parse_csv_tags(q.get("exclude_tags"))
|
||||
name_contains = q.get("name_contains")
|
||||
|
||||
# Optional JSON metadata filter (top-level key equality only for now)
|
||||
metadata_filter = None
|
||||
raw_meta = q.get("metadata_filter")
|
||||
if raw_meta:
|
||||
try:
|
||||
metadata_filter = json.loads(raw_meta)
|
||||
if not isinstance(metadata_filter, dict):
|
||||
metadata_filter = None
|
||||
except Exception:
|
||||
# Silently ignore malformed JSON for first iteration; could 400 in future
|
||||
metadata_filter = None
|
||||
|
||||
limit = _parse_int(q.get("limit"), default=20, lo=1, hi=100)
|
||||
offset = _parse_int(q.get("offset"), default=0, lo=0, hi=10_000_000)
|
||||
sort = q.get("sort", "created_at")
|
||||
order = q.get("order", "desc")
|
||||
try:
|
||||
q = ListAssetsQuery.model_validate(query_dict)
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_QUERY", ve)
|
||||
|
||||
payload = await assets_manager.list_assets(
|
||||
include_tags=include_tags,
|
||||
exclude_tags=exclude_tags,
|
||||
name_contains=name_contains,
|
||||
metadata_filter=metadata_filter,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
sort=sort,
|
||||
order=order,
|
||||
include_tags=q.include_tags,
|
||||
exclude_tags=q.exclude_tags,
|
||||
name_contains=q.name_contains,
|
||||
metadata_filter=q.metadata_filter,
|
||||
limit=q.limit,
|
||||
offset=q.offset,
|
||||
sort=q.sort,
|
||||
order=q.order,
|
||||
)
|
||||
return web.json_response(payload)
|
||||
|
||||
@@ -55,29 +41,18 @@ async def update_asset(request: web.Request) -> web.Response:
|
||||
return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.")
|
||||
|
||||
try:
|
||||
payload = await request.json()
|
||||
body = UpdateAssetBody.model_validate(await request.json())
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
name = payload.get("name", None)
|
||||
tags = payload.get("tags", None)
|
||||
user_metadata = payload.get("user_metadata", None)
|
||||
|
||||
if name is None and tags is None and user_metadata is None:
|
||||
return _error_response(400, "NO_FIELDS", "Provide at least one of: name, tags, user_metadata.")
|
||||
|
||||
if tags is not None and (not isinstance(tags, list) or not all(isinstance(t, str) for t in tags)):
|
||||
return _error_response(400, "INVALID_TAGS", "Field 'tags' must be an array of strings.")
|
||||
|
||||
if user_metadata is not None and not isinstance(user_metadata, dict):
|
||||
return _error_response(400, "INVALID_METADATA", "Field 'user_metadata' must be an object.")
|
||||
|
||||
try:
|
||||
result = await assets_manager.update_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
name=name,
|
||||
tags=tags,
|
||||
user_metadata=user_metadata,
|
||||
name=body.name,
|
||||
tags=body.tags,
|
||||
user_metadata=body.user_metadata,
|
||||
)
|
||||
except ValueError as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
@@ -90,21 +65,9 @@ def register_assets_routes(app: web.Application) -> None:
|
||||
app.add_routes(ROUTES)
|
||||
|
||||
|
||||
def _parse_csv_tags(raw: str | None) -> list[str]:
|
||||
if not raw:
|
||||
return []
|
||||
return [t.strip() for t in raw.split(",") if t.strip()]
|
||||
|
||||
|
||||
def _parse_int(qval: str | None, default: int, lo: int, hi: int) -> int:
|
||||
if not qval:
|
||||
return default
|
||||
try:
|
||||
v = int(qval)
|
||||
except Exception:
|
||||
return default
|
||||
return max(lo, min(hi, v))
|
||||
|
||||
|
||||
def _error_response(status: int, code: str, message: str, details: dict | None = None) -> web.Response:
|
||||
def _error_response(status: int, code: str, message: str, details: Optional[dict] = None) -> web.Response:
|
||||
return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status)
|
||||
|
||||
|
||||
def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
|
||||
return _error_response(400, code, "Validation failed.", {"errors": ve.errors()})
|
||||
|
||||
66
app/api/schemas_in.py
Normal file
66
app/api/schemas_in.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional, Literal
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator, conint
|
||||
|
||||
|
||||
class ListAssetsQuery(BaseModel):
|
||||
include_tags: list[str] = Field(default_factory=list)
|
||||
exclude_tags: list[str] = Field(default_factory=list)
|
||||
name_contains: Optional[str] = None
|
||||
|
||||
# Accept either a JSON string (query param) or a dict
|
||||
metadata_filter: Optional[dict[str, Any]] = None
|
||||
|
||||
limit: conint(ge=1, le=500) = 20
|
||||
offset: conint(ge=0) = 0
|
||||
|
||||
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at"
|
||||
order: Literal["asc", "desc"] = "desc"
|
||||
|
||||
@field_validator("include_tags", "exclude_tags", mode="before")
|
||||
@classmethod
|
||||
def _split_csv_tags(cls, v):
|
||||
# Accept "a,b,c" or ["a","b"] (we are liberal in what we accept)
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, str):
|
||||
return [t.strip() for t in v.split(",") if t.strip()]
|
||||
if isinstance(v, list):
|
||||
out: list[str] = []
|
||||
for item in v:
|
||||
if isinstance(item, str):
|
||||
out.extend([t.strip() for t in item.split(",") if t.strip()])
|
||||
return out
|
||||
return v
|
||||
|
||||
@field_validator("metadata_filter", mode="before")
|
||||
@classmethod
|
||||
def _parse_metadata_json(cls, v):
|
||||
if v is None or isinstance(v, dict):
|
||||
return v
|
||||
if isinstance(v, str) and v.strip():
|
||||
import json
|
||||
try:
|
||||
parsed = json.loads(v)
|
||||
except Exception as e:
|
||||
raise ValueError(f"metadata_filter must be JSON: {e}") from e
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValueError("metadata_filter must be a JSON object")
|
||||
return parsed
|
||||
return None
|
||||
|
||||
|
||||
class UpdateAssetBody(BaseModel):
|
||||
name: Optional[str] = None
|
||||
tags: Optional[list[str]] = None
|
||||
user_metadata: Optional[dict[str, Any]] = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _at_least_one(self):
|
||||
if self.name is None and self.tags is None and self.user_metadata is None:
|
||||
raise ValueError("Provide at least one of: name, tags, user_metadata.")
|
||||
if self.tags is not None:
|
||||
if not isinstance(self.tags, list) or not all(isinstance(t, str) for t in self.tags):
|
||||
raise ValueError("Field 'tags' must be an array of strings.")
|
||||
return self
|
||||
Reference in New Issue
Block a user