use Pydantic for output; finished Tags endpoints

This commit is contained in:
bigcat88
2025-08-24 11:02:30 +03:00
parent 5c1b5973ac
commit 8d46bec951
6 changed files with 473 additions and 39 deletions

View File

@@ -493,7 +493,7 @@ async def replace_asset_info_metadata_projection(
await session.flush()
async def get_asset_tags(session: AsyncSession, *, asset_info_id: int) -> list[Tag]:
async def get_asset_tags(session: AsyncSession, *, asset_info_id: int) -> list[str]:
return [
tag_name
for (tag_name,) in (
@@ -504,6 +504,179 @@ async def get_asset_tags(session: AsyncSession, *, asset_info_id: int) -> list[T
]
async def list_tags_with_usage(
session,
*,
prefix: str | None = None,
limit: int = 100,
offset: int = 0,
include_zero: bool = True,
order: str = "count_desc", # "count_desc" | "name_asc"
) -> tuple[list[tuple[str, str, int]], int]:
"""
Returns:
rows: list of (name, tag_type, count)
total: number of tags matching filter (independent of pagination)
"""
# Subquery with counts by tag_name
counts_sq = (
select(
AssetInfoTag.tag_name.label("tag_name"),
func.count(AssetInfoTag.asset_info_id).label("cnt"),
)
.group_by(AssetInfoTag.tag_name)
.subquery()
)
# Base select with LEFT JOIN so we can include zero-usage tags
q = (
select(
Tag.name,
Tag.tag_type,
func.coalesce(counts_sq.c.cnt, 0).label("count"),
)
.select_from(Tag)
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
)
# Prefix filter (tags are lowercase by check constraint)
if prefix:
q = q.where(Tag.name.like(prefix.strip().lower() + "%"))
# Include_zero toggles: if False, drop zero-usage tags
if not include_zero:
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
# Ordering
if order == "name_asc":
q = q.order_by(Tag.name.asc())
else: # default "count_desc"
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
# Total (without limit/offset, same filters)
total_q = select(func.count()).select_from(Tag)
if prefix:
total_q = total_q.where(Tag.name.like(prefix.strip().lower() + "%"))
if not include_zero:
# count only names that appear in counts subquery
total_q = total_q.where(
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
)
rows = (await session.execute(q.limit(limit).offset(offset))).all()
total = (await session.execute(total_q)).scalar_one()
# Normalize counts to int for SQLite/Postgres consistency
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
return rows_norm, int(total or 0)
async def add_tags_to_asset_info(
session: AsyncSession,
*,
asset_info_id: int,
tags: Sequence[str],
origin: str = "manual",
added_by: Optional[str] = None,
create_if_missing: bool = True,
) -> dict:
"""Adds tags to an AssetInfo.
If create_if_missing=True, missing tag rows are created as 'user'.
Returns: {"added": [...], "already_present": [...], "total_tags": [...]}
"""
info = await session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = _normalize_tags(tags)
if not norm:
total = await get_asset_tags(session, asset_info_id=asset_info_id)
return {"added": [], "already_present": [], "total_tags": total}
# Ensure tag rows exist if requested.
if create_if_missing:
await _ensure_tags_exist(session, norm, tag_type="user")
# Current links
existing = {
tname
for (tname,) in (
await session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
}
to_add = [t for t in norm if t not in existing]
already = [t for t in norm if t in existing]
if to_add:
now = datetime.now(timezone.utc)
# Make insert race-safe with a nested tx; ignore dup conflicts if any.
async with session.begin_nested():
session.add_all([
AssetInfoTag(
asset_info_id=asset_info_id,
tag_name=t,
origin=origin,
added_by=added_by,
added_at=now,
) for t in to_add
])
try:
await session.flush()
except IntegrityError:
# Another writer linked the same tag at the same time -> ok, treat as already present.
await session.rollback()
total = await get_asset_tags(session, asset_info_id=asset_info_id)
return {"added": sorted(set(to_add)), "already_present": sorted(set(already)), "total_tags": total}
async def remove_tags_from_asset_info(
session: AsyncSession,
*,
asset_info_id: int,
tags: Sequence[str],
) -> dict:
"""Removes tags from an AssetInfo.
Returns: {"removed": [...], "not_present": [...], "total_tags": [...]}
"""
info = await session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = _normalize_tags(tags)
if not norm:
total = await get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": [], "not_present": [], "total_tags": total}
existing = {
tname
for (tname,) in (
await session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
}
to_remove = sorted(set(t for t in norm if t in existing))
not_present = sorted(set(t for t in norm if t not in existing))
if to_remove:
await session.execute(
delete(AssetInfoTag)
.where(
AssetInfoTag.asset_info_id == asset_info_id,
AssetInfoTag.tag_name.in_(to_remove),
)
)
await session.flush()
total = await get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
def _normalize_tags(tags: Sequence[str] | None) -> list[str]:
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]