mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-26 17:39:15 +00:00
use Pydantic for output; finished Tags endpoints
This commit is contained in:
@@ -493,7 +493,7 @@ async def replace_asset_info_metadata_projection(
|
||||
await session.flush()
|
||||
|
||||
|
||||
async def get_asset_tags(session: AsyncSession, *, asset_info_id: int) -> list[Tag]:
|
||||
async def get_asset_tags(session: AsyncSession, *, asset_info_id: int) -> list[str]:
|
||||
return [
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
@@ -504,6 +504,179 @@ async def get_asset_tags(session: AsyncSession, *, asset_info_id: int) -> list[T
|
||||
]
|
||||
|
||||
|
||||
async def list_tags_with_usage(
|
||||
session,
|
||||
*,
|
||||
prefix: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
include_zero: bool = True,
|
||||
order: str = "count_desc", # "count_desc" | "name_asc"
|
||||
) -> tuple[list[tuple[str, str, int]], int]:
|
||||
"""
|
||||
Returns:
|
||||
rows: list of (name, tag_type, count)
|
||||
total: number of tags matching filter (independent of pagination)
|
||||
"""
|
||||
# Subquery with counts by tag_name
|
||||
counts_sq = (
|
||||
select(
|
||||
AssetInfoTag.tag_name.label("tag_name"),
|
||||
func.count(AssetInfoTag.asset_info_id).label("cnt"),
|
||||
)
|
||||
.group_by(AssetInfoTag.tag_name)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
# Base select with LEFT JOIN so we can include zero-usage tags
|
||||
q = (
|
||||
select(
|
||||
Tag.name,
|
||||
Tag.tag_type,
|
||||
func.coalesce(counts_sq.c.cnt, 0).label("count"),
|
||||
)
|
||||
.select_from(Tag)
|
||||
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
|
||||
)
|
||||
|
||||
# Prefix filter (tags are lowercase by check constraint)
|
||||
if prefix:
|
||||
q = q.where(Tag.name.like(prefix.strip().lower() + "%"))
|
||||
|
||||
# Include_zero toggles: if False, drop zero-usage tags
|
||||
if not include_zero:
|
||||
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
|
||||
|
||||
# Ordering
|
||||
if order == "name_asc":
|
||||
q = q.order_by(Tag.name.asc())
|
||||
else: # default "count_desc"
|
||||
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
|
||||
|
||||
# Total (without limit/offset, same filters)
|
||||
total_q = select(func.count()).select_from(Tag)
|
||||
if prefix:
|
||||
total_q = total_q.where(Tag.name.like(prefix.strip().lower() + "%"))
|
||||
if not include_zero:
|
||||
# count only names that appear in counts subquery
|
||||
total_q = total_q.where(
|
||||
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
|
||||
)
|
||||
|
||||
rows = (await session.execute(q.limit(limit).offset(offset))).all()
|
||||
total = (await session.execute(total_q)).scalar_one()
|
||||
|
||||
# Normalize counts to int for SQLite/Postgres consistency
|
||||
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
|
||||
return rows_norm, int(total or 0)
|
||||
|
||||
|
||||
async def add_tags_to_asset_info(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: int,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
added_by: Optional[str] = None,
|
||||
create_if_missing: bool = True,
|
||||
) -> dict:
|
||||
"""Adds tags to an AssetInfo.
|
||||
If create_if_missing=True, missing tag rows are created as 'user'.
|
||||
Returns: {"added": [...], "already_present": [...], "total_tags": [...]}
|
||||
"""
|
||||
info = await session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
norm = _normalize_tags(tags)
|
||||
if not norm:
|
||||
total = await get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"added": [], "already_present": [], "total_tags": total}
|
||||
|
||||
# Ensure tag rows exist if requested.
|
||||
if create_if_missing:
|
||||
await _ensure_tags_exist(session, norm, tag_type="user")
|
||||
|
||||
# Current links
|
||||
existing = {
|
||||
tname
|
||||
for (tname,) in (
|
||||
await session.execute(
|
||||
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
|
||||
to_add = [t for t in norm if t not in existing]
|
||||
already = [t for t in norm if t in existing]
|
||||
|
||||
if to_add:
|
||||
now = datetime.now(timezone.utc)
|
||||
# Make insert race-safe with a nested tx; ignore dup conflicts if any.
|
||||
async with session.begin_nested():
|
||||
session.add_all([
|
||||
AssetInfoTag(
|
||||
asset_info_id=asset_info_id,
|
||||
tag_name=t,
|
||||
origin=origin,
|
||||
added_by=added_by,
|
||||
added_at=now,
|
||||
) for t in to_add
|
||||
])
|
||||
try:
|
||||
await session.flush()
|
||||
except IntegrityError:
|
||||
# Another writer linked the same tag at the same time -> ok, treat as already present.
|
||||
await session.rollback()
|
||||
|
||||
total = await get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"added": sorted(set(to_add)), "already_present": sorted(set(already)), "total_tags": total}
|
||||
|
||||
|
||||
async def remove_tags_from_asset_info(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: int,
|
||||
tags: Sequence[str],
|
||||
) -> dict:
|
||||
"""Removes tags from an AssetInfo.
|
||||
Returns: {"removed": [...], "not_present": [...], "total_tags": [...]}
|
||||
"""
|
||||
info = await session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
norm = _normalize_tags(tags)
|
||||
if not norm:
|
||||
total = await get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"removed": [], "not_present": [], "total_tags": total}
|
||||
|
||||
existing = {
|
||||
tname
|
||||
for (tname,) in (
|
||||
await session.execute(
|
||||
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
|
||||
to_remove = sorted(set(t for t in norm if t in existing))
|
||||
not_present = sorted(set(t for t in norm if t not in existing))
|
||||
|
||||
if to_remove:
|
||||
await session.execute(
|
||||
delete(AssetInfoTag)
|
||||
.where(
|
||||
AssetInfoTag.asset_info_id == asset_info_id,
|
||||
AssetInfoTag.tag_name.in_(to_remove),
|
||||
)
|
||||
)
|
||||
await session.flush()
|
||||
|
||||
total = await get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
|
||||
|
||||
|
||||
def _normalize_tags(tags: Sequence[str] | None) -> list[str]:
|
||||
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user