mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-11 08:10:06 +00:00
Add model management and database
- use sqlalchemy + alembic + sqlite for db - extract model data and previews - endpoints for db interactions - add tests
This commit is contained in:
116
app/database/db.py
Normal file
116
app/database/db.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from app.database.models import Tag
|
||||
from comfy.cli_args import args
|
||||
|
||||
try:
|
||||
import alembic
|
||||
import sqlalchemy
|
||||
except ImportError as e:
|
||||
req_path = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "../..", "requirements.txt")
|
||||
)
|
||||
logging.error(
|
||||
f"\n\n********** ERROR ***********\n\nRequirements are not installed ({e}). Please install the requirements.txt file by running:\n{sys.executable} -s -m pip install -r {req_path}\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem\n********** ERROR **********\n"
|
||||
)
|
||||
exit(-1)
|
||||
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
from alembic.runtime.migration import MigrationContext
|
||||
from alembic.script import ScriptDirectory
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
Session = None
|
||||
|
||||
|
||||
def get_alembic_config():
|
||||
root_path = os.path.join(os.path.dirname(__file__), "../..")
|
||||
config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
|
||||
scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
|
||||
|
||||
config = Config(config_path)
|
||||
config.set_main_option("script_location", scripts_path)
|
||||
config.set_main_option("sqlalchemy.url", args.database_url)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def get_db_path():
|
||||
url = args.database_url
|
||||
if url.startswith("sqlite:///"):
|
||||
return url.split("///")[1]
|
||||
else:
|
||||
raise ValueError(f"Unsupported database URL '{url}'.")
|
||||
|
||||
|
||||
def init_db():
|
||||
db_url = args.database_url
|
||||
logging.debug(f"Database URL: {db_url}")
|
||||
|
||||
config = get_alembic_config()
|
||||
|
||||
# Check if we need to upgrade
|
||||
engine = create_engine(db_url)
|
||||
conn = engine.connect()
|
||||
|
||||
context = MigrationContext.configure(conn)
|
||||
current_rev = context.get_current_revision()
|
||||
|
||||
script = ScriptDirectory.from_config(config)
|
||||
target_rev = script.get_current_head()
|
||||
|
||||
if current_rev != target_rev:
|
||||
# Backup the database pre upgrade
|
||||
db_path = get_db_path()
|
||||
backup_path = db_path + ".bkp"
|
||||
if os.path.exists(db_path):
|
||||
shutil.copy(db_path, backup_path)
|
||||
else:
|
||||
backup_path = None
|
||||
|
||||
try:
|
||||
command.upgrade(config, target_rev)
|
||||
logging.info(f"Database upgraded from {current_rev} to {target_rev}")
|
||||
except Exception as e:
|
||||
if backup_path:
|
||||
# Restore the database from backup if upgrade fails
|
||||
shutil.copy(backup_path, db_path)
|
||||
os.remove(backup_path)
|
||||
logging.error(f"Error upgrading database: {e}")
|
||||
raise e
|
||||
|
||||
global Session
|
||||
Session = sessionmaker(bind=engine)
|
||||
|
||||
if not current_rev:
|
||||
# Init db, populate models
|
||||
from app.model_processor import model_processor
|
||||
|
||||
session = create_session()
|
||||
model_processor.populate_models(session)
|
||||
|
||||
# populate tags
|
||||
tags = (
|
||||
"character",
|
||||
"style",
|
||||
"concept",
|
||||
"clothing",
|
||||
"pose",
|
||||
"background",
|
||||
"vehicle",
|
||||
"object",
|
||||
"animal",
|
||||
"action",
|
||||
)
|
||||
for tag in tags:
|
||||
session.add(Tag(name=tag))
|
||||
|
||||
session.commit()
|
||||
|
||||
|
||||
def create_session():
|
||||
return Session()
|
||||
76
app/database/models.py
Normal file
76
app/database/models.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Integer,
|
||||
Text,
|
||||
DateTime,
|
||||
Table,
|
||||
ForeignKeyConstraint,
|
||||
)
|
||||
from sqlalchemy.orm import relationship, declarative_base
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
def to_dict(obj):
|
||||
fields = obj.__table__.columns.keys()
|
||||
return {
|
||||
field: (val.to_dict() if hasattr(val, "to_dict") else val)
|
||||
for field in fields
|
||||
if (val := getattr(obj, field))
|
||||
}
|
||||
|
||||
|
||||
ModelTag = Table(
|
||||
"model_tag",
|
||||
Base.metadata,
|
||||
Column(
|
||||
"model_type",
|
||||
Text,
|
||||
primary_key=True,
|
||||
),
|
||||
Column(
|
||||
"model_path",
|
||||
Text,
|
||||
primary_key=True,
|
||||
),
|
||||
Column("tag_id", Integer, primary_key=True),
|
||||
ForeignKeyConstraint(
|
||||
["model_type", "model_path"], ["model.type", "model.path"], ondelete="CASCADE"
|
||||
),
|
||||
ForeignKeyConstraint(["tag_id"], ["tag.id"], ondelete="CASCADE"),
|
||||
)
|
||||
|
||||
|
||||
class Model(Base):
|
||||
__tablename__ = "model"
|
||||
|
||||
type = Column(Text, primary_key=True)
|
||||
path = Column(Text, primary_key=True)
|
||||
title = Column(Text)
|
||||
description = Column(Text)
|
||||
architecture = Column(Text)
|
||||
hash = Column(Text)
|
||||
source_url = Column(Text)
|
||||
date_added = Column(DateTime, server_default=func.now())
|
||||
|
||||
# Relationship with tags
|
||||
tags = relationship("Tag", secondary=ModelTag, back_populates="models")
|
||||
|
||||
def to_dict(self):
|
||||
dict = to_dict(self)
|
||||
dict["tags"] = [tag.to_dict() for tag in self.tags]
|
||||
return dict
|
||||
|
||||
|
||||
class Tag(Base):
|
||||
__tablename__ = "tag"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
name = Column(Text, nullable=False, unique=True)
|
||||
|
||||
# Relationship with models
|
||||
models = relationship("Model", secondary=ModelTag, back_populates="tags")
|
||||
|
||||
def to_dict(self):
|
||||
return to_dict(self)
|
||||
@@ -1,19 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import base64
|
||||
import json
|
||||
import time
|
||||
import logging
|
||||
from app.database.db import create_session
|
||||
import folder_paths
|
||||
import glob
|
||||
import comfy.utils
|
||||
from aiohttp import web
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types
|
||||
from folder_paths import map_legacy, filter_files_extensions, get_full_path
|
||||
from app.database.models import Tag, Model
|
||||
from app.model_processor import get_model_previews, model_processor
|
||||
from utils.web import dumps
|
||||
from sqlalchemy.orm import joinedload
|
||||
import sqlalchemy.exc
|
||||
|
||||
|
||||
def bad_request(message: str):
|
||||
return web.json_response({"error": message}, status=400)
|
||||
|
||||
def missing_field(field: str):
|
||||
return bad_request(f"{field} is required")
|
||||
|
||||
def not_found(message: str):
|
||||
return web.json_response({"error": message + " not found"}, status=404)
|
||||
|
||||
class ModelFileManager:
|
||||
def __init__(self) -> None:
|
||||
self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
|
||||
@@ -62,7 +73,7 @@ class ModelFileManager:
|
||||
folder = folders[0][path_index]
|
||||
full_filename = os.path.join(folder, filename)
|
||||
|
||||
previews = self.get_model_previews(full_filename)
|
||||
previews = get_model_previews(full_filename)
|
||||
default_preview = previews[0] if len(previews) > 0 else None
|
||||
if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
|
||||
return web.Response(status=404)
|
||||
@@ -76,6 +87,183 @@ class ModelFileManager:
|
||||
except:
|
||||
return web.Response(status=404)
|
||||
|
||||
@routes.get("/v2/models")
|
||||
async def get_models(request):
|
||||
with create_session() as session:
|
||||
model_path = request.query.get("path", None)
|
||||
model_type = request.query.get("type", None)
|
||||
query = session.query(Model).options(joinedload(Model.tags))
|
||||
if model_path:
|
||||
query = query.filter(Model.path == model_path)
|
||||
if model_type:
|
||||
query = query.filter(Model.type == model_type)
|
||||
models = query.all()
|
||||
if model_path and model_type:
|
||||
if len(models) == 0:
|
||||
return not_found("Model")
|
||||
return web.json_response(models[0].to_dict(), dumps=dumps)
|
||||
|
||||
return web.json_response([model.to_dict() for model in models], dumps=dumps)
|
||||
|
||||
@routes.post("/v2/models")
|
||||
async def add_model(request):
|
||||
with create_session() as session:
|
||||
data = await request.json()
|
||||
model_type = data.get("type", None)
|
||||
model_path = data.get("path", None)
|
||||
|
||||
if not model_type:
|
||||
return missing_field("type")
|
||||
if not model_path:
|
||||
return missing_field("path")
|
||||
|
||||
tags = data.pop("tags", [])
|
||||
fields = Model.metadata.tables["model"].columns.keys()
|
||||
|
||||
# Validate keys are valid model fields
|
||||
for key in data.keys():
|
||||
if key not in fields:
|
||||
return bad_request(f"Invalid field: {key}")
|
||||
|
||||
# Validate file exists
|
||||
if not get_full_path(model_type, model_path):
|
||||
return not_found(f"File '{model_type}/{model_path}'")
|
||||
|
||||
model = Model()
|
||||
for field in fields:
|
||||
if field in data:
|
||||
setattr(model, field, data[field])
|
||||
|
||||
model.tags = session.query(Tag).filter(Tag.id.in_(tags)).all()
|
||||
for tag in tags:
|
||||
if tag not in [t.id for t in model.tags]:
|
||||
return not_found(f"Tag '{tag}'")
|
||||
|
||||
try:
|
||||
session.add(model)
|
||||
session.commit()
|
||||
except sqlalchemy.exc.IntegrityError as e:
|
||||
session.rollback()
|
||||
return bad_request(e.orig.args[0])
|
||||
|
||||
model_processor.run()
|
||||
|
||||
return web.json_response(model.to_dict(), dumps=dumps)
|
||||
|
||||
@routes.delete("/v2/models")
|
||||
async def delete_model(request):
|
||||
with create_session() as session:
|
||||
model_path = request.query.get("path", None)
|
||||
model_type = request.query.get("type", None)
|
||||
if not model_path:
|
||||
return missing_field("path")
|
||||
if not model_type:
|
||||
return missing_field("type")
|
||||
|
||||
full_path = get_full_path(model_type, model_path)
|
||||
if full_path:
|
||||
return bad_request("Model file exists, please delete the file before deleting the model record.")
|
||||
|
||||
model = session.query(Model).filter(Model.path == model_path, Model.type == model_type).first()
|
||||
if not model:
|
||||
return not_found("Model")
|
||||
session.delete(model)
|
||||
session.commit()
|
||||
return web.Response(status=204)
|
||||
|
||||
@routes.get("/v2/tags")
|
||||
async def get_tags(request):
|
||||
with create_session() as session:
|
||||
tags = session.query(Tag).all()
|
||||
return web.json_response(
|
||||
[{"id": tag.id, "name": tag.name} for tag in tags]
|
||||
)
|
||||
|
||||
@routes.post("/v2/tags")
|
||||
async def create_tag(request):
|
||||
with create_session() as session:
|
||||
data = await request.json()
|
||||
name = data.get("name", None)
|
||||
if not name:
|
||||
return missing_field("name")
|
||||
tag = Tag(name=name)
|
||||
session.add(tag)
|
||||
session.commit()
|
||||
return web.json_response({"id": tag.id, "name": tag.name})
|
||||
|
||||
@routes.delete("/v2/tags")
|
||||
async def delete_tag(request):
|
||||
with create_session() as session:
|
||||
tag_id = request.query.get("id", None)
|
||||
if not tag_id:
|
||||
return missing_field("id")
|
||||
tag = session.query(Tag).filter(Tag.id == tag_id).first()
|
||||
if not tag:
|
||||
return not_found("Tag")
|
||||
session.delete(tag)
|
||||
session.commit()
|
||||
return web.Response(status=204)
|
||||
|
||||
@routes.post("/v2/models/tags")
|
||||
async def add_model_tag(request):
|
||||
with create_session() as session:
|
||||
data = await request.json()
|
||||
tag_id = data.get("tag", None)
|
||||
model_path = data.get("path", None)
|
||||
model_type = data.get("type", None)
|
||||
|
||||
if tag_id is None:
|
||||
return missing_field("tag")
|
||||
if model_path is None:
|
||||
return missing_field("path")
|
||||
if model_type is None:
|
||||
return missing_field("type")
|
||||
|
||||
try:
|
||||
tag_id = int(tag_id)
|
||||
except ValueError:
|
||||
return bad_request("Invalid tag id")
|
||||
|
||||
tag = session.query(Tag).filter(Tag.id == tag_id).first()
|
||||
model = session.query(Model).filter(Model.path == model_path, Model.type == model_type).first()
|
||||
if not model:
|
||||
return not_found("Model")
|
||||
model.tags.append(tag)
|
||||
session.commit()
|
||||
return web.json_response(model.to_dict(), dumps=dumps)
|
||||
|
||||
@routes.delete("/v2/models/tags")
|
||||
async def delete_model_tag(request):
|
||||
with create_session() as session:
|
||||
tag_id = request.query.get("tag", None)
|
||||
model_path = request.query.get("path", None)
|
||||
model_type = request.query.get("type", None)
|
||||
|
||||
if tag_id is None:
|
||||
return missing_field("tag")
|
||||
if model_path is None:
|
||||
return missing_field("path")
|
||||
if model_type is None:
|
||||
return missing_field("type")
|
||||
|
||||
try:
|
||||
tag_id = int(tag_id)
|
||||
except ValueError:
|
||||
return bad_request("Invalid tag id")
|
||||
|
||||
model = session.query(Model).filter(Model.path == model_path, Model.type == model_type).first()
|
||||
if not model:
|
||||
return not_found("Model")
|
||||
model.tags = [tag for tag in model.tags if tag.id != tag_id]
|
||||
session.commit()
|
||||
return web.Response(status=204)
|
||||
|
||||
|
||||
|
||||
@routes.get("/v2/models/missing")
|
||||
async def get_missing_models(request):
|
||||
return web.json_response(model_processor.missing_models)
|
||||
|
||||
def get_model_file_list(self, folder_name: str):
|
||||
folder_name = map_legacy(folder_name)
|
||||
folders = folder_paths.folder_names_and_paths[folder_name]
|
||||
@@ -146,39 +334,5 @@ class ModelFileManager:
|
||||
|
||||
return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()
|
||||
|
||||
def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
|
||||
dirname = os.path.dirname(filepath)
|
||||
|
||||
if not os.path.exists(dirname):
|
||||
return []
|
||||
|
||||
basename = os.path.splitext(filepath)[0]
|
||||
match_files = glob.glob(f"{basename}.*", recursive=False)
|
||||
image_files = filter_files_content_types(match_files, "image")
|
||||
safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None)
|
||||
safetensors_metadata = {}
|
||||
|
||||
result: list[str | BytesIO] = []
|
||||
|
||||
for filename in image_files:
|
||||
_basename = os.path.splitext(filename)[0]
|
||||
if _basename == basename:
|
||||
result.append(filename)
|
||||
if _basename == f"{basename}.preview":
|
||||
result.append(filename)
|
||||
|
||||
if safetensors_file:
|
||||
safetensors_filepath = os.path.join(dirname, safetensors_file)
|
||||
header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024)
|
||||
if header:
|
||||
safetensors_metadata = json.loads(header)
|
||||
safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None)
|
||||
if safetensors_images:
|
||||
safetensors_images = json.loads(safetensors_images)
|
||||
for image in safetensors_images:
|
||||
result.append(BytesIO(base64.b64decode(image)))
|
||||
|
||||
return result
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.clear_cache()
|
||||
|
||||
263
app/model_processor.py
Normal file
263
app/model_processor.py
Normal file
@@ -0,0 +1,263 @@
|
||||
import base64
|
||||
from datetime import datetime
|
||||
import glob
|
||||
import hashlib
|
||||
from io import BytesIO
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import comfy.utils
|
||||
from app.database.models import Model
|
||||
from app.database.db import create_session
|
||||
from comfy.cli_args import args
|
||||
from folder_paths import (
|
||||
filter_files_content_types,
|
||||
get_full_path,
|
||||
folder_names_and_paths,
|
||||
get_filename_list,
|
||||
)
|
||||
from PIL import Image
|
||||
from urllib import request
|
||||
|
||||
|
||||
def get_model_previews(
|
||||
filepath: str, check_metadata: bool = True
|
||||
) -> list[str | BytesIO]:
|
||||
dirname = os.path.dirname(filepath)
|
||||
|
||||
if not os.path.exists(dirname):
|
||||
return []
|
||||
|
||||
basename = os.path.splitext(filepath)[0]
|
||||
match_files = glob.glob(f"{basename}.*", recursive=False)
|
||||
image_files = filter_files_content_types(match_files, "image")
|
||||
|
||||
result: list[str | BytesIO] = []
|
||||
|
||||
for filename in image_files:
|
||||
_basename = os.path.splitext(filename)[0]
|
||||
if _basename == basename:
|
||||
result.append(filename)
|
||||
if _basename == f"{basename}.preview":
|
||||
result.append(filename)
|
||||
|
||||
if not check_metadata:
|
||||
return result
|
||||
|
||||
safetensors_file = next(
|
||||
filter(lambda x: x.endswith(".safetensors"), match_files), None
|
||||
)
|
||||
safetensors_metadata = {}
|
||||
|
||||
if safetensors_file:
|
||||
safetensors_filepath = os.path.join(dirname, safetensors_file)
|
||||
header = comfy.utils.safetensors_header(
|
||||
safetensors_filepath, max_size=8 * 1024 * 1024
|
||||
)
|
||||
if header:
|
||||
safetensors_metadata = json.loads(header)
|
||||
safetensors_images = safetensors_metadata.get("__metadata__", {}).get(
|
||||
"ssmd_cover_images", None
|
||||
)
|
||||
if safetensors_images:
|
||||
safetensors_images = json.loads(safetensors_images)
|
||||
for image in safetensors_images:
|
||||
result.append(BytesIO(base64.b64decode(image)))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class ModelProcessor:
|
||||
def __init__(self):
|
||||
self._thread = None
|
||||
self._lock = threading.Lock()
|
||||
self._run = False
|
||||
self.missing_models = []
|
||||
|
||||
def run(self):
|
||||
if args.disable_model_processing:
|
||||
return
|
||||
|
||||
if self._thread is None:
|
||||
# Lock to prevent multiple threads from starting
|
||||
with self._lock:
|
||||
self._run = True
|
||||
if self._thread is None:
|
||||
self._thread = threading.Thread(target=self._process_models)
|
||||
self._thread.daemon = True
|
||||
self._thread.start()
|
||||
|
||||
def populate_models(self, session):
|
||||
# Ensure database state matches filesystem
|
||||
|
||||
existing_models = session.query(Model).all()
|
||||
|
||||
for folder_name in folder_names_and_paths.keys():
|
||||
if folder_name == "custom_nodes" or folder_name == "configs":
|
||||
continue
|
||||
seen = set()
|
||||
files = get_filename_list(folder_name)
|
||||
|
||||
for file in files:
|
||||
if file in seen:
|
||||
logging.warning(f"Skipping duplicate named model: {file}")
|
||||
continue
|
||||
seen.add(file)
|
||||
|
||||
existing_model = None
|
||||
for model in existing_models:
|
||||
if model.path == file and model.type == folder_name:
|
||||
existing_model = model
|
||||
break
|
||||
|
||||
if existing_model:
|
||||
# Model already exists in db, remove from list and skip
|
||||
existing_models.remove(existing_model)
|
||||
continue
|
||||
|
||||
file_path = get_full_path(folder_name, file)
|
||||
|
||||
model = Model(
|
||||
path=file,
|
||||
type=folder_name,
|
||||
date_added=datetime.fromtimestamp(os.path.getctime(file_path)),
|
||||
)
|
||||
session.add(model)
|
||||
|
||||
for model in existing_models:
|
||||
if not get_full_path(model.type, model.path):
|
||||
logging.warning(f"Model {model.path} not found")
|
||||
self.missing_models.append({"type": model.type, "path": model.path})
|
||||
|
||||
session.commit()
|
||||
|
||||
def _get_models(self, session):
|
||||
models = session.query(Model).filter(Model.hash == None).all()
|
||||
return models
|
||||
|
||||
def _process_file(self, model_path):
|
||||
is_safetensors = model_path.endswith(".safetensors")
|
||||
metadata = {}
|
||||
h = hashlib.sha256()
|
||||
|
||||
with open(model_path, "rb", buffering=0) as f:
|
||||
if is_safetensors:
|
||||
# Read header length (8 bytes)
|
||||
header_size_bytes = f.read(8)
|
||||
header_len = int.from_bytes(header_size_bytes, "little")
|
||||
h.update(header_size_bytes)
|
||||
|
||||
# Read header
|
||||
header_bytes = f.read(header_len)
|
||||
h.update(header_bytes)
|
||||
try:
|
||||
metadata = json.loads(header_bytes)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Read rest of file
|
||||
b = bytearray(128 * 1024)
|
||||
mv = memoryview(b)
|
||||
while n := f.readinto(mv):
|
||||
h.update(mv[:n])
|
||||
|
||||
return h.hexdigest(), metadata
|
||||
|
||||
def _populate_info(self, model, metadata):
|
||||
model.title = metadata.get("modelspec.title", None)
|
||||
model.description = metadata.get("modelspec.description", None)
|
||||
model.architecture = metadata.get("modelspec.architecture", None)
|
||||
|
||||
def _extract_image(self, model_path, metadata):
|
||||
# check if image already exists
|
||||
if len(get_model_previews(model_path, check_metadata=False)) > 0:
|
||||
return
|
||||
|
||||
image_path = os.path.splitext(model_path)[0] + ".webp"
|
||||
if os.path.exists(image_path):
|
||||
return
|
||||
|
||||
cover_images = metadata.get("ssmd_cover_images", None)
|
||||
image = None
|
||||
if cover_images:
|
||||
try:
|
||||
cover_images = json.loads(cover_images)
|
||||
if len(cover_images) > 0:
|
||||
image_data = cover_images[0]
|
||||
image = Image.open(BytesIO(base64.b64decode(image_data)))
|
||||
except Exception as e:
|
||||
logging.warning(
|
||||
f"Error extracting cover image for model {model_path}: {e}"
|
||||
)
|
||||
|
||||
if not image:
|
||||
thumbnail = metadata.get("modelspec.thumbnail", None)
|
||||
if thumbnail:
|
||||
try:
|
||||
response = request.urlopen(thumbnail)
|
||||
image = Image.open(response)
|
||||
except Exception as e:
|
||||
logging.warning(
|
||||
f"Error extracting thumbnail for model {model_path}: {e}"
|
||||
)
|
||||
|
||||
if image:
|
||||
image.thumbnail((512, 512))
|
||||
image.save(image_path)
|
||||
image.close()
|
||||
|
||||
def _process_models(self):
|
||||
with create_session() as session:
|
||||
checked = set()
|
||||
self.populate_models(session)
|
||||
|
||||
while self._run:
|
||||
self._run = False
|
||||
|
||||
models = self._get_models(session)
|
||||
|
||||
if len(models) == 0:
|
||||
break
|
||||
|
||||
for model in models:
|
||||
# prevent looping on the same model if it crashes
|
||||
if model.path in checked:
|
||||
continue
|
||||
|
||||
checked.add(model.path)
|
||||
|
||||
try:
|
||||
time.sleep(0)
|
||||
now = time.time()
|
||||
model_path = get_full_path(model.type, model.path)
|
||||
|
||||
if not model_path:
|
||||
logging.warning(f"Model {model.path} not found")
|
||||
self.missing_models.append(model.path)
|
||||
continue
|
||||
|
||||
logging.debug(f"Processing model {model_path}")
|
||||
hash, header = self._process_file(model_path)
|
||||
logging.debug(
|
||||
f"Processed model {model_path} in {time.time() - now} seconds"
|
||||
)
|
||||
model.hash = hash
|
||||
|
||||
if header:
|
||||
metadata = header.get("__metadata__", None)
|
||||
|
||||
if metadata:
|
||||
self._populate_info(model, metadata)
|
||||
self._extract_image(model_path, metadata)
|
||||
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
logging.error(f"Error processing model {model.path}: {e}")
|
||||
|
||||
with self._lock:
|
||||
self._thread = None
|
||||
|
||||
|
||||
model_processor = ModelProcessor()
|
||||
Reference in New Issue
Block a user