dev: Everything is Assets
app/database/__init__.py
@@ -1,112 +1,267 @@
 import logging
 import os
 import shutil
+from contextlib import asynccontextmanager
+from typing import Optional

 from app.logger import log_startup_warning
 from utils.install_util import get_missing_requirements_message
 from comfy.cli_args import args

 _DB_AVAILABLE = False
-Session = None
+
+LOGGER = logging.getLogger(__name__)

 # Attempt imports which may not exist in some environments
 try:
     from alembic import command
     from alembic.config import Config
     from alembic.runtime.migration import MigrationContext
     from alembic.script import ScriptDirectory
-    from sqlalchemy import create_engine
-    from sqlalchemy.orm import sessionmaker
+    from sqlalchemy import create_engine, text
+    from sqlalchemy.engine import make_url
+    from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine

     _DB_AVAILABLE = True
+    ENGINE: AsyncEngine | None = None
+    SESSION: async_sessionmaker | None = None
 except ImportError as e:
     log_startup_warning(
-        f"""
-------------------------------------------------------------------------
-Error importing dependencies: {e}
-{get_missing_requirements_message()}
-This error is happening because ComfyUI now uses a local sqlite database.
-------------------------------------------------------------------------
-""".strip()
+        (
+            "------------------------------------------------------------------------\n"
+            f"Error importing DB dependencies: {e}\n"
+            f"{get_missing_requirements_message()}\n"
+            "This error is happening because ComfyUI now uses a local database.\n"
+            "------------------------------------------------------------------------"
+        ).strip()
     )
     _DB_AVAILABLE = False
+    ENGINE = None
+    SESSION = None


-def dependencies_available():
-    """
-    Temporary function to check if the dependencies are available
-    """
+def dependencies_available() -> bool:
+    """Check if DB dependencies are importable."""
     return _DB_AVAILABLE


-def can_create_session():
-    """
-    Temporary function to check if the database is available to create a session
-    During initial release there may be environmental issues (or missing dependencies) that prevent the database from being created
-    """
-    return dependencies_available() and Session is not None
-
-
-def get_alembic_config():
-    root_path = os.path.join(os.path.dirname(__file__), "../..")
-    config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
-    scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
-
-    config = Config(config_path)
-    config.set_main_option("script_location", scripts_path)
-    config.set_main_option("sqlalchemy.url", args.database_url)
-
-    return config
+def _root_paths():
+    """Resolve alembic.ini and migrations script folder."""
+    root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
+    config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
+    scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
+    return config_path, scripts_path
-def get_db_path():
-    url = args.database_url
-    if url.startswith("sqlite:///"):
-        return url.split("///")[1]
+def _absolutize_sqlite_url(db_url: str) -> str:
+    """Make SQLite database path absolute. No-op for non-SQLite URLs."""
+    try:
+        u = make_url(db_url)
+    except Exception:
+        return db_url
+
+    if not u.drivername.startswith("sqlite"):
+        return db_url
+
+    # Make path absolute if relative
+    db_path = u.database or ""
+    if not os.path.isabs(db_path):
+        db_path = os.path.abspath(os.path.join(os.getcwd(), db_path))
+    u = u.set(database=db_path)
+    return str(u)
+
+
+def _to_sync_driver_url(async_url: str) -> str:
+    """Convert an async SQLAlchemy URL to a sync URL for Alembic."""
+    u = make_url(async_url)
+    driver = u.drivername
+
+    if driver.startswith("sqlite+aiosqlite"):
+        u = u.set(drivername="sqlite")
+    elif driver.startswith("postgresql+asyncpg"):
+        u = u.set(drivername="postgresql")
+    elif "+" in driver:
+        # Generic: strip the async driver part if present
+        u = u.set(drivername=driver.split("+", 1)[0])
+    elif not driver.startswith(("sqlite", "postgresql")):
+        raise ValueError(f"Unsupported database URL '{async_url}'.")
+
+    return str(u)
-def init_db():
-    db_url = args.database_url
-    logging.debug(f"Database URL: {db_url}")
-    db_path = get_db_path()
-    db_exists = os.path.exists(db_path)
-
-    config = get_alembic_config()
-
-    # Check if we need to upgrade
-    engine = create_engine(db_url)
-    conn = engine.connect()
-
-    context = MigrationContext.configure(conn)
-    current_rev = context.get_current_revision()
-
-    script = ScriptDirectory.from_config(config)
+def _get_sqlite_file_path(sync_url: str) -> Optional[str]:
+    """Return the on-disk path for a SQLite URL, else None."""
+    try:
+        u = make_url(sync_url)
+    except Exception:
+        return None
+    if not u.drivername.startswith("sqlite"):
+        return None
+    return u.database
+
+
+def _get_alembic_config(sync_url: str) -> Config:
+    """Prepare Alembic Config with script location and DB URL."""
+    config_path, scripts_path = _root_paths()
+    cfg = Config(config_path)
+    cfg.set_main_option("script_location", scripts_path)
+    cfg.set_main_option("sqlalchemy.url", sync_url)
+    return cfg
+async def init_db_engine() -> None:
+    """Initialize async engine + sessionmaker and run migrations to head.
+
+    This must be called once on application startup before any DB usage.
+    """
+    global ENGINE, SESSION
+
+    if not dependencies_available():
+        raise RuntimeError("Database dependencies are not available.")
+
+    if ENGINE is not None:
+        return
+
+    raw_url = args.database_url
+    if not raw_url:
+        raise RuntimeError("Database URL is not configured.")
+
+    # Absolutize SQLite path for async engine
+    db_url = _absolutize_sqlite_url(raw_url)
+
+    # Prepare async engine
+    connect_args = {}
+    if db_url.startswith("sqlite"):
+        connect_args = {
+            "check_same_thread": False,
+            "timeout": 12,
+        }
+
+    ENGINE = create_async_engine(
+        db_url,
+        connect_args=connect_args,
+        pool_pre_ping=True,
+        future=True,
+    )
+
+    # Enforce SQLite pragmas on the async engine
+    if db_url.startswith("sqlite"):
+        async with ENGINE.begin() as conn:
+            # WAL for concurrency and durability, Foreign Keys for referential integrity
+            current_mode = (await conn.execute(text("PRAGMA journal_mode;"))).scalar()
+            if str(current_mode).lower() != "wal":
+                new_mode = (await conn.execute(text("PRAGMA journal_mode=WAL;"))).scalar()
+                if str(new_mode).lower() != "wal":
+                    raise RuntimeError("Failed to set SQLite journal mode to WAL.")
+                LOGGER.info("SQLite journal mode set to WAL.")
+
+            await conn.execute(text("PRAGMA foreign_keys = ON;"))
+            await conn.execute(text("PRAGMA synchronous = NORMAL;"))
+
+    await _run_migrations(raw_url=db_url)
+
+    SESSION = async_sessionmaker(
+        bind=ENGINE,
+        class_=AsyncSession,
+        expire_on_commit=False,
+        autoflush=False,
+        autocommit=False,
+    )
+async def _run_migrations(raw_url: str) -> None:
+    """
+    Run Alembic migrations up to head.
+
+    We deliberately use a synchronous engine for migrations because Alembic's
+    programmatic API is synchronous by default and this path is robust.
+    """
+    # Convert to sync URL and make SQLite URL an absolute one
+    sync_url = _to_sync_driver_url(raw_url)
+    sync_url = _absolutize_sqlite_url(sync_url)
+
+    cfg = _get_alembic_config(sync_url)
+
+    # Inspect current and target heads
+    engine = create_engine(sync_url, future=True)
+    with engine.connect() as conn:
+        context = MigrationContext.configure(conn)
+        current_rev = context.get_current_revision()
+
+    script = ScriptDirectory.from_config(cfg)
+    target_rev = script.get_current_head()
+
+    if target_rev is None:
+        LOGGER.warning("Alembic: no target revision found.")
+        return
+
+    if current_rev == target_rev:
+        LOGGER.debug("Alembic: database already at head %s", target_rev)
+        return
+
+    LOGGER.info("Alembic: upgrading database from %s to %s", current_rev, target_rev)
+
+    # Optional backup for SQLite file DBs
+    backup_path = None
+    sqlite_path = _get_sqlite_file_path(sync_url)
+    if sqlite_path and os.path.exists(sqlite_path):
+        backup_path = sqlite_path + ".bkp"
+        try:
+            shutil.copy(sqlite_path, backup_path)
+        except Exception as exc:
+            LOGGER.warning("Failed to create SQLite backup before migration: %s", exc)
+
+    try:
+        command.upgrade(cfg, target_rev)
+    except Exception:
+        if backup_path and os.path.exists(backup_path):
+            LOGGER.exception("Error upgrading database, attempting restore from backup.")
+            try:
+                shutil.copy(backup_path, sqlite_path)  # restore
+                os.remove(backup_path)
+            except Exception as re:
+                LOGGER.error("Failed to restore SQLite backup: %s", re)
+        else:
+            LOGGER.exception("Error upgrading database, backup is not available.")
+        raise

-    if target_rev is None:
-        logging.warning("No target revision found.")
-    elif current_rev != target_rev:
-        # Backup the database pre upgrade
-        backup_path = db_path + ".bkp"
-        if db_exists:
-            shutil.copy(db_path, backup_path)
-        else:
-            backup_path = None
-
-        try:
-            command.upgrade(config, target_rev)
-            logging.info(f"Database upgraded from {current_rev} to {target_rev}")
-        except Exception as e:
-            if backup_path:
-                # Restore the database from backup if upgrade fails
-                shutil.copy(backup_path, db_path)
-            logging.exception("Error upgrading database: ")
-            raise e
-
-    global Session
-    Session = sessionmaker(bind=engine)
-def create_session():
-    return Session()
+def get_engine():
+    """Return the global async engine (initialized after init_db_engine())."""
+    if ENGINE is None:
+        raise RuntimeError("Engine is not initialized. Call init_db_engine() first.")
+    return ENGINE
+
+
+def get_session_maker():
+    """Return the global async_sessionmaker (initialized after init_db_engine())."""
+    if SESSION is None:
+        raise RuntimeError("Session maker is not initialized. Call init_db_engine() first.")
+    return SESSION
+
+
+@asynccontextmanager
+async def session_scope():
+    """Async context manager for a unit of work:
+
+        async with session_scope() as sess:
+            ... use sess ...
+    """
+    maker = get_session_maker()
+    async with maker() as sess:
+        try:
+            yield sess
+            await sess.commit()
+        except Exception:
+            await sess.rollback()
+            raise
+
+
+async def create_session():
+    """Convenience helper to acquire a single AsyncSession instance.
+
+    Typical usage:
+        async with (await create_session()) as sess:
+            ...
+    """
+    maker = get_session_maker()
+    return maker()
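# --- Usage sketch (editor's illustration; not part of the diff) ---------------
# How the new async API above is meant to be used at startup. The
# `app.database.models` import path is inferred from the relative
# `from .models import ...` in app/database/services.py below.
import asyncio

from app.database import dependencies_available, init_db_engine, session_scope
from app.database.models import Asset


async def main() -> None:
    if not dependencies_available():
        raise SystemExit("database dependencies are missing")
    await init_db_engine()  # creates ENGINE/SESSION and runs Alembic migrations to head
    async with session_scope() as sess:  # commit on success, rollback on error
        asset = await sess.get(Asset, "0" * 64)  # look up an Asset by its hash primary key
        print(asset)


asyncio.run(main())
# ------------------------------------------------------------------------------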
@@ -1,59 +1,257 @@
|
||||
 from datetime import datetime
 from typing import Any, Optional

 from sqlalchemy import (
-    Column,
     Integer,
     BigInteger,
     DateTime,
     ForeignKey,
     Index,
     JSON,
     String,
     Text,
     CheckConstraint,
     Numeric,
     Boolean,
 )
-from sqlalchemy.orm import declarative_base
 from sqlalchemy.sql import func
+from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship, foreign

-Base = declarative_base()

+class Base(DeclarativeBase):
+    pass

-def to_dict(obj):
+
+def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
     fields = obj.__table__.columns.keys()
-    return {
-        field: (val.to_dict() if hasattr(val, "to_dict") else val)
-        for field in fields
-        if (val := getattr(obj, field))
-    }
+    out: dict[str, Any] = {}
+    for field in fields:
+        val = getattr(obj, field)
+        if val is None and not include_none:
+            continue
+        if isinstance(val, datetime):
+            out[field] = val.isoformat()
+        else:
+            out[field] = val
+    return out
-class Model(Base):
-    """
-    sqlalchemy model representing a model file in the system.
-
-    This class defines the database schema for storing information about model files,
-    including their type, path, hash, and when they were added to the system.
-
-    Attributes:
-        type (Text): The type of the model, this is the name of the folder in the models folder (primary key)
-        path (Text): The file path of the model relative to the type folder (primary key)
-        file_name (Text): The name of the model file
-        file_size (Integer): The size of the model file in bytes
-        hash (Text): A hash of the model file
-        hash_algorithm (Text): The algorithm used to generate the hash
-        source_url (Text): The URL of the model file
-        date_added (DateTime): Timestamp of when the model was added to the system
-    """
-
-    __tablename__ = "model"
-
-    type = Column(Text, primary_key=True)
-    path = Column(Text, primary_key=True)
-    file_name = Column(Text)
-    file_size = Column(Integer)
-    hash = Column(Text)
-    hash_algorithm = Column(Text)
-    source_url = Column(Text)
-    date_added = Column(DateTime, server_default=func.now())
-
-    def to_dict(self):
-        """
-        Convert the model instance to a dictionary representation.
-
-        Returns:
-            dict: A dictionary containing the attributes of the model
-        """
-        dict = to_dict(self)
-        return dict
+class Asset(Base):
+    __tablename__ = "assets"
+
+    hash: Mapped[str] = mapped_column(String(256), primary_key=True)
+    size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
+    mime_type: Mapped[str | None] = mapped_column(String(255))
+    refcount: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
+    storage_backend: Mapped[str] = mapped_column(String(32), nullable=False, default="fs")
+    storage_locator: Mapped[str] = mapped_column(Text, nullable=False)
+    created_at: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
+    )
+    updated_at: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
+    )
+
+    infos: Mapped[list["AssetInfo"]] = relationship(
+        "AssetInfo",
+        back_populates="asset",
+        primaryjoin=lambda: Asset.hash == foreign(AssetInfo.asset_hash),
+        foreign_keys=lambda: [AssetInfo.asset_hash],
+        cascade="all,delete-orphan",
+        passive_deletes=True,
+    )
+
+    preview_of: Mapped[list["AssetInfo"]] = relationship(
+        "AssetInfo",
+        back_populates="preview_asset",
+        primaryjoin=lambda: Asset.hash == foreign(AssetInfo.preview_hash),
+        foreign_keys=lambda: [AssetInfo.preview_hash],
+        viewonly=True,
+    )
+
+    locator_state: Mapped["AssetLocatorState | None"] = relationship(
+        back_populates="asset",
+        uselist=False,
+        cascade="all, delete-orphan",
+        passive_deletes=True,
+    )
+
+    __table_args__ = (
+        Index("ix_assets_mime_type", "mime_type"),
+        Index("ix_assets_backend_locator", "storage_backend", "storage_locator"),
+    )
+
+    def to_dict(self, include_none: bool = False) -> dict[str, Any]:
+        return to_dict(self, include_none=include_none)
+
+    def __repr__(self) -> str:
+        return f"<Asset hash={self.hash[:12]} backend={self.storage_backend}>"
class AssetLocatorState(Base):
|
||||
__tablename__ = "asset_locator_state"
|
||||
|
||||
asset_hash: Mapped[str] = mapped_column(
|
||||
String(256), ForeignKey("assets.hash", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
# For fs backends: nanosecond mtime; nullable if not applicable
|
||||
mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
|
||||
# For HTTP/S3/GCS/Azure, etc.: optional validators
|
||||
etag: Mapped[str | None] = mapped_column(String(256), nullable=True)
|
||||
last_modified: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
|
||||
asset: Mapped["Asset"] = relationship(back_populates="locator_state", uselist=False)
|
||||
|
||||
__table_args__ = (
|
||||
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_als_mtime_nonneg"),
|
||||
)
|
||||
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
return to_dict(self, include_none=include_none)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<AssetLocatorState hash={self.asset_hash[:12]} mtime_ns={self.mtime_ns}>"
|
||||
|
||||
|
||||
class AssetInfo(Base):
|
||||
__tablename__ = "assets_info"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
owner_id: Mapped[str | None] = mapped_column(String(128))
|
||||
name: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||
asset_hash: Mapped[str] = mapped_column(
|
||||
String(256), ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False
|
||||
)
|
||||
preview_hash: Mapped[str | None] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="SET NULL"))
|
||||
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
|
||||
)
|
||||
last_access_time: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
|
||||
)
|
||||
|
||||
# Relationships
|
||||
asset: Mapped[Asset] = relationship(
|
||||
"Asset",
|
||||
back_populates="infos",
|
||||
foreign_keys=[asset_hash],
|
||||
)
|
||||
preview_asset: Mapped[Asset | None] = relationship(
|
||||
"Asset",
|
||||
back_populates="preview_of",
|
||||
foreign_keys=[preview_hash],
|
||||
)
|
||||
|
||||
metadata_entries: Mapped[list["AssetInfoMeta"]] = relationship(
|
||||
back_populates="asset_info",
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
tag_links: Mapped[list["AssetInfoTag"]] = relationship(
|
||||
back_populates="asset_info",
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
overlaps="tags,asset_infos",
|
||||
)
|
||||
|
||||
tags: Mapped[list["Tag"]] = relationship(
|
||||
secondary="asset_info_tags",
|
||||
back_populates="asset_infos",
|
||||
lazy="joined",
|
||||
viewonly=True,
|
||||
overlaps="tag_links,asset_info_links,asset_infos,tag",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_assets_info_owner_id", "owner_id"),
|
||||
Index("ix_assets_info_asset_hash", "asset_hash"),
|
||||
Index("ix_assets_info_name", "name"),
|
||||
Index("ix_assets_info_created_at", "created_at"),
|
||||
Index("ix_assets_info_last_access_time", "last_access_time"),
|
||||
{"sqlite_autoincrement": True},
|
||||
)
|
||||
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
data = to_dict(self, include_none=include_none)
|
||||
data["tags"] = [t.name for t in self.tags]
|
||||
return data
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<AssetInfo id={self.id} name={self.name!r} hash={self.asset_hash[:12]}>"
|
||||
|
||||
|
||||
|
||||
class AssetInfoMeta(Base):
|
||||
__tablename__ = "asset_info_meta"
|
||||
|
||||
asset_info_id: Mapped[int] = mapped_column(
|
||||
Integer, ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
key: Mapped[str] = mapped_column(String(256), primary_key=True)
|
||||
ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0)
|
||||
|
||||
val_str: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True)
|
||||
val_num: Mapped[Optional[float]] = mapped_column(Numeric(38, 10), nullable=True)
|
||||
val_bool: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True)
|
||||
val_json: Mapped[Optional[Any]] = mapped_column(JSON, nullable=True)
|
||||
|
||||
asset_info: Mapped["AssetInfo"] = relationship(back_populates="metadata_entries")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_asset_info_meta_key", "key"),
|
||||
Index("ix_asset_info_meta_key_val_str", "key", "val_str"),
|
||||
Index("ix_asset_info_meta_key_val_num", "key", "val_num"),
|
||||
Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"),
|
||||
)
|
||||
|
||||
|
||||
class AssetInfoTag(Base):
|
||||
__tablename__ = "asset_info_tags"
|
||||
|
||||
asset_info_id: Mapped[int] = mapped_column(
|
||||
Integer, ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
tag_name: Mapped[str] = mapped_column(
|
||||
String(128), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
|
||||
)
|
||||
origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
|
||||
added_by: Mapped[str | None] = mapped_column(String(128))
|
||||
added_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
|
||||
asset_info: Mapped["AssetInfo"] = relationship(back_populates="tag_links")
|
||||
tag: Mapped["Tag"] = relationship(back_populates="asset_info_links")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_asset_info_tags_tag_name", "tag_name"),
|
||||
Index("ix_asset_info_tags_asset_info_id", "asset_info_id"),
|
||||
)
|
||||
|
||||
|
||||
class Tag(Base):
|
||||
__tablename__ = "tags"
|
||||
|
||||
name: Mapped[str] = mapped_column(String(128), primary_key=True)
|
||||
tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")
|
||||
|
||||
asset_info_links: Mapped[list["AssetInfoTag"]] = relationship(
|
||||
back_populates="tag",
|
||||
overlaps="asset_infos,tags",
|
||||
)
|
||||
asset_infos: Mapped[list["AssetInfo"]] = relationship(
|
||||
secondary="asset_info_tags",
|
||||
back_populates="tags",
|
||||
viewonly=True,
|
||||
overlaps="asset_info_links,tag_links,tags,asset_info",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_tags_tag_type", "tag_type"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Tag {self.name}>"
|
||||
|
||||
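# --- Schema sketch (editor's illustration; not part of the diff) --------------
# How the tables above relate: one content-addressed Asset row per unique hash,
# any number of AssetInfo rows referencing it, and tags attached through
# AssetInfoTag link rows. All values below are made up.
from datetime import datetime, timezone

now = datetime.now(timezone.utc)
asset = Asset(
    hash="a" * 64,
    size_bytes=1234,
    mime_type="application/octet-stream",
    refcount=0,
    storage_backend="fs",
    storage_locator="/models/checkpoints/example.safetensors",
    created_at=now,
    updated_at=now,
)
info = AssetInfo(
    name="example.safetensors",
    asset_hash=asset.hash,
    created_at=now,
    updated_at=now,
    last_access_time=now,
)
link = AssetInfoTag(asset_info_id=1, tag_name="checkpoints", origin="manual", added_at=now)
# info.to_dict() serializes the column values (datetimes become ISO strings) and
# appends the tag names collected through the `tags` relationship.
# ------------------------------------------------------------------------------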
app/database/services.py (new file, 683 lines)
@@ -0,0 +1,683 @@
|
||||
import os
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from decimal import Decimal
|
||||
from typing import Any, Sequence, Optional, Iterable
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, delete, exists, func
|
||||
from sqlalchemy.orm import contains_eager
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from .models import Asset, AssetInfo, AssetInfoTag, AssetLocatorState, Tag, AssetInfoMeta
|
||||
|
||||
|
||||
async def check_fs_asset_exists_quick(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
file_path: str,
|
||||
size_bytes: Optional[int] = None,
|
||||
mtime_ns: Optional[int] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Returns 'True' if there is already an Asset present whose canonical locator matches this absolute path,
|
||||
AND (if provided) mtime_ns matches stored locator-state,
|
||||
AND (if provided) size_bytes matches verified size when known.
|
||||
"""
|
||||
locator = os.path.abspath(file_path)
|
||||
|
||||
stmt = select(sa.literal(True)).select_from(Asset)
|
||||
|
||||
conditions = [
|
||||
Asset.storage_backend == "fs",
|
||||
Asset.storage_locator == locator,
|
||||
]
|
||||
|
||||
# If size_bytes provided require equality when the asset has a verified (non-zero) size.
|
||||
# If verified size is 0 (unknown), we don't force equality.
|
||||
if size_bytes is not None:
|
||||
conditions.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes)))
|
||||
|
||||
# If mtime_ns provided require the locator-state to exist and match.
|
||||
if mtime_ns is not None:
|
||||
stmt = stmt.join(AssetLocatorState, AssetLocatorState.asset_hash == Asset.hash)
|
||||
conditions.append(AssetLocatorState.mtime_ns == int(mtime_ns))
|
||||
|
||||
stmt = stmt.where(*conditions).limit(1)
|
||||
|
||||
row = (await session.execute(stmt)).first()
|
||||
return row is not None
|
||||
|
||||
|
||||
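# --- Usage sketch (editor's illustration; not part of the file) ---------------
# check_fs_asset_exists_quick() lets a scanner skip re-hashing files that are
# already registered and unchanged (same path, same mtime_ns, same size).
# `some_path` is a placeholder.
async def _already_ingested(session: AsyncSession, some_path: str) -> bool:
    st = os.stat(some_path)
    return await check_fs_asset_exists_quick(
        session,
        file_path=some_path,
        size_bytes=st.st_size,
        mtime_ns=st.st_mtime_ns,
    )
# ------------------------------------------------------------------------------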
async def ingest_fs_asset(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_hash: str,
|
||||
abs_path: str,
|
||||
size_bytes: int,
|
||||
mtime_ns: int,
|
||||
mime_type: Optional[str] = None,
|
||||
info_name: Optional[str] = None,
|
||||
owner_id: Optional[str] = None,
|
||||
preview_hash: Optional[str] = None,
|
||||
user_metadata: Optional[dict] = None,
|
||||
tags: Sequence[str] = (),
|
||||
tag_origin: str = "manual",
|
||||
added_by: Optional[str] = None,
|
||||
require_existing_tags: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Creates or updates Asset record for a local (fs) asset.
|
||||
|
||||
Always:
|
||||
- Insert Asset if missing; else update size_bytes (and updated_at) if different.
|
||||
- Insert AssetLocatorState if missing; else update mtime_ns if different.
|
||||
|
||||
Optionally (when info_name is provided):
|
||||
- Create an AssetInfo (no refcount changes).
|
||||
- Link provided tags to that AssetInfo.
|
||||
* If require_existing_tags=True, raise ValueError if any tag does not exist in the `tags` table.
|
||||
* If False (default), silently skips unknown tags.
|
||||
|
||||
Returns flags and ids:
|
||||
{
|
||||
"asset_created": bool,
|
||||
"asset_updated": bool,
|
||||
"state_created": bool,
|
||||
"state_updated": bool,
|
||||
"asset_info_id": int | None,
|
||||
"tags_added": list[str],
|
||||
"tags_missing": list[str], # filled only when require_existing_tags=False
|
||||
}
|
||||
"""
|
||||
locator = os.path.abspath(abs_path)
|
||||
datetime_now = datetime.now(timezone.utc)
|
||||
|
||||
out = {
|
||||
"asset_created": False,
|
||||
"asset_updated": False,
|
||||
"state_created": False,
|
||||
"state_updated": False,
|
||||
"asset_info_id": None,
|
||||
"tags_added": [],
|
||||
"tags_missing": [],
|
||||
}
|
||||
|
||||
# ---- Step 1: INSERT Asset or UPDATE size_bytes/updated_at if exists ----
|
||||
async with session.begin_nested() as sp1:
|
||||
try:
|
||||
session.add(
|
||||
Asset(
|
||||
hash=asset_hash,
|
||||
size_bytes=int(size_bytes),
|
||||
mime_type=mime_type,
|
||||
refcount=0,
|
||||
storage_backend="fs",
|
||||
storage_locator=locator,
|
||||
created_at=datetime_now,
|
||||
updated_at=datetime_now,
|
||||
)
|
||||
)
|
||||
await session.flush()
|
||||
out["asset_created"] = True
|
||||
except IntegrityError:
|
||||
await sp1.rollback()
|
||||
# Already exists by hash -> update selected fields if different
|
||||
existing = await session.get(Asset, asset_hash)
|
||||
if existing is not None:
|
||||
desired_size = int(size_bytes)
|
||||
if existing.size_bytes != desired_size:
|
||||
existing.size_bytes = desired_size
|
||||
existing.updated_at = datetime_now
|
||||
out["asset_updated"] = True
|
||||
else:
|
||||
# This should not occur. Log for visibility.
|
||||
logging.error("Asset %s not found after conflict; skipping update.", asset_hash)
|
||||
except Exception:
|
||||
await sp1.rollback()
|
||||
logging.exception("Unexpected error inserting Asset (hash=%s, locator=%s)", asset_hash, locator)
|
||||
raise
|
||||
|
||||
# ---- Step 2: INSERT/UPDATE AssetLocatorState (mtime_ns) ----
|
||||
async with session.begin_nested() as sp2:
|
||||
try:
|
||||
session.add(
|
||||
AssetLocatorState(
|
||||
asset_hash=asset_hash,
|
||||
mtime_ns=int(mtime_ns),
|
||||
)
|
||||
)
|
||||
await session.flush()
|
||||
out["state_created"] = True
|
||||
except IntegrityError:
|
||||
await sp2.rollback()
|
||||
state = await session.get(AssetLocatorState, asset_hash)
|
||||
if state is not None:
|
||||
desired_mtime = int(mtime_ns)
|
||||
if state.mtime_ns != desired_mtime:
|
||||
state.mtime_ns = desired_mtime
|
||||
out["state_updated"] = True
|
||||
else:
|
||||
logging.debug("Locator state missing for %s after conflict; skipping update.", asset_hash)
|
||||
except Exception:
|
||||
await sp2.rollback()
|
||||
logging.exception("Unexpected error inserting AssetLocatorState (hash=%s)", asset_hash)
|
||||
raise
|
||||
|
||||
# ---- Optional: AssetInfo + tag links ----
|
||||
if info_name:
|
||||
# 2a) Create AssetInfo (no refcount bump)
|
||||
async with session.begin_nested() as sp3:
|
||||
try:
|
||||
info = AssetInfo(
|
||||
owner_id=owner_id,
|
||||
name=info_name,
|
||||
asset_hash=asset_hash,
|
||||
preview_hash=preview_hash,
|
||||
created_at=datetime_now,
|
||||
updated_at=datetime_now,
|
||||
last_access_time=datetime_now,
|
||||
)
|
||||
session.add(info)
|
||||
await session.flush() # get info.id
|
||||
out["asset_info_id"] = info.id
|
||||
except Exception:
|
||||
await sp3.rollback()
|
||||
logging.exception(
|
||||
"Unexpected error inserting AssetInfo (hash=%s, name=%s)", asset_hash, info_name
|
||||
)
|
||||
raise
|
||||
|
||||
# 2b) Link tags (if any). We DO NOT create new Tag rows here by default.
|
||||
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||
if norm and out["asset_info_id"] is not None:
|
||||
# Which tags exist?
|
||||
existing_tag_names = set(
|
||||
name for (name,) in (await session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
|
||||
)
|
||||
missing = [t for t in norm if t not in existing_tag_names]
|
||||
if missing and require_existing_tags:
|
||||
raise ValueError(f"Unknown tags: {missing}")
|
||||
|
||||
# Which links already exist?
|
||||
existing_links = set(
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
await session.execute(
|
||||
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
|
||||
)
|
||||
).all()
|
||||
)
|
||||
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
|
||||
if to_add:
|
||||
session.add_all(
|
||||
[
|
||||
AssetInfoTag(
|
||||
asset_info_id=out["asset_info_id"],
|
||||
tag_name=t,
|
||||
origin=tag_origin,
|
||||
added_by=added_by,
|
||||
added_at=datetime_now,
|
||||
)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
await session.flush()
|
||||
out["tags_added"] = to_add
|
||||
out["tags_missing"] = missing
|
||||
|
||||
# 2c) Rebuild metadata projection if provided
|
||||
if user_metadata is not None and out["asset_info_id"] is not None:
|
||||
await replace_asset_info_metadata_projection(
|
||||
session,
|
||||
asset_info_id=out["asset_info_id"],
|
||||
user_metadata=user_metadata,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
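# --- Usage sketch (editor's illustration; not part of the file) ---------------
# Registering a local file with ingest_fs_asset(): hash the file, stat it, and
# write Asset/AssetLocatorState (plus an optional AssetInfo with tags) in one
# transaction. hashlib.file_digest requires Python 3.11+; the "checkpoints" tag
# is only linked if it already exists (or require_existing_tags is left False).
import hashlib

from app.database import session_scope


async def register_checkpoint(path: str) -> dict:
    with open(path, "rb") as f:
        digest = hashlib.file_digest(f, "sha256").hexdigest()
    st = os.stat(path)
    async with session_scope() as sess:
        return await ingest_fs_asset(
            sess,
            asset_hash=digest,
            abs_path=path,
            size_bytes=st.st_size,
            mtime_ns=st.st_mtime_ns,
            mime_type="application/octet-stream",
            info_name=os.path.basename(path),
            tags=["checkpoints"],
        )
# ------------------------------------------------------------------------------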
async def touch_asset_infos_by_fs_path(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
abs_path: str,
|
||||
ts: Optional[datetime] = None,
|
||||
only_if_newer: bool = True,
|
||||
) -> int:
|
||||
locator = os.path.abspath(abs_path)
|
||||
ts = ts or datetime.now(timezone.utc)
|
||||
|
||||
stmt = sa.update(AssetInfo).where(
|
||||
sa.exists(
|
||||
sa.select(sa.literal(1))
|
||||
.select_from(Asset)
|
||||
.where(
|
||||
Asset.hash == AssetInfo.asset_hash,
|
||||
Asset.storage_backend == "fs",
|
||||
Asset.storage_locator == locator,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if only_if_newer:
|
||||
stmt = stmt.where(
|
||||
sa.or_(
|
||||
AssetInfo.last_access_time.is_(None),
|
||||
AssetInfo.last_access_time < ts,
|
||||
)
|
||||
)
|
||||
|
||||
stmt = stmt.values(last_access_time=ts)
|
||||
|
||||
res = await session.execute(stmt)
|
||||
return int(res.rowcount or 0)
|
||||
|
||||
|
||||
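# --- Usage sketch (editor's illustration; not part of the file) ---------------
# Bump last_access_time for every AssetInfo pointing at a given local file,
# e.g. right after the file is loaded; only_if_newer keeps the column monotonic.
async def mark_used(session: AsyncSession, path: str) -> None:
    updated = await touch_asset_infos_by_fs_path(session, abs_path=path)
    logging.debug("touched %d asset info rows for %s", updated, path)
# ------------------------------------------------------------------------------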
async def list_asset_infos_page(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
) -> tuple[list[AssetInfo], dict[int, list[str]], int]:
|
||||
"""
|
||||
Returns a page of AssetInfo rows with their Asset eagerly loaded (no N+1),
|
||||
plus a map of asset_info_id -> [tags], and the total count.
|
||||
|
||||
We purposely collect tags in a separate (single) query to avoid row explosion.
|
||||
"""
|
||||
# Clamp
|
||||
if limit <= 0:
|
||||
limit = 1
|
||||
if limit > 100:
|
||||
limit = 100
|
||||
if offset < 0:
|
||||
offset = 0
|
||||
|
||||
# Build base query
|
||||
base = (
|
||||
select(AssetInfo)
|
||||
.join(Asset, Asset.hash == AssetInfo.asset_hash)
|
||||
.options(contains_eager(AssetInfo.asset))
|
||||
)
|
||||
|
||||
# Filters
|
||||
if name_contains:
|
||||
base = base.where(AssetInfo.name.ilike(f"%{name_contains}%"))
|
||||
|
||||
base = _apply_tag_filters(base, include_tags, exclude_tags)
|
||||
|
||||
base = _apply_metadata_filter(base, metadata_filter)
|
||||
|
||||
# Sort
|
||||
sort = (sort or "created_at").lower()
|
||||
order = (order or "desc").lower()
|
||||
sort_map = {
|
||||
"name": AssetInfo.name,
|
||||
"created_at": AssetInfo.created_at,
|
||||
"updated_at": AssetInfo.updated_at,
|
||||
"last_access_time": AssetInfo.last_access_time,
|
||||
"size": Asset.size_bytes,
|
||||
}
|
||||
sort_col = sort_map.get(sort, AssetInfo.created_at)
|
||||
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
|
||||
|
||||
base = base.order_by(sort_exp).limit(limit).offset(offset)
|
||||
|
||||
# Total count (same filters, no ordering/limit/offset)
|
||||
count_stmt = (
|
||||
select(func.count())
|
||||
.select_from(AssetInfo)
|
||||
.join(Asset, Asset.hash == AssetInfo.asset_hash)
|
||||
)
|
||||
if name_contains:
|
||||
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{name_contains}%"))
|
||||
count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
||||
count_stmt = _apply_metadata_filter(count_stmt, metadata_filter)
|
||||
|
||||
total = (await session.execute(count_stmt)).scalar_one()
|
||||
|
||||
# Fetch rows
|
||||
infos = (await session.execute(base)).scalars().unique().all()
|
||||
|
||||
# Collect tags in bulk (single query)
|
||||
id_list = [i.id for i in infos]
|
||||
tag_map: dict[int, list[str]] = defaultdict(list)
|
||||
if id_list:
|
||||
rows = await session.execute(
|
||||
select(AssetInfoTag.asset_info_id, Tag.name)
|
||||
.join(Tag, Tag.name == AssetInfoTag.tag_name)
|
||||
.where(AssetInfoTag.asset_info_id.in_(id_list))
|
||||
)
|
||||
for aid, tag_name in rows.all():
|
||||
tag_map[aid].append(tag_name)
|
||||
|
||||
return infos, tag_map, total
|
||||
|
||||
|
||||
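# --- Usage sketch (editor's illustration; not part of the file) ---------------
# Paging through assets: newest "checkpoints"-tagged entries whose name contains
# "sd", excluding anything tagged "broken". The tag names here are made up.
async def first_page(session: AsyncSession) -> dict:
    infos, tag_map, total = await list_asset_infos_page(
        session,
        include_tags=["checkpoints"],
        exclude_tags=["broken"],
        name_contains="sd",
        limit=20,
        offset=0,
        sort="created_at",
        order="desc",
    )
    return {
        "total": total,
        "items": [
            {**info.to_dict(), "tags": tag_map.get(info.id, [])}
            for info in infos
        ],
    }
# ------------------------------------------------------------------------------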
async def set_asset_info_tags(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: int,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
added_by: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Replace the tag set on an AssetInfo with `tags`. Idempotent.
|
||||
Creates missing tag names as 'user'.
|
||||
"""
|
||||
desired = _normalize_tags(tags)
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# current links
|
||||
current = set(
|
||||
tag_name for (tag_name,) in (
|
||||
await session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
|
||||
).all()
|
||||
)
|
||||
|
||||
to_add = [t for t in desired if t not in current]
|
||||
to_remove = [t for t in current if t not in desired]
|
||||
|
||||
if to_add:
|
||||
await _ensure_tags_exist(session, to_add, tag_type="user")
|
||||
session.add_all([
|
||||
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_by=added_by, added_at=now)
|
||||
for t in to_add
|
||||
])
|
||||
await session.flush()
|
||||
|
||||
if to_remove:
|
||||
await session.execute(
|
||||
delete(AssetInfoTag)
|
||||
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
|
||||
)
|
||||
await session.flush()
|
||||
|
||||
return {"added": to_add, "removed": to_remove, "total": desired}
|
||||
|
||||
|
||||
async def update_asset_info_full(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: int,
|
||||
name: Optional[str] = None,
|
||||
tags: Optional[Sequence[str]] = None,
|
||||
user_metadata: Optional[dict] = None,
|
||||
tag_origin: str = "manual",
|
||||
added_by: Optional[str] = None,
|
||||
) -> AssetInfo:
|
||||
"""
|
||||
Update AssetInfo fields:
|
||||
- name (if provided)
|
||||
- user_metadata blob + rebuild projection (if provided)
|
||||
- replace tags with provided set (if provided)
|
||||
Returns the updated AssetInfo.
|
||||
"""
|
||||
info = await session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
touched = False
|
||||
if name is not None and name != info.name:
|
||||
info.name = name
|
||||
touched = True
|
||||
|
||||
if user_metadata is not None:
|
||||
await replace_asset_info_metadata_projection(
|
||||
session, asset_info_id=asset_info_id, user_metadata=user_metadata
|
||||
)
|
||||
touched = True
|
||||
|
||||
if tags is not None:
|
||||
await set_asset_info_tags(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
origin=tag_origin,
|
||||
added_by=added_by,
|
||||
)
|
||||
touched = True
|
||||
|
||||
if touched and user_metadata is None:
|
||||
info.updated_at = datetime.now(timezone.utc)
|
||||
await session.flush()
|
||||
|
||||
return info
|
||||
|
||||
|
||||
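# --- Usage sketch (editor's illustration; not part of the file) ---------------
# Renaming an entry, replacing its tag set, and rewriting its metadata blob in
# one call; the projection rows in asset_info_meta are rebuilt as part of it.
# Values are illustrative.
async def rename_and_retag(session: AsyncSession, info_id: int) -> AssetInfo:
    return await update_asset_info_full(
        session,
        asset_info_id=info_id,
        name="sdxl-base-1.0.safetensors",
        tags=["checkpoints", "sdxl"],
        user_metadata={"base_model": "sdxl", "nsfw": False},
    )
# ------------------------------------------------------------------------------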
async def replace_asset_info_metadata_projection(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: int,
|
||||
user_metadata: dict | None,
|
||||
) -> None:
|
||||
"""Replaces the `assets_info.user_metadata` AND rebuild the projection rows in `asset_info_meta`."""
|
||||
info = await session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
info.user_metadata = user_metadata or {}
|
||||
info.updated_at = datetime.now(timezone.utc)
|
||||
await session.flush()
|
||||
|
||||
await session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
|
||||
await session.flush()
|
||||
|
||||
if not user_metadata:
|
||||
return
|
||||
|
||||
rows: list[AssetInfoMeta] = []
|
||||
for k, v in user_metadata.items():
|
||||
for r in _project_kv(k, v):
|
||||
rows.append(
|
||||
AssetInfoMeta(
|
||||
asset_info_id=asset_info_id,
|
||||
key=r["key"],
|
||||
ordinal=int(r["ordinal"]),
|
||||
val_str=r.get("val_str"),
|
||||
val_num=r.get("val_num"),
|
||||
val_bool=r.get("val_bool"),
|
||||
val_json=r.get("val_json"),
|
||||
)
|
||||
)
|
||||
if rows:
|
||||
session.add_all(rows)
|
||||
await session.flush()
|
||||
|
||||
|
||||
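# --- Projection sketch (editor's illustration; not part of the file) ----------
# What replace_asset_info_metadata_projection() stores for a metadata blob:
# scalars land in the typed columns, lists fan out one row per element, and
# nested objects fall back to val_json (see _project_kv below). For example,
#
#   {"base_model": "sdxl", "steps": 1000, "nsfw": False,
#    "trigger_words": ["cat", "dog"], "extra": {"a": 1}}
#
# becomes asset_info_meta rows roughly like:
#
#   key            ordinal  val_str  val_num  val_bool  val_json
#   base_model     0        "sdxl"
#   steps          0                 1000
#   nsfw           0                          False
#   trigger_words  0        "cat"
#   trigger_words  1        "dog"
#   extra          0                                    {"a": 1}
# ------------------------------------------------------------------------------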
async def get_asset_tags(session: AsyncSession, *, asset_info_id: int) -> list[str]:
|
||||
return [
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
await session.execute(
|
||||
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
]
|
||||
|
||||
|
||||
def _normalize_tags(tags: Sequence[str] | None) -> list[str]:
|
||||
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||
|
||||
|
||||
async def _ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> list[Tag]:
|
||||
wanted = _normalize_tags(list(names))
|
||||
if not wanted:
|
||||
return []
|
||||
existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all()
|
||||
by_name = {t.name: t for t in existing}
|
||||
to_create = [Tag(name=n, tag_type=tag_type) for n in wanted if n not in by_name]
|
||||
if to_create:
|
||||
session.add_all(to_create)
|
||||
await session.flush()
|
||||
by_name.update({t.name: t for t in to_create})
|
||||
return [by_name[n] for n in wanted]
|
||||
|
||||
|
||||
def _apply_tag_filters(
|
||||
stmt: sa.sql.Select,
|
||||
include_tags: Sequence[str] | None,
|
||||
exclude_tags: Sequence[str] | None,
|
||||
) -> sa.sql.Select:
|
||||
"""include_tags: every tag must be present; exclude_tags: none may be present."""
|
||||
include_tags = _normalize_tags(include_tags)
|
||||
exclude_tags = _normalize_tags(exclude_tags)
|
||||
|
||||
if include_tags:
|
||||
for tag_name in include_tags:
|
||||
stmt = stmt.where(
|
||||
exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name == tag_name)
|
||||
)
|
||||
)
|
||||
|
||||
if exclude_tags:
|
||||
stmt = stmt.where(
|
||||
~exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name.in_(exclude_tags))
|
||||
)
|
||||
)
|
||||
return stmt
|
||||
|
||||
def _apply_metadata_filter(
|
||||
stmt: sa.sql.Select,
|
||||
metadata_filter: dict | None,
|
||||
) -> sa.sql.Select:
|
||||
"""Apply metadata filters using the projection table asset_info_meta.
|
||||
|
||||
Semantics:
|
||||
- For scalar values: require EXISTS(asset_info_meta) with matching key + typed value.
|
||||
- For None: key is missing OR key has explicit null (val_json IS NULL).
|
||||
- For list values: ANY-of the list elements matches (EXISTS for any).
|
||||
(Change to ALL-of by 'for each element: stmt = stmt.where(_exists_clause_for_value(key, elem))')
|
||||
"""
|
||||
if not metadata_filter:
|
||||
return stmt
|
||||
|
||||
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
|
||||
subquery = (
|
||||
select(sa.literal(1))
|
||||
.select_from(AssetInfoMeta)
|
||||
.where(
|
||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||
AssetInfoMeta.key == key,
|
||||
*preds,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
return sa.exists(subquery)
|
||||
|
||||
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
|
||||
# Missing OR null:
|
||||
if value is None:
|
||||
# either: no row for key OR a row for key with explicit null
|
||||
no_row_for_key = ~sa.exists(
|
||||
select(sa.literal(1))
|
||||
.select_from(AssetInfoMeta)
|
||||
.where(
|
||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||
AssetInfoMeta.key == key,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
null_row = _exists_for_pred(key, AssetInfoMeta.val_json.is_(None))
|
||||
return sa.or_(no_row_for_key, null_row)
|
||||
|
||||
# Typed scalar matches:
|
||||
if isinstance(value, bool):
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
|
||||
if isinstance(value, (int, float, Decimal)):
|
||||
# store as Decimal for equality against NUMERIC(38,10)
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
|
||||
if isinstance(value, str):
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
|
||||
|
||||
# Complex: compare JSON (no index, but supported)
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
|
||||
|
||||
for k, v in metadata_filter.items():
|
||||
if isinstance(v, list):
|
||||
# ANY-of (exists for any element)
|
||||
ors = [ _exists_clause_for_value(k, elem) for elem in v ]
|
||||
if ors:
|
||||
stmt = stmt.where(sa.or_(*ors))
|
||||
else:
|
||||
stmt = stmt.where(_exists_clause_for_value(k, v))
|
||||
return stmt
|
||||
|
||||
|
||||
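# --- Filter sketch (editor's illustration; not part of the file) --------------
# metadata_filter semantics for list_asset_infos_page(): every key must match
# (AND across keys); a list value means "any of these" for that key; None means
# "key absent or explicitly null". Values are illustrative.
example_metadata_filter = {
    "base_model": ["sdxl", "sd15"],  # ANY-of the listed values
    "nsfw": False,                   # typed scalar match (val_bool)
    "notes": None,                   # missing key OR explicit null
}
# ------------------------------------------------------------------------------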
def _is_scalar(v: Any) -> bool:
|
||||
if v is None: # treat None as a value (explicit null) so it can be indexed for "is null" queries
|
||||
return True
|
||||
if isinstance(v, bool):
|
||||
return True
|
||||
if isinstance(v, (int, float, Decimal, str)):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _project_kv(key: str, value: Any) -> list[dict]:
|
||||
"""
|
||||
Turn a metadata key/value into one or more projection rows:
|
||||
- scalar -> one row (ordinal=0) in the proper typed column
|
||||
- list of scalars -> one row per element with ordinal=i
|
||||
- dict or list with non-scalars -> single row with val_json (or one per element w/ val_json if list)
|
||||
- None -> single row with val_json = None
|
||||
Each row: {"key": key, "ordinal": i, "val_str"/"val_num"/"val_bool"/"val_json": ...}
|
||||
"""
|
||||
rows: list[dict] = []
|
||||
|
||||
# None
|
||||
if value is None:
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": None})
|
||||
return rows
|
||||
|
||||
# Scalars
|
||||
if _is_scalar(value):
|
||||
if isinstance(value, bool):
|
||||
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
|
||||
elif isinstance(value, (int, float, Decimal)):
|
||||
# store numeric; SQLAlchemy will coerce to Numeric
|
||||
rows.append({"key": key, "ordinal": 0, "val_num": value})
|
||||
elif isinstance(value, str):
|
||||
rows.append({"key": key, "ordinal": 0, "val_str": value})
|
||||
else:
|
||||
# Fallback to json
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
|
||||
# Lists
|
||||
if isinstance(value, list):
|
||||
# list of scalars?
|
||||
if all(_is_scalar(x) for x in value):
|
||||
for i, x in enumerate(value):
|
||||
if x is None:
|
||||
rows.append({"key": key, "ordinal": i, "val_json": None})
|
||||
elif isinstance(x, bool):
|
||||
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
|
||||
elif isinstance(x, (int, float, Decimal)):
|
||||
rows.append({"key": key, "ordinal": i, "val_num": x})
|
||||
elif isinstance(x, str):
|
||||
rows.append({"key": key, "ordinal": i, "val_str": x})
|
||||
else:
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
# list contains objects -> one val_json per element
|
||||
for i, x in enumerate(value):
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
|
||||
# Dict or any other structure -> single json row
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
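# --- _project_kv sketch (editor's illustration; not part of the file) ---------
# Concrete input/output pairs for the projection helper above:
#   _project_kv("trigger_words", ["cat", "dog"])
#     -> [{"key": "trigger_words", "ordinal": 0, "val_str": "cat"},
#         {"key": "trigger_words", "ordinal": 1, "val_str": "dog"}]
#   _project_kv("steps", 1000)
#     -> [{"key": "steps", "ordinal": 0, "val_num": 1000}]
# ------------------------------------------------------------------------------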