Add model management and database

- use sqlalchemy + alembic + sqlite for db
- extract model data and previews
- endpoints for db interactions
- add tests
This commit is contained in:
pythongosssss
2025-03-28 11:39:56 +08:00
parent 1709a8441e
commit 7bf381bc9e
14 changed files with 1264 additions and 40 deletions

116
app/database/db.py Normal file
View File

@@ -0,0 +1,116 @@
import logging
import os
import shutil
import sys
from comfy.cli_args import args

# Verify the database dependencies up front so that a missing package yields
# an actionable message instead of a bare traceback.
try:
    import alembic
    import sqlalchemy
except ImportError as e:
    req_path = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "../..", "requirements.txt")
    )
    logging.error(
        f"\n\n********** ERROR ***********\n\nRequirements are not installed ({e}). Please install the requirements.txt file by running:\n{sys.executable} -s -m pip install -r {req_path}\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem\n********** ERROR **********\n"
    )
    # sys.exit() is always available; the bare exit() helper is only
    # injected by the site module and may be absent.
    sys.exit(-1)

# Imported after the guard on purpose: app.database.models imports sqlalchemy
# at module level, so importing it first would raise ImportError before the
# friendly message above could be logged.
from app.database.models import Tag
from alembic import command
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
# Session factory used by create_session(); stays None until init_db() binds
# a sessionmaker to the engine.
Session = None
def get_alembic_config():
    """Build the Alembic ``Config`` for this installation.

    ``alembic.ini`` and the ``alembic_db`` scripts directory are resolved two
    levels above this file, and the database URL is taken from
    ``args.database_url``.
    """
    project_root = os.path.join(os.path.dirname(__file__), "../..")

    def resolve(name):
        # Absolute path of an entry that lives in the project root.
        return os.path.abspath(os.path.join(project_root, name))

    config = Config(resolve("alembic.ini"))
    overrides = (
        ("script_location", resolve("alembic_db")),
        ("sqlalchemy.url", args.database_url),
    )
    for option, value in overrides:
        config.set_main_option(option, value)
    return config
def get_db_path():
    """Return the file path portion of the configured SQLite database URL.

    Raises:
        ValueError: if ``args.database_url`` does not start with
            ``sqlite:///``.
    """
    url = args.database_url
    prefix = "sqlite:///"
    if not url.startswith(prefix):
        raise ValueError(f"Unsupported database URL '{url}'.")
    # Strip only the leading scheme. Splitting on "///" and taking the second
    # piece would truncate any path that contains another triple slash.
    return url[len(prefix):]
def init_db():
    """Create or upgrade the database and configure the session factory.

    Steps:
      1. Compare the database's current Alembic revision with the head of the
         migration scripts.
      2. If they differ, copy the database file to ``<path>.bkp`` (when the
         file exists), run the upgrade, and restore the copy if the upgrade
         raises.
      3. Bind the module-level ``Session`` factory to the engine.
      4. When the database had no revision at all, populate the model table
         from disk and seed the default tags.

    Raises:
        Exception: whatever ``command.upgrade`` raised, re-raised after the
            backup has been restored.
    """
    global Session

    db_url = args.database_url
    logging.debug(f"Database URL: {db_url}")

    config = get_alembic_config()

    # Check if we need to upgrade. The connection is only needed for the
    # revision lookup, so it is closed before any migration or file copy.
    engine = create_engine(db_url)
    with engine.connect() as conn:
        current_rev = MigrationContext.configure(conn).get_current_revision()

    target_rev = ScriptDirectory.from_config(config).get_current_head()

    if current_rev != target_rev:
        # Backup the database pre upgrade
        db_path = get_db_path()
        backup_path = db_path + ".bkp"
        if os.path.exists(db_path):
            shutil.copy(db_path, backup_path)
        else:
            backup_path = None

        try:
            command.upgrade(config, target_rev)
            logging.info(f"Database upgraded from {current_rev} to {target_rev}")
        except Exception as e:
            if backup_path:
                # Restore the database from backup if upgrade fails
                shutil.copy(backup_path, db_path)
                os.remove(backup_path)
            logging.error(f"Error upgrading database: {e}")
            # Bare raise keeps the original traceback intact.
            raise

    Session = sessionmaker(bind=engine)

    if not current_rev:
        # Init db, populate models. Imported here rather than at module level
        # because app.model_processor imports this module.
        from app.model_processor import model_processor

        default_tags = (
            "character",
            "style",
            "concept",
            "clothing",
            "pose",
            "background",
            "vehicle",
            "object",
            "animal",
            "action",
        )

        # The context manager closes the session even if seeding fails.
        with create_session() as session:
            model_processor.populate_models(session)
            session.add_all([Tag(name=name) for name in default_tags])
            session.commit()
def create_session():
    """Return a new ORM session from the module-level ``Session`` factory.

    Raises:
        RuntimeError: if ``init_db()`` has not configured the factory yet.
    """
    if Session is None:
        # Fail with an explanation rather than "'NoneType' is not callable".
        raise RuntimeError(
            "Database has not been initialized; call init_db() first."
        )
    return Session()

76
app/database/models.py Normal file
View File

@@ -0,0 +1,76 @@
from sqlalchemy import (
Column,
Integer,
Text,
DateTime,
Table,
ForeignKeyConstraint,
)
from sqlalchemy.orm import relationship, declarative_base
from sqlalchemy.sql import func
# Declarative base whose metadata is shared by the ORM classes and the
# association table defined in this module.
Base = declarative_base()
def to_dict(obj):
    """Serialize the table columns of *obj* into a plain dict.

    Column names come from ``obj.__table__.columns``. Columns whose value is
    ``None`` are omitted; values that expose their own ``to_dict`` method are
    serialized through it.
    """
    result = {}
    for field in obj.__table__.columns.keys():
        val = getattr(obj, field)
        # Skip only unset values. A plain truthiness test would also drop
        # legitimate falsy data such as 0 or an empty string.
        if val is None:
            continue
        result[field] = val.to_dict() if hasattr(val, "to_dict") else val
    return result
# Association table for the many-to-many link between models and tags.
# A model is referenced by its composite key (type, path); rows are removed
# with either side through ON DELETE CASCADE.
ModelTag = Table(
    "model_tag",
    Base.metadata,
    Column("model_type", Text, primary_key=True),
    Column("model_path", Text, primary_key=True),
    Column("tag_id", Integer, primary_key=True),
    ForeignKeyConstraint(
        ["model_type", "model_path"],
        ["model.type", "model.path"],
        ondelete="CASCADE",
    ),
    ForeignKeyConstraint(
        ["tag_id"],
        ["tag.id"],
        ondelete="CASCADE",
    ),
)
class Model(Base):
    """ORM row for a model file, identified by the composite key (type, path)."""

    __tablename__ = "model"

    # Composite primary key.
    type = Column(Text, primary_key=True)
    path = Column(Text, primary_key=True)

    # Optional descriptive fields.
    title = Column(Text)
    description = Column(Text)
    architecture = Column(Text)
    hash = Column(Text)
    source_url = Column(Text)

    # Defaults to the database server's current time on insert.
    date_added = Column(DateTime, server_default=func.now())

    # Relationship with tags
    tags = relationship("Tag", secondary=ModelTag, back_populates="models")

    def to_dict(self):
        """Return the column values plus a ``"tags"`` list of serialized tags."""
        # Named ``data`` so the builtin ``dict`` is not shadowed. The call
        # below resolves to the module-level ``to_dict`` helper.
        data = to_dict(self)
        data["tags"] = [tag.to_dict() for tag in self.tags]
        return data
class Tag(Base):
    """ORM row for a uniquely named tag that can be attached to models."""

    __tablename__ = "tag"

    id = Column(Integer, autoincrement=True, primary_key=True)
    name = Column(Text, unique=True, nullable=False)

    # Many-to-many link to models through the model_tag association table.
    models = relationship("Model", back_populates="tags", secondary=ModelTag)

    def to_dict(self):
        """Serialize this tag through the module-level ``to_dict`` helper."""
        serialized = to_dict(self)
        return serialized

View File

@@ -1,19 +1,30 @@
from __future__ import annotations
import os
import base64
import json
import time
import logging
from app.database.db import create_session
import folder_paths
import glob
import comfy.utils
from aiohttp import web
from PIL import Image
from io import BytesIO
from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types
from folder_paths import map_legacy, filter_files_extensions, get_full_path
from app.database.models import Tag, Model
from app.model_processor import get_model_previews, model_processor
from utils.web import dumps
from sqlalchemy.orm import joinedload
import sqlalchemy.exc
def bad_request(message: str):
    """Build a 400 JSON response carrying *message* under the ``error`` key."""
    payload = {"error": message}
    return web.json_response(payload, status=400)
def missing_field(field: str):
    """Build the 400 response used when the required *field* is absent."""
    reason = f"{field} is required"
    return bad_request(reason)
def not_found(message: str):
    """Build a 404 JSON response whose error reads ``"<message> not found"``."""
    payload = {"error": "".join((message, " not found"))}
    return web.json_response(payload, status=404)
class ModelFileManager:
def __init__(self) -> None:
self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
@@ -62,7 +73,7 @@ class ModelFileManager:
folder = folders[0][path_index]
full_filename = os.path.join(folder, filename)
previews = self.get_model_previews(full_filename)
previews = get_model_previews(full_filename)
default_preview = previews[0] if len(previews) > 0 else None
if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
return web.Response(status=404)
@@ -76,6 +87,183 @@ class ModelFileManager:
except:
return web.Response(status=404)
@routes.get("/v2/models")
async def get_models(request):
    """List model records, optionally filtered by ``path`` and/or ``type``.

    When both filters are supplied, a single object is returned (404 if there
    is no match); otherwise a JSON list is returned.
    """
    with create_session() as session:
        wanted_path = request.query.get("path", None)
        wanted_type = request.query.get("type", None)

        # Collect only the filters that were actually supplied.
        criteria = []
        if wanted_path:
            criteria.append(Model.path == wanted_path)
        if wanted_type:
            criteria.append(Model.type == wanted_type)

        matches = (
            session.query(Model)
            .options(joinedload(Model.tags))
            .filter(*criteria)
            .all()
        )

        if not (wanted_path and wanted_type):
            serialized = [match.to_dict() for match in matches]
            return web.json_response(serialized, dumps=dumps)

        # Both key parts given: respond with one record or a 404.
        if len(matches) == 0:
            return not_found("Model")
        return web.json_response(matches[0].to_dict(), dumps=dumps)
@routes.post("/v2/models")
async def add_model(request):
    """Create a model record from the JSON body.

    ``type`` and ``path`` are required and must resolve to an existing file.
    Every other key must be a column of the ``model`` table, and ``tags`` is
    an optional list of existing tag ids. Responds with the stored record,
    400 for invalid input or integrity violations, 404 for an unknown file or
    tag.
    """
    with create_session() as session:
        data = await request.json()
        # A JSON array or scalar has no .get/.pop; reject it up front.
        if not isinstance(data, dict):
            return bad_request("Request body must be a JSON object")

        model_type = data.get("type", None)
        model_path = data.get("path", None)

        if not model_type:
            return missing_field("type")
        if not model_path:
            return missing_field("path")

        tags = data.pop("tags", [])
        fields = Model.metadata.tables["model"].columns.keys()

        # Validate keys are valid model fields
        for key in data:
            if key not in fields:
                return bad_request(f"Invalid field: {key}")

        # Validate file exists
        if not get_full_path(model_type, model_path):
            return not_found(f"File '{model_type}/{model_path}'")

        # Resolve and validate the tags before touching the new record. The
        # id list is built once instead of on every loop iteration.
        tag_rows = session.query(Tag).filter(Tag.id.in_(tags)).all()
        known_ids = [row.id for row in tag_rows]
        for tag in tags:
            if tag not in known_ids:
                return not_found(f"Tag '{tag}'")

        model = Model()
        for field in fields:
            if field in data:
                setattr(model, field, data[field])
        model.tags = tag_rows

        try:
            session.add(model)
            session.commit()
        except sqlalchemy.exc.IntegrityError as e:
            session.rollback()
            return bad_request(e.orig.args[0])

        # Start background processing so the new record gets hashed.
        model_processor.run()

        return web.json_response(model.to_dict(), dumps=dumps)
@routes.delete("/v2/models")
async def delete_model(request):
    """Delete the model record identified by the ``path``/``type`` query args.

    The record can only be removed once its file no longer resolves on disk.
    Responds 204 on success, 400 for missing arguments or an existing file,
    404 when no such record exists.
    """
    with create_session() as session:
        params = request.query
        model_path = params.get("path", None)
        model_type = params.get("type", None)

        # Required arguments, checked in the same order as they are reported.
        for field, value in (("path", model_path), ("type", model_type)):
            if not value:
                return missing_field(field)

        if get_full_path(model_type, model_path):
            return bad_request("Model file exists, please delete the file before deleting the model record.")

        record = (
            session.query(Model)
            .filter(Model.path == model_path, Model.type == model_type)
            .first()
        )
        if not record:
            return not_found("Model")

        session.delete(record)
        session.commit()
        return web.Response(status=204)
@routes.get("/v2/tags")
async def get_tags(request):
    """Return every tag as a JSON list of ``{"id", "name"}`` objects."""
    with create_session() as session:
        rows = session.query(Tag).all()
        payload = [dict(id=row.id, name=row.name) for row in rows]
        return web.json_response(payload)
@routes.post("/v2/tags")
async def create_tag(request):
    """Create a tag from the JSON body's ``name``.

    Responds with the new tag's id and name, or 400 when ``name`` is missing
    or violates a database constraint (tag names are unique).
    """
    with create_session() as session:
        data = await request.json()
        name = data.get("name", None)
        if not name:
            return missing_field("name")

        tag = Tag(name=name)
        try:
            session.add(tag)
            session.commit()
        except sqlalchemy.exc.IntegrityError as e:
            # Same handling as add_model: report the violation as a client
            # error instead of letting it surface as a 500.
            session.rollback()
            return bad_request(e.orig.args[0])

        return web.json_response({"id": tag.id, "name": tag.name})
@routes.delete("/v2/tags")
async def delete_tag(request):
    """Delete the tag given by the ``id`` query argument.

    Responds 204 on success, 400 when ``id`` is missing, 404 when no tag has
    that id.
    """
    with create_session() as session:
        requested_id = request.query.get("id", None)
        if not requested_id:
            return missing_field("id")

        match = session.query(Tag).filter(Tag.id == requested_id).first()
        if match:
            session.delete(match)
            session.commit()
            return web.Response(status=204)

        return not_found("Tag")
@routes.post("/v2/models/tags")
async def add_model_tag(request):
    """Attach an existing tag to an existing model.

    The JSON body must contain ``tag`` (tag id), ``path`` and ``type``.
    Responds with the updated model, 400 for missing or invalid input, 404
    when the model or the tag does not exist.
    """
    with create_session() as session:
        data = await request.json()
        tag_id = data.get("tag", None)
        model_path = data.get("path", None)
        model_type = data.get("type", None)

        if tag_id is None:
            return missing_field("tag")
        if model_path is None:
            return missing_field("path")
        if model_type is None:
            return missing_field("type")

        try:
            tag_id = int(tag_id)
        except (TypeError, ValueError):
            # JSON can carry lists or objects, which make int() raise
            # TypeError rather than ValueError.
            return bad_request("Invalid tag id")

        tag = session.query(Tag).filter(Tag.id == tag_id).first()
        model = (
            session.query(Model)
            .filter(Model.path == model_path, Model.type == model_type)
            .first()
        )
        if not model:
            return not_found("Model")
        if not tag:
            # Without this check None would be appended to the relationship.
            return not_found("Tag")

        # Re-adding an attached tag is a no-op, so no duplicate association
        # row is inserted.
        if tag not in model.tags:
            model.tags.append(tag)
            session.commit()

        return web.json_response(model.to_dict(), dumps=dumps)
@routes.delete("/v2/models/tags")
async def delete_model_tag(request):
    """Detach a tag from a model.

    Query arguments ``tag`` (tag id), ``path`` and ``type`` are required.
    Responds 204 on success, 400 for missing or invalid input, 404 when the
    model does not exist.
    """
    with create_session() as session:
        # Fetch the required arguments, reporting the first one missing.
        supplied = {}
        for field in ("tag", "path", "type"):
            value = request.query.get(field, None)
            if value is None:
                return missing_field(field)
            supplied[field] = value

        try:
            tag_id = int(supplied["tag"])
        except ValueError:
            return bad_request("Invalid tag id")

        record = (
            session.query(Model)
            .filter(Model.path == supplied["path"], Model.type == supplied["type"])
            .first()
        )
        if not record:
            return not_found("Model")

        # Keep every tag except the one being detached.
        record.tags = list(filter(lambda tag: tag.id != tag_id, record.tags))
        session.commit()
        return web.Response(status=204)
@routes.get("/v2/models/missing")
async def get_missing_models(request):
    """Return the model processor's current ``missing_models`` list as JSON."""
    missing = model_processor.missing_models
    return web.json_response(missing)
def get_model_file_list(self, folder_name: str):
folder_name = map_legacy(folder_name)
folders = folder_paths.folder_names_and_paths[folder_name]
@@ -146,39 +334,5 @@ class ModelFileManager:
return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()
def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
"""Collect preview images for the model file at *filepath*.

Returns image files named ``<basename>.<ext>`` or ``<basename>.preview.<ext>``
as paths, followed by any cover images embedded in a sibling ``.safetensors``
header as ``BytesIO`` objects. Returns an empty list when the containing
directory does not exist.
"""
dirname = os.path.dirname(filepath)
if not os.path.exists(dirname):
return []
# Everything that shares the model's name without its extension.
basename = os.path.splitext(filepath)[0]
match_files = glob.glob(f"{basename}.*", recursive=False)
image_files = filter_files_content_types(match_files, "image")
safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None)
safetensors_metadata = {}
result: list[str | BytesIO] = []
for filename in image_files:
_basename = os.path.splitext(filename)[0]
if _basename == basename:
result.append(filename)
if _basename == f"{basename}.preview":
result.append(filename)
if safetensors_file:
# NOTE(review): glob matches are built from the full basename, so they
# already carry the directory; joining with dirname looks redundant - confirm.
safetensors_filepath = os.path.join(dirname, safetensors_file)
header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024)
if header:
safetensors_metadata = json.loads(header)
# ssmd_cover_images is itself a JSON-encoded list of base64 image strings.
safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None)
if safetensors_images:
safetensors_images = json.loads(safetensors_images)
for image in safetensors_images:
result.append(BytesIO(base64.b64decode(image)))
return result
def __exit__(self, exc_type, exc_value, traceback):
"""Call ``self.clear_cache()`` on context exit; the exception details are not inspected."""
self.clear_cache()

263
app/model_processor.py Normal file
View File

@@ -0,0 +1,263 @@
import base64
from datetime import datetime
import glob
import hashlib
from io import BytesIO
import json
import logging
import os
import threading
import time
import comfy.utils
from app.database.models import Model
from app.database.db import create_session
from comfy.cli_args import args
from folder_paths import (
filter_files_content_types,
get_full_path,
folder_names_and_paths,
get_filename_list,
)
from PIL import Image
from urllib import request
def get_model_previews(
    filepath: str, check_metadata: bool = True
) -> list[str | BytesIO]:
    """Collect preview images for the model file at *filepath*.

    Args:
        filepath: Path to the model file.
        check_metadata: When true, cover images embedded in a sibling
            ``.safetensors`` header are appended as well.

    Returns:
        Paths of image files named ``<basename>.<ext>`` or
        ``<basename>.preview.<ext>``, followed (optionally) by ``BytesIO``
        objects holding the decoded embedded cover images. An empty list is
        returned when the containing directory does not exist.
    """
    dirname = os.path.dirname(filepath)
    if not os.path.exists(dirname):
        return []

    basename = os.path.splitext(filepath)[0]
    # Escape glob metacharacters so model names containing "[", "]", "*" or
    # "?" are matched literally instead of being read as a pattern.
    match_files = glob.glob(f"{glob.escape(basename)}.*", recursive=False)
    image_files = filter_files_content_types(match_files, "image")

    accepted_stems = (basename, f"{basename}.preview")
    result: list[str | BytesIO] = [
        filename
        for filename in image_files
        if os.path.splitext(filename)[0] in accepted_stems
    ]

    if not check_metadata:
        return result

    # glob results already include the directory part of ``basename``, so the
    # match is used as-is. Joining it with ``dirname`` again would duplicate
    # the directory for relative paths.
    safetensors_file = next(
        (name for name in match_files if name.endswith(".safetensors")), None
    )
    if not safetensors_file:
        return result

    header = comfy.utils.safetensors_header(
        safetensors_file, max_size=8 * 1024 * 1024
    )
    if not header:
        return result

    safetensors_metadata = json.loads(header)
    # ssmd_cover_images is itself a JSON-encoded list of base64 strings.
    safetensors_images = safetensors_metadata.get("__metadata__", {}).get(
        "ssmd_cover_images", None
    )
    if safetensors_images:
        for image in json.loads(safetensors_images):
            result.append(BytesIO(base64.b64decode(image)))

    return result
class ModelProcessor:
    """Background worker that syncs model records with disk and enriches them.

    A single daemon thread hashes every model whose ``hash`` column is empty,
    copies title/description/architecture from the safetensors metadata, and
    writes a ``.webp`` preview next to the model when none exists.
    ``missing_models`` lists ``{"type", "path"}`` entries for records whose
    file could not be resolved.
    """

    # Largest safetensors JSON header that will be read into memory for
    # parsing; larger headers are still hashed, just streamed.
    # NOTE(review): presumably matches the reference parser's limit - confirm.
    _MAX_HEADER_SIZE = 100_000_000

    # Seconds to wait for a remote thumbnail before giving up, so one
    # unresponsive host cannot stall the worker indefinitely.
    _THUMBNAIL_TIMEOUT = 10

    def __init__(self):
        self._thread = None
        self._lock = threading.Lock()
        # True while a processing pass has been requested but not yet claimed.
        self._run = False
        self.missing_models = []

    def run(self):
        """Request a processing pass, starting the worker thread if needed.

        Does nothing when ``args.disable_model_processing`` is set.
        """
        if args.disable_model_processing:
            return

        with self._lock:
            # Always flag the request so that an already running worker does
            # another pass instead of missing records added meanwhile.
            self._run = True
            if self._thread is None:
                self._thread = threading.Thread(
                    target=self._process_models, daemon=True
                )
                self._thread.start()

    def _record_missing(self, model_type, model_path):
        """Add a ``{"type", "path"}`` entry to ``missing_models`` once."""
        entry = {"type": model_type, "path": model_path}
        if entry not in self.missing_models:
            self.missing_models.append(entry)

    def populate_models(self, session):
        """Insert records for files on disk that are not in the database yet.

        Records whose file can no longer be resolved are logged and added to
        ``missing_models``. Commits the session before returning.
        """
        # Ensure database state matches filesystem. Existing rows are indexed
        # by primary key so each file is matched with one dict lookup.
        unmatched = {
            (model.type, model.path): model
            for model in session.query(Model).all()
        }

        for folder_name in folder_names_and_paths.keys():
            if folder_name in ("custom_nodes", "configs"):
                continue

            seen = set()
            for file in get_filename_list(folder_name):
                if file in seen:
                    logging.warning(f"Skipping duplicate named model: {file}")
                    continue
                seen.add(file)

                # Model already exists in db, remove from the index and skip
                if unmatched.pop((folder_name, file), None) is not None:
                    continue

                file_path = get_full_path(folder_name, file)
                if not file_path:
                    # Listed but no longer resolvable; nothing to record yet.
                    continue

                session.add(
                    Model(
                        path=file,
                        type=folder_name,
                        date_added=datetime.fromtimestamp(
                            os.path.getctime(file_path)
                        ),
                    )
                )

        # Whatever is left was in the database but not listed on disk.
        for model in unmatched.values():
            if not get_full_path(model.type, model.path):
                logging.warning(f"Model {model.path} not found")
                self._record_missing(model.type, model.path)

        session.commit()

    def _get_models(self, session):
        """Return all model records that have no hash yet."""
        return session.query(Model).filter(Model.hash.is_(None)).all()

    def _process_file(self, model_path):
        """Hash the file and, for ``.safetensors`` files, parse its header.

        Returns:
            ``(sha256_hexdigest, header)`` where ``header`` is the parsed JSON
            header as a dict, or ``{}`` when the file is not a safetensors
            file or the header is too large, unparsable, or not an object.
        """
        is_safetensors = model_path.endswith(".safetensors")
        metadata = {}
        h = hashlib.sha256()

        with open(model_path, "rb", buffering=0) as f:
            if is_safetensors:
                # Read header length (8 bytes)
                header_size_bytes = f.read(8)
                header_len = int.from_bytes(header_size_bytes, "little")
                h.update(header_size_bytes)

                # The length comes from the file itself, so only read the
                # header into memory when it is plausibly sized. Otherwise the
                # bytes are hashed by the streaming loop below.
                if header_len <= self._MAX_HEADER_SIZE:
                    header_bytes = f.read(header_len)
                    h.update(header_bytes)
                    try:
                        parsed = json.loads(header_bytes)
                    except ValueError:
                        # Covers JSONDecodeError and UnicodeDecodeError.
                        parsed = None
                    if isinstance(parsed, dict):
                        metadata = parsed

            # Read rest of file
            b = bytearray(128 * 1024)
            mv = memoryview(b)
            while n := f.readinto(mv):
                h.update(mv[:n])

        return h.hexdigest(), metadata

    def _populate_info(self, model, metadata):
        """Copy the modelspec title/description/architecture onto *model*."""
        model.title = metadata.get("modelspec.title", None)
        model.description = metadata.get("modelspec.description", None)
        model.architecture = metadata.get("modelspec.architecture", None)

    def _load_cover_image(self, model_path, metadata):
        """Return the first embedded ``ssmd_cover_images`` image, or None."""
        cover_images = metadata.get("ssmd_cover_images", None)
        if not cover_images:
            return None
        try:
            cover_images = json.loads(cover_images)
            if len(cover_images) > 0:
                image_data = cover_images[0]
                return Image.open(BytesIO(base64.b64decode(image_data)))
        except Exception as e:
            logging.warning(
                f"Error extracting cover image for model {model_path}: {e}"
            )
        return None

    def _load_thumbnail(self, model_path, metadata):
        """Return the image referenced by ``modelspec.thumbnail``, or None."""
        thumbnail = metadata.get("modelspec.thumbnail", None)
        if not thumbnail:
            return None

        from urllib.parse import urlsplit

        try:
            # The URL comes from untrusted model metadata: allow only web and
            # inline data URLs so it cannot be used to read local files.
            scheme = urlsplit(thumbnail).scheme.lower()
            if scheme not in ("http", "https", "data"):
                raise ValueError(f"unsupported URL scheme '{scheme}'")
            with request.urlopen(
                thumbnail, timeout=self._THUMBNAIL_TIMEOUT
            ) as response:
                # Read everything first so the image does not depend on the
                # response, which is closed on leaving this block.
                return Image.open(BytesIO(response.read()))
        except Exception as e:
            logging.warning(
                f"Error extracting thumbnail for model {model_path}: {e}"
            )
        return None

    def _extract_image(self, model_path, metadata):
        """Write a ``.webp`` preview next to the model if it has none yet.

        The first embedded cover image is preferred, with the modelspec
        thumbnail as a fallback. Failures are logged and never raised.
        """
        # check if image already exists
        if len(get_model_previews(model_path, check_metadata=False)) > 0:
            return

        image_path = os.path.splitext(model_path)[0] + ".webp"
        if os.path.exists(image_path):
            return

        image = self._load_cover_image(model_path, metadata)
        if image is None:
            image = self._load_thumbnail(model_path, metadata)
        if image is None:
            return

        try:
            image.thumbnail((512, 512))
            image.save(image_path)
        except Exception as e:
            # A preview is optional; failing to write it must not prevent the
            # model's hash from being stored.
            logging.warning(
                f"Error saving preview image for model {model_path}: {e}"
            )
        finally:
            image.close()

    def _claim_pass(self):
        """Consume a pending pass request, or release the worker slot.

        Returns True when another pass should run. Otherwise ``_thread`` is
        cleared under the same lock, so a concurrent ``run()`` either gets its
        request consumed here or starts a fresh worker.
        """
        with self._lock:
            if self._run:
                self._run = False
                return True
            self._thread = None
            return False

    def _process_model(self, session, model):
        """Hash one model, store its metadata and preview, then commit."""
        now = time.time()
        model_path = get_full_path(model.type, model.path)

        if not model_path:
            logging.warning(f"Model {model.path} not found")
            self._record_missing(model.type, model.path)
            return

        logging.debug(f"Processing model {model_path}")
        digest, header = self._process_file(model_path)
        logging.debug(
            f"Processed model {model_path} in {time.time() - now} seconds"
        )
        model.hash = digest

        metadata = header.get("__metadata__", None) if header else None
        if metadata and isinstance(metadata, dict):
            self._populate_info(model, metadata)
            self._extract_image(model_path, metadata)

        session.commit()

    def _process_models(self):
        """Worker thread body: sync with disk, then process unhashed models."""
        try:
            with create_session() as session:
                # Keyed by primary key so a failing model is tried only once
                # per worker run, without hiding a same-named model of
                # another type.
                checked = set()
                self.populate_models(session)

                while self._claim_pass():
                    for model in self._get_models(session):
                        key = (model.type, model.path)
                        # prevent looping on the same model if it crashes
                        if key in checked:
                            continue
                        checked.add(key)

                        try:
                            # Yield to other threads between models.
                            time.sleep(0)
                            self._process_model(session, model)
                        except Exception as e:
                            logging.error(f"Error processing model {key[1]}: {e}")
                            # A failed flush leaves the session unusable until
                            # it is rolled back.
                            session.rollback()
        except Exception:
            logging.exception("Model processing stopped unexpectedly")
        finally:
            # Release the slot after a crash too, but never clobber a worker
            # that was started after this one released the slot itself.
            with self._lock:
                if self._thread is threading.current_thread():
                    self._thread = None


model_processor = ModelProcessor()