dev: Everything is Assets

This commit is contained in:
bigcat88
2025-08-19 19:56:59 +03:00
parent c708d0a433
commit f92307cd4c
22 changed files with 1650 additions and 977 deletions

app/database/__init__.py Normal file

@@ -1,112 +1,267 @@
import logging
import os
import shutil
from contextlib import asynccontextmanager
from typing import Optional
from app.logger import log_startup_warning
from utils.install_util import get_missing_requirements_message
from comfy.cli_args import args
_DB_AVAILABLE = False
LOGGER = logging.getLogger(__name__)

# Attempt imports which may not exist in some environments
try:
    from alembic import command
    from alembic.config import Config
    from alembic.runtime.migration import MigrationContext
    from alembic.script import ScriptDirectory
    from sqlalchemy import create_engine, text
    from sqlalchemy.engine import make_url
    from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine

    _DB_AVAILABLE = True
    ENGINE: AsyncEngine | None = None
    SESSION: async_sessionmaker | None = None
except ImportError as e:
    log_startup_warning(
        (
            "------------------------------------------------------------------------\n"
            f"Error importing DB dependencies: {e}\n"
            f"{get_missing_requirements_message()}\n"
            "This error is happening because ComfyUI now uses a local database.\n"
            "------------------------------------------------------------------------"
        ).strip()
    )
    _DB_AVAILABLE = False
    ENGINE = None
    SESSION = None

def dependencies_available() -> bool:
    """Check if DB dependencies are importable."""
    return _DB_AVAILABLE


def _root_paths():
    """Resolve alembic.ini and the migrations script folder."""
    root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
    config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
    scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
    return config_path, scripts_path


def _absolutize_sqlite_url(db_url: str) -> str:
    """Make a SQLite database path absolute. No-op for non-SQLite URLs."""
    try:
        u = make_url(db_url)
    except Exception:
        return db_url
    if not u.drivername.startswith("sqlite"):
        return db_url
    # Make the path absolute if it is relative
    db_path = u.database or ""
    if not os.path.isabs(db_path):
        db_path = os.path.abspath(os.path.join(os.getcwd(), db_path))
    u = u.set(database=db_path)
    return str(u)


def _to_sync_driver_url(async_url: str) -> str:
    """Convert an async SQLAlchemy URL to a sync URL for Alembic."""
    u = make_url(async_url)
    driver = u.drivername
    if driver.startswith("sqlite+aiosqlite"):
        u = u.set(drivername="sqlite")
    elif driver.startswith("postgresql+asyncpg"):
        u = u.set(drivername="postgresql")
    elif "+" in driver:
        # Generic: strip the async driver part if present
        u = u.set(drivername=driver.split("+", 1)[0])
    return str(u)

def _get_sqlite_file_path(sync_url: str) -> Optional[str]:
    """Return the on-disk path for a SQLite URL, else None."""
    try:
        u = make_url(sync_url)
    except Exception:
        return None
    if not u.drivername.startswith("sqlite"):
        return None
    return u.database


def _get_alembic_config(sync_url: str) -> Config:
    """Prepare an Alembic Config with the script location and DB URL."""
    config_path, scripts_path = _root_paths()
    cfg = Config(config_path)
    cfg.set_main_option("script_location", scripts_path)
    cfg.set_main_option("sqlalchemy.url", sync_url)
    return cfg

async def init_db_engine() -> None:
"""Initialize async engine + sessionmaker and run migrations to head.
This must be called once on application startup before any DB usage.
"""
global ENGINE, SESSION
if not dependencies_available():
raise RuntimeError("Database dependencies are not available.")
if ENGINE is not None:
return
raw_url = args.database_url
if not raw_url:
raise RuntimeError("Database URL is not configured.")
# Absolutize SQLite path for async engine
db_url = _absolutize_sqlite_url(raw_url)
# Prepare async engine
connect_args = {}
if db_url.startswith("sqlite"):
connect_args = {
"check_same_thread": False,
"timeout": 12,
}
ENGINE = create_async_engine(
db_url,
connect_args=connect_args,
pool_pre_ping=True,
future=True,
)
# Enforce SQLite pragmas on the async engine
if db_url.startswith("sqlite"):
async with ENGINE.begin() as conn:
# WAL for concurrency and durability, Foreign Keys for referential integrity
current_mode = (await conn.execute(text("PRAGMA journal_mode;"))).scalar()
if str(current_mode).lower() != "wal":
new_mode = (await conn.execute(text("PRAGMA journal_mode=WAL;"))).scalar()
if str(new_mode).lower() != "wal":
raise RuntimeError("Failed to set SQLite journal mode to WAL.")
LOGGER.info("SQLite journal mode set to WAL.")
await conn.execute(text("PRAGMA foreign_keys = ON;"))
await conn.execute(text("PRAGMA synchronous = NORMAL;"))
await _run_migrations(raw_url=db_url)
SESSION = async_sessionmaker(
bind=ENGINE,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
autocommit=False,
)

async def _run_migrations(raw_url: str) -> None:
    """Run Alembic migrations up to head.

    We deliberately use a synchronous engine for migrations because Alembic's
    programmatic API is synchronous by default and this path is robust.
    """
    # Convert to a sync URL and make a SQLite URL absolute
    sync_url = _to_sync_driver_url(raw_url)
    sync_url = _absolutize_sqlite_url(sync_url)
    cfg = _get_alembic_config(sync_url)

    # Inspect current and target heads
    engine = create_engine(sync_url, future=True)
    with engine.connect() as conn:
        context = MigrationContext.configure(conn)
        current_rev = context.get_current_revision()

    script = ScriptDirectory.from_config(cfg)
    target_rev = script.get_current_head()
    if target_rev is None:
        LOGGER.warning("Alembic: no target revision found.")
        return
    if current_rev == target_rev:
        LOGGER.debug("Alembic: database already at head %s", target_rev)
        return
    LOGGER.info("Alembic: upgrading database from %s to %s", current_rev, target_rev)

    # Optional backup for SQLite file DBs
    backup_path = None
    sqlite_path = _get_sqlite_file_path(sync_url)
    if sqlite_path and os.path.exists(sqlite_path):
        backup_path = sqlite_path + ".bkp"
        try:
            shutil.copy(sqlite_path, backup_path)
        except Exception as exc:
            LOGGER.warning("Failed to create SQLite backup before migration: %s", exc)
    try:
        command.upgrade(cfg, target_rev)
    except Exception:
        if backup_path and os.path.exists(backup_path):
            LOGGER.exception("Error upgrading database, attempting restore from backup.")
            try:
                shutil.copy(backup_path, sqlite_path)  # restore
                os.remove(backup_path)
            except Exception as re:
                LOGGER.error("Failed to restore SQLite backup: %s", re)
        else:
            LOGGER.exception("Error upgrading database, backup is not available.")
        raise

def get_engine():
"""Return the global async engine (initialized after init_db_engine())."""
if ENGINE is None:
raise RuntimeError("Engine is not initialized. Call init_db_engine() first.")
return ENGINE
def get_session_maker():
"""Return the global async_sessionmaker (initialized after init_db_engine())."""
if SESSION is None:
raise RuntimeError("Session maker is not initialized. Call init_db_engine() first.")
return SESSION
@asynccontextmanager
async def session_scope():
"""Async context manager for a unit of work:
async with session_scope() as sess:
... use sess ...
"""
maker = get_session_maker()
async with maker() as sess:
try:
yield sess
await sess.commit()
except Exception:
await sess.rollback()
raise
async def create_session():
"""Convenience helper to acquire a single AsyncSession instance.
Typical usage:
async with (await create_session()) as sess:
...
"""
maker = get_session_maker()
return maker()
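
Taken together, the new module gives ComfyUI a small async database lifecycle: init_db_engine() builds the engine, applies the SQLite pragmas, and runs Alembic migrations to head, after which session_scope() hands out commit-or-rollback units of work. A minimal usage sketch follows; the call site, the asset hash, and the asyncio.run() entry point are illustrative, assume args.database_url is already configured, and are not part of this commit:

import asyncio

from app.database import init_db_engine, session_scope
from app.database.models import Asset


async def main() -> None:
    await init_db_engine()  # builds the async engine, sets SQLite pragmas, migrates to head

    async with session_scope() as sess:
        # session_scope() commits on success and rolls back on any exception
        asset = await sess.get(Asset, "0123456789abcdef" * 4)  # hypothetical content hash
        print(asset.to_dict() if asset else "asset not found")


asyncio.run(main())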

app/database/models.py
@@ -1,59 +1,257 @@
from datetime import datetime
from typing import Any, Optional

from sqlalchemy import (
    Integer,
    Text,
    BigInteger,
    DateTime,
    ForeignKey,
    Index,
    JSON,
    String,
    CheckConstraint,
    Numeric,
    Boolean,
)
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship, foreign
from sqlalchemy.sql import func


class Base(DeclarativeBase):
    pass


def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
    fields = obj.__table__.columns.keys()
    out: dict[str, Any] = {}
    for field in fields:
        val = getattr(obj, field)
        if val is None and not include_none:
            continue
        if isinstance(val, datetime):
            out[field] = val.isoformat()
        else:
            out[field] = val
    return out

class Asset(Base):
    __tablename__ = "assets"

    hash: Mapped[str] = mapped_column(String(256), primary_key=True)
    size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
    mime_type: Mapped[str | None] = mapped_column(String(255))
    refcount: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
    storage_backend: Mapped[str] = mapped_column(String(32), nullable=False, default="fs")
    storage_locator: Mapped[str] = mapped_column(Text, nullable=False)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
    )

    infos: Mapped[list["AssetInfo"]] = relationship(
        "AssetInfo",
        back_populates="asset",
        primaryjoin=lambda: Asset.hash == foreign(AssetInfo.asset_hash),
        foreign_keys=lambda: [AssetInfo.asset_hash],
        cascade="all,delete-orphan",
        passive_deletes=True,
    )

    preview_of: Mapped[list["AssetInfo"]] = relationship(
        "AssetInfo",
        back_populates="preview_asset",
        primaryjoin=lambda: Asset.hash == foreign(AssetInfo.preview_hash),
        foreign_keys=lambda: [AssetInfo.preview_hash],
        viewonly=True,
    )

    locator_state: Mapped["AssetLocatorState | None"] = relationship(
        back_populates="asset",
        uselist=False,
        cascade="all, delete-orphan",
        passive_deletes=True,
    )

    __table_args__ = (
        Index("ix_assets_mime_type", "mime_type"),
        Index("ix_assets_backend_locator", "storage_backend", "storage_locator"),
    )

    def to_dict(self, include_none: bool = False) -> dict[str, Any]:
        return to_dict(self, include_none=include_none)

    def __repr__(self) -> str:
        return f"<Asset hash={self.hash[:12]} backend={self.storage_backend}>"

class AssetLocatorState(Base):
__tablename__ = "asset_locator_state"
asset_hash: Mapped[str] = mapped_column(
String(256), ForeignKey("assets.hash", ondelete="CASCADE"), primary_key=True
)
# For fs backends: nanosecond mtime; nullable if not applicable
mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
# For HTTP/S3/GCS/Azure, etc.: optional validators
etag: Mapped[str | None] = mapped_column(String(256), nullable=True)
last_modified: Mapped[str | None] = mapped_column(String(128), nullable=True)
asset: Mapped["Asset"] = relationship(back_populates="locator_state", uselist=False)
__table_args__ = (
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_als_mtime_nonneg"),
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
return to_dict(self, include_none=include_none)
def __repr__(self) -> str:
return f"<AssetLocatorState hash={self.asset_hash[:12]} mtime_ns={self.mtime_ns}>"
class AssetInfo(Base):
__tablename__ = "assets_info"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
owner_id: Mapped[str | None] = mapped_column(String(128))
name: Mapped[str] = mapped_column(String(512), nullable=False)
asset_hash: Mapped[str] = mapped_column(
String(256), ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False
)
preview_hash: Mapped[str | None] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="SET NULL"))
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
)
last_access_time: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
)
# Relationships
asset: Mapped[Asset] = relationship(
"Asset",
back_populates="infos",
foreign_keys=[asset_hash],
)
preview_asset: Mapped[Asset | None] = relationship(
"Asset",
back_populates="preview_of",
foreign_keys=[preview_hash],
)
metadata_entries: Mapped[list["AssetInfoMeta"]] = relationship(
back_populates="asset_info",
cascade="all,delete-orphan",
passive_deletes=True,
)
tag_links: Mapped[list["AssetInfoTag"]] = relationship(
back_populates="asset_info",
cascade="all,delete-orphan",
passive_deletes=True,
overlaps="tags,asset_infos",
)
tags: Mapped[list["Tag"]] = relationship(
secondary="asset_info_tags",
back_populates="asset_infos",
lazy="joined",
viewonly=True,
overlaps="tag_links,asset_info_links,asset_infos,tag",
)
__table_args__ = (
Index("ix_assets_info_owner_id", "owner_id"),
Index("ix_assets_info_asset_hash", "asset_hash"),
Index("ix_assets_info_name", "name"),
Index("ix_assets_info_created_at", "created_at"),
Index("ix_assets_info_last_access_time", "last_access_time"),
{"sqlite_autoincrement": True},
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
data = to_dict(self, include_none=include_none)
data["tags"] = [t.name for t in self.tags]
return data
def __repr__(self) -> str:
return f"<AssetInfo id={self.id} name={self.name!r} hash={self.asset_hash[:12]}>"
class AssetInfoMeta(Base):
__tablename__ = "asset_info_meta"
asset_info_id: Mapped[int] = mapped_column(
Integer, ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
)
key: Mapped[str] = mapped_column(String(256), primary_key=True)
ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0)
val_str: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True)
val_num: Mapped[Optional[float]] = mapped_column(Numeric(38, 10), nullable=True)
val_bool: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True)
val_json: Mapped[Optional[Any]] = mapped_column(JSON, nullable=True)
asset_info: Mapped["AssetInfo"] = relationship(back_populates="metadata_entries")
__table_args__ = (
Index("ix_asset_info_meta_key", "key"),
Index("ix_asset_info_meta_key_val_str", "key", "val_str"),
Index("ix_asset_info_meta_key_val_num", "key", "val_num"),
Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"),
)
class AssetInfoTag(Base):
__tablename__ = "asset_info_tags"
asset_info_id: Mapped[int] = mapped_column(
Integer, ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
)
tag_name: Mapped[str] = mapped_column(
String(128), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
)
origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
added_by: Mapped[str | None] = mapped_column(String(128))
added_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
asset_info: Mapped["AssetInfo"] = relationship(back_populates="tag_links")
tag: Mapped["Tag"] = relationship(back_populates="asset_info_links")
__table_args__ = (
Index("ix_asset_info_tags_tag_name", "tag_name"),
Index("ix_asset_info_tags_asset_info_id", "asset_info_id"),
)
class Tag(Base):
__tablename__ = "tags"
name: Mapped[str] = mapped_column(String(128), primary_key=True)
tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")
asset_info_links: Mapped[list["AssetInfoTag"]] = relationship(
back_populates="tag",
overlaps="asset_infos,tags",
)
asset_infos: Mapped[list["AssetInfo"]] = relationship(
secondary="asset_info_tags",
back_populates="tags",
viewonly=True,
overlaps="asset_info_links,tag_links,tags,asset_info",
)
__table_args__ = (
Index("ix_tags_tag_type", "tag_type"),
)
def __repr__(self) -> str:
return f"<Tag {self.name}>"

app/database/services.py Normal file

@@ -0,0 +1,683 @@
import os
import logging
from collections import defaultdict
from datetime import datetime, timezone
from decimal import Decimal
from typing import Any, Sequence, Optional, Iterable
import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, delete, exists, func
from sqlalchemy.orm import contains_eager
from sqlalchemy.exc import IntegrityError
from .models import Asset, AssetInfo, AssetInfoTag, AssetLocatorState, Tag, AssetInfoMeta
async def check_fs_asset_exists_quick(
    session: AsyncSession,
    *,
    file_path: str,
    size_bytes: Optional[int] = None,
    mtime_ns: Optional[int] = None,
) -> bool:
    """
    Return True if an Asset already exists whose canonical locator matches this absolute path,
    AND (if provided) mtime_ns matches the stored locator state,
    AND (if provided) size_bytes matches the verified size when known.
    """
locator = os.path.abspath(file_path)
stmt = select(sa.literal(True)).select_from(Asset)
conditions = [
Asset.storage_backend == "fs",
Asset.storage_locator == locator,
]
# If size_bytes provided require equality when the asset has a verified (non-zero) size.
# If verified size is 0 (unknown), we don't force equality.
if size_bytes is not None:
conditions.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes)))
# If mtime_ns provided require the locator-state to exist and match.
if mtime_ns is not None:
stmt = stmt.join(AssetLocatorState, AssetLocatorState.asset_hash == Asset.hash)
conditions.append(AssetLocatorState.mtime_ns == int(mtime_ns))
stmt = stmt.where(*conditions).limit(1)
row = (await session.execute(stmt)).first()
return row is not None
async def ingest_fs_asset(
session: AsyncSession,
*,
asset_hash: str,
abs_path: str,
size_bytes: int,
mtime_ns: int,
mime_type: Optional[str] = None,
info_name: Optional[str] = None,
owner_id: Optional[str] = None,
preview_hash: Optional[str] = None,
user_metadata: Optional[dict] = None,
tags: Sequence[str] = (),
tag_origin: str = "manual",
added_by: Optional[str] = None,
require_existing_tags: bool = False,
) -> dict:
"""
Creates or updates Asset record for a local (fs) asset.
Always:
- Insert Asset if missing; else update size_bytes (and updated_at) if different.
- Insert AssetLocatorState if missing; else update mtime_ns if different.
Optionally (when info_name is provided):
- Create an AssetInfo (no refcount changes).
- Link provided tags to that AssetInfo.
* If require_existing_tags=True, raise ValueError if any tag does not exist in the `tags` table.
* If False (the default), silently skip unknown tags.
Returns flags and ids:
{
"asset_created": bool,
"asset_updated": bool,
"state_created": bool,
"state_updated": bool,
"asset_info_id": int | None,
"tags_added": list[str],
"tags_missing": list[str], # filled only when require_existing_tags=False
}
"""
locator = os.path.abspath(abs_path)
datetime_now = datetime.now(timezone.utc)
out = {
"asset_created": False,
"asset_updated": False,
"state_created": False,
"state_updated": False,
"asset_info_id": None,
"tags_added": [],
"tags_missing": [],
}
# ---- Step 1: INSERT Asset or UPDATE size_bytes/updated_at if exists ----
async with session.begin_nested() as sp1:
try:
session.add(
Asset(
hash=asset_hash,
size_bytes=int(size_bytes),
mime_type=mime_type,
refcount=0,
storage_backend="fs",
storage_locator=locator,
created_at=datetime_now,
updated_at=datetime_now,
)
)
await session.flush()
out["asset_created"] = True
except IntegrityError:
await sp1.rollback()
# Already exists by hash -> update selected fields if different
existing = await session.get(Asset, asset_hash)
if existing is not None:
desired_size = int(size_bytes)
if existing.size_bytes != desired_size:
existing.size_bytes = desired_size
existing.updated_at = datetime_now
out["asset_updated"] = True
else:
# This should not occur. Log for visibility.
logging.error("Asset %s not found after conflict; skipping update.", asset_hash)
except Exception:
await sp1.rollback()
logging.exception("Unexpected error inserting Asset (hash=%s, locator=%s)", asset_hash, locator)
raise
# ---- Step 2: INSERT/UPDATE AssetLocatorState (mtime_ns) ----
async with session.begin_nested() as sp2:
try:
session.add(
AssetLocatorState(
asset_hash=asset_hash,
mtime_ns=int(mtime_ns),
)
)
await session.flush()
out["state_created"] = True
except IntegrityError:
await sp2.rollback()
state = await session.get(AssetLocatorState, asset_hash)
if state is not None:
desired_mtime = int(mtime_ns)
if state.mtime_ns != desired_mtime:
state.mtime_ns = desired_mtime
out["state_updated"] = True
else:
logging.debug("Locator state missing for %s after conflict; skipping update.", asset_hash)
except Exception:
await sp2.rollback()
logging.exception("Unexpected error inserting AssetLocatorState (hash=%s)", asset_hash)
raise
# ---- Optional: AssetInfo + tag links ----
if info_name:
# 2a) Create AssetInfo (no refcount bump)
async with session.begin_nested() as sp3:
try:
info = AssetInfo(
owner_id=owner_id,
name=info_name,
asset_hash=asset_hash,
preview_hash=preview_hash,
created_at=datetime_now,
updated_at=datetime_now,
last_access_time=datetime_now,
)
session.add(info)
await session.flush() # get info.id
out["asset_info_id"] = info.id
except Exception:
await sp3.rollback()
logging.exception(
"Unexpected error inserting AssetInfo (hash=%s, name=%s)", asset_hash, info_name
)
raise
# 2b) Link tags (if any). We DO NOT create new Tag rows here by default.
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
if norm and out["asset_info_id"] is not None:
# Which tags exist?
existing_tag_names = set(
name for (name,) in (await session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
)
missing = [t for t in norm if t not in existing_tag_names]
if missing and require_existing_tags:
raise ValueError(f"Unknown tags: {missing}")
# Which links already exist?
existing_links = set(
tag_name
for (tag_name,) in (
await session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
)
).all()
)
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
if to_add:
session.add_all(
[
AssetInfoTag(
asset_info_id=out["asset_info_id"],
tag_name=t,
origin=tag_origin,
added_by=added_by,
added_at=datetime_now,
)
for t in to_add
]
)
await session.flush()
out["tags_added"] = to_add
out["tags_missing"] = missing
# 2c) Rebuild metadata projection if provided
if user_metadata is not None and out["asset_info_id"] is not None:
await replace_asset_info_metadata_projection(
session,
asset_info_id=out["asset_info_id"],
user_metadata=user_metadata,
)
return out
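# Illustrative usage (not part of this file): ingesting a local file found during a folder scan.
# `session_scope` comes from app/database/__init__.py in this commit; `file_hash` and `st`
# are assumed to be computed by the caller (content digest plus os.stat() of the file).
#
#   async with session_scope() as sess:
#       result = await ingest_fs_asset(
#           sess,
#           asset_hash=file_hash,
#           abs_path="/models/loras/foo.safetensors",
#           size_bytes=st.st_size,
#           mtime_ns=st.st_mtime_ns,
#           info_name="foo.safetensors",
#           tags=["loras"],
#       )
#       # result["asset_created"] / result["asset_updated"] tell the scanner what changed.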
async def touch_asset_infos_by_fs_path(
session: AsyncSession,
*,
abs_path: str,
ts: Optional[datetime] = None,
only_if_newer: bool = True,
) -> int:
locator = os.path.abspath(abs_path)
ts = ts or datetime.now(timezone.utc)
stmt = sa.update(AssetInfo).where(
sa.exists(
sa.select(sa.literal(1))
.select_from(Asset)
.where(
Asset.hash == AssetInfo.asset_hash,
Asset.storage_backend == "fs",
Asset.storage_locator == locator,
)
)
)
if only_if_newer:
stmt = stmt.where(
sa.or_(
AssetInfo.last_access_time.is_(None),
AssetInfo.last_access_time < ts,
)
)
stmt = stmt.values(last_access_time=ts)
res = await session.execute(stmt)
return int(res.rowcount or 0)
async def list_asset_infos_page(
session: AsyncSession,
*,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
) -> tuple[list[AssetInfo], dict[int, list[str]], int]:
"""
Returns a page of AssetInfo rows with their Asset eagerly loaded (no N+1),
plus a map of asset_info_id -> [tags], and the total count.
We purposely collect tags in a separate (single) query to avoid row explosion.
"""
# Clamp
if limit <= 0:
limit = 1
if limit > 100:
limit = 100
if offset < 0:
offset = 0
# Build base query
base = (
select(AssetInfo)
.join(Asset, Asset.hash == AssetInfo.asset_hash)
.options(contains_eager(AssetInfo.asset))
)
# Filters
if name_contains:
base = base.where(AssetInfo.name.ilike(f"%{name_contains}%"))
base = _apply_tag_filters(base, include_tags, exclude_tags)
base = _apply_metadata_filter(base, metadata_filter)
# Sort
sort = (sort or "created_at").lower()
order = (order or "desc").lower()
sort_map = {
"name": AssetInfo.name,
"created_at": AssetInfo.created_at,
"updated_at": AssetInfo.updated_at,
"last_access_time": AssetInfo.last_access_time,
"size": Asset.size_bytes,
}
sort_col = sort_map.get(sort, AssetInfo.created_at)
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
base = base.order_by(sort_exp).limit(limit).offset(offset)
# Total count (same filters, no ordering/limit/offset)
count_stmt = (
select(func.count())
.select_from(AssetInfo)
.join(Asset, Asset.hash == AssetInfo.asset_hash)
)
if name_contains:
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{name_contains}%"))
count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = _apply_metadata_filter(count_stmt, metadata_filter)
total = (await session.execute(count_stmt)).scalar_one()
# Fetch rows
infos = (await session.execute(base)).scalars().unique().all()
# Collect tags in bulk (single query)
id_list = [i.id for i in infos]
tag_map: dict[int, list[str]] = defaultdict(list)
if id_list:
rows = await session.execute(
select(AssetInfoTag.asset_info_id, Tag.name)
.join(Tag, Tag.name == AssetInfoTag.tag_name)
.where(AssetInfoTag.asset_info_id.in_(id_list))
)
for aid, tag_name in rows.all():
tag_map[aid].append(tag_name)
return infos, tag_map, total
async def set_asset_info_tags(
session: AsyncSession,
*,
asset_info_id: int,
tags: Sequence[str],
origin: str = "manual",
added_by: Optional[str] = None,
) -> dict:
"""
Replace the tag set on an AssetInfo with `tags`. Idempotent.
Creates missing tag names as 'user'.
"""
desired = _normalize_tags(tags)
now = datetime.now(timezone.utc)
# current links
current = set(
tag_name for (tag_name,) in (
await session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
).all()
)
to_add = [t for t in desired if t not in current]
to_remove = [t for t in current if t not in desired]
if to_add:
await _ensure_tags_exist(session, to_add, tag_type="user")
session.add_all([
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_by=added_by, added_at=now)
for t in to_add
])
await session.flush()
if to_remove:
await session.execute(
delete(AssetInfoTag)
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
)
await session.flush()
return {"added": to_add, "removed": to_remove, "total": desired}
async def update_asset_info_full(
session: AsyncSession,
*,
asset_info_id: int,
name: Optional[str] = None,
tags: Optional[Sequence[str]] = None,
user_metadata: Optional[dict] = None,
tag_origin: str = "manual",
added_by: Optional[str] = None,
) -> AssetInfo:
"""
Update AssetInfo fields:
- name (if provided)
- user_metadata blob + rebuild projection (if provided)
- replace tags with provided set (if provided)
Returns the updated AssetInfo.
"""
info = await session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
touched = False
if name is not None and name != info.name:
info.name = name
touched = True
if user_metadata is not None:
await replace_asset_info_metadata_projection(
session, asset_info_id=asset_info_id, user_metadata=user_metadata
)
touched = True
if tags is not None:
await set_asset_info_tags(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=tag_origin,
added_by=added_by,
)
touched = True
if touched and user_metadata is None:
info.updated_at = datetime.now(timezone.utc)
await session.flush()
return info
async def replace_asset_info_metadata_projection(
session: AsyncSession,
*,
asset_info_id: int,
user_metadata: dict | None,
) -> None:
"""Replaces the `assets_info.user_metadata` AND rebuild the projection rows in `asset_info_meta`."""
info = await session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info.user_metadata = user_metadata or {}
info.updated_at = datetime.now(timezone.utc)
await session.flush()
await session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
await session.flush()
if not user_metadata:
return
rows: list[AssetInfoMeta] = []
for k, v in user_metadata.items():
for r in _project_kv(k, v):
rows.append(
AssetInfoMeta(
asset_info_id=asset_info_id,
key=r["key"],
ordinal=int(r["ordinal"]),
val_str=r.get("val_str"),
val_num=r.get("val_num"),
val_bool=r.get("val_bool"),
val_json=r.get("val_json"),
)
)
if rows:
session.add_all(rows)
await session.flush()
async def get_asset_tags(session: AsyncSession, *, asset_info_id: int) -> list[str]:
return [
tag_name
for (tag_name,) in (
await session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
]
def _normalize_tags(tags: Sequence[str] | None) -> list[str]:
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
async def _ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> list[Tag]:
wanted = _normalize_tags(list(names))
if not wanted:
return []
existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all()
by_name = {t.name: t for t in existing}
to_create = [Tag(name=n, tag_type=tag_type) for n in wanted if n not in by_name]
if to_create:
session.add_all(to_create)
await session.flush()
by_name.update({t.name: t for t in to_create})
return [by_name[n] for n in wanted]
def _apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Sequence[str] | None,
exclude_tags: Sequence[str] | None,
) -> sa.sql.Select:
"""include_tags: every tag must be present; exclude_tags: none may be present."""
include_tags = _normalize_tags(include_tags)
exclude_tags = _normalize_tags(exclude_tags)
if include_tags:
for tag_name in include_tags:
stmt = stmt.where(
exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name == tag_name)
)
)
if exclude_tags:
stmt = stmt.where(
~exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name.in_(exclude_tags))
)
)
return stmt
def _apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: dict | None,
) -> sa.sql.Select:
"""Apply metadata filters using the projection table asset_info_meta.
Semantics:
- For scalar values: require EXISTS(asset_info_meta) with matching key + typed value.
- For None: key is missing OR key has explicit null (val_json IS NULL).
- For list values: ANY-of the list elements matches (EXISTS for any).
(Change to ALL-of by 'for each element: stmt = stmt.where(_meta_exists_clause(key, elem))')
"""
if not metadata_filter:
return stmt
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
subquery = (
select(sa.literal(1))
.select_from(AssetInfoMeta)
.where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
*preds,
)
.limit(1)
)
return sa.exists(subquery)
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
# Missing OR null:
if value is None:
# either: no row for key OR a row for key with explicit null
no_row_for_key = ~sa.exists(
select(sa.literal(1))
.select_from(AssetInfoMeta)
.where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
)
.limit(1)
)
null_row = _exists_for_pred(key, AssetInfoMeta.val_json.is_(None))
return sa.or_(no_row_for_key, null_row)
# Typed scalar matches:
if isinstance(value, bool):
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
if isinstance(value, (int, float, Decimal)):
# store as Decimal for equality against NUMERIC(38,10)
num = value if isinstance(value, Decimal) else Decimal(str(value))
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
if isinstance(value, str):
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
# Complex: compare JSON (no index, but supported)
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
for k, v in metadata_filter.items():
if isinstance(v, list):
# ANY-of (exists for any element)
ors = [ _exists_clause_for_value(k, elem) for elem in v ]
if ors:
stmt = stmt.where(sa.or_(*ors))
else:
stmt = stmt.where(_exists_clause_for_value(k, v))
return stmt
def _is_scalar(v: Any) -> bool:
if v is None: # treat None as a value (explicit null) so it can be indexed for "is null" queries
return True
if isinstance(v, bool):
return True
if isinstance(v, (int, float, Decimal, str)):
return True
return False
def _project_kv(key: str, value: Any) -> list[dict]:
"""
Turn a metadata key/value into one or more projection rows:
- scalar -> one row (ordinal=0) in the proper typed column
- list of scalars -> one row per element with ordinal=i
- dict or list with non-scalars -> single row with val_json (or one per element w/ val_json if list)
- None -> single row with val_json = None
Each row: {"key": key, "ordinal": i, "val_str"/"val_num"/"val_bool"/"val_json": ...}
"""
rows: list[dict] = []
# None
if value is None:
rows.append({"key": key, "ordinal": 0, "val_json": None})
return rows
# Scalars
if _is_scalar(value):
if isinstance(value, bool):
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
elif isinstance(value, (int, float, Decimal)):
# store numeric; SQLAlchemy will coerce to Numeric
rows.append({"key": key, "ordinal": 0, "val_num": value})
elif isinstance(value, str):
rows.append({"key": key, "ordinal": 0, "val_str": value})
else:
# Fallback to json
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows
# Lists
if isinstance(value, list):
# list of scalars?
if all(_is_scalar(x) for x in value):
for i, x in enumerate(value):
if x is None:
rows.append({"key": key, "ordinal": i, "val_json": None})
elif isinstance(x, bool):
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
elif isinstance(x, (int, float, Decimal)):
rows.append({"key": key, "ordinal": i, "val_num": x})
elif isinstance(x, str):
rows.append({"key": key, "ordinal": i, "val_str": x})
else:
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
# list contains objects -> one val_json per element
for i, x in enumerate(value):
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
# Dict or any other structure -> single json row
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows
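
The projection helpers at the bottom of services.py are what make metadata_filter queries cheap: _project_kv flattens each user_metadata key into typed rows for asset_info_meta, and _apply_metadata_filter turns a filter dict into EXISTS subqueries over those rows. A short worked example of the projection (values illustrative; the private helper is imported here only for demonstration):

from app.database.services import _project_kv

user_metadata = {
    "base_model": "sdxl",           # scalar str   -> val_str
    "steps": 30,                    # scalar num   -> val_num
    "nsfw": False,                  # scalar bool  -> val_bool
    "triggers": ["cat", "neon"],    # list of str  -> one row per element (ordinal 0, 1)
    "source": {"site": "civitai"},  # dict         -> single val_json row
    "notes": None,                  # explicit null -> val_json = None
}

rows = []
for key, value in user_metadata.items():
    rows.extend(_project_kv(key, value))

# rows now contains, among others:
#   {"key": "base_model", "ordinal": 0, "val_str": "sdxl"}
#   {"key": "steps",      "ordinal": 0, "val_num": 30}
#   {"key": "triggers",   "ordinal": 0, "val_str": "cat"}
#   {"key": "triggers",   "ordinal": 1, "val_str": "neon"}
#   {"key": "source",     "ordinal": 0, "val_json": {"site": "civitai"}}
#   {"key": "notes",      "ordinal": 0, "val_json": None}
#
# list_asset_infos_page(..., metadata_filter={"triggers": ["cat"]}) then matches through an
# EXISTS over these asset_info_meta rows (ANY-of semantics for list-valued filters).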