dev: Everything is Assets

bigcat88
2025-08-19 19:56:59 +03:00
parent c708d0a433
commit f92307cd4c
22 changed files with 1650 additions and 977 deletions

0
app/api/__init__.py Normal file

110
app/api/assets_routes.py Normal file

@@ -0,0 +1,110 @@
import json
from typing import Sequence
from aiohttp import web
from app import assets_manager
ROUTES = web.RouteTableDef()
@ROUTES.get("/api/assets")
async def list_assets(request: web.Request) -> web.Response:
q = request.rel_url.query
include_tags: Sequence[str] = _parse_csv_tags(q.get("include_tags"))
exclude_tags: Sequence[str] = _parse_csv_tags(q.get("exclude_tags"))
name_contains = q.get("name_contains")
# Optional JSON metadata filter (top-level key equality only for now)
metadata_filter = None
raw_meta = q.get("metadata_filter")
if raw_meta:
try:
metadata_filter = json.loads(raw_meta)
if not isinstance(metadata_filter, dict):
metadata_filter = None
except Exception:
# Silently ignore malformed JSON for first iteration; could 400 in future
metadata_filter = None
limit = _parse_int(q.get("limit"), default=20, lo=1, hi=100)
offset = _parse_int(q.get("offset"), default=0, lo=0, hi=10_000_000)
sort = q.get("sort", "created_at")
order = q.get("order", "desc")
payload = await assets_manager.list_assets(
include_tags=include_tags,
exclude_tags=exclude_tags,
name_contains=name_contains,
metadata_filter=metadata_filter,
limit=limit,
offset=offset,
sort=sort,
order=order,
)
return web.json_response(payload)
@ROUTES.put("/api/assets/{id}")
async def update_asset(request: web.Request) -> web.Response:
asset_info_id_raw = request.match_info.get("id")
try:
asset_info_id = int(asset_info_id_raw)
except Exception:
return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.")
try:
payload = await request.json()
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
name = payload.get("name", None)
tags = payload.get("tags", None)
user_metadata = payload.get("user_metadata", None)
if name is None and tags is None and user_metadata is None:
return _error_response(400, "NO_FIELDS", "Provide at least one of: name, tags, user_metadata.")
if tags is not None and (not isinstance(tags, list) or not all(isinstance(t, str) for t in tags)):
return _error_response(400, "INVALID_TAGS", "Field 'tags' must be an array of strings.")
if user_metadata is not None and not isinstance(user_metadata, dict):
return _error_response(400, "INVALID_METADATA", "Field 'user_metadata' must be an object.")
try:
result = await assets_manager.update_asset(
asset_info_id=asset_info_id,
name=name,
tags=tags,
user_metadata=user_metadata,
)
except ValueError as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result, status=200)
def register_assets_routes(app: web.Application) -> None:
app.add_routes(ROUTES)
def _parse_csv_tags(raw: str | None) -> list[str]:
if not raw:
return []
return [t.strip() for t in raw.split(",") if t.strip()]
def _parse_int(qval: str | None, default: int, lo: int, hi: int) -> int:
if not qval:
return default
try:
v = int(qval)
except Exception:
return default
return max(lo, min(hi, v))
def _error_response(status: int, code: str, message: str, details: dict | None = None) -> web.Response:
return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status)

148
app/assets_manager.py Normal file

@@ -0,0 +1,148 @@
import os
from datetime import datetime, timezone
from typing import Optional, Sequence
from .database.db import create_session
from .storage import hashing
from .database.services import (
check_fs_asset_exists_quick,
ingest_fs_asset,
touch_asset_infos_by_fs_path,
list_asset_infos_page,
update_asset_info_full,
get_asset_tags,
)
def get_size_mtime_ns(path: str) -> tuple[int, int]:
st = os.stat(path, follow_symlinks=True)
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> None:
"""Adds a local asset to the DB. If already present and unchanged, does nothing.
Notes:
- Uses absolute path as the canonical locator for the 'fs' backend.
- Computes BLAKE3 only when the fast existence check indicates it's needed.
- This function ensures the identity row and seeds mtime in asset_locator_state.
"""
abs_path = os.path.abspath(file_path)
size_bytes, mtime_ns = get_size_mtime_ns(abs_path)
if not size_bytes:
return
async with await create_session() as session:
if await check_fs_asset_exists_quick(session, file_path=abs_path, size_bytes=size_bytes, mtime_ns=mtime_ns):
await touch_asset_infos_by_fs_path(session, abs_path=abs_path, ts=datetime.now(timezone.utc))
await session.commit()
return
asset_hash = hashing.blake3_hash_sync(abs_path)
async with await create_session() as session:
await ingest_fs_asset(
session,
asset_hash="blake3:" + asset_hash,
abs_path=abs_path,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
mime_type=None,
info_name=file_name,
tag_origin="automatic",
tags=tags,
)
await session.commit()
async def list_assets(
*,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: Optional[str] = None,
metadata_filter: Optional[dict] = None,
limit: int = 20,
offset: int = 0,
sort: str | None = "created_at",
order: str | None = "desc",
) -> dict:
sort = _safe_sort_field(sort)
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
async with await create_session() as session:
infos, tag_map, total = await list_asset_infos_page(
session,
include_tags=include_tags,
exclude_tags=exclude_tags,
name_contains=name_contains,
metadata_filter=metadata_filter,
limit=limit,
offset=offset,
sort=sort,
order=order,
)
assets_json = []
for info in infos:
asset = info.asset # populated via contains_eager
tags = tag_map.get(info.id, [])
assets_json.append(
{
"id": info.id,
"name": info.name,
"asset_hash": info.asset_hash,
"size": int(asset.size_bytes) if asset else None,
"mime_type": asset.mime_type if asset else None,
"tags": tags,
"preview_url": f"/api/v1/assets/{info.id}/content", # TODO: implement actual content endpoint later
"created_at": info.created_at.isoformat() if info.created_at else None,
"updated_at": info.updated_at.isoformat() if info.updated_at else None,
"last_access_time": info.last_access_time.isoformat() if info.last_access_time else None,
}
)
return {
"assets": assets_json,
"total": total,
"has_more": (offset + len(assets_json)) < total,
}
async def update_asset(
*,
asset_info_id: int,
name: str | None = None,
tags: list[str] | None = None,
user_metadata: dict | None = None,
) -> dict:
async with await create_session() as session:
info = await update_asset_info_full(
session,
asset_info_id=asset_info_id,
name=name,
tags=tags,
user_metadata=user_metadata,
tag_origin="manual",
added_by=None,
)
tag_names = await get_asset_tags(session, asset_info_id=asset_info_id)
await session.commit()
return {
"id": info.id,
"name": info.name,
"asset_hash": info.asset_hash,
"tags": tag_names,
"user_metadata": info.user_metadata or {},
"updated_at": info.updated_at.isoformat() if info.updated_at else None,
}
def _safe_sort_field(requested: str | None) -> str:
if not requested:
return "created_at"
v = requested.lower()
if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
return v
return "created_at"

0
app/database/__init__.py Normal file


@@ -1,112 +1,267 @@
import logging
import os
import shutil
from contextlib import asynccontextmanager
from typing import Optional
from app.logger import log_startup_warning
from utils.install_util import get_missing_requirements_message
from comfy.cli_args import args
_DB_AVAILABLE = False
LOGGER = logging.getLogger(__name__)
# Attempt imports which may not exist in some environments
try:
from alembic import command
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from sqlalchemy import create_engine, text
from sqlalchemy.engine import make_url
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
_DB_AVAILABLE = True
ENGINE: AsyncEngine | None = None
SESSION: async_sessionmaker | None = None
except ImportError as e:
log_startup_warning(
f"""
------------------------------------------------------------------------
Error importing dependencies: {e}
{get_missing_requirements_message()}
This error is happening because ComfyUI now uses a local sqlite database.
------------------------------------------------------------------------
""".strip()
(
"------------------------------------------------------------------------\n"
f"Error importing DB dependencies: {e}\n"
f"{get_missing_requirements_message()}\n"
"This error is happening because ComfyUI now uses a local database.\n"
"------------------------------------------------------------------------"
).strip()
)
_DB_AVAILABLE = False
ENGINE = None
SESSION = None
def dependencies_available() -> bool:
"""Check if DB dependencies are importable."""
return _DB_AVAILABLE
def _root_paths():
"""Resolve alembic.ini and migrations script folder."""
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
return config_path, scripts_path
def _absolutize_sqlite_url(db_url: str) -> str:
"""Make SQLite database path absolute. No-op for non-SQLite URLs."""
try:
u = make_url(db_url)
except Exception:
return db_url
if not u.drivername.startswith("sqlite"):
return db_url
# Make path absolute if relative
db_path = u.database or ""
if not os.path.isabs(db_path):
db_path = os.path.abspath(os.path.join(os.getcwd(), db_path))
u = u.set(database=db_path)
return str(u)
def _to_sync_driver_url(async_url: str) -> str:
"""Convert an async SQLAlchemy URL to a sync URL for Alembic."""
u = make_url(async_url)
driver = u.drivername
if driver.startswith("sqlite+aiosqlite"):
u = u.set(drivername="sqlite")
elif driver.startswith("postgresql+asyncpg"):
u = u.set(drivername="postgresql")
# Generic: strip the async driver part if present
if "+" in driver:
u = u.set(drivername=driver.split("+", 1)[0])
return str(u)
def _get_sqlite_file_path(sync_url: str) -> Optional[str]:
"""Return the on-disk path for a SQLite URL, else None."""
try:
u = make_url(sync_url)
except Exception:
return None
if not u.drivername.startswith("sqlite"):
return None
return u.database
def _get_alembic_config(sync_url: str) -> Config:
"""Prepare Alembic Config with script location and DB URL."""
config_path, scripts_path = _root_paths()
cfg = Config(config_path)
cfg.set_main_option("script_location", scripts_path)
cfg.set_main_option("sqlalchemy.url", sync_url)
return cfg
async def init_db_engine() -> None:
"""Initialize async engine + sessionmaker and run migrations to head.
This must be called once on application startup before any DB usage.
"""
global ENGINE, SESSION
if not dependencies_available():
raise RuntimeError("Database dependencies are not available.")
if ENGINE is not None:
return
raw_url = args.database_url
if not raw_url:
raise RuntimeError("Database URL is not configured.")
# Absolutize SQLite path for async engine
db_url = _absolutize_sqlite_url(raw_url)
# Prepare async engine
connect_args = {}
if db_url.startswith("sqlite"):
connect_args = {
"check_same_thread": False,
"timeout": 12,
}
ENGINE = create_async_engine(
db_url,
connect_args=connect_args,
pool_pre_ping=True,
future=True,
)
# Enforce SQLite pragmas on the async engine
if db_url.startswith("sqlite"):
async with ENGINE.begin() as conn:
# WAL for concurrency and durability, Foreign Keys for referential integrity
current_mode = (await conn.execute(text("PRAGMA journal_mode;"))).scalar()
if str(current_mode).lower() != "wal":
new_mode = (await conn.execute(text("PRAGMA journal_mode=WAL;"))).scalar()
if str(new_mode).lower() != "wal":
raise RuntimeError("Failed to set SQLite journal mode to WAL.")
LOGGER.info("SQLite journal mode set to WAL.")
await conn.execute(text("PRAGMA foreign_keys = ON;"))
await conn.execute(text("PRAGMA synchronous = NORMAL;"))
await _run_migrations(raw_url=db_url)
SESSION = async_sessionmaker(
bind=ENGINE,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
autocommit=False,
)
async def _run_migrations(raw_url: str) -> None:
"""
Run Alembic migrations up to head.
We deliberately use a synchronous engine for migrations because Alembic's
programmatic API is synchronous by default and this path is robust.
"""
# Convert to sync URL and make SQLite URL an absolute one
sync_url = _to_sync_driver_url(raw_url)
sync_url = _absolutize_sqlite_url(sync_url)
cfg = _get_alembic_config(sync_url)
# Inspect current and target heads
engine = create_engine(sync_url, future=True)
with engine.connect() as conn:
context = MigrationContext.configure(conn)
current_rev = context.get_current_revision()
script = ScriptDirectory.from_config(cfg)
target_rev = script.get_current_head()
if target_rev is None:
logging.warning("No target revision found.")
elif current_rev != target_rev:
# Backup the database pre upgrade
backup_path = db_path + ".bkp"
if db_exists:
shutil.copy(db_path, backup_path)
else:
backup_path = None
LOGGER.warning("Alembic: no target revision found.")
return
if current_rev == target_rev:
LOGGER.debug("Alembic: database already at head %s", target_rev)
return
LOGGER.info("Alembic: upgrading database from %s to %s", current_rev, target_rev)
# Optional backup for SQLite file DBs
backup_path = None
sqlite_path = _get_sqlite_file_path(sync_url)
if sqlite_path and os.path.exists(sqlite_path):
backup_path = sqlite_path + ".bkp"
        try:
            shutil.copy(sqlite_path, backup_path)
        except Exception as exc:
            LOGGER.warning("Failed to create SQLite backup before migration: %s", exc)
    try:
        command.upgrade(cfg, target_rev)
    except Exception:
        if backup_path and os.path.exists(backup_path):
            LOGGER.exception("Error upgrading database, attempting restore from backup.")
            try:
                shutil.copy(backup_path, sqlite_path)  # restore
                os.remove(backup_path)
            except Exception as re:
                LOGGER.error("Failed to restore SQLite backup: %s", re)
        else:
            LOGGER.exception("Error upgrading database, backup is not available.")
        raise
def get_engine():
"""Return the global async engine (initialized after init_db_engine())."""
if ENGINE is None:
raise RuntimeError("Engine is not initialized. Call init_db_engine() first.")
return ENGINE
def get_session_maker():
"""Return the global async_sessionmaker (initialized after init_db_engine())."""
if SESSION is None:
raise RuntimeError("Session maker is not initialized. Call init_db_engine() first.")
return SESSION
@asynccontextmanager
async def session_scope():
"""Async context manager for a unit of work:
async with session_scope() as sess:
... use sess ...
"""
maker = get_session_maker()
async with maker() as sess:
try:
yield sess
await sess.commit()
except Exception:
await sess.rollback()
raise
async def create_session():
"""Convenience helper to acquire a single AsyncSession instance.
Typical usage:
async with (await create_session()) as sess:
...
"""
maker = get_session_maker()
return maker()
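
A sketch of the intended startup order, following the docstrings above and assuming args.database_url is configured:

import asyncio
from app.database.db import init_db_engine, session_scope

async def main() -> None:
    await init_db_engine()  # create engine, apply SQLite pragmas, run migrations
    async with session_scope() as sess:
        ...  # unit of work: commits on success, rolls back on exception

asyncio.run(main())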


@@ -1,59 +1,257 @@
from datetime import datetime
from typing import Any, Optional

from sqlalchemy import (
    Integer,
    BigInteger,
    DateTime,
    ForeignKey,
    Index,
    JSON,
    String,
    Text,
    CheckConstraint,
    Numeric,
    Boolean,
)
from sqlalchemy.sql import func
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship, foreign


class Base(DeclarativeBase):
    pass

def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
fields = obj.__table__.columns.keys()
out: dict[str, Any] = {}
for field in fields:
val = getattr(obj, field)
if val is None and not include_none:
continue
if isinstance(val, datetime):
out[field] = val.isoformat()
else:
out[field] = val
return out
class Asset(Base):
    __tablename__ = "assets"

    hash: Mapped[str] = mapped_column(String(256), primary_key=True)
    size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
    mime_type: Mapped[str | None] = mapped_column(String(255))
    refcount: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
    storage_backend: Mapped[str] = mapped_column(String(32), nullable=False, default="fs")
    storage_locator: Mapped[str] = mapped_column(Text, nullable=False)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
    )

    infos: Mapped[list["AssetInfo"]] = relationship(
        "AssetInfo",
        back_populates="asset",
        primaryjoin=lambda: Asset.hash == foreign(AssetInfo.asset_hash),
        foreign_keys=lambda: [AssetInfo.asset_hash],
        cascade="all,delete-orphan",
        passive_deletes=True,
    )

    preview_of: Mapped[list["AssetInfo"]] = relationship(
        "AssetInfo",
        back_populates="preview_asset",
        primaryjoin=lambda: Asset.hash == foreign(AssetInfo.preview_hash),
        foreign_keys=lambda: [AssetInfo.preview_hash],
        viewonly=True,
    )

    locator_state: Mapped["AssetLocatorState | None"] = relationship(
        back_populates="asset",
        uselist=False,
        cascade="all, delete-orphan",
        passive_deletes=True,
    )

    __table_args__ = (
        Index("ix_assets_mime_type", "mime_type"),
        Index("ix_assets_backend_locator", "storage_backend", "storage_locator"),
    )

    def to_dict(self, include_none: bool = False) -> dict[str, Any]:
        return to_dict(self, include_none=include_none)

    def __repr__(self) -> str:
        return f"<Asset hash={self.hash[:12]} backend={self.storage_backend}>"
class AssetLocatorState(Base):
__tablename__ = "asset_locator_state"
asset_hash: Mapped[str] = mapped_column(
String(256), ForeignKey("assets.hash", ondelete="CASCADE"), primary_key=True
)
# For fs backends: nanosecond mtime; nullable if not applicable
mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
# For HTTP/S3/GCS/Azure, etc.: optional validators
etag: Mapped[str | None] = mapped_column(String(256), nullable=True)
last_modified: Mapped[str | None] = mapped_column(String(128), nullable=True)
asset: Mapped["Asset"] = relationship(back_populates="locator_state", uselist=False)
__table_args__ = (
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_als_mtime_nonneg"),
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
return to_dict(self, include_none=include_none)
def __repr__(self) -> str:
return f"<AssetLocatorState hash={self.asset_hash[:12]} mtime_ns={self.mtime_ns}>"
class AssetInfo(Base):
__tablename__ = "assets_info"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
owner_id: Mapped[str | None] = mapped_column(String(128))
name: Mapped[str] = mapped_column(String(512), nullable=False)
asset_hash: Mapped[str] = mapped_column(
String(256), ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False
)
preview_hash: Mapped[str | None] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="SET NULL"))
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
)
last_access_time: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
)
# Relationships
asset: Mapped[Asset] = relationship(
"Asset",
back_populates="infos",
foreign_keys=[asset_hash],
)
preview_asset: Mapped[Asset | None] = relationship(
"Asset",
back_populates="preview_of",
foreign_keys=[preview_hash],
)
metadata_entries: Mapped[list["AssetInfoMeta"]] = relationship(
back_populates="asset_info",
cascade="all,delete-orphan",
passive_deletes=True,
)
tag_links: Mapped[list["AssetInfoTag"]] = relationship(
back_populates="asset_info",
cascade="all,delete-orphan",
passive_deletes=True,
overlaps="tags,asset_infos",
)
tags: Mapped[list["Tag"]] = relationship(
secondary="asset_info_tags",
back_populates="asset_infos",
lazy="joined",
viewonly=True,
overlaps="tag_links,asset_info_links,asset_infos,tag",
)
__table_args__ = (
Index("ix_assets_info_owner_id", "owner_id"),
Index("ix_assets_info_asset_hash", "asset_hash"),
Index("ix_assets_info_name", "name"),
Index("ix_assets_info_created_at", "created_at"),
Index("ix_assets_info_last_access_time", "last_access_time"),
{"sqlite_autoincrement": True},
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
data = to_dict(self, include_none=include_none)
data["tags"] = [t.name for t in self.tags]
return data
def __repr__(self) -> str:
return f"<AssetInfo id={self.id} name={self.name!r} hash={self.asset_hash[:12]}>"
class AssetInfoMeta(Base):
__tablename__ = "asset_info_meta"
asset_info_id: Mapped[int] = mapped_column(
Integer, ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
)
key: Mapped[str] = mapped_column(String(256), primary_key=True)
ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0)
val_str: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True)
val_num: Mapped[Optional[float]] = mapped_column(Numeric(38, 10), nullable=True)
val_bool: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True)
val_json: Mapped[Optional[Any]] = mapped_column(JSON, nullable=True)
asset_info: Mapped["AssetInfo"] = relationship(back_populates="metadata_entries")
__table_args__ = (
Index("ix_asset_info_meta_key", "key"),
Index("ix_asset_info_meta_key_val_str", "key", "val_str"),
Index("ix_asset_info_meta_key_val_num", "key", "val_num"),
Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"),
)
class AssetInfoTag(Base):
__tablename__ = "asset_info_tags"
asset_info_id: Mapped[int] = mapped_column(
Integer, ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
)
tag_name: Mapped[str] = mapped_column(
String(128), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
)
origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
added_by: Mapped[str | None] = mapped_column(String(128))
added_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
asset_info: Mapped["AssetInfo"] = relationship(back_populates="tag_links")
tag: Mapped["Tag"] = relationship(back_populates="asset_info_links")
__table_args__ = (
Index("ix_asset_info_tags_tag_name", "tag_name"),
Index("ix_asset_info_tags_asset_info_id", "asset_info_id"),
)
class Tag(Base):
__tablename__ = "tags"
name: Mapped[str] = mapped_column(String(128), primary_key=True)
tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")
asset_info_links: Mapped[list["AssetInfoTag"]] = relationship(
back_populates="tag",
overlaps="asset_infos,tags",
)
asset_infos: Mapped[list["AssetInfo"]] = relationship(
secondary="asset_info_tags",
back_populates="tags",
viewonly=True,
overlaps="asset_info_links,tag_links,tags,asset_info",
)
__table_args__ = (
Index("ix_tags_tag_type", "tag_type"),
)
def __repr__(self) -> str:
return f"<Tag {self.name}>"

683
app/database/services.py Normal file

@@ -0,0 +1,683 @@
import os
import logging
from collections import defaultdict
from datetime import datetime, timezone
from decimal import Decimal
from typing import Any, Sequence, Optional, Iterable
import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, delete, exists, func
from sqlalchemy.orm import contains_eager
from sqlalchemy.exc import IntegrityError
from .models import Asset, AssetInfo, AssetInfoTag, AssetLocatorState, Tag, AssetInfoMeta
async def check_fs_asset_exists_quick(
    session: AsyncSession,
*,
file_path: str,
size_bytes: Optional[int] = None,
mtime_ns: Optional[int] = None,
) -> bool:
"""
Returns 'True' if there is already an Asset present whose canonical locator matches this absolute path,
AND (if provided) mtime_ns matches stored locator-state,
AND (if provided) size_bytes matches verified size when known.
"""
locator = os.path.abspath(file_path)
stmt = select(sa.literal(True)).select_from(Asset)
conditions = [
Asset.storage_backend == "fs",
Asset.storage_locator == locator,
]
# If size_bytes provided require equality when the asset has a verified (non-zero) size.
# If verified size is 0 (unknown), we don't force equality.
if size_bytes is not None:
conditions.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes)))
# If mtime_ns provided require the locator-state to exist and match.
if mtime_ns is not None:
stmt = stmt.join(AssetLocatorState, AssetLocatorState.asset_hash == Asset.hash)
conditions.append(AssetLocatorState.mtime_ns == int(mtime_ns))
stmt = stmt.where(*conditions).limit(1)
row = (await session.execute(stmt)).first()
return row is not None
async def ingest_fs_asset(
session: AsyncSession,
*,
asset_hash: str,
abs_path: str,
size_bytes: int,
mtime_ns: int,
mime_type: Optional[str] = None,
info_name: Optional[str] = None,
owner_id: Optional[str] = None,
preview_hash: Optional[str] = None,
user_metadata: Optional[dict] = None,
tags: Sequence[str] = (),
tag_origin: str = "manual",
added_by: Optional[str] = None,
require_existing_tags: bool = False,
) -> dict:
"""
Creates or updates Asset record for a local (fs) asset.
Always:
- Insert Asset if missing; else update size_bytes (and updated_at) if different.
- Insert AssetLocatorState if missing; else update mtime_ns if different.
Optionally (when info_name is provided):
- Create an AssetInfo (no refcount changes).
- Link provided tags to that AssetInfo.
        * If require_existing_tags=True, raises ValueError if any tag does not exist in the `tags` table.
* If False (default), silently skips unknown tags.
Returns flags and ids:
{
"asset_created": bool,
"asset_updated": bool,
"state_created": bool,
"state_updated": bool,
"asset_info_id": int | None,
"tags_added": list[str],
"tags_missing": list[str], # filled only when require_existing_tags=False
}
"""
locator = os.path.abspath(abs_path)
datetime_now = datetime.now(timezone.utc)
out = {
"asset_created": False,
"asset_updated": False,
"state_created": False,
"state_updated": False,
"asset_info_id": None,
"tags_added": [],
"tags_missing": [],
}
# ---- Step 1: INSERT Asset or UPDATE size_bytes/updated_at if exists ----
async with session.begin_nested() as sp1:
try:
session.add(
Asset(
hash=asset_hash,
size_bytes=int(size_bytes),
mime_type=mime_type,
refcount=0,
storage_backend="fs",
storage_locator=locator,
created_at=datetime_now,
updated_at=datetime_now,
)
)
await session.flush()
out["asset_created"] = True
except IntegrityError:
await sp1.rollback()
# Already exists by hash -> update selected fields if different
existing = await session.get(Asset, asset_hash)
if existing is not None:
desired_size = int(size_bytes)
if existing.size_bytes != desired_size:
existing.size_bytes = desired_size
existing.updated_at = datetime_now
out["asset_updated"] = True
else:
# This should not occur. Log for visibility.
logging.error("Asset %s not found after conflict; skipping update.", asset_hash)
except Exception:
await sp1.rollback()
logging.exception("Unexpected error inserting Asset (hash=%s, locator=%s)", asset_hash, locator)
raise
# ---- Step 2: INSERT/UPDATE AssetLocatorState (mtime_ns) ----
async with session.begin_nested() as sp2:
try:
session.add(
AssetLocatorState(
asset_hash=asset_hash,
mtime_ns=int(mtime_ns),
)
)
await session.flush()
out["state_created"] = True
except IntegrityError:
await sp2.rollback()
state = await session.get(AssetLocatorState, asset_hash)
if state is not None:
desired_mtime = int(mtime_ns)
if state.mtime_ns != desired_mtime:
state.mtime_ns = desired_mtime
out["state_updated"] = True
else:
logging.debug("Locator state missing for %s after conflict; skipping update.", asset_hash)
except Exception:
await sp2.rollback()
logging.exception("Unexpected error inserting AssetLocatorState (hash=%s)", asset_hash)
raise
# ---- Optional: AssetInfo + tag links ----
if info_name:
# 2a) Create AssetInfo (no refcount bump)
async with session.begin_nested() as sp3:
try:
info = AssetInfo(
owner_id=owner_id,
name=info_name,
asset_hash=asset_hash,
preview_hash=preview_hash,
created_at=datetime_now,
updated_at=datetime_now,
last_access_time=datetime_now,
)
session.add(info)
await session.flush() # get info.id
out["asset_info_id"] = info.id
except Exception:
await sp3.rollback()
logging.exception(
"Unexpected error inserting AssetInfo (hash=%s, name=%s)", asset_hash, info_name
)
raise
# 2b) Link tags (if any). We DO NOT create new Tag rows here by default.
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
if norm and out["asset_info_id"] is not None:
# Which tags exist?
existing_tag_names = set(
name for (name,) in (await session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
)
missing = [t for t in norm if t not in existing_tag_names]
if missing and require_existing_tags:
raise ValueError(f"Unknown tags: {missing}")
# Which links already exist?
existing_links = set(
tag_name
for (tag_name,) in (
await session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
)
).all()
)
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
if to_add:
session.add_all(
[
AssetInfoTag(
asset_info_id=out["asset_info_id"],
tag_name=t,
origin=tag_origin,
added_by=added_by,
added_at=datetime_now,
)
for t in to_add
]
)
await session.flush()
out["tags_added"] = to_add
out["tags_missing"] = missing
# 2c) Rebuild metadata projection if provided
if user_metadata is not None and out["asset_info_id"] is not None:
await replace_asset_info_metadata_projection(
session,
asset_info_id=out["asset_info_id"],
user_metadata=user_metadata,
)
return out
async def touch_asset_infos_by_fs_path(
session: AsyncSession,
*,
abs_path: str,
ts: Optional[datetime] = None,
only_if_newer: bool = True,
) -> int:
locator = os.path.abspath(abs_path)
ts = ts or datetime.now(timezone.utc)
stmt = sa.update(AssetInfo).where(
sa.exists(
sa.select(sa.literal(1))
.select_from(Asset)
.where(
Asset.hash == AssetInfo.asset_hash,
Asset.storage_backend == "fs",
Asset.storage_locator == locator,
)
)
)
if only_if_newer:
stmt = stmt.where(
sa.or_(
AssetInfo.last_access_time.is_(None),
AssetInfo.last_access_time < ts,
)
)
stmt = stmt.values(last_access_time=ts)
res = await session.execute(stmt)
return int(res.rowcount or 0)
async def list_asset_infos_page(
session: AsyncSession,
*,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
) -> tuple[list[AssetInfo], dict[int, list[str]], int]:
"""
Returns a page of AssetInfo rows with their Asset eagerly loaded (no N+1),
plus a map of asset_info_id -> [tags], and the total count.
We purposely collect tags in a separate (single) query to avoid row explosion.
"""
# Clamp
if limit <= 0:
limit = 1
if limit > 100:
limit = 100
if offset < 0:
offset = 0
# Build base query
base = (
select(AssetInfo)
.join(Asset, Asset.hash == AssetInfo.asset_hash)
.options(contains_eager(AssetInfo.asset))
)
# Filters
if name_contains:
base = base.where(AssetInfo.name.ilike(f"%{name_contains}%"))
base = _apply_tag_filters(base, include_tags, exclude_tags)
base = _apply_metadata_filter(base, metadata_filter)
# Sort
sort = (sort or "created_at").lower()
order = (order or "desc").lower()
sort_map = {
"name": AssetInfo.name,
"created_at": AssetInfo.created_at,
"updated_at": AssetInfo.updated_at,
"last_access_time": AssetInfo.last_access_time,
"size": Asset.size_bytes,
}
sort_col = sort_map.get(sort, AssetInfo.created_at)
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
base = base.order_by(sort_exp).limit(limit).offset(offset)
# Total count (same filters, no ordering/limit/offset)
count_stmt = (
select(func.count())
.select_from(AssetInfo)
.join(Asset, Asset.hash == AssetInfo.asset_hash)
)
if name_contains:
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{name_contains}%"))
count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = _apply_metadata_filter(count_stmt, metadata_filter)
total = (await session.execute(count_stmt)).scalar_one()
# Fetch rows
infos = (await session.execute(base)).scalars().unique().all()
# Collect tags in bulk (single query)
id_list = [i.id for i in infos]
tag_map: dict[int, list[str]] = defaultdict(list)
if id_list:
rows = await session.execute(
select(AssetInfoTag.asset_info_id, Tag.name)
.join(Tag, Tag.name == AssetInfoTag.tag_name)
.where(AssetInfoTag.asset_info_id.in_(id_list))
)
for aid, tag_name in rows.all():
tag_map[aid].append(tag_name)
return infos, tag_map, total
async def set_asset_info_tags(
session: AsyncSession,
*,
asset_info_id: int,
tags: Sequence[str],
origin: str = "manual",
added_by: Optional[str] = None,
) -> dict:
"""
Replace the tag set on an AssetInfo with `tags`. Idempotent.
Creates missing tag names as 'user'.
"""
desired = _normalize_tags(tags)
now = datetime.now(timezone.utc)
# current links
current = set(
tag_name for (tag_name,) in (
await session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
).all()
)
to_add = [t for t in desired if t not in current]
to_remove = [t for t in current if t not in desired]
if to_add:
await _ensure_tags_exist(session, to_add, tag_type="user")
session.add_all([
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_by=added_by, added_at=now)
for t in to_add
])
await session.flush()
if to_remove:
await session.execute(
delete(AssetInfoTag)
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
)
await session.flush()
return {"added": to_add, "removed": to_remove, "total": desired}
async def update_asset_info_full(
session: AsyncSession,
*,
asset_info_id: int,
name: Optional[str] = None,
tags: Optional[Sequence[str]] = None,
user_metadata: Optional[dict] = None,
tag_origin: str = "manual",
added_by: Optional[str] = None,
) -> AssetInfo:
"""
Update AssetInfo fields:
- name (if provided)
- user_metadata blob + rebuild projection (if provided)
- replace tags with provided set (if provided)
Returns the updated AssetInfo.
"""
info = await session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
touched = False
if name is not None and name != info.name:
info.name = name
touched = True
if user_metadata is not None:
await replace_asset_info_metadata_projection(
session, asset_info_id=asset_info_id, user_metadata=user_metadata
)
touched = True
if tags is not None:
await set_asset_info_tags(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=tag_origin,
added_by=added_by,
)
touched = True
if touched and user_metadata is None:
info.updated_at = datetime.now(timezone.utc)
await session.flush()
return info
async def replace_asset_info_metadata_projection(
session: AsyncSession,
*,
asset_info_id: int,
user_metadata: dict | None,
) -> None:
"""Replaces the `assets_info.user_metadata` AND rebuild the projection rows in `asset_info_meta`."""
info = await session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info.user_metadata = user_metadata or {}
info.updated_at = datetime.now(timezone.utc)
await session.flush()
await session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
await session.flush()
if not user_metadata:
return
rows: list[AssetInfoMeta] = []
for k, v in user_metadata.items():
for r in _project_kv(k, v):
rows.append(
AssetInfoMeta(
asset_info_id=asset_info_id,
key=r["key"],
ordinal=int(r["ordinal"]),
val_str=r.get("val_str"),
val_num=r.get("val_num"),
val_bool=r.get("val_bool"),
val_json=r.get("val_json"),
)
)
if rows:
session.add_all(rows)
await session.flush()
async def get_asset_tags(session: AsyncSession, *, asset_info_id: int) -> list[str]:
return [
tag_name
for (tag_name,) in (
await session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
]
def _normalize_tags(tags: Sequence[str] | None) -> list[str]:
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
async def _ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> list[Tag]:
wanted = _normalize_tags(list(names))
if not wanted:
return []
existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all()
by_name = {t.name: t for t in existing}
to_create = [Tag(name=n, tag_type=tag_type) for n in wanted if n not in by_name]
if to_create:
session.add_all(to_create)
await session.flush()
by_name.update({t.name: t for t in to_create})
return [by_name[n] for n in wanted]
def _apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Sequence[str] | None,
exclude_tags: Sequence[str] | None,
) -> sa.sql.Select:
"""include_tags: every tag must be present; exclude_tags: none may be present."""
include_tags = _normalize_tags(include_tags)
exclude_tags = _normalize_tags(exclude_tags)
if include_tags:
for tag_name in include_tags:
stmt = stmt.where(
exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name == tag_name)
)
)
if exclude_tags:
stmt = stmt.where(
~exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name.in_(exclude_tags))
)
)
return stmt
def _apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: dict | None,
) -> sa.sql.Select:
"""Apply metadata filters using the projection table asset_info_meta.
Semantics:
- For scalar values: require EXISTS(asset_info_meta) with matching key + typed value.
- For None: key is missing OR key has explicit null (val_json IS NULL).
- For list values: ANY-of the list elements matches (EXISTS for any).
    (Change to ALL-of by 'for each element: stmt = stmt.where(_exists_clause_for_value(key, elem))')
"""
if not metadata_filter:
return stmt
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
subquery = (
select(sa.literal(1))
.select_from(AssetInfoMeta)
.where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
*preds,
)
.limit(1)
)
return sa.exists(subquery)
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
# Missing OR null:
if value is None:
# either: no row for key OR a row for key with explicit null
no_row_for_key = ~sa.exists(
select(sa.literal(1))
.select_from(AssetInfoMeta)
.where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
)
.limit(1)
)
null_row = _exists_for_pred(key, AssetInfoMeta.val_json.is_(None))
return sa.or_(no_row_for_key, null_row)
# Typed scalar matches:
if isinstance(value, bool):
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
if isinstance(value, (int, float, Decimal)):
# store as Decimal for equality against NUMERIC(38,10)
num = value if isinstance(value, Decimal) else Decimal(str(value))
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
if isinstance(value, str):
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
# Complex: compare JSON (no index, but supported)
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
for k, v in metadata_filter.items():
if isinstance(v, list):
# ANY-of (exists for any element)
ors = [ _exists_clause_for_value(k, elem) for elem in v ]
if ors:
stmt = stmt.where(sa.or_(*ors))
else:
stmt = stmt.where(_exists_clause_for_value(k, v))
return stmt
def _is_scalar(v: Any) -> bool:
if v is None: # treat None as a value (explicit null) so it can be indexed for "is null" queries
return True
if isinstance(v, bool):
return True
if isinstance(v, (int, float, Decimal, str)):
return True
return False
def _project_kv(key: str, value: Any) -> list[dict]:
"""
Turn a metadata key/value into one or more projection rows:
- scalar -> one row (ordinal=0) in the proper typed column
- list of scalars -> one row per element with ordinal=i
- dict or list with non-scalars -> single row with val_json (or one per element w/ val_json if list)
- None -> single row with val_json = None
Each row: {"key": key, "ordinal": i, "val_str"/"val_num"/"val_bool"/"val_json": ...}
"""
rows: list[dict] = []
# None
if value is None:
rows.append({"key": key, "ordinal": 0, "val_json": None})
return rows
# Scalars
if _is_scalar(value):
if isinstance(value, bool):
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
elif isinstance(value, (int, float, Decimal)):
# store numeric; SQLAlchemy will coerce to Numeric
rows.append({"key": key, "ordinal": 0, "val_num": value})
elif isinstance(value, str):
rows.append({"key": key, "ordinal": 0, "val_str": value})
else:
# Fallback to json
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows
# Lists
if isinstance(value, list):
# list of scalars?
if all(_is_scalar(x) for x in value):
for i, x in enumerate(value):
if x is None:
rows.append({"key": key, "ordinal": i, "val_json": None})
elif isinstance(x, bool):
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
elif isinstance(x, (int, float, Decimal)):
rows.append({"key": key, "ordinal": i, "val_num": x})
elif isinstance(x, str):
rows.append({"key": key, "ordinal": i, "val_str": x})
else:
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
# list contains objects -> one val_json per element
for i, x in enumerate(value):
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
# Dict or any other structure -> single json row
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows
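
A worked example of the metadata projection: one user_metadata dict and the rows _project_kv yields for it, derived from the rules above (expected output shown as comments):

meta = {"steps": 30, "tags": ["a", "b"], "cfg": {"scale": 7.5}, "note": None}
for k, v in meta.items():
    print(_project_kv(k, v))
# steps -> [{"key": "steps", "ordinal": 0, "val_num": 30}]
# tags  -> [{"key": "tags", "ordinal": 0, "val_str": "a"},
#           {"key": "tags", "ordinal": 1, "val_str": "b"}]
# cfg   -> [{"key": "cfg", "ordinal": 0, "val_json": {"scale": 7.5}}]
# note  -> [{"key": "note", "ordinal": 0, "val_json": None}]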


@@ -1,331 +0,0 @@
import os
import logging
import time
import requests
from tqdm import tqdm
from folder_paths import get_relative_path, get_full_path
from app.database.db import create_session, dependencies_available, can_create_session
import blake3
import comfy.utils
if dependencies_available():
from app.database.models import Model
class ModelProcessor:
def _validate_path(self, model_path):
try:
if not self._file_exists(model_path):
logging.error(f"Model file not found: {model_path}")
return None
result = get_relative_path(model_path)
if not result:
logging.error(
f"Model file not in a recognized model directory: {model_path}"
)
return None
return result
except Exception as e:
logging.error(f"Error validating model path {model_path}: {str(e)}")
return None
def _file_exists(self, path):
"""Check if a file exists."""
return os.path.exists(path)
def _get_file_size(self, path):
"""Get file size."""
return os.path.getsize(path)
def _get_hasher(self):
return blake3.blake3()
def _hash_file(self, model_path):
try:
hasher = self._get_hasher()
with open(model_path, "rb", buffering=0) as f:
b = bytearray(128 * 1024)
mv = memoryview(b)
while n := f.readinto(mv):
hasher.update(mv[:n])
return hasher.hexdigest()
except Exception as e:
logging.error(f"Error hashing file {model_path}: {str(e)}")
return None
def _get_existing_model(self, session, model_type, model_relative_path):
return (
session.query(Model)
.filter(Model.type == model_type)
.filter(Model.path == model_relative_path)
.first()
)
def _ensure_source_url(self, session, model, source_url):
if model.source_url is None:
model.source_url = source_url
session.commit()
def _update_database(
self,
session,
model_type,
model_path,
model_relative_path,
model_hash,
model,
source_url,
):
try:
if not model:
model = self._get_existing_model(
session, model_type, model_relative_path
)
if not model:
model = Model(
path=model_relative_path,
type=model_type,
file_name=os.path.basename(model_path),
)
session.add(model)
model.file_size = self._get_file_size(model_path)
model.hash = model_hash
if model_hash:
model.hash_algorithm = "blake3"
model.source_url = source_url
session.commit()
return model
except Exception as e:
logging.error(
f"Error updating database for {model_relative_path}: {str(e)}"
)
def process_file(self, model_path, source_url=None, model_hash=None):
"""
Process a model file and update the database with metadata.
If the file already exists and matches the database, it will not be processed again.
        Returns the model object, or None if an error occurs.
"""
try:
if not can_create_session():
return
result = self._validate_path(model_path)
if not result:
return
model_type, model_relative_path = result
with create_session() as session:
session.expire_on_commit = False
existing_model = self._get_existing_model(
session, model_type, model_relative_path
)
if (
existing_model
and existing_model.hash
and existing_model.file_size == self._get_file_size(model_path)
):
# File exists with hash and same size, no need to process
self._ensure_source_url(session, existing_model, source_url)
return existing_model
if model_hash:
model_hash = model_hash.lower()
logging.info(f"Using provided hash: {model_hash}")
else:
start_time = time.time()
logging.info(f"Hashing model {model_relative_path}")
model_hash = self._hash_file(model_path)
if not model_hash:
return
logging.info(
f"Model hash: {model_hash} (duration: {time.time() - start_time} seconds)"
)
return self._update_database(
session,
model_type,
model_path,
model_relative_path,
model_hash,
existing_model,
source_url,
)
except Exception as e:
logging.error(f"Error processing model file {model_path}: {str(e)}")
return None
def retrieve_model_by_hash(self, model_hash, model_type=None, session=None):
"""
Retrieve a model file from the database by hash and optionally by model type.
        Returns the model object or None if the model doesn't exist or an error occurs.
"""
try:
if not can_create_session():
return
dispose_session = False
if session is None:
session = create_session()
dispose_session = True
model = session.query(Model).filter(Model.hash == model_hash)
if model_type is not None:
model = model.filter(Model.type == model_type)
return model.first()
except Exception as e:
logging.error(f"Error retrieving model by hash {model_hash}: {str(e)}")
return None
finally:
if dispose_session:
session.close()
def retrieve_hash(self, model_path, model_type=None):
"""
Retrieve the hash of a model file from the database.
        Returns the hash or None if the model doesn't exist or an error occurs.
"""
try:
if not can_create_session():
return
if model_type is not None:
result = self._validate_path(model_path)
if not result:
return None
model_type, model_relative_path = result
with create_session() as session:
model = self._get_existing_model(
session, model_type, model_relative_path
)
if model and model.hash:
return model.hash
return None
except Exception as e:
logging.error(f"Error retrieving hash for {model_path}: {str(e)}")
return None
def _validate_file_extension(self, file_name):
"""Validate that the file extension is supported."""
extension = os.path.splitext(file_name)[1]
if extension not in (".safetensors", ".sft", ".txt", ".csv", ".json", ".yaml"):
raise ValueError(f"Unsupported unsafe file for download: {file_name}")
def _check_existing_file(self, model_type, file_name, expected_hash):
"""Check if file exists and has correct hash."""
destination_path = get_full_path(model_type, file_name, allow_missing=True)
if self._file_exists(destination_path):
model = self.process_file(destination_path)
if model and (expected_hash is None or model.hash == expected_hash):
logging.debug(
f"File {destination_path} already exists in the database and has the correct hash or no hash was provided."
)
return destination_path
else:
raise ValueError(
f"File {destination_path} exists with hash {model.hash if model else 'unknown'} but expected {expected_hash}. Please delete the file and try again."
)
return None
def _check_existing_file_by_hash(self, hash, type, url):
"""Check if a file with the given hash exists in the database and on disk."""
hash = hash.lower()
with create_session() as session:
model = self.retrieve_model_by_hash(hash, type, session)
if model:
existing_path = get_full_path(type, model.path)
if existing_path:
logging.debug(
f"File {model.path} already exists in the database at {existing_path}"
)
self._ensure_source_url(session, model, url)
return existing_path
else:
logging.debug(
f"File {model.path} exists in the database but not on disk"
)
return None
def _download_file(self, url, destination_path, hasher):
"""Download a file and update the hasher with its contents."""
response = requests.get(url, stream=True)
logging.info(f"Downloading {url} to {destination_path}")
with open(destination_path, "wb") as f:
total_size = int(response.headers.get("content-length", 0))
if total_size > 0:
pbar = comfy.utils.ProgressBar(total_size)
else:
pbar = None
with tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar:
for chunk in response.iter_content(chunk_size=128 * 1024):
if chunk:
f.write(chunk)
hasher.update(chunk)
progress_bar.update(len(chunk))
if pbar:
pbar.update(len(chunk))
def _verify_downloaded_hash(self, calculated_hash, expected_hash, destination_path):
"""Verify that the downloaded file has the expected hash."""
if expected_hash is not None and calculated_hash != expected_hash:
self._remove_file(destination_path)
raise ValueError(
f"Downloaded file hash {calculated_hash} does not match expected hash {expected_hash}"
)
def _remove_file(self, file_path):
"""Remove a file from disk."""
os.remove(file_path)
def ensure_downloaded(self, type, url, desired_file_name, hash=None):
"""
Ensure a model file is downloaded and has the correct hash.
Returns the path to the downloaded file.
"""
logging.debug(
f"Ensuring {type} file is downloaded. URL='{url}' Destination='{desired_file_name}' Hash='{hash}'"
)
# Validate file extension
self._validate_file_extension(desired_file_name)
# Check if file exists with correct hash
if hash:
existing_path = self._check_existing_file_by_hash(hash, type, url)
if existing_path:
return existing_path
# Check if file exists locally
destination_path = get_full_path(type, desired_file_name, allow_missing=True)
existing_path = self._check_existing_file(type, desired_file_name, hash)
if existing_path:
return existing_path
# Download the file
hasher = self._get_hasher()
self._download_file(url, destination_path, hasher)
# Verify hash
calculated_hash = hasher.hexdigest()
self._verify_downloaded_hash(calculated_hash, hash, destination_path)
# Update database
self.process_file(destination_path, url, calculated_hash)
# TODO: Notify frontend to reload models
return destination_path
model_processor = ModelProcessor()

0
app/storage/__init__.py Normal file

72
app/storage/hashing.py Normal file

@@ -0,0 +1,72 @@
import asyncio
import os
from typing import IO, Union
from blake3 import blake3
DEFAULT_CHUNK = 8 * 1024 * 1024 # 8 MiB
def _hash_file_obj_sync(file_obj: IO[bytes], chunk_size: int) -> str:
"""Hash an already-open binary file object by streaming in chunks.
- Seeks to the beginning before reading (if supported).
- Restores the original position afterward (if tell/seek are supported).
"""
if chunk_size <= 0:
chunk_size = DEFAULT_CHUNK
orig_pos = None
if hasattr(file_obj, "tell"):
orig_pos = file_obj.tell()
try:
if hasattr(file_obj, "seek"):
file_obj.seek(0)
h = blake3()
while True:
chunk = file_obj.read(chunk_size)
if not chunk:
break
h.update(chunk)
return h.hexdigest()
finally:
if hasattr(file_obj, "seek") and orig_pos is not None:
file_obj.seek(orig_pos)
def blake3_hash_sync(
fp: Union[str, bytes, os.PathLike[str], os.PathLike[bytes], IO[bytes]],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
"""Returns a BLAKE3 hex digest for ``fp``, which may be:
- a filename (str/bytes) or PathLike
- an open binary file object
If ``fp`` is a file object, it must be opened in **binary** mode and support
``read``, ``seek``, and ``tell``. The function will seek to the start before
reading and will attempt to restore the original position afterward.
"""
if hasattr(fp, "read"):
return _hash_file_obj_sync(fp, chunk_size)
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj_sync(f, chunk_size)
async def blake3_hash(
fp: Union[str, bytes, os.PathLike[str], os.PathLike[bytes], IO[bytes]],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
"""Async wrapper for ``blake3_hash_sync``.
Uses a worker thread so the event loop remains responsive.
"""
# If it is a path, open inside the worker thread to keep I/O off the loop.
if hasattr(fp, "read"):
return await asyncio.to_thread(blake3_hash_sync, fp, chunk_size)
def _worker() -> str:
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj_sync(f, chunk_size)
return await asyncio.to_thread(_worker)
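
A usage sketch for both entry points; the file path is illustrative:

import asyncio
from app.storage.hashing import blake3_hash, blake3_hash_sync

digest = blake3_hash_sync("/models/x.safetensors")                # blocking call
digest_async = asyncio.run(blake3_hash("/models/x.safetensors"))  # hashes in a worker thread
assert digest == digest_async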