mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-27 18:34:20 +00:00
global refactoring; add support for Assets without the computed hash
This commit is contained in:
23
app/database/helpers/__init__.py
Normal file
23
app/database/helpers/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from .filters import apply_metadata_filter, apply_tag_filters
|
||||
from .ownership import visible_owner_clause
|
||||
from .projection import is_scalar, project_kv
|
||||
from .tags import (
|
||||
add_missing_tag_for_asset_hash,
|
||||
add_missing_tag_for_asset_id,
|
||||
ensure_tags_exist,
|
||||
remove_missing_tag_for_asset_hash,
|
||||
remove_missing_tag_for_asset_id,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"apply_tag_filters",
|
||||
"apply_metadata_filter",
|
||||
"is_scalar",
|
||||
"project_kv",
|
||||
"ensure_tags_exist",
|
||||
"add_missing_tag_for_asset_id",
|
||||
"add_missing_tag_for_asset_hash",
|
||||
"remove_missing_tag_for_asset_id",
|
||||
"remove_missing_tag_for_asset_hash",
|
||||
"visible_owner_clause",
|
||||
]
|
||||
87
app/database/helpers/filters.py
Normal file
87
app/database/helpers/filters.py
Normal file
@@ -0,0 +1,87 @@
|
||||
from typing import Optional, Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import exists
|
||||
|
||||
from ..._assets_helpers import normalize_tags
|
||||
from ..models import AssetInfo, AssetInfoMeta, AssetInfoTag
|
||||
|
||||
|
||||
def apply_tag_filters(
|
||||
stmt: sa.sql.Select,
|
||||
include_tags: Optional[Sequence[str]],
|
||||
exclude_tags: Optional[Sequence[str]],
|
||||
) -> sa.sql.Select:
|
||||
"""include_tags: every tag must be present; exclude_tags: none may be present."""
|
||||
include_tags = normalize_tags(include_tags)
|
||||
exclude_tags = normalize_tags(exclude_tags)
|
||||
|
||||
if include_tags:
|
||||
for tag_name in include_tags:
|
||||
stmt = stmt.where(
|
||||
exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name == tag_name)
|
||||
)
|
||||
)
|
||||
|
||||
if exclude_tags:
|
||||
stmt = stmt.where(
|
||||
~exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name.in_(exclude_tags))
|
||||
)
|
||||
)
|
||||
return stmt
|
||||
|
||||
|
||||
def apply_metadata_filter(
|
||||
stmt: sa.sql.Select,
|
||||
metadata_filter: Optional[dict],
|
||||
) -> sa.sql.Select:
|
||||
"""Apply filters using asset_info_meta projection table."""
|
||||
if not metadata_filter:
|
||||
return stmt
|
||||
|
||||
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
|
||||
return sa.exists().where(
|
||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||
AssetInfoMeta.key == key,
|
||||
*preds,
|
||||
)
|
||||
|
||||
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
|
||||
if value is None:
|
||||
no_row_for_key = sa.not_(
|
||||
sa.exists().where(
|
||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||
AssetInfoMeta.key == key,
|
||||
)
|
||||
)
|
||||
null_row = _exists_for_pred(
|
||||
key,
|
||||
AssetInfoMeta.val_json.is_(None),
|
||||
AssetInfoMeta.val_str.is_(None),
|
||||
AssetInfoMeta.val_num.is_(None),
|
||||
AssetInfoMeta.val_bool.is_(None),
|
||||
)
|
||||
return sa.or_(no_row_for_key, null_row)
|
||||
|
||||
if isinstance(value, bool):
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
|
||||
if isinstance(value, (int, float)):
|
||||
from decimal import Decimal
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
|
||||
if isinstance(value, str):
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
|
||||
|
||||
for k, v in metadata_filter.items():
|
||||
if isinstance(v, list):
|
||||
ors = [_exists_clause_for_value(k, elem) for elem in v]
|
||||
if ors:
|
||||
stmt = stmt.where(sa.or_(*ors))
|
||||
else:
|
||||
stmt = stmt.where(_exists_clause_for_value(k, v))
|
||||
return stmt
|
||||
12
app/database/helpers/ownership.py
Normal file
12
app/database/helpers/ownership.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import sqlalchemy as sa
|
||||
|
||||
from ..models import AssetInfo
|
||||
|
||||
|
||||
def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
|
||||
"""Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
|
||||
|
||||
owner_id = (owner_id or "").strip()
|
||||
if owner_id == "":
|
||||
return AssetInfo.owner_id == ""
|
||||
return AssetInfo.owner_id.in_(["", owner_id])
|
||||
64
app/database/helpers/projection.py
Normal file
64
app/database/helpers/projection.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from decimal import Decimal
|
||||
|
||||
|
||||
def is_scalar(v):
|
||||
if v is None:
|
||||
return True
|
||||
if isinstance(v, bool):
|
||||
return True
|
||||
if isinstance(v, (int, float, Decimal, str)):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def project_kv(key: str, value):
|
||||
"""
|
||||
Turn a metadata key/value into typed projection rows.
|
||||
Returns list[dict] with keys:
|
||||
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
|
||||
"""
|
||||
rows: list[dict] = []
|
||||
|
||||
def _null_row(ordinal: int) -> dict:
|
||||
return {
|
||||
"key": key, "ordinal": ordinal,
|
||||
"val_str": None, "val_num": None, "val_bool": None, "val_json": None
|
||||
}
|
||||
|
||||
if value is None:
|
||||
rows.append(_null_row(0))
|
||||
return rows
|
||||
|
||||
if is_scalar(value):
|
||||
if isinstance(value, bool):
|
||||
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
|
||||
elif isinstance(value, (int, float, Decimal)):
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
rows.append({"key": key, "ordinal": 0, "val_num": num})
|
||||
elif isinstance(value, str):
|
||||
rows.append({"key": key, "ordinal": 0, "val_str": value})
|
||||
else:
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
|
||||
if isinstance(value, list):
|
||||
if all(is_scalar(x) for x in value):
|
||||
for i, x in enumerate(value):
|
||||
if x is None:
|
||||
rows.append(_null_row(i))
|
||||
elif isinstance(x, bool):
|
||||
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
|
||||
elif isinstance(x, (int, float, Decimal)):
|
||||
num = x if isinstance(x, Decimal) else Decimal(str(x))
|
||||
rows.append({"key": key, "ordinal": i, "val_num": num})
|
||||
elif isinstance(x, str):
|
||||
rows.append({"key": key, "ordinal": i, "val_str": x})
|
||||
else:
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
for i, x in enumerate(value):
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
102
app/database/helpers/tags.py
Normal file
102
app/database/helpers/tags.py
Normal file
@@ -0,0 +1,102 @@
|
||||
from typing import Iterable
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..._assets_helpers import normalize_tags
|
||||
from ..models import Asset, AssetInfo, AssetInfoTag, Tag
|
||||
from ..timeutil import utcnow
|
||||
|
||||
|
||||
async def ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> list[Tag]:
|
||||
wanted = normalize_tags(list(names))
|
||||
if not wanted:
|
||||
return []
|
||||
existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all()
|
||||
by_name = {t.name: t for t in existing}
|
||||
to_create = [Tag(name=n, tag_type=tag_type) for n in wanted if n not in by_name]
|
||||
if to_create:
|
||||
session.add_all(to_create)
|
||||
await session.flush()
|
||||
by_name.update({t.name: t for t in to_create})
|
||||
return [by_name[n] for n in wanted]
|
||||
|
||||
|
||||
async def add_missing_tag_for_asset_id(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_id: str,
|
||||
origin: str = "automatic",
|
||||
) -> int:
|
||||
"""Ensure every AssetInfo for asset_id has 'missing' tag."""
|
||||
ids = (await session.execute(select(AssetInfo.id).where(AssetInfo.asset_id == asset_id))).scalars().all()
|
||||
if not ids:
|
||||
return 0
|
||||
|
||||
existing = {
|
||||
asset_info_id
|
||||
for (asset_info_id,) in (
|
||||
await session.execute(
|
||||
select(AssetInfoTag.asset_info_id).where(
|
||||
AssetInfoTag.asset_info_id.in_(ids),
|
||||
AssetInfoTag.tag_name == "missing",
|
||||
)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
to_add = [i for i in ids if i not in existing]
|
||||
if not to_add:
|
||||
return 0
|
||||
|
||||
now = utcnow()
|
||||
session.add_all(
|
||||
[
|
||||
AssetInfoTag(asset_info_id=i, tag_name="missing", origin=origin, added_at=now)
|
||||
for i in to_add
|
||||
]
|
||||
)
|
||||
await session.flush()
|
||||
return len(to_add)
|
||||
|
||||
|
||||
async def add_missing_tag_for_asset_hash(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_hash: str,
|
||||
origin: str = "automatic",
|
||||
) -> int:
|
||||
asset = (await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))).scalars().first()
|
||||
if not asset:
|
||||
return 0
|
||||
return await add_missing_tag_for_asset_id(session, asset_id=asset.id, origin=origin)
|
||||
|
||||
|
||||
async def remove_missing_tag_for_asset_id(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_id: str,
|
||||
) -> int:
|
||||
"""Remove the 'missing' tag from all AssetInfos for asset_id."""
|
||||
ids = (await session.execute(select(AssetInfo.id).where(AssetInfo.asset_id == asset_id))).scalars().all()
|
||||
if not ids:
|
||||
return 0
|
||||
|
||||
res = await session.execute(
|
||||
delete(AssetInfoTag).where(
|
||||
AssetInfoTag.asset_info_id.in_(ids),
|
||||
AssetInfoTag.tag_name == "missing",
|
||||
)
|
||||
)
|
||||
await session.flush()
|
||||
return int(res.rowcount or 0)
|
||||
|
||||
|
||||
async def remove_missing_tag_for_asset_hash(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_hash: str,
|
||||
) -> int:
|
||||
asset = (await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))).scalars().first()
|
||||
if not asset:
|
||||
return 0
|
||||
return await remove_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
Reference in New Issue
Block a user