mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-02 19:59:52 +00:00
add "--multi-user" support
This commit is contained in:
@@ -14,7 +14,7 @@ from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from .models import Asset, AssetInfo, AssetInfoTag, AssetCacheState, Tag, AssetInfoMeta, AssetLocation
|
||||
from .timeutil import utcnow
|
||||
from .._assets_helpers import normalize_tags
|
||||
from .._assets_helpers import normalize_tags, visible_owner_clause
|
||||
|
||||
|
||||
async def asset_exists_by_hash(session: AsyncSession, *, asset_hash: str) -> bool:
|
||||
@@ -30,6 +30,10 @@ async def get_asset_by_hash(session: AsyncSession, *, asset_hash: str) -> Option
|
||||
return await session.get(Asset, asset_hash)
|
||||
|
||||
|
||||
async def get_asset_info_by_id(session: AsyncSession, *, asset_info_id: int) -> Optional[AssetInfo]:
|
||||
return await session.get(AssetInfo, asset_info_id)
|
||||
|
||||
|
||||
async def check_fs_asset_exists_quick(
|
||||
session,
|
||||
*,
|
||||
@@ -317,6 +321,7 @@ async def touch_asset_info_by_id(
|
||||
async def list_asset_infos_page(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
owner_id: str = "",
|
||||
include_tags: Optional[Sequence[str]] = None,
|
||||
exclude_tags: Optional[Sequence[str]] = None,
|
||||
name_contains: Optional[str] = None,
|
||||
@@ -326,26 +331,18 @@ async def list_asset_infos_page(
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
) -> tuple[list[AssetInfo], dict[int, list[str]], int]:
|
||||
"""
|
||||
Returns a page of AssetInfo rows with their Asset eagerly loaded (no N+1),
|
||||
plus a map of asset_info_id -> [tags], and the total count.
|
||||
|
||||
We purposely collect tags in a separate (single) query to avoid row explosion.
|
||||
"""
|
||||
|
||||
# Build base query
|
||||
"""Return page of AssetInfo rows in the viewers visibility."""
|
||||
base = (
|
||||
select(AssetInfo)
|
||||
.join(Asset, Asset.hash == AssetInfo.asset_hash)
|
||||
.options(contains_eager(AssetInfo.asset))
|
||||
.where(visible_owner_clause(owner_id))
|
||||
)
|
||||
|
||||
# Filters
|
||||
if name_contains:
|
||||
base = base.where(AssetInfo.name.ilike(f"%{name_contains}%"))
|
||||
|
||||
base = _apply_tag_filters(base, include_tags, exclude_tags)
|
||||
|
||||
base = _apply_metadata_filter(base, metadata_filter)
|
||||
|
||||
# Sort
|
||||
@@ -368,13 +365,14 @@ async def list_asset_infos_page(
|
||||
select(func.count())
|
||||
.select_from(AssetInfo)
|
||||
.join(Asset, Asset.hash == AssetInfo.asset_hash)
|
||||
.where(visible_owner_clause(owner_id))
|
||||
)
|
||||
if name_contains:
|
||||
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{name_contains}%"))
|
||||
count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
||||
count_stmt = _apply_metadata_filter(count_stmt, metadata_filter)
|
||||
|
||||
total = (await session.execute(count_stmt)).scalar_one()
|
||||
total = int((await session.execute(count_stmt)).scalar_one() or 0)
|
||||
|
||||
# Fetch rows
|
||||
infos = (await session.execute(base)).scalars().unique().all()
|
||||
@@ -394,13 +392,22 @@ async def list_asset_infos_page(
|
||||
return infos, tag_map, total
|
||||
|
||||
|
||||
async def fetch_asset_info_and_asset(session: AsyncSession, *, asset_info_id: int) -> Optional[tuple[AssetInfo, Asset]]:
|
||||
row = await session.execute(
|
||||
async def fetch_asset_info_and_asset(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: int,
|
||||
owner_id: str = "",
|
||||
) -> Optional[tuple[AssetInfo, Asset]]:
|
||||
stmt = (
|
||||
select(AssetInfo, Asset)
|
||||
.join(Asset, Asset.hash == AssetInfo.asset_hash)
|
||||
.where(AssetInfo.id == asset_info_id)
|
||||
.where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
row = await session.execute(stmt)
|
||||
pair = row.first()
|
||||
if not pair:
|
||||
return None
|
||||
@@ -411,18 +418,17 @@ async def fetch_asset_info_asset_and_tags(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: int,
|
||||
owner_id: str = "",
|
||||
) -> Optional[tuple[AssetInfo, Asset, list[str]]]:
|
||||
"""Fetch AssetInfo, its Asset, and all tag names.
|
||||
|
||||
Returns:
|
||||
(AssetInfo, Asset, [tag_names]) or None if the asset_info_id does not exist.
|
||||
"""
|
||||
stmt = (
|
||||
select(AssetInfo, Asset, Tag.name)
|
||||
.join(Asset, Asset.hash == AssetInfo.asset_hash)
|
||||
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
|
||||
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
|
||||
.where(AssetInfo.id == asset_info_id)
|
||||
.where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
.options(noload(AssetInfo.tags))
|
||||
.order_by(Tag.name.asc())
|
||||
)
|
||||
@@ -511,11 +517,12 @@ async def create_asset_info_for_existing_asset(
|
||||
tags: Optional[Sequence[str]] = None,
|
||||
tag_origin: str = "manual",
|
||||
added_by: Optional[str] = None,
|
||||
owner_id: str = "",
|
||||
) -> AssetInfo:
|
||||
"""Create a new AssetInfo referencing an existing Asset (no content write)."""
|
||||
now = utcnow()
|
||||
info = AssetInfo(
|
||||
owner_id="",
|
||||
owner_id=owner_id,
|
||||
name=name,
|
||||
asset_hash=asset_hash,
|
||||
preview_hash=None,
|
||||
@@ -593,6 +600,7 @@ async def update_asset_info_full(
|
||||
user_metadata: Optional[dict] = None,
|
||||
tag_origin: str = "manual",
|
||||
added_by: Optional[str] = None,
|
||||
asset_info_row: Any = None,
|
||||
) -> AssetInfo:
|
||||
"""
|
||||
Update AssetInfo fields:
|
||||
@@ -601,9 +609,12 @@ async def update_asset_info_full(
|
||||
- replace tags with provided set (if provided)
|
||||
Returns the updated AssetInfo.
|
||||
"""
|
||||
info = await session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if not asset_info_row:
|
||||
info = await session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
else:
|
||||
info = asset_info_row
|
||||
|
||||
touched = False
|
||||
if name is not None and name != info.name:
|
||||
@@ -633,9 +644,12 @@ async def update_asset_info_full(
|
||||
return info
|
||||
|
||||
|
||||
async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: int) -> bool:
|
||||
async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: int, owner_id: str) -> bool:
|
||||
"""Delete the user-visible AssetInfo row. Cascades clear tags and metadata."""
|
||||
res = await session.execute(delete(AssetInfo).where(AssetInfo.id == asset_info_id))
|
||||
res = await session.execute(delete(AssetInfo).where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
))
|
||||
return bool(res.rowcount)
|
||||
|
||||
|
||||
@@ -691,25 +705,24 @@ async def get_asset_tags(session: AsyncSession, *, asset_info_id: int) -> list[s
|
||||
|
||||
|
||||
async def list_tags_with_usage(
|
||||
session,
|
||||
session: AsyncSession,
|
||||
*,
|
||||
prefix: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
include_zero: bool = True,
|
||||
order: str = "count_desc", # "count_desc" | "name_asc"
|
||||
order: str = "count_desc", # "count_desc" | "name_asc"
|
||||
owner_id: str = "",
|
||||
) -> tuple[list[tuple[str, str, int]], int]:
|
||||
"""
|
||||
Returns:
|
||||
rows: list of (name, tag_type, count)
|
||||
total: number of tags matching filter (independent of pagination)
|
||||
"""
|
||||
# Subquery with counts by tag_name
|
||||
# Subquery with counts by tag_name and owner_id
|
||||
counts_sq = (
|
||||
select(
|
||||
AssetInfoTag.tag_name.label("tag_name"),
|
||||
func.count(AssetInfoTag.asset_info_id).label("cnt"),
|
||||
)
|
||||
.select_from(AssetInfoTag)
|
||||
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
|
||||
.where(visible_owner_clause(owner_id))
|
||||
.group_by(AssetInfoTag.tag_name)
|
||||
.subquery()
|
||||
)
|
||||
@@ -765,14 +778,16 @@ async def add_tags_to_asset_info(
|
||||
origin: str = "manual",
|
||||
added_by: Optional[str] = None,
|
||||
create_if_missing: bool = True,
|
||||
asset_info_row: Any = None,
|
||||
) -> dict:
|
||||
"""Adds tags to an AssetInfo.
|
||||
If create_if_missing=True, missing tag rows are created as 'user'.
|
||||
Returns: {"added": [...], "already_present": [...], "total_tags": [...]}
|
||||
"""
|
||||
info = await session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if not asset_info_row:
|
||||
info = await session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
norm = normalize_tags(tags)
|
||||
if not norm:
|
||||
|
||||
Reference in New Issue
Block a user