mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-27 10:24:06 +00:00
dev: Everything is Assets
This commit is contained in:
0
app/api/__init__.py
Normal file
0
app/api/__init__.py
Normal file
110
app/api/assets_routes.py
Normal file
110
app/api/assets_routes.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import json
|
||||
from typing import Sequence
|
||||
from aiohttp import web
|
||||
|
||||
from app import assets_manager
|
||||
|
||||
|
||||
ROUTES = web.RouteTableDef()
|
||||
|
||||
|
||||
@ROUTES.get("/api/assets")
|
||||
async def list_assets(request: web.Request) -> web.Response:
|
||||
q = request.rel_url.query
|
||||
|
||||
include_tags: Sequence[str] = _parse_csv_tags(q.get("include_tags"))
|
||||
exclude_tags: Sequence[str] = _parse_csv_tags(q.get("exclude_tags"))
|
||||
name_contains = q.get("name_contains")
|
||||
|
||||
# Optional JSON metadata filter (top-level key equality only for now)
|
||||
metadata_filter = None
|
||||
raw_meta = q.get("metadata_filter")
|
||||
if raw_meta:
|
||||
try:
|
||||
metadata_filter = json.loads(raw_meta)
|
||||
if not isinstance(metadata_filter, dict):
|
||||
metadata_filter = None
|
||||
except Exception:
|
||||
# Silently ignore malformed JSON for first iteration; could 400 in future
|
||||
metadata_filter = None
|
||||
|
||||
limit = _parse_int(q.get("limit"), default=20, lo=1, hi=100)
|
||||
offset = _parse_int(q.get("offset"), default=0, lo=0, hi=10_000_000)
|
||||
sort = q.get("sort", "created_at")
|
||||
order = q.get("order", "desc")
|
||||
|
||||
payload = await assets_manager.list_assets(
|
||||
include_tags=include_tags,
|
||||
exclude_tags=exclude_tags,
|
||||
name_contains=name_contains,
|
||||
metadata_filter=metadata_filter,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
sort=sort,
|
||||
order=order,
|
||||
)
|
||||
return web.json_response(payload)
|
||||
|
||||
|
||||
@ROUTES.put("/api/assets/{id}")
|
||||
async def update_asset(request: web.Request) -> web.Response:
|
||||
asset_info_id_raw = request.match_info.get("id")
|
||||
try:
|
||||
asset_info_id = int(asset_info_id_raw)
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.")
|
||||
|
||||
try:
|
||||
payload = await request.json()
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
name = payload.get("name", None)
|
||||
tags = payload.get("tags", None)
|
||||
user_metadata = payload.get("user_metadata", None)
|
||||
|
||||
if name is None and tags is None and user_metadata is None:
|
||||
return _error_response(400, "NO_FIELDS", "Provide at least one of: name, tags, user_metadata.")
|
||||
|
||||
if tags is not None and (not isinstance(tags, list) or not all(isinstance(t, str) for t in tags)):
|
||||
return _error_response(400, "INVALID_TAGS", "Field 'tags' must be an array of strings.")
|
||||
|
||||
if user_metadata is not None and not isinstance(user_metadata, dict):
|
||||
return _error_response(400, "INVALID_METADATA", "Field 'user_metadata' must be an object.")
|
||||
|
||||
try:
|
||||
result = await assets_manager.update_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
name=name,
|
||||
tags=tags,
|
||||
user_metadata=user_metadata,
|
||||
)
|
||||
except ValueError as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response(result, status=200)
|
||||
|
||||
|
||||
def register_assets_routes(app: web.Application) -> None:
|
||||
app.add_routes(ROUTES)
|
||||
|
||||
|
||||
def _parse_csv_tags(raw: str | None) -> list[str]:
|
||||
if not raw:
|
||||
return []
|
||||
return [t.strip() for t in raw.split(",") if t.strip()]
|
||||
|
||||
|
||||
def _parse_int(qval: str | None, default: int, lo: int, hi: int) -> int:
|
||||
if not qval:
|
||||
return default
|
||||
try:
|
||||
v = int(qval)
|
||||
except Exception:
|
||||
return default
|
||||
return max(lo, min(hi, v))
|
||||
|
||||
|
||||
def _error_response(status: int, code: str, message: str, details: dict | None = None) -> web.Response:
|
||||
return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status)
|
||||
148
app/assets_manager.py
Normal file
148
app/assets_manager.py
Normal file
@@ -0,0 +1,148 @@
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from .database.db import create_session
|
||||
from .storage import hashing
|
||||
from .database.services import (
|
||||
check_fs_asset_exists_quick,
|
||||
ingest_fs_asset,
|
||||
touch_asset_infos_by_fs_path,
|
||||
list_asset_infos_page,
|
||||
update_asset_info_full,
|
||||
get_asset_tags,
|
||||
)
|
||||
|
||||
|
||||
def get_size_mtime_ns(path: str) -> tuple[int, int]:
|
||||
st = os.stat(path, follow_symlinks=True)
|
||||
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
|
||||
|
||||
|
||||
async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> None:
|
||||
"""Adds a local asset to the DB. If already present and unchanged, does nothing.
|
||||
|
||||
Notes:
|
||||
- Uses absolute path as the canonical locator for the 'fs' backend.
|
||||
- Computes BLAKE3 only when the fast existence check indicates it's needed.
|
||||
- This function ensures the identity row and seeds mtime in asset_locator_state.
|
||||
"""
|
||||
abs_path = os.path.abspath(file_path)
|
||||
size_bytes, mtime_ns = get_size_mtime_ns(abs_path)
|
||||
if not size_bytes:
|
||||
return
|
||||
|
||||
async with await create_session() as session:
|
||||
if await check_fs_asset_exists_quick(session, file_path=abs_path, size_bytes=size_bytes, mtime_ns=mtime_ns):
|
||||
await touch_asset_infos_by_fs_path(session, abs_path=abs_path, ts=datetime.now(timezone.utc))
|
||||
await session.commit()
|
||||
return
|
||||
|
||||
asset_hash = hashing.blake3_hash_sync(abs_path)
|
||||
|
||||
async with await create_session() as session:
|
||||
await ingest_fs_asset(
|
||||
session,
|
||||
asset_hash="blake3:" + asset_hash,
|
||||
abs_path=abs_path,
|
||||
size_bytes=size_bytes,
|
||||
mtime_ns=mtime_ns,
|
||||
mime_type=None,
|
||||
info_name=file_name,
|
||||
tag_origin="automatic",
|
||||
tags=tags,
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def list_assets(
|
||||
*,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: Optional[str] = None,
|
||||
metadata_filter: Optional[dict] = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
sort: str | None = "created_at",
|
||||
order: str | None = "desc",
|
||||
) -> dict:
|
||||
sort = _safe_sort_field(sort)
|
||||
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
|
||||
|
||||
async with await create_session() as session:
|
||||
infos, tag_map, total = await list_asset_infos_page(
|
||||
session,
|
||||
include_tags=include_tags,
|
||||
exclude_tags=exclude_tags,
|
||||
name_contains=name_contains,
|
||||
metadata_filter=metadata_filter,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
sort=sort,
|
||||
order=order,
|
||||
)
|
||||
|
||||
assets_json = []
|
||||
for info in infos:
|
||||
asset = info.asset # populated via contains_eager
|
||||
tags = tag_map.get(info.id, [])
|
||||
assets_json.append(
|
||||
{
|
||||
"id": info.id,
|
||||
"name": info.name,
|
||||
"asset_hash": info.asset_hash,
|
||||
"size": int(asset.size_bytes) if asset else None,
|
||||
"mime_type": asset.mime_type if asset else None,
|
||||
"tags": tags,
|
||||
"preview_url": f"/api/v1/assets/{info.id}/content", # TODO: implement actual content endpoint later
|
||||
"created_at": info.created_at.isoformat() if info.created_at else None,
|
||||
"updated_at": info.updated_at.isoformat() if info.updated_at else None,
|
||||
"last_access_time": info.last_access_time.isoformat() if info.last_access_time else None,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"assets": assets_json,
|
||||
"total": total,
|
||||
"has_more": (offset + len(assets_json)) < total,
|
||||
}
|
||||
|
||||
|
||||
async def update_asset(
|
||||
*,
|
||||
asset_info_id: int,
|
||||
name: str | None = None,
|
||||
tags: list[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
) -> dict:
|
||||
async with await create_session() as session:
|
||||
info = await update_asset_info_full(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
name=name,
|
||||
tags=tags,
|
||||
user_metadata=user_metadata,
|
||||
tag_origin="manual",
|
||||
added_by=None,
|
||||
)
|
||||
|
||||
tag_names = await get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
await session.commit()
|
||||
|
||||
return {
|
||||
"id": info.id,
|
||||
"name": info.name,
|
||||
"asset_hash": info.asset_hash,
|
||||
"tags": tag_names,
|
||||
"user_metadata": info.user_metadata or {},
|
||||
"updated_at": info.updated_at.isoformat() if info.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
def _safe_sort_field(requested: str | None) -> str:
|
||||
if not requested:
|
||||
return "created_at"
|
||||
v = requested.lower()
|
||||
if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
|
||||
return v
|
||||
return "created_at"
|
||||
0
app/database/__init__.py
Normal file
0
app/database/__init__.py
Normal file
@@ -1,112 +1,267 @@
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional
|
||||
|
||||
from app.logger import log_startup_warning
|
||||
from utils.install_util import get_missing_requirements_message
|
||||
from comfy.cli_args import args
|
||||
|
||||
_DB_AVAILABLE = False
|
||||
Session = None
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
# Attempt imports which may not exist in some environments
|
||||
try:
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
from alembic.runtime.migration import MigrationContext
|
||||
from alembic.script import ScriptDirectory
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.engine import make_url
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
_DB_AVAILABLE = True
|
||||
ENGINE: AsyncEngine | None = None
|
||||
SESSION: async_sessionmaker | None = None
|
||||
except ImportError as e:
|
||||
log_startup_warning(
|
||||
f"""
|
||||
------------------------------------------------------------------------
|
||||
Error importing dependencies: {e}
|
||||
{get_missing_requirements_message()}
|
||||
This error is happening because ComfyUI now uses a local sqlite database.
|
||||
------------------------------------------------------------------------
|
||||
""".strip()
|
||||
(
|
||||
"------------------------------------------------------------------------\n"
|
||||
f"Error importing DB dependencies: {e}\n"
|
||||
f"{get_missing_requirements_message()}\n"
|
||||
"This error is happening because ComfyUI now uses a local database.\n"
|
||||
"------------------------------------------------------------------------"
|
||||
).strip()
|
||||
)
|
||||
_DB_AVAILABLE = False
|
||||
ENGINE = None
|
||||
SESSION = None
|
||||
|
||||
|
||||
def dependencies_available():
|
||||
"""
|
||||
Temporary function to check if the dependencies are available
|
||||
"""
|
||||
def dependencies_available() -> bool:
|
||||
"""Check if DB dependencies are importable."""
|
||||
return _DB_AVAILABLE
|
||||
|
||||
|
||||
def can_create_session():
|
||||
"""
|
||||
Temporary function to check if the database is available to create a session
|
||||
During initial release there may be environmental issues (or missing dependencies) that prevent the database from being created
|
||||
"""
|
||||
return dependencies_available() and Session is not None
|
||||
|
||||
|
||||
def get_alembic_config():
|
||||
root_path = os.path.join(os.path.dirname(__file__), "../..")
|
||||
def _root_paths():
|
||||
"""Resolve alembic.ini and migrations script folder."""
|
||||
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
|
||||
config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
|
||||
scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
|
||||
|
||||
config = Config(config_path)
|
||||
config.set_main_option("script_location", scripts_path)
|
||||
config.set_main_option("sqlalchemy.url", args.database_url)
|
||||
|
||||
return config
|
||||
return config_path, scripts_path
|
||||
|
||||
|
||||
def get_db_path():
|
||||
url = args.database_url
|
||||
if url.startswith("sqlite:///"):
|
||||
return url.split("///")[1]
|
||||
def _absolutize_sqlite_url(db_url: str) -> str:
|
||||
"""Make SQLite database path absolute. No-op for non-SQLite URLs."""
|
||||
try:
|
||||
u = make_url(db_url)
|
||||
except Exception:
|
||||
return db_url
|
||||
|
||||
if not u.drivername.startswith("sqlite"):
|
||||
return db_url
|
||||
|
||||
# Make path absolute if relative
|
||||
db_path = u.database or ""
|
||||
if not os.path.isabs(db_path):
|
||||
db_path = os.path.abspath(os.path.join(os.getcwd(), db_path))
|
||||
u = u.set(database=db_path)
|
||||
return str(u)
|
||||
|
||||
|
||||
def _to_sync_driver_url(async_url: str) -> str:
|
||||
"""Convert an async SQLAlchemy URL to a sync URL for Alembic."""
|
||||
u = make_url(async_url)
|
||||
driver = u.drivername
|
||||
|
||||
if driver.startswith("sqlite+aiosqlite"):
|
||||
u = u.set(drivername="sqlite")
|
||||
elif driver.startswith("postgresql+asyncpg"):
|
||||
u = u.set(drivername="postgresql")
|
||||
else:
|
||||
raise ValueError(f"Unsupported database URL '{url}'.")
|
||||
# Generic: strip the async driver part if present
|
||||
if "+" in driver:
|
||||
u = u.set(drivername=driver.split("+", 1)[0])
|
||||
|
||||
return str(u)
|
||||
|
||||
|
||||
def init_db():
|
||||
db_url = args.database_url
|
||||
logging.debug(f"Database URL: {db_url}")
|
||||
db_path = get_db_path()
|
||||
db_exists = os.path.exists(db_path)
|
||||
def _get_sqlite_file_path(sync_url: str) -> Optional[str]:
|
||||
"""Return the on-disk path for a SQLite URL, else None."""
|
||||
try:
|
||||
u = make_url(sync_url)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
config = get_alembic_config()
|
||||
if not u.drivername.startswith("sqlite"):
|
||||
return None
|
||||
return u.database
|
||||
|
||||
# Check if we need to upgrade
|
||||
engine = create_engine(db_url)
|
||||
conn = engine.connect()
|
||||
|
||||
context = MigrationContext.configure(conn)
|
||||
current_rev = context.get_current_revision()
|
||||
def _get_alembic_config(sync_url: str) -> Config:
|
||||
"""Prepare Alembic Config with script location and DB URL."""
|
||||
config_path, scripts_path = _root_paths()
|
||||
cfg = Config(config_path)
|
||||
cfg.set_main_option("script_location", scripts_path)
|
||||
cfg.set_main_option("sqlalchemy.url", sync_url)
|
||||
return cfg
|
||||
|
||||
script = ScriptDirectory.from_config(config)
|
||||
|
||||
async def init_db_engine() -> None:
|
||||
"""Initialize async engine + sessionmaker and run migrations to head.
|
||||
|
||||
This must be called once on application startup before any DB usage.
|
||||
"""
|
||||
global ENGINE, SESSION
|
||||
|
||||
if not dependencies_available():
|
||||
raise RuntimeError("Database dependencies are not available.")
|
||||
|
||||
if ENGINE is not None:
|
||||
return
|
||||
|
||||
raw_url = args.database_url
|
||||
if not raw_url:
|
||||
raise RuntimeError("Database URL is not configured.")
|
||||
|
||||
# Absolutize SQLite path for async engine
|
||||
db_url = _absolutize_sqlite_url(raw_url)
|
||||
|
||||
# Prepare async engine
|
||||
connect_args = {}
|
||||
if db_url.startswith("sqlite"):
|
||||
connect_args = {
|
||||
"check_same_thread": False,
|
||||
"timeout": 12,
|
||||
}
|
||||
|
||||
ENGINE = create_async_engine(
|
||||
db_url,
|
||||
connect_args=connect_args,
|
||||
pool_pre_ping=True,
|
||||
future=True,
|
||||
)
|
||||
|
||||
# Enforce SQLite pragmas on the async engine
|
||||
if db_url.startswith("sqlite"):
|
||||
async with ENGINE.begin() as conn:
|
||||
# WAL for concurrency and durability, Foreign Keys for referential integrity
|
||||
current_mode = (await conn.execute(text("PRAGMA journal_mode;"))).scalar()
|
||||
if str(current_mode).lower() != "wal":
|
||||
new_mode = (await conn.execute(text("PRAGMA journal_mode=WAL;"))).scalar()
|
||||
if str(new_mode).lower() != "wal":
|
||||
raise RuntimeError("Failed to set SQLite journal mode to WAL.")
|
||||
LOGGER.info("SQLite journal mode set to WAL.")
|
||||
|
||||
await conn.execute(text("PRAGMA foreign_keys = ON;"))
|
||||
await conn.execute(text("PRAGMA synchronous = NORMAL;"))
|
||||
|
||||
await _run_migrations(raw_url=db_url)
|
||||
|
||||
SESSION = async_sessionmaker(
|
||||
bind=ENGINE,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
autocommit=False,
|
||||
)
|
||||
|
||||
|
||||
async def _run_migrations(raw_url: str) -> None:
|
||||
"""
|
||||
Run Alembic migrations up to head.
|
||||
|
||||
We deliberately use a synchronous engine for migrations because Alembic's
|
||||
programmatic API is synchronous by default and this path is robust.
|
||||
"""
|
||||
# Convert to sync URL and make SQLite URL an absolute one
|
||||
sync_url = _to_sync_driver_url(raw_url)
|
||||
sync_url = _absolutize_sqlite_url(sync_url)
|
||||
|
||||
cfg = _get_alembic_config(sync_url)
|
||||
|
||||
# Inspect current and target heads
|
||||
engine = create_engine(sync_url, future=True)
|
||||
with engine.connect() as conn:
|
||||
context = MigrationContext.configure(conn)
|
||||
current_rev = context.get_current_revision()
|
||||
|
||||
script = ScriptDirectory.from_config(cfg)
|
||||
target_rev = script.get_current_head()
|
||||
|
||||
if target_rev is None:
|
||||
logging.warning("No target revision found.")
|
||||
elif current_rev != target_rev:
|
||||
# Backup the database pre upgrade
|
||||
backup_path = db_path + ".bkp"
|
||||
if db_exists:
|
||||
shutil.copy(db_path, backup_path)
|
||||
else:
|
||||
backup_path = None
|
||||
LOGGER.warning("Alembic: no target revision found.")
|
||||
return
|
||||
|
||||
if current_rev == target_rev:
|
||||
LOGGER.debug("Alembic: database already at head %s", target_rev)
|
||||
return
|
||||
|
||||
LOGGER.info("Alembic: upgrading database from %s to %s", current_rev, target_rev)
|
||||
|
||||
# Optional backup for SQLite file DBs
|
||||
backup_path = None
|
||||
sqlite_path = _get_sqlite_file_path(sync_url)
|
||||
if sqlite_path and os.path.exists(sqlite_path):
|
||||
backup_path = sqlite_path + ".bkp"
|
||||
try:
|
||||
command.upgrade(config, target_rev)
|
||||
logging.info(f"Database upgraded from {current_rev} to {target_rev}")
|
||||
except Exception as e:
|
||||
if backup_path:
|
||||
# Restore the database from backup if upgrade fails
|
||||
shutil.copy(backup_path, db_path)
|
||||
shutil.copy(sqlite_path, backup_path)
|
||||
except Exception as exc:
|
||||
LOGGER.warning("Failed to create SQLite backup before migration: %s", exc)
|
||||
|
||||
try:
|
||||
command.upgrade(cfg, target_rev)
|
||||
except Exception:
|
||||
if backup_path and os.path.exists(backup_path):
|
||||
LOGGER.exception("Error upgrading database, attempting restore from backup.")
|
||||
try:
|
||||
shutil.copy(backup_path, sqlite_path) # restore
|
||||
os.remove(backup_path)
|
||||
logging.exception("Error upgrading database: ")
|
||||
raise e
|
||||
|
||||
global Session
|
||||
Session = sessionmaker(bind=engine)
|
||||
except Exception as re:
|
||||
LOGGER.error("Failed to restore SQLite backup: %s", re)
|
||||
else:
|
||||
LOGGER.exception("Error upgrading database, backup is not available.")
|
||||
raise
|
||||
|
||||
|
||||
def create_session():
|
||||
return Session()
|
||||
def get_engine():
|
||||
"""Return the global async engine (initialized after init_db_engine())."""
|
||||
if ENGINE is None:
|
||||
raise RuntimeError("Engine is not initialized. Call init_db_engine() first.")
|
||||
return ENGINE
|
||||
|
||||
|
||||
def get_session_maker():
|
||||
"""Return the global async_sessionmaker (initialized after init_db_engine())."""
|
||||
if SESSION is None:
|
||||
raise RuntimeError("Session maker is not initialized. Call init_db_engine() first.")
|
||||
return SESSION
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def session_scope():
|
||||
"""Async context manager for a unit of work:
|
||||
|
||||
async with session_scope() as sess:
|
||||
... use sess ...
|
||||
"""
|
||||
maker = get_session_maker()
|
||||
async with maker() as sess:
|
||||
try:
|
||||
yield sess
|
||||
await sess.commit()
|
||||
except Exception:
|
||||
await sess.rollback()
|
||||
raise
|
||||
|
||||
|
||||
async def create_session():
|
||||
"""Convenience helper to acquire a single AsyncSession instance.
|
||||
|
||||
Typical usage:
|
||||
async with (await create_session()) as sess:
|
||||
...
|
||||
"""
|
||||
maker = get_session_maker()
|
||||
return maker()
|
||||
|
||||
@@ -1,59 +1,257 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Integer,
|
||||
Text,
|
||||
BigInteger,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Index,
|
||||
JSON,
|
||||
String,
|
||||
Text,
|
||||
CheckConstraint,
|
||||
Numeric,
|
||||
Boolean,
|
||||
)
|
||||
from sqlalchemy.orm import declarative_base
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
Base = declarative_base()
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship, foreign
|
||||
|
||||
|
||||
def to_dict(obj):
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
|
||||
fields = obj.__table__.columns.keys()
|
||||
return {
|
||||
field: (val.to_dict() if hasattr(val, "to_dict") else val)
|
||||
for field in fields
|
||||
if (val := getattr(obj, field))
|
||||
}
|
||||
out: dict[str, Any] = {}
|
||||
for field in fields:
|
||||
val = getattr(obj, field)
|
||||
if val is None and not include_none:
|
||||
continue
|
||||
if isinstance(val, datetime):
|
||||
out[field] = val.isoformat()
|
||||
else:
|
||||
out[field] = val
|
||||
return out
|
||||
|
||||
|
||||
class Model(Base):
|
||||
"""
|
||||
sqlalchemy model representing a model file in the system.
|
||||
class Asset(Base):
|
||||
__tablename__ = "assets"
|
||||
|
||||
This class defines the database schema for storing information about model files,
|
||||
including their type, path, hash, and when they were added to the system.
|
||||
hash: Mapped[str] = mapped_column(String(256), primary_key=True)
|
||||
size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
|
||||
mime_type: Mapped[str | None] = mapped_column(String(255))
|
||||
refcount: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
|
||||
storage_backend: Mapped[str] = mapped_column(String(32), nullable=False, default="fs")
|
||||
storage_locator: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
|
||||
)
|
||||
|
||||
Attributes:
|
||||
type (Text): The type of the model, this is the name of the folder in the models folder (primary key)
|
||||
path (Text): The file path of the model relative to the type folder (primary key)
|
||||
file_name (Text): The name of the model file
|
||||
file_size (Integer): The size of the model file in bytes
|
||||
hash (Text): A hash of the model file
|
||||
hash_algorithm (Text): The algorithm used to generate the hash
|
||||
source_url (Text): The URL of the model file
|
||||
date_added (DateTime): Timestamp of when the model was added to the system
|
||||
"""
|
||||
infos: Mapped[list["AssetInfo"]] = relationship(
|
||||
"AssetInfo",
|
||||
back_populates="asset",
|
||||
primaryjoin=lambda: Asset.hash == foreign(AssetInfo.asset_hash),
|
||||
foreign_keys=lambda: [AssetInfo.asset_hash],
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
__tablename__ = "model"
|
||||
preview_of: Mapped[list["AssetInfo"]] = relationship(
|
||||
"AssetInfo",
|
||||
back_populates="preview_asset",
|
||||
primaryjoin=lambda: Asset.hash == foreign(AssetInfo.preview_hash),
|
||||
foreign_keys=lambda: [AssetInfo.preview_hash],
|
||||
viewonly=True,
|
||||
)
|
||||
|
||||
type = Column(Text, primary_key=True)
|
||||
path = Column(Text, primary_key=True)
|
||||
file_name = Column(Text)
|
||||
file_size = Column(Integer)
|
||||
hash = Column(Text)
|
||||
hash_algorithm = Column(Text)
|
||||
source_url = Column(Text)
|
||||
date_added = Column(DateTime, server_default=func.now())
|
||||
locator_state: Mapped["AssetLocatorState | None"] = relationship(
|
||||
back_populates="asset",
|
||||
uselist=False,
|
||||
cascade="all, delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Convert the model instance to a dictionary representation.
|
||||
__table_args__ = (
|
||||
Index("ix_assets_mime_type", "mime_type"),
|
||||
Index("ix_assets_backend_locator", "storage_backend", "storage_locator"),
|
||||
)
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the attributes of the model
|
||||
"""
|
||||
dict = to_dict(self)
|
||||
return dict
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
return to_dict(self, include_none=include_none)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Asset hash={self.hash[:12]} backend={self.storage_backend}>"
|
||||
|
||||
|
||||
class AssetLocatorState(Base):
|
||||
__tablename__ = "asset_locator_state"
|
||||
|
||||
asset_hash: Mapped[str] = mapped_column(
|
||||
String(256), ForeignKey("assets.hash", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
# For fs backends: nanosecond mtime; nullable if not applicable
|
||||
mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
|
||||
# For HTTP/S3/GCS/Azure, etc.: optional validators
|
||||
etag: Mapped[str | None] = mapped_column(String(256), nullable=True)
|
||||
last_modified: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
|
||||
asset: Mapped["Asset"] = relationship(back_populates="locator_state", uselist=False)
|
||||
|
||||
__table_args__ = (
|
||||
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_als_mtime_nonneg"),
|
||||
)
|
||||
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
return to_dict(self, include_none=include_none)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<AssetLocatorState hash={self.asset_hash[:12]} mtime_ns={self.mtime_ns}>"
|
||||
|
||||
|
||||
class AssetInfo(Base):
|
||||
__tablename__ = "assets_info"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
owner_id: Mapped[str | None] = mapped_column(String(128))
|
||||
name: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||
asset_hash: Mapped[str] = mapped_column(
|
||||
String(256), ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False
|
||||
)
|
||||
preview_hash: Mapped[str | None] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="SET NULL"))
|
||||
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
|
||||
)
|
||||
last_access_time: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
|
||||
)
|
||||
|
||||
# Relationships
|
||||
asset: Mapped[Asset] = relationship(
|
||||
"Asset",
|
||||
back_populates="infos",
|
||||
foreign_keys=[asset_hash],
|
||||
)
|
||||
preview_asset: Mapped[Asset | None] = relationship(
|
||||
"Asset",
|
||||
back_populates="preview_of",
|
||||
foreign_keys=[preview_hash],
|
||||
)
|
||||
|
||||
metadata_entries: Mapped[list["AssetInfoMeta"]] = relationship(
|
||||
back_populates="asset_info",
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
tag_links: Mapped[list["AssetInfoTag"]] = relationship(
|
||||
back_populates="asset_info",
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
overlaps="tags,asset_infos",
|
||||
)
|
||||
|
||||
tags: Mapped[list["Tag"]] = relationship(
|
||||
secondary="asset_info_tags",
|
||||
back_populates="asset_infos",
|
||||
lazy="joined",
|
||||
viewonly=True,
|
||||
overlaps="tag_links,asset_info_links,asset_infos,tag",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_assets_info_owner_id", "owner_id"),
|
||||
Index("ix_assets_info_asset_hash", "asset_hash"),
|
||||
Index("ix_assets_info_name", "name"),
|
||||
Index("ix_assets_info_created_at", "created_at"),
|
||||
Index("ix_assets_info_last_access_time", "last_access_time"),
|
||||
{"sqlite_autoincrement": True},
|
||||
)
|
||||
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
data = to_dict(self, include_none=include_none)
|
||||
data["tags"] = [t.name for t in self.tags]
|
||||
return data
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<AssetInfo id={self.id} name={self.name!r} hash={self.asset_hash[:12]}>"
|
||||
|
||||
|
||||
|
||||
class AssetInfoMeta(Base):
|
||||
__tablename__ = "asset_info_meta"
|
||||
|
||||
asset_info_id: Mapped[int] = mapped_column(
|
||||
Integer, ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
key: Mapped[str] = mapped_column(String(256), primary_key=True)
|
||||
ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0)
|
||||
|
||||
val_str: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True)
|
||||
val_num: Mapped[Optional[float]] = mapped_column(Numeric(38, 10), nullable=True)
|
||||
val_bool: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True)
|
||||
val_json: Mapped[Optional[Any]] = mapped_column(JSON, nullable=True)
|
||||
|
||||
asset_info: Mapped["AssetInfo"] = relationship(back_populates="metadata_entries")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_asset_info_meta_key", "key"),
|
||||
Index("ix_asset_info_meta_key_val_str", "key", "val_str"),
|
||||
Index("ix_asset_info_meta_key_val_num", "key", "val_num"),
|
||||
Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"),
|
||||
)
|
||||
|
||||
|
||||
class AssetInfoTag(Base):
|
||||
__tablename__ = "asset_info_tags"
|
||||
|
||||
asset_info_id: Mapped[int] = mapped_column(
|
||||
Integer, ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
tag_name: Mapped[str] = mapped_column(
|
||||
String(128), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
|
||||
)
|
||||
origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
|
||||
added_by: Mapped[str | None] = mapped_column(String(128))
|
||||
added_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
|
||||
asset_info: Mapped["AssetInfo"] = relationship(back_populates="tag_links")
|
||||
tag: Mapped["Tag"] = relationship(back_populates="asset_info_links")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_asset_info_tags_tag_name", "tag_name"),
|
||||
Index("ix_asset_info_tags_asset_info_id", "asset_info_id"),
|
||||
)
|
||||
|
||||
|
||||
class Tag(Base):
|
||||
__tablename__ = "tags"
|
||||
|
||||
name: Mapped[str] = mapped_column(String(128), primary_key=True)
|
||||
tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")
|
||||
|
||||
asset_info_links: Mapped[list["AssetInfoTag"]] = relationship(
|
||||
back_populates="tag",
|
||||
overlaps="asset_infos,tags",
|
||||
)
|
||||
asset_infos: Mapped[list["AssetInfo"]] = relationship(
|
||||
secondary="asset_info_tags",
|
||||
back_populates="tags",
|
||||
viewonly=True,
|
||||
overlaps="asset_info_links,tag_links,tags,asset_info",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_tags_tag_type", "tag_type"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Tag {self.name}>"
|
||||
|
||||
683
app/database/services.py
Normal file
683
app/database/services.py
Normal file
@@ -0,0 +1,683 @@
|
||||
import os
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from decimal import Decimal
|
||||
from typing import Any, Sequence, Optional, Iterable
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, delete, exists, func
|
||||
from sqlalchemy.orm import contains_eager
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from .models import Asset, AssetInfo, AssetInfoTag, AssetLocatorState, Tag, AssetInfoMeta
|
||||
|
||||
|
||||
async def check_fs_asset_exists_quick(
|
||||
session,
|
||||
*,
|
||||
file_path: str,
|
||||
size_bytes: Optional[int] = None,
|
||||
mtime_ns: Optional[int] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Returns 'True' if there is already an Asset present whose canonical locator matches this absolute path,
|
||||
AND (if provided) mtime_ns matches stored locator-state,
|
||||
AND (if provided) size_bytes matches verified size when known.
|
||||
"""
|
||||
locator = os.path.abspath(file_path)
|
||||
|
||||
stmt = select(sa.literal(True)).select_from(Asset)
|
||||
|
||||
conditions = [
|
||||
Asset.storage_backend == "fs",
|
||||
Asset.storage_locator == locator,
|
||||
]
|
||||
|
||||
# If size_bytes provided require equality when the asset has a verified (non-zero) size.
|
||||
# If verified size is 0 (unknown), we don't force equality.
|
||||
if size_bytes is not None:
|
||||
conditions.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes)))
|
||||
|
||||
# If mtime_ns provided require the locator-state to exist and match.
|
||||
if mtime_ns is not None:
|
||||
stmt = stmt.join(AssetLocatorState, AssetLocatorState.asset_hash == Asset.hash)
|
||||
conditions.append(AssetLocatorState.mtime_ns == int(mtime_ns))
|
||||
|
||||
stmt = stmt.where(*conditions).limit(1)
|
||||
|
||||
row = (await session.execute(stmt)).first()
|
||||
return row is not None
|
||||
|
||||
|
||||
async def ingest_fs_asset(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_hash: str,
|
||||
abs_path: str,
|
||||
size_bytes: int,
|
||||
mtime_ns: int,
|
||||
mime_type: Optional[str] = None,
|
||||
info_name: Optional[str] = None,
|
||||
owner_id: Optional[str] = None,
|
||||
preview_hash: Optional[str] = None,
|
||||
user_metadata: Optional[dict] = None,
|
||||
tags: Sequence[str] = (),
|
||||
tag_origin: str = "manual",
|
||||
added_by: Optional[str] = None,
|
||||
require_existing_tags: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Creates or updates Asset record for a local (fs) asset.
|
||||
|
||||
Always:
|
||||
- Insert Asset if missing; else update size_bytes (and updated_at) if different.
|
||||
- Insert AssetLocatorState if missing; else update mtime_ns if different.
|
||||
|
||||
Optionally (when info_name is provided):
|
||||
- Create an AssetInfo (no refcount changes).
|
||||
- Link provided tags to that AssetInfo.
|
||||
* If the require_existing_tags=True, raises ValueError if any tag does not exist in `tags` table.
|
||||
* If False (default), silently skips unknown tags.
|
||||
|
||||
Returns flags and ids:
|
||||
{
|
||||
"asset_created": bool,
|
||||
"asset_updated": bool,
|
||||
"state_created": bool,
|
||||
"state_updated": bool,
|
||||
"asset_info_id": int | None,
|
||||
"tags_added": list[str],
|
||||
"tags_missing": list[str], # filled only when require_existing_tags=False
|
||||
}
|
||||
"""
|
||||
locator = os.path.abspath(abs_path)
|
||||
datetime_now = datetime.now(timezone.utc)
|
||||
|
||||
out = {
|
||||
"asset_created": False,
|
||||
"asset_updated": False,
|
||||
"state_created": False,
|
||||
"state_updated": False,
|
||||
"asset_info_id": None,
|
||||
"tags_added": [],
|
||||
"tags_missing": [],
|
||||
}
|
||||
|
||||
# ---- Step 1: INSERT Asset or UPDATE size_bytes/updated_at if exists ----
|
||||
async with session.begin_nested() as sp1:
|
||||
try:
|
||||
session.add(
|
||||
Asset(
|
||||
hash=asset_hash,
|
||||
size_bytes=int(size_bytes),
|
||||
mime_type=mime_type,
|
||||
refcount=0,
|
||||
storage_backend="fs",
|
||||
storage_locator=locator,
|
||||
created_at=datetime_now,
|
||||
updated_at=datetime_now,
|
||||
)
|
||||
)
|
||||
await session.flush()
|
||||
out["asset_created"] = True
|
||||
except IntegrityError:
|
||||
await sp1.rollback()
|
||||
# Already exists by hash -> update selected fields if different
|
||||
existing = await session.get(Asset, asset_hash)
|
||||
if existing is not None:
|
||||
desired_size = int(size_bytes)
|
||||
if existing.size_bytes != desired_size:
|
||||
existing.size_bytes = desired_size
|
||||
existing.updated_at = datetime_now
|
||||
out["asset_updated"] = True
|
||||
else:
|
||||
# This should not occur. Log for visibility.
|
||||
logging.error("Asset %s not found after conflict; skipping update.", asset_hash)
|
||||
except Exception:
|
||||
await sp1.rollback()
|
||||
logging.exception("Unexpected error inserting Asset (hash=%s, locator=%s)", asset_hash, locator)
|
||||
raise
|
||||
|
||||
# ---- Step 2: INSERT/UPDATE AssetLocatorState (mtime_ns) ----
|
||||
async with session.begin_nested() as sp2:
|
||||
try:
|
||||
session.add(
|
||||
AssetLocatorState(
|
||||
asset_hash=asset_hash,
|
||||
mtime_ns=int(mtime_ns),
|
||||
)
|
||||
)
|
||||
await session.flush()
|
||||
out["state_created"] = True
|
||||
except IntegrityError:
|
||||
await sp2.rollback()
|
||||
state = await session.get(AssetLocatorState, asset_hash)
|
||||
if state is not None:
|
||||
desired_mtime = int(mtime_ns)
|
||||
if state.mtime_ns != desired_mtime:
|
||||
state.mtime_ns = desired_mtime
|
||||
out["state_updated"] = True
|
||||
else:
|
||||
logging.debug("Locator state missing for %s after conflict; skipping update.", asset_hash)
|
||||
except Exception:
|
||||
await sp2.rollback()
|
||||
logging.exception("Unexpected error inserting AssetLocatorState (hash=%s)", asset_hash)
|
||||
raise
|
||||
|
||||
# ---- Optional: AssetInfo + tag links ----
|
||||
if info_name:
|
||||
# 2a) Create AssetInfo (no refcount bump)
|
||||
async with session.begin_nested() as sp3:
|
||||
try:
|
||||
info = AssetInfo(
|
||||
owner_id=owner_id,
|
||||
name=info_name,
|
||||
asset_hash=asset_hash,
|
||||
preview_hash=preview_hash,
|
||||
created_at=datetime_now,
|
||||
updated_at=datetime_now,
|
||||
last_access_time=datetime_now,
|
||||
)
|
||||
session.add(info)
|
||||
await session.flush() # get info.id
|
||||
out["asset_info_id"] = info.id
|
||||
except Exception:
|
||||
await sp3.rollback()
|
||||
logging.exception(
|
||||
"Unexpected error inserting AssetInfo (hash=%s, name=%s)", asset_hash, info_name
|
||||
)
|
||||
raise
|
||||
|
||||
# 2b) Link tags (if any). We DO NOT create new Tag rows here by default.
|
||||
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||
if norm and out["asset_info_id"] is not None:
|
||||
# Which tags exist?
|
||||
existing_tag_names = set(
|
||||
name for (name,) in (await session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
|
||||
)
|
||||
missing = [t for t in norm if t not in existing_tag_names]
|
||||
if missing and require_existing_tags:
|
||||
raise ValueError(f"Unknown tags: {missing}")
|
||||
|
||||
# Which links already exist?
|
||||
existing_links = set(
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
await session.execute(
|
||||
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
|
||||
)
|
||||
).all()
|
||||
)
|
||||
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
|
||||
if to_add:
|
||||
session.add_all(
|
||||
[
|
||||
AssetInfoTag(
|
||||
asset_info_id=out["asset_info_id"],
|
||||
tag_name=t,
|
||||
origin=tag_origin,
|
||||
added_by=added_by,
|
||||
added_at=datetime_now,
|
||||
)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
await session.flush()
|
||||
out["tags_added"] = to_add
|
||||
out["tags_missing"] = missing
|
||||
|
||||
# 2c) Rebuild metadata projection if provided
|
||||
if user_metadata is not None and out["asset_info_id"] is not None:
|
||||
await replace_asset_info_metadata_projection(
|
||||
session,
|
||||
asset_info_id=out["asset_info_id"],
|
||||
user_metadata=user_metadata,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
async def touch_asset_infos_by_fs_path(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
abs_path: str,
|
||||
ts: Optional[datetime] = None,
|
||||
only_if_newer: bool = True,
|
||||
) -> int:
|
||||
locator = os.path.abspath(abs_path)
|
||||
ts = ts or datetime.now(timezone.utc)
|
||||
|
||||
stmt = sa.update(AssetInfo).where(
|
||||
sa.exists(
|
||||
sa.select(sa.literal(1))
|
||||
.select_from(Asset)
|
||||
.where(
|
||||
Asset.hash == AssetInfo.asset_hash,
|
||||
Asset.storage_backend == "fs",
|
||||
Asset.storage_locator == locator,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if only_if_newer:
|
||||
stmt = stmt.where(
|
||||
sa.or_(
|
||||
AssetInfo.last_access_time.is_(None),
|
||||
AssetInfo.last_access_time < ts,
|
||||
)
|
||||
)
|
||||
|
||||
stmt = stmt.values(last_access_time=ts)
|
||||
|
||||
res = await session.execute(stmt)
|
||||
return int(res.rowcount or 0)
|
||||
|
||||
|
||||
async def list_asset_infos_page(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
) -> tuple[list[AssetInfo], dict[int, list[str]], int]:
|
||||
"""
|
||||
Returns a page of AssetInfo rows with their Asset eagerly loaded (no N+1),
|
||||
plus a map of asset_info_id -> [tags], and the total count.
|
||||
|
||||
We purposely collect tags in a separate (single) query to avoid row explosion.
|
||||
"""
|
||||
# Clamp
|
||||
if limit <= 0:
|
||||
limit = 1
|
||||
if limit > 100:
|
||||
limit = 100
|
||||
if offset < 0:
|
||||
offset = 0
|
||||
|
||||
# Build base query
|
||||
base = (
|
||||
select(AssetInfo)
|
||||
.join(Asset, Asset.hash == AssetInfo.asset_hash)
|
||||
.options(contains_eager(AssetInfo.asset))
|
||||
)
|
||||
|
||||
# Filters
|
||||
if name_contains:
|
||||
base = base.where(AssetInfo.name.ilike(f"%{name_contains}%"))
|
||||
|
||||
base = _apply_tag_filters(base, include_tags, exclude_tags)
|
||||
|
||||
base = _apply_metadata_filter(base, metadata_filter)
|
||||
|
||||
# Sort
|
||||
sort = (sort or "created_at").lower()
|
||||
order = (order or "desc").lower()
|
||||
sort_map = {
|
||||
"name": AssetInfo.name,
|
||||
"created_at": AssetInfo.created_at,
|
||||
"updated_at": AssetInfo.updated_at,
|
||||
"last_access_time": AssetInfo.last_access_time,
|
||||
"size": Asset.size_bytes,
|
||||
}
|
||||
sort_col = sort_map.get(sort, AssetInfo.created_at)
|
||||
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
|
||||
|
||||
base = base.order_by(sort_exp).limit(limit).offset(offset)
|
||||
|
||||
# Total count (same filters, no ordering/limit/offset)
|
||||
count_stmt = (
|
||||
select(func.count())
|
||||
.select_from(AssetInfo)
|
||||
.join(Asset, Asset.hash == AssetInfo.asset_hash)
|
||||
)
|
||||
if name_contains:
|
||||
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{name_contains}%"))
|
||||
count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
||||
count_stmt = _apply_metadata_filter(count_stmt, metadata_filter)
|
||||
|
||||
total = (await session.execute(count_stmt)).scalar_one()
|
||||
|
||||
# Fetch rows
|
||||
infos = (await session.execute(base)).scalars().unique().all()
|
||||
|
||||
# Collect tags in bulk (single query)
|
||||
id_list = [i.id for i in infos]
|
||||
tag_map: dict[int, list[str]] = defaultdict(list)
|
||||
if id_list:
|
||||
rows = await session.execute(
|
||||
select(AssetInfoTag.asset_info_id, Tag.name)
|
||||
.join(Tag, Tag.name == AssetInfoTag.tag_name)
|
||||
.where(AssetInfoTag.asset_info_id.in_(id_list))
|
||||
)
|
||||
for aid, tag_name in rows.all():
|
||||
tag_map[aid].append(tag_name)
|
||||
|
||||
return infos, tag_map, total
|
||||
|
||||
|
||||
async def set_asset_info_tags(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: int,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
added_by: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Replace the tag set on an AssetInfo with `tags`. Idempotent.
|
||||
Creates missing tag names as 'user'.
|
||||
"""
|
||||
desired = _normalize_tags(tags)
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# current links
|
||||
current = set(
|
||||
tag_name for (tag_name,) in (
|
||||
await session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
|
||||
).all()
|
||||
)
|
||||
|
||||
to_add = [t for t in desired if t not in current]
|
||||
to_remove = [t for t in current if t not in desired]
|
||||
|
||||
if to_add:
|
||||
await _ensure_tags_exist(session, to_add, tag_type="user")
|
||||
session.add_all([
|
||||
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_by=added_by, added_at=now)
|
||||
for t in to_add
|
||||
])
|
||||
await session.flush()
|
||||
|
||||
if to_remove:
|
||||
await session.execute(
|
||||
delete(AssetInfoTag)
|
||||
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
|
||||
)
|
||||
await session.flush()
|
||||
|
||||
return {"added": to_add, "removed": to_remove, "total": desired}
|
||||
|
||||
|
||||
async def update_asset_info_full(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: int,
|
||||
name: Optional[str] = None,
|
||||
tags: Optional[Sequence[str]] = None,
|
||||
user_metadata: Optional[dict] = None,
|
||||
tag_origin: str = "manual",
|
||||
added_by: Optional[str] = None,
|
||||
) -> AssetInfo:
|
||||
"""
|
||||
Update AssetInfo fields:
|
||||
- name (if provided)
|
||||
- user_metadata blob + rebuild projection (if provided)
|
||||
- replace tags with provided set (if provided)
|
||||
Returns the updated AssetInfo.
|
||||
"""
|
||||
info = await session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
touched = False
|
||||
if name is not None and name != info.name:
|
||||
info.name = name
|
||||
touched = True
|
||||
|
||||
if user_metadata is not None:
|
||||
await replace_asset_info_metadata_projection(
|
||||
session, asset_info_id=asset_info_id, user_metadata=user_metadata
|
||||
)
|
||||
touched = True
|
||||
|
||||
if tags is not None:
|
||||
await set_asset_info_tags(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
origin=tag_origin,
|
||||
added_by=added_by,
|
||||
)
|
||||
touched = True
|
||||
|
||||
if touched and user_metadata is None:
|
||||
info.updated_at = datetime.now(timezone.utc)
|
||||
await session.flush()
|
||||
|
||||
return info
|
||||
|
||||
|
||||
async def replace_asset_info_metadata_projection(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: int,
|
||||
user_metadata: dict | None,
|
||||
) -> None:
|
||||
"""Replaces the `assets_info.user_metadata` AND rebuild the projection rows in `asset_info_meta`."""
|
||||
info = await session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
info.user_metadata = user_metadata or {}
|
||||
info.updated_at = datetime.now(timezone.utc)
|
||||
await session.flush()
|
||||
|
||||
await session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
|
||||
await session.flush()
|
||||
|
||||
if not user_metadata:
|
||||
return
|
||||
|
||||
rows: list[AssetInfoMeta] = []
|
||||
for k, v in user_metadata.items():
|
||||
for r in _project_kv(k, v):
|
||||
rows.append(
|
||||
AssetInfoMeta(
|
||||
asset_info_id=asset_info_id,
|
||||
key=r["key"],
|
||||
ordinal=int(r["ordinal"]),
|
||||
val_str=r.get("val_str"),
|
||||
val_num=r.get("val_num"),
|
||||
val_bool=r.get("val_bool"),
|
||||
val_json=r.get("val_json"),
|
||||
)
|
||||
)
|
||||
if rows:
|
||||
session.add_all(rows)
|
||||
await session.flush()
|
||||
|
||||
|
||||
async def get_asset_tags(session: AsyncSession, *, asset_info_id: int) -> list[Tag]:
|
||||
return [
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
await session.execute(
|
||||
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
]
|
||||
|
||||
|
||||
def _normalize_tags(tags: Sequence[str] | None) -> list[str]:
|
||||
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||
|
||||
|
||||
async def _ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> list[Tag]:
|
||||
wanted = _normalize_tags(list(names))
|
||||
if not wanted:
|
||||
return []
|
||||
existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all()
|
||||
by_name = {t.name: t for t in existing}
|
||||
to_create = [Tag(name=n, tag_type=tag_type) for n in wanted if n not in by_name]
|
||||
if to_create:
|
||||
session.add_all(to_create)
|
||||
await session.flush()
|
||||
by_name.update({t.name: t for t in to_create})
|
||||
return [by_name[n] for n in wanted]
|
||||
|
||||
|
||||
def _apply_tag_filters(
|
||||
stmt: sa.sql.Select,
|
||||
include_tags: Sequence[str] | None,
|
||||
exclude_tags: Sequence[str] | None,
|
||||
) -> sa.sql.Select:
|
||||
"""include_tags: every tag must be present; exclude_tags: none may be present."""
|
||||
include_tags = _normalize_tags(include_tags)
|
||||
exclude_tags = _normalize_tags(exclude_tags)
|
||||
|
||||
if include_tags:
|
||||
for tag_name in include_tags:
|
||||
stmt = stmt.where(
|
||||
exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name == tag_name)
|
||||
)
|
||||
)
|
||||
|
||||
if exclude_tags:
|
||||
stmt = stmt.where(
|
||||
~exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name.in_(exclude_tags))
|
||||
)
|
||||
)
|
||||
return stmt
|
||||
|
||||
def _apply_metadata_filter(
|
||||
stmt: sa.sql.Select,
|
||||
metadata_filter: dict | None,
|
||||
) -> sa.sql.Select:
|
||||
"""Apply metadata filters using the projection table asset_info_meta.
|
||||
|
||||
Semantics:
|
||||
- For scalar values: require EXISTS(asset_info_meta) with matching key + typed value.
|
||||
- For None: key is missing OR key has explicit null (val_json IS NULL).
|
||||
- For list values: ANY-of the list elements matches (EXISTS for any).
|
||||
(Change to ALL-of by 'for each element: stmt = stmt.where(_meta_exists_clause(key, elem))')
|
||||
"""
|
||||
if not metadata_filter:
|
||||
return stmt
|
||||
|
||||
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
|
||||
subquery = (
|
||||
select(sa.literal(1))
|
||||
.select_from(AssetInfoMeta)
|
||||
.where(
|
||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||
AssetInfoMeta.key == key,
|
||||
*preds,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
return sa.exists(subquery)
|
||||
|
||||
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
|
||||
# Missing OR null:
|
||||
if value is None:
|
||||
# either: no row for key OR a row for key with explicit null
|
||||
no_row_for_key = ~sa.exists(
|
||||
select(sa.literal(1))
|
||||
.select_from(AssetInfoMeta)
|
||||
.where(
|
||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||
AssetInfoMeta.key == key,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
null_row = _exists_for_pred(key, AssetInfoMeta.val_json.is_(None))
|
||||
return sa.or_(no_row_for_key, null_row)
|
||||
|
||||
# Typed scalar matches:
|
||||
if isinstance(value, bool):
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
|
||||
if isinstance(value, (int, float, Decimal)):
|
||||
# store as Decimal for equality against NUMERIC(38,10)
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
|
||||
if isinstance(value, str):
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
|
||||
|
||||
# Complex: compare JSON (no index, but supported)
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
|
||||
|
||||
for k, v in metadata_filter.items():
|
||||
if isinstance(v, list):
|
||||
# ANY-of (exists for any element)
|
||||
ors = [ _exists_clause_for_value(k, elem) for elem in v ]
|
||||
if ors:
|
||||
stmt = stmt.where(sa.or_(*ors))
|
||||
else:
|
||||
stmt = stmt.where(_exists_clause_for_value(k, v))
|
||||
return stmt
|
||||
|
||||
|
||||
def _is_scalar(v: Any) -> bool:
|
||||
if v is None: # treat None as a value (explicit null) so it can be indexed for "is null" queries
|
||||
return True
|
||||
if isinstance(v, bool):
|
||||
return True
|
||||
if isinstance(v, (int, float, Decimal, str)):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _project_kv(key: str, value: Any) -> list[dict]:
|
||||
"""
|
||||
Turn a metadata key/value into one or more projection rows:
|
||||
- scalar -> one row (ordinal=0) in the proper typed column
|
||||
- list of scalars -> one row per element with ordinal=i
|
||||
- dict or list with non-scalars -> single row with val_json (or one per element w/ val_json if list)
|
||||
- None -> single row with val_json = None
|
||||
Each row: {"key": key, "ordinal": i, "val_str"/"val_num"/"val_bool"/"val_json": ...}
|
||||
"""
|
||||
rows: list[dict] = []
|
||||
|
||||
# None
|
||||
if value is None:
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": None})
|
||||
return rows
|
||||
|
||||
# Scalars
|
||||
if _is_scalar(value):
|
||||
if isinstance(value, bool):
|
||||
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
|
||||
elif isinstance(value, (int, float, Decimal)):
|
||||
# store numeric; SQLAlchemy will coerce to Numeric
|
||||
rows.append({"key": key, "ordinal": 0, "val_num": value})
|
||||
elif isinstance(value, str):
|
||||
rows.append({"key": key, "ordinal": 0, "val_str": value})
|
||||
else:
|
||||
# Fallback to json
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
|
||||
# Lists
|
||||
if isinstance(value, list):
|
||||
# list of scalars?
|
||||
if all(_is_scalar(x) for x in value):
|
||||
for i, x in enumerate(value):
|
||||
if x is None:
|
||||
rows.append({"key": key, "ordinal": i, "val_json": None})
|
||||
elif isinstance(x, bool):
|
||||
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
|
||||
elif isinstance(x, (int, float, Decimal)):
|
||||
rows.append({"key": key, "ordinal": i, "val_num": x})
|
||||
elif isinstance(x, str):
|
||||
rows.append({"key": key, "ordinal": i, "val_str": x})
|
||||
else:
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
# list contains objects -> one val_json per element
|
||||
for i, x in enumerate(value):
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
|
||||
# Dict or any other structure -> single json row
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
@@ -1,331 +0,0 @@
|
||||
import os
|
||||
import logging
|
||||
import time
|
||||
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
from folder_paths import get_relative_path, get_full_path
|
||||
from app.database.db import create_session, dependencies_available, can_create_session
|
||||
import blake3
|
||||
import comfy.utils
|
||||
|
||||
|
||||
if dependencies_available():
|
||||
from app.database.models import Model
|
||||
|
||||
|
||||
class ModelProcessor:
|
||||
def _validate_path(self, model_path):
|
||||
try:
|
||||
if not self._file_exists(model_path):
|
||||
logging.error(f"Model file not found: {model_path}")
|
||||
return None
|
||||
|
||||
result = get_relative_path(model_path)
|
||||
if not result:
|
||||
logging.error(
|
||||
f"Model file not in a recognized model directory: {model_path}"
|
||||
)
|
||||
return None
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logging.error(f"Error validating model path {model_path}: {str(e)}")
|
||||
return None
|
||||
|
||||
def _file_exists(self, path):
|
||||
"""Check if a file exists."""
|
||||
return os.path.exists(path)
|
||||
|
||||
def _get_file_size(self, path):
|
||||
"""Get file size."""
|
||||
return os.path.getsize(path)
|
||||
|
||||
def _get_hasher(self):
|
||||
return blake3.blake3()
|
||||
|
||||
def _hash_file(self, model_path):
|
||||
try:
|
||||
hasher = self._get_hasher()
|
||||
with open(model_path, "rb", buffering=0) as f:
|
||||
b = bytearray(128 * 1024)
|
||||
mv = memoryview(b)
|
||||
while n := f.readinto(mv):
|
||||
hasher.update(mv[:n])
|
||||
return hasher.hexdigest()
|
||||
except Exception as e:
|
||||
logging.error(f"Error hashing file {model_path}: {str(e)}")
|
||||
return None
|
||||
|
||||
def _get_existing_model(self, session, model_type, model_relative_path):
|
||||
return (
|
||||
session.query(Model)
|
||||
.filter(Model.type == model_type)
|
||||
.filter(Model.path == model_relative_path)
|
||||
.first()
|
||||
)
|
||||
|
||||
def _ensure_source_url(self, session, model, source_url):
|
||||
if model.source_url is None:
|
||||
model.source_url = source_url
|
||||
session.commit()
|
||||
|
||||
def _update_database(
|
||||
self,
|
||||
session,
|
||||
model_type,
|
||||
model_path,
|
||||
model_relative_path,
|
||||
model_hash,
|
||||
model,
|
||||
source_url,
|
||||
):
|
||||
try:
|
||||
if not model:
|
||||
model = self._get_existing_model(
|
||||
session, model_type, model_relative_path
|
||||
)
|
||||
|
||||
if not model:
|
||||
model = Model(
|
||||
path=model_relative_path,
|
||||
type=model_type,
|
||||
file_name=os.path.basename(model_path),
|
||||
)
|
||||
session.add(model)
|
||||
|
||||
model.file_size = self._get_file_size(model_path)
|
||||
model.hash = model_hash
|
||||
if model_hash:
|
||||
model.hash_algorithm = "blake3"
|
||||
model.source_url = source_url
|
||||
|
||||
session.commit()
|
||||
return model
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Error updating database for {model_relative_path}: {str(e)}"
|
||||
)
|
||||
|
||||
def process_file(self, model_path, source_url=None, model_hash=None):
|
||||
"""
|
||||
Process a model file and update the database with metadata.
|
||||
If the file already exists and matches the database, it will not be processed again.
|
||||
Returns the model object or if an error occurs, returns None.
|
||||
"""
|
||||
try:
|
||||
if not can_create_session():
|
||||
return
|
||||
|
||||
result = self._validate_path(model_path)
|
||||
if not result:
|
||||
return
|
||||
model_type, model_relative_path = result
|
||||
|
||||
with create_session() as session:
|
||||
session.expire_on_commit = False
|
||||
|
||||
existing_model = self._get_existing_model(
|
||||
session, model_type, model_relative_path
|
||||
)
|
||||
if (
|
||||
existing_model
|
||||
and existing_model.hash
|
||||
and existing_model.file_size == self._get_file_size(model_path)
|
||||
):
|
||||
# File exists with hash and same size, no need to process
|
||||
self._ensure_source_url(session, existing_model, source_url)
|
||||
return existing_model
|
||||
|
||||
if model_hash:
|
||||
model_hash = model_hash.lower()
|
||||
logging.info(f"Using provided hash: {model_hash}")
|
||||
else:
|
||||
start_time = time.time()
|
||||
logging.info(f"Hashing model {model_relative_path}")
|
||||
model_hash = self._hash_file(model_path)
|
||||
if not model_hash:
|
||||
return
|
||||
logging.info(
|
||||
f"Model hash: {model_hash} (duration: {time.time() - start_time} seconds)"
|
||||
)
|
||||
|
||||
return self._update_database(
|
||||
session,
|
||||
model_type,
|
||||
model_path,
|
||||
model_relative_path,
|
||||
model_hash,
|
||||
existing_model,
|
||||
source_url,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error processing model file {model_path}: {str(e)}")
|
||||
return None
|
||||
|
||||
def retrieve_model_by_hash(self, model_hash, model_type=None, session=None):
|
||||
"""
|
||||
Retrieve a model file from the database by hash and optionally by model type.
|
||||
Returns the model object or None if the model doesnt exist or an error occurs.
|
||||
"""
|
||||
try:
|
||||
if not can_create_session():
|
||||
return
|
||||
|
||||
dispose_session = False
|
||||
|
||||
if session is None:
|
||||
session = create_session()
|
||||
dispose_session = True
|
||||
|
||||
model = session.query(Model).filter(Model.hash == model_hash)
|
||||
if model_type is not None:
|
||||
model = model.filter(Model.type == model_type)
|
||||
return model.first()
|
||||
except Exception as e:
|
||||
logging.error(f"Error retrieving model by hash {model_hash}: {str(e)}")
|
||||
return None
|
||||
finally:
|
||||
if dispose_session:
|
||||
session.close()
|
||||
|
||||
def retrieve_hash(self, model_path, model_type=None):
|
||||
"""
|
||||
Retrieve the hash of a model file from the database.
|
||||
Returns the hash or None if the model doesnt exist or an error occurs.
|
||||
"""
|
||||
try:
|
||||
if not can_create_session():
|
||||
return
|
||||
|
||||
if model_type is not None:
|
||||
result = self._validate_path(model_path)
|
||||
if not result:
|
||||
return None
|
||||
model_type, model_relative_path = result
|
||||
|
||||
with create_session() as session:
|
||||
model = self._get_existing_model(
|
||||
session, model_type, model_relative_path
|
||||
)
|
||||
if model and model.hash:
|
||||
return model.hash
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"Error retrieving hash for {model_path}: {str(e)}")
|
||||
return None
|
||||
|
||||
def _validate_file_extension(self, file_name):
|
||||
"""Validate that the file extension is supported."""
|
||||
extension = os.path.splitext(file_name)[1]
|
||||
if extension not in (".safetensors", ".sft", ".txt", ".csv", ".json", ".yaml"):
|
||||
raise ValueError(f"Unsupported unsafe file for download: {file_name}")
|
||||
|
||||
def _check_existing_file(self, model_type, file_name, expected_hash):
|
||||
"""Check if file exists and has correct hash."""
|
||||
destination_path = get_full_path(model_type, file_name, allow_missing=True)
|
||||
if self._file_exists(destination_path):
|
||||
model = self.process_file(destination_path)
|
||||
if model and (expected_hash is None or model.hash == expected_hash):
|
||||
logging.debug(
|
||||
f"File {destination_path} already exists in the database and has the correct hash or no hash was provided."
|
||||
)
|
||||
return destination_path
|
||||
else:
|
||||
raise ValueError(
|
||||
f"File {destination_path} exists with hash {model.hash if model else 'unknown'} but expected {expected_hash}. Please delete the file and try again."
|
||||
)
|
||||
return None
|
||||
|
||||
def _check_existing_file_by_hash(self, hash, type, url):
|
||||
"""Check if a file with the given hash exists in the database and on disk."""
|
||||
hash = hash.lower()
|
||||
with create_session() as session:
|
||||
model = self.retrieve_model_by_hash(hash, type, session)
|
||||
if model:
|
||||
existing_path = get_full_path(type, model.path)
|
||||
if existing_path:
|
||||
logging.debug(
|
||||
f"File {model.path} already exists in the database at {existing_path}"
|
||||
)
|
||||
self._ensure_source_url(session, model, url)
|
||||
return existing_path
|
||||
else:
|
||||
logging.debug(
|
||||
f"File {model.path} exists in the database but not on disk"
|
||||
)
|
||||
return None
|
||||
|
||||
def _download_file(self, url, destination_path, hasher):
|
||||
"""Download a file and update the hasher with its contents."""
|
||||
response = requests.get(url, stream=True)
|
||||
logging.info(f"Downloading {url} to {destination_path}")
|
||||
|
||||
with open(destination_path, "wb") as f:
|
||||
total_size = int(response.headers.get("content-length", 0))
|
||||
if total_size > 0:
|
||||
pbar = comfy.utils.ProgressBar(total_size)
|
||||
else:
|
||||
pbar = None
|
||||
with tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar:
|
||||
for chunk in response.iter_content(chunk_size=128 * 1024):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
hasher.update(chunk)
|
||||
progress_bar.update(len(chunk))
|
||||
if pbar:
|
||||
pbar.update(len(chunk))
|
||||
|
||||
def _verify_downloaded_hash(self, calculated_hash, expected_hash, destination_path):
|
||||
"""Verify that the downloaded file has the expected hash."""
|
||||
if expected_hash is not None and calculated_hash != expected_hash:
|
||||
self._remove_file(destination_path)
|
||||
raise ValueError(
|
||||
f"Downloaded file hash {calculated_hash} does not match expected hash {expected_hash}"
|
||||
)
|
||||
|
||||
def _remove_file(self, file_path):
|
||||
"""Remove a file from disk."""
|
||||
os.remove(file_path)
|
||||
|
||||
def ensure_downloaded(self, type, url, desired_file_name, hash=None):
|
||||
"""
|
||||
Ensure a model file is downloaded and has the correct hash.
|
||||
Returns the path to the downloaded file.
|
||||
"""
|
||||
logging.debug(
|
||||
f"Ensuring {type} file is downloaded. URL='{url}' Destination='{desired_file_name}' Hash='{hash}'"
|
||||
)
|
||||
|
||||
# Validate file extension
|
||||
self._validate_file_extension(desired_file_name)
|
||||
|
||||
# Check if file exists with correct hash
|
||||
if hash:
|
||||
existing_path = self._check_existing_file_by_hash(hash, type, url)
|
||||
if existing_path:
|
||||
return existing_path
|
||||
|
||||
# Check if file exists locally
|
||||
destination_path = get_full_path(type, desired_file_name, allow_missing=True)
|
||||
existing_path = self._check_existing_file(type, desired_file_name, hash)
|
||||
if existing_path:
|
||||
return existing_path
|
||||
|
||||
# Download the file
|
||||
hasher = self._get_hasher()
|
||||
self._download_file(url, destination_path, hasher)
|
||||
|
||||
# Verify hash
|
||||
calculated_hash = hasher.hexdigest()
|
||||
self._verify_downloaded_hash(calculated_hash, hash, destination_path)
|
||||
|
||||
# Update database
|
||||
self.process_file(destination_path, url, calculated_hash)
|
||||
|
||||
# TODO: Notify frontend to reload models
|
||||
|
||||
return destination_path
|
||||
|
||||
|
||||
model_processor = ModelProcessor()
|
||||
0
app/storage/__init__.py
Normal file
0
app/storage/__init__.py
Normal file
72
app/storage/hashing.py
Normal file
72
app/storage/hashing.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import asyncio
|
||||
import os
|
||||
from typing import IO, Union
|
||||
|
||||
from blake3 import blake3
|
||||
|
||||
DEFAULT_CHUNK = 8 * 1024 * 1024 # 8 MiB
|
||||
|
||||
|
||||
def _hash_file_obj_sync(file_obj: IO[bytes], chunk_size: int) -> str:
|
||||
"""Hash an already-open binary file object by streaming in chunks.
|
||||
- Seeks to the beginning before reading (if supported).
|
||||
- Restores the original position afterward (if tell/seek are supported).
|
||||
"""
|
||||
if chunk_size <= 0:
|
||||
chunk_size = DEFAULT_CHUNK
|
||||
|
||||
orig_pos = None
|
||||
if hasattr(file_obj, "tell"):
|
||||
orig_pos = file_obj.tell()
|
||||
|
||||
try:
|
||||
if hasattr(file_obj, "seek"):
|
||||
file_obj.seek(0)
|
||||
|
||||
h = blake3()
|
||||
while True:
|
||||
chunk = file_obj.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
h.update(chunk)
|
||||
return h.hexdigest()
|
||||
finally:
|
||||
if hasattr(file_obj, "seek") and orig_pos is not None:
|
||||
file_obj.seek(orig_pos)
|
||||
|
||||
|
||||
def blake3_hash_sync(
|
||||
fp: Union[str, bytes, os.PathLike[str], os.PathLike[bytes], IO[bytes]],
|
||||
chunk_size: int = DEFAULT_CHUNK,
|
||||
) -> str:
|
||||
"""Returns a BLAKE3 hex digest for ``fp``, which may be:
|
||||
- a filename (str/bytes) or PathLike
|
||||
- an open binary file object
|
||||
|
||||
If ``fp`` is a file object, it must be opened in **binary** mode and support
|
||||
``read``, ``seek``, and ``tell``. The function will seek to the start before
|
||||
reading and will attempt to restore the original position afterward.
|
||||
"""
|
||||
if hasattr(fp, "read"):
|
||||
return _hash_file_obj_sync(fp, chunk_size)
|
||||
|
||||
with open(os.fspath(fp), "rb") as f:
|
||||
return _hash_file_obj_sync(f, chunk_size)
|
||||
|
||||
|
||||
async def blake3_hash(
|
||||
fp: Union[str, bytes, os.PathLike[str], os.PathLike[bytes], IO[bytes]],
|
||||
chunk_size: int = DEFAULT_CHUNK,
|
||||
) -> str:
|
||||
"""Async wrapper for ``blake3_hash_sync``.
|
||||
Uses a worker thread so the event loop remains responsive.
|
||||
"""
|
||||
# If it is a path, open inside the worker thread to keep I/O off the loop.
|
||||
if hasattr(fp, "read"):
|
||||
return await asyncio.to_thread(blake3_hash_sync, fp, chunk_size)
|
||||
|
||||
def _worker() -> str:
|
||||
with open(os.fspath(fp), "rb") as f:
|
||||
return _hash_file_obj_sync(f, chunk_size)
|
||||
|
||||
return await asyncio.to_thread(_worker)
|
||||
Reference in New Issue
Block a user