mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-16 05:00:03 +00:00
Compare commits
8 Commits
v0.3.50
...
pysssss-mo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3089936a2c | ||
|
|
cd679129e3 | ||
|
|
d7062277a7 | ||
|
|
54cf14cbbb | ||
|
|
7d5160f92c | ||
|
|
7f7b3f1695 | ||
|
|
9da6aca0d0 | ||
|
|
1cb3c98947 |
15
.github/workflows/stable-release.yml
vendored
15
.github/workflows/stable-release.yml
vendored
@@ -12,17 +12,17 @@ on:
|
||||
description: 'CUDA version'
|
||||
required: true
|
||||
type: string
|
||||
default: "129"
|
||||
default: "128"
|
||||
python_minor:
|
||||
description: 'Python minor version'
|
||||
required: true
|
||||
type: string
|
||||
default: "13"
|
||||
default: "12"
|
||||
python_patch:
|
||||
description: 'Python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "6"
|
||||
default: "10"
|
||||
|
||||
|
||||
jobs:
|
||||
@@ -66,13 +66,8 @@ jobs:
|
||||
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
|
||||
./python.exe get-pip.py
|
||||
./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/*
|
||||
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
|
||||
|
||||
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
|
||||
rm ./Lib/site-packages/torch/lib/libprotoc.lib
|
||||
rm ./Lib/site-packages/torch/lib/libprotobuf.lib
|
||||
|
||||
cd ..
|
||||
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
|
||||
cd ..
|
||||
|
||||
git clone --depth 1 https://github.com/comfyanonymous/taesd
|
||||
cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/
|
||||
|
||||
@@ -17,19 +17,19 @@ on:
|
||||
description: 'cuda version'
|
||||
required: true
|
||||
type: string
|
||||
default: "129"
|
||||
default: "128"
|
||||
|
||||
python_minor:
|
||||
description: 'python minor version'
|
||||
required: true
|
||||
type: string
|
||||
default: "13"
|
||||
default: "12"
|
||||
|
||||
python_patch:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "6"
|
||||
default: "10"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
|
||||
10
.github/workflows/windows_release_package.yml
vendored
10
.github/workflows/windows_release_package.yml
vendored
@@ -7,19 +7,19 @@ on:
|
||||
description: 'cuda version'
|
||||
required: true
|
||||
type: string
|
||||
default: "129"
|
||||
default: "128"
|
||||
|
||||
python_minor:
|
||||
description: 'python minor version'
|
||||
required: true
|
||||
type: string
|
||||
default: "13"
|
||||
default: "12"
|
||||
|
||||
python_patch:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "6"
|
||||
default: "10"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
@@ -64,10 +64,6 @@ jobs:
|
||||
./python.exe get-pip.py
|
||||
./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/*
|
||||
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
|
||||
|
||||
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
|
||||
rm ./Lib/site-packages/torch/lib/libprotoc.lib
|
||||
rm ./Lib/site-packages/torch/lib/libprotobuf.lib
|
||||
cd ..
|
||||
|
||||
git clone --depth 1 https://github.com/comfyanonymous/taesd
|
||||
|
||||
40
alembic_db/versions/e9c714da8d57_init.py
Normal file
40
alembic_db/versions/e9c714da8d57_init.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""init
|
||||
|
||||
Revision ID: e9c714da8d57
|
||||
Revises:
|
||||
Create Date: 2025-05-30 20:14:33.772039
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'e9c714da8d57'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
op.create_table('model',
|
||||
sa.Column('type', sa.Text(), nullable=False),
|
||||
sa.Column('path', sa.Text(), nullable=False),
|
||||
sa.Column('file_name', sa.Text(), nullable=True),
|
||||
sa.Column('file_size', sa.Integer(), nullable=True),
|
||||
sa.Column('hash', sa.Text(), nullable=True),
|
||||
sa.Column('hash_algorithm', sa.Text(), nullable=True),
|
||||
sa.Column('source_url', sa.Text(), nullable=True),
|
||||
sa.Column('date_added', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
|
||||
sa.PrimaryKeyConstraint('type', 'path')
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('model')
|
||||
# ### end Alembic commands ###
|
||||
@@ -1,4 +1,11 @@
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Integer,
|
||||
Text,
|
||||
DateTime,
|
||||
)
|
||||
from sqlalchemy.orm import declarative_base
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
@@ -11,4 +18,42 @@ def to_dict(obj):
|
||||
if (val := getattr(obj, field))
|
||||
}
|
||||
|
||||
# TODO: Define models here
|
||||
|
||||
class Model(Base):
|
||||
"""
|
||||
sqlalchemy model representing a model file in the system.
|
||||
|
||||
This class defines the database schema for storing information about model files,
|
||||
including their type, path, hash, and when they were added to the system.
|
||||
|
||||
Attributes:
|
||||
type (Text): The type of the model, this is the name of the folder in the models folder (primary key)
|
||||
path (Text): The file path of the model relative to the type folder (primary key)
|
||||
file_name (Text): The name of the model file
|
||||
file_size (Integer): The size of the model file in bytes
|
||||
hash (Text): A hash of the model file
|
||||
hash_algorithm (Text): The algorithm used to generate the hash
|
||||
source_url (Text): The URL of the model file
|
||||
date_added (DateTime): Timestamp of when the model was added to the system
|
||||
"""
|
||||
|
||||
__tablename__ = "model"
|
||||
|
||||
type = Column(Text, primary_key=True)
|
||||
path = Column(Text, primary_key=True)
|
||||
file_name = Column(Text)
|
||||
file_size = Column(Integer)
|
||||
hash = Column(Text)
|
||||
hash_algorithm = Column(Text)
|
||||
source_url = Column(Text)
|
||||
date_added = Column(DateTime, server_default=func.now())
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Convert the model instance to a dictionary representation.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the attributes of the model
|
||||
"""
|
||||
dict = to_dict(self)
|
||||
return dict
|
||||
|
||||
@@ -196,6 +196,17 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
|
||||
|
||||
|
||||
class FrontendManager:
|
||||
"""
|
||||
A class to manage ComfyUI frontend versions and installations.
|
||||
|
||||
This class handles the initialization and management of different frontend versions,
|
||||
including the default frontend from the pip package and custom frontend versions
|
||||
from GitHub repositories.
|
||||
|
||||
Attributes:
|
||||
CUSTOM_FRONTENDS_ROOT (str): The root directory where custom frontend versions are stored.
|
||||
"""
|
||||
|
||||
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
|
||||
|
||||
@classmethod
|
||||
@@ -205,6 +216,15 @@ class FrontendManager:
|
||||
|
||||
@classmethod
|
||||
def default_frontend_path(cls) -> str:
|
||||
"""
|
||||
Get the path to the default frontend installation from the pip package.
|
||||
|
||||
Returns:
|
||||
str: The path to the default frontend static files.
|
||||
|
||||
Raises:
|
||||
SystemExit: If the comfyui-frontend-package is not installed.
|
||||
"""
|
||||
try:
|
||||
import comfyui_frontend_package
|
||||
|
||||
@@ -225,6 +245,15 @@ comfyui-frontend-package is not installed.
|
||||
|
||||
@classmethod
|
||||
def templates_path(cls) -> str:
|
||||
"""
|
||||
Get the path to the workflow templates.
|
||||
|
||||
Returns:
|
||||
str: The path to the workflow templates directory.
|
||||
|
||||
Raises:
|
||||
SystemExit: If the comfyui-workflow-templates package is not installed.
|
||||
"""
|
||||
try:
|
||||
import comfyui_workflow_templates
|
||||
|
||||
@@ -260,11 +289,16 @@ comfyui-workflow-templates is not installed.
|
||||
@classmethod
|
||||
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
|
||||
"""
|
||||
Parse a version string into its components.
|
||||
|
||||
The version string should be in the format: 'owner/repo@version'
|
||||
where version can be either a semantic version (v1.2.3) or 'latest'.
|
||||
|
||||
Args:
|
||||
value (str): The version string to parse.
|
||||
|
||||
Returns:
|
||||
tuple[str, str]: A tuple containing provider name and version.
|
||||
tuple[str, str, str]: A tuple containing (owner, repo, version).
|
||||
|
||||
Raises:
|
||||
argparse.ArgumentTypeError: If the version string is invalid.
|
||||
@@ -281,18 +315,22 @@ comfyui-workflow-templates is not installed.
|
||||
cls, version_string: str, provider: Optional[FrontEndProvider] = None
|
||||
) -> str:
|
||||
"""
|
||||
Initializes the frontend for the specified version.
|
||||
Initialize a frontend version without error handling.
|
||||
|
||||
This method attempts to initialize a specific frontend version, either from
|
||||
the default pip package or from a custom GitHub repository. It will download
|
||||
and extract the frontend files if necessary.
|
||||
|
||||
Args:
|
||||
version_string (str): The version string.
|
||||
provider (FrontEndProvider, optional): The provider to use. Defaults to None.
|
||||
version_string (str): The version string specifying which frontend to use.
|
||||
provider (FrontEndProvider, optional): The provider to use for custom frontends.
|
||||
|
||||
Returns:
|
||||
str: The path to the initialized frontend.
|
||||
|
||||
Raises:
|
||||
Exception: If there is an error during the initialization process.
|
||||
main error source might be request timeout or invalid URL.
|
||||
Exception: If there is an error during initialization (e.g., network timeout,
|
||||
invalid URL, or missing assets).
|
||||
"""
|
||||
if version_string == DEFAULT_VERSION_STRING:
|
||||
check_frontend_version()
|
||||
@@ -344,13 +382,17 @@ comfyui-workflow-templates is not installed.
|
||||
@classmethod
|
||||
def init_frontend(cls, version_string: str) -> str:
|
||||
"""
|
||||
Initializes the frontend with the specified version string.
|
||||
Initialize a frontend version with error handling.
|
||||
|
||||
This is the main method to initialize a frontend version. It wraps init_frontend_unsafe
|
||||
with error handling, falling back to the default frontend if initialization fails.
|
||||
|
||||
Args:
|
||||
version_string (str): The version string to initialize the frontend with.
|
||||
version_string (str): The version string specifying which frontend to use.
|
||||
|
||||
Returns:
|
||||
str: The path of the initialized frontend.
|
||||
str: The path to the initialized frontend. If initialization fails,
|
||||
returns the path to the default frontend.
|
||||
"""
|
||||
try:
|
||||
return cls.init_frontend_unsafe(version_string)
|
||||
|
||||
331
app/model_processor.py
Normal file
331
app/model_processor.py
Normal file
@@ -0,0 +1,331 @@
|
||||
import os
|
||||
import logging
|
||||
import time
|
||||
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
from folder_paths import get_relative_path, get_full_path
|
||||
from app.database.db import create_session, dependencies_available, can_create_session
|
||||
import blake3
|
||||
import comfy.utils
|
||||
|
||||
|
||||
if dependencies_available():
|
||||
from app.database.models import Model
|
||||
|
||||
|
||||
class ModelProcessor:
|
||||
def _validate_path(self, model_path):
|
||||
try:
|
||||
if not self._file_exists(model_path):
|
||||
logging.error(f"Model file not found: {model_path}")
|
||||
return None
|
||||
|
||||
result = get_relative_path(model_path)
|
||||
if not result:
|
||||
logging.error(
|
||||
f"Model file not in a recognized model directory: {model_path}"
|
||||
)
|
||||
return None
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logging.error(f"Error validating model path {model_path}: {str(e)}")
|
||||
return None
|
||||
|
||||
def _file_exists(self, path):
|
||||
"""Check if a file exists."""
|
||||
return os.path.exists(path)
|
||||
|
||||
def _get_file_size(self, path):
|
||||
"""Get file size."""
|
||||
return os.path.getsize(path)
|
||||
|
||||
def _get_hasher(self):
|
||||
return blake3.blake3()
|
||||
|
||||
def _hash_file(self, model_path):
|
||||
try:
|
||||
hasher = self._get_hasher()
|
||||
with open(model_path, "rb", buffering=0) as f:
|
||||
b = bytearray(128 * 1024)
|
||||
mv = memoryview(b)
|
||||
while n := f.readinto(mv):
|
||||
hasher.update(mv[:n])
|
||||
return hasher.hexdigest()
|
||||
except Exception as e:
|
||||
logging.error(f"Error hashing file {model_path}: {str(e)}")
|
||||
return None
|
||||
|
||||
def _get_existing_model(self, session, model_type, model_relative_path):
|
||||
return (
|
||||
session.query(Model)
|
||||
.filter(Model.type == model_type)
|
||||
.filter(Model.path == model_relative_path)
|
||||
.first()
|
||||
)
|
||||
|
||||
def _ensure_source_url(self, session, model, source_url):
|
||||
if model.source_url is None:
|
||||
model.source_url = source_url
|
||||
session.commit()
|
||||
|
||||
def _update_database(
|
||||
self,
|
||||
session,
|
||||
model_type,
|
||||
model_path,
|
||||
model_relative_path,
|
||||
model_hash,
|
||||
model,
|
||||
source_url,
|
||||
):
|
||||
try:
|
||||
if not model:
|
||||
model = self._get_existing_model(
|
||||
session, model_type, model_relative_path
|
||||
)
|
||||
|
||||
if not model:
|
||||
model = Model(
|
||||
path=model_relative_path,
|
||||
type=model_type,
|
||||
file_name=os.path.basename(model_path),
|
||||
)
|
||||
session.add(model)
|
||||
|
||||
model.file_size = self._get_file_size(model_path)
|
||||
model.hash = model_hash
|
||||
if model_hash:
|
||||
model.hash_algorithm = "blake3"
|
||||
model.source_url = source_url
|
||||
|
||||
session.commit()
|
||||
return model
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Error updating database for {model_relative_path}: {str(e)}"
|
||||
)
|
||||
|
||||
def process_file(self, model_path, source_url=None, model_hash=None):
|
||||
"""
|
||||
Process a model file and update the database with metadata.
|
||||
If the file already exists and matches the database, it will not be processed again.
|
||||
Returns the model object or if an error occurs, returns None.
|
||||
"""
|
||||
try:
|
||||
if not can_create_session():
|
||||
return
|
||||
|
||||
result = self._validate_path(model_path)
|
||||
if not result:
|
||||
return
|
||||
model_type, model_relative_path = result
|
||||
|
||||
with create_session() as session:
|
||||
session.expire_on_commit = False
|
||||
|
||||
existing_model = self._get_existing_model(
|
||||
session, model_type, model_relative_path
|
||||
)
|
||||
if (
|
||||
existing_model
|
||||
and existing_model.hash
|
||||
and existing_model.file_size == self._get_file_size(model_path)
|
||||
):
|
||||
# File exists with hash and same size, no need to process
|
||||
self._ensure_source_url(session, existing_model, source_url)
|
||||
return existing_model
|
||||
|
||||
if model_hash:
|
||||
model_hash = model_hash.lower()
|
||||
logging.info(f"Using provided hash: {model_hash}")
|
||||
else:
|
||||
start_time = time.time()
|
||||
logging.info(f"Hashing model {model_relative_path}")
|
||||
model_hash = self._hash_file(model_path)
|
||||
if not model_hash:
|
||||
return
|
||||
logging.info(
|
||||
f"Model hash: {model_hash} (duration: {time.time() - start_time} seconds)"
|
||||
)
|
||||
|
||||
return self._update_database(
|
||||
session,
|
||||
model_type,
|
||||
model_path,
|
||||
model_relative_path,
|
||||
model_hash,
|
||||
existing_model,
|
||||
source_url,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error processing model file {model_path}: {str(e)}")
|
||||
return None
|
||||
|
||||
def retrieve_model_by_hash(self, model_hash, model_type=None, session=None):
|
||||
"""
|
||||
Retrieve a model file from the database by hash and optionally by model type.
|
||||
Returns the model object or None if the model doesnt exist or an error occurs.
|
||||
"""
|
||||
try:
|
||||
if not can_create_session():
|
||||
return
|
||||
|
||||
dispose_session = False
|
||||
|
||||
if session is None:
|
||||
session = create_session()
|
||||
dispose_session = True
|
||||
|
||||
model = session.query(Model).filter(Model.hash == model_hash)
|
||||
if model_type is not None:
|
||||
model = model.filter(Model.type == model_type)
|
||||
return model.first()
|
||||
except Exception as e:
|
||||
logging.error(f"Error retrieving model by hash {model_hash}: {str(e)}")
|
||||
return None
|
||||
finally:
|
||||
if dispose_session:
|
||||
session.close()
|
||||
|
||||
def retrieve_hash(self, model_path, model_type=None):
|
||||
"""
|
||||
Retrieve the hash of a model file from the database.
|
||||
Returns the hash or None if the model doesnt exist or an error occurs.
|
||||
"""
|
||||
try:
|
||||
if not can_create_session():
|
||||
return
|
||||
|
||||
if model_type is not None:
|
||||
result = self._validate_path(model_path)
|
||||
if not result:
|
||||
return None
|
||||
model_type, model_relative_path = result
|
||||
|
||||
with create_session() as session:
|
||||
model = self._get_existing_model(
|
||||
session, model_type, model_relative_path
|
||||
)
|
||||
if model and model.hash:
|
||||
return model.hash
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"Error retrieving hash for {model_path}: {str(e)}")
|
||||
return None
|
||||
|
||||
def _validate_file_extension(self, file_name):
|
||||
"""Validate that the file extension is supported."""
|
||||
extension = os.path.splitext(file_name)[1]
|
||||
if extension not in (".safetensors", ".sft", ".txt", ".csv", ".json", ".yaml"):
|
||||
raise ValueError(f"Unsupported unsafe file for download: {file_name}")
|
||||
|
||||
def _check_existing_file(self, model_type, file_name, expected_hash):
|
||||
"""Check if file exists and has correct hash."""
|
||||
destination_path = get_full_path(model_type, file_name, allow_missing=True)
|
||||
if self._file_exists(destination_path):
|
||||
model = self.process_file(destination_path)
|
||||
if model and (expected_hash is None or model.hash == expected_hash):
|
||||
logging.debug(
|
||||
f"File {destination_path} already exists in the database and has the correct hash or no hash was provided."
|
||||
)
|
||||
return destination_path
|
||||
else:
|
||||
raise ValueError(
|
||||
f"File {destination_path} exists with hash {model.hash if model else 'unknown'} but expected {expected_hash}. Please delete the file and try again."
|
||||
)
|
||||
return None
|
||||
|
||||
def _check_existing_file_by_hash(self, hash, type, url):
|
||||
"""Check if a file with the given hash exists in the database and on disk."""
|
||||
hash = hash.lower()
|
||||
with create_session() as session:
|
||||
model = self.retrieve_model_by_hash(hash, type, session)
|
||||
if model:
|
||||
existing_path = get_full_path(type, model.path)
|
||||
if existing_path:
|
||||
logging.debug(
|
||||
f"File {model.path} already exists in the database at {existing_path}"
|
||||
)
|
||||
self._ensure_source_url(session, model, url)
|
||||
return existing_path
|
||||
else:
|
||||
logging.debug(
|
||||
f"File {model.path} exists in the database but not on disk"
|
||||
)
|
||||
return None
|
||||
|
||||
def _download_file(self, url, destination_path, hasher):
|
||||
"""Download a file and update the hasher with its contents."""
|
||||
response = requests.get(url, stream=True)
|
||||
logging.info(f"Downloading {url} to {destination_path}")
|
||||
|
||||
with open(destination_path, "wb") as f:
|
||||
total_size = int(response.headers.get("content-length", 0))
|
||||
if total_size > 0:
|
||||
pbar = comfy.utils.ProgressBar(total_size)
|
||||
else:
|
||||
pbar = None
|
||||
with tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar:
|
||||
for chunk in response.iter_content(chunk_size=128 * 1024):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
hasher.update(chunk)
|
||||
progress_bar.update(len(chunk))
|
||||
if pbar:
|
||||
pbar.update(len(chunk))
|
||||
|
||||
def _verify_downloaded_hash(self, calculated_hash, expected_hash, destination_path):
|
||||
"""Verify that the downloaded file has the expected hash."""
|
||||
if expected_hash is not None and calculated_hash != expected_hash:
|
||||
self._remove_file(destination_path)
|
||||
raise ValueError(
|
||||
f"Downloaded file hash {calculated_hash} does not match expected hash {expected_hash}"
|
||||
)
|
||||
|
||||
def _remove_file(self, file_path):
|
||||
"""Remove a file from disk."""
|
||||
os.remove(file_path)
|
||||
|
||||
def ensure_downloaded(self, type, url, desired_file_name, hash=None):
|
||||
"""
|
||||
Ensure a model file is downloaded and has the correct hash.
|
||||
Returns the path to the downloaded file.
|
||||
"""
|
||||
logging.debug(
|
||||
f"Ensuring {type} file is downloaded. URL='{url}' Destination='{desired_file_name}' Hash='{hash}'"
|
||||
)
|
||||
|
||||
# Validate file extension
|
||||
self._validate_file_extension(desired_file_name)
|
||||
|
||||
# Check if file exists with correct hash
|
||||
if hash:
|
||||
existing_path = self._check_existing_file_by_hash(hash, type, url)
|
||||
if existing_path:
|
||||
return existing_path
|
||||
|
||||
# Check if file exists locally
|
||||
destination_path = get_full_path(type, desired_file_name, allow_missing=True)
|
||||
existing_path = self._check_existing_file(type, desired_file_name, hash)
|
||||
if existing_path:
|
||||
return existing_path
|
||||
|
||||
# Download the file
|
||||
hasher = self._get_hasher()
|
||||
self._download_file(url, destination_path, hasher)
|
||||
|
||||
# Verify hash
|
||||
calculated_hash = hasher.hexdigest()
|
||||
self._verify_downloaded_hash(calculated_hash, hash, destination_path)
|
||||
|
||||
# Update database
|
||||
self.process_file(destination_path, url, calculated_hash)
|
||||
|
||||
# TODO: Notify frontend to reload models
|
||||
|
||||
return destination_path
|
||||
|
||||
|
||||
model_processor = ModelProcessor()
|
||||
@@ -210,6 +210,7 @@ database_default_path = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
|
||||
)
|
||||
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
|
||||
parser.add_argument("--disable-model-processing", action="store_true", help="Disable model file processing, e.g. computing hashes and extracting metadata.")
|
||||
|
||||
if comfy.options.args_parsing:
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -391,7 +391,6 @@ class WanModel(torch.nn.Module):
|
||||
cross_attn_norm=True,
|
||||
eps=1e-6,
|
||||
flf_pos_embed_token_number=None,
|
||||
in_dim_ref_conv=None,
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
@@ -485,11 +484,6 @@ class WanModel(torch.nn.Module):
|
||||
else:
|
||||
self.img_emb = None
|
||||
|
||||
if in_dim_ref_conv is not None:
|
||||
self.ref_conv = operations.Conv2d(in_dim_ref_conv, dim, kernel_size=patch_size[1:], stride=patch_size[1:], device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
else:
|
||||
self.ref_conv = None
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
x,
|
||||
@@ -532,13 +526,6 @@ class WanModel(torch.nn.Module):
|
||||
e = e.reshape(t.shape[0], -1, e.shape[-1])
|
||||
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
||||
|
||||
full_ref = None
|
||||
if self.ref_conv is not None:
|
||||
full_ref = kwargs.get("reference_latent", None)
|
||||
if full_ref is not None:
|
||||
full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2)
|
||||
x = torch.concat((full_ref, x), dim=1)
|
||||
|
||||
# context
|
||||
context = self.text_embedding(context)
|
||||
|
||||
@@ -565,9 +552,6 @@ class WanModel(torch.nn.Module):
|
||||
# head
|
||||
x = self.head(x, e)
|
||||
|
||||
if full_ref is not None:
|
||||
x = x[:, full_ref.shape[1]:]
|
||||
|
||||
# unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
@@ -586,9 +570,6 @@ class WanModel(torch.nn.Module):
|
||||
x = torch.cat([x, time_dim_concat], dim=2)
|
||||
t_len = ((x.shape[2] + (patch_size[0] // 2)) // patch_size[0])
|
||||
|
||||
if self.ref_conv is not None and "reference_latent" in kwargs:
|
||||
t_len += 1
|
||||
|
||||
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
|
||||
|
||||
@@ -1124,11 +1124,7 @@ class WAN21(BaseModel):
|
||||
mask = mask.repeat(1, 4, 1, 1, 1)
|
||||
mask = utils.resize_to_batch_size(mask, noise.shape[0])
|
||||
|
||||
concat_mask_index = kwargs.get("concat_mask_index", 0)
|
||||
if concat_mask_index != 0:
|
||||
return torch.cat((image[:, :concat_mask_index], mask, image[:, concat_mask_index:]), dim=1)
|
||||
else:
|
||||
return torch.cat((mask, image), dim=1)
|
||||
return torch.cat((mask, image), dim=1)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
@@ -1144,10 +1140,6 @@ class WAN21(BaseModel):
|
||||
if time_dim_concat is not None:
|
||||
out['time_dim_concat'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_concat))
|
||||
|
||||
reference_latents = kwargs.get("reference_latents", None)
|
||||
if reference_latents is not None:
|
||||
out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])[:, :, 0])
|
||||
|
||||
return out
|
||||
|
||||
|
||||
|
||||
@@ -373,11 +373,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix))
|
||||
if flf_weight is not None:
|
||||
dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1]
|
||||
|
||||
ref_conv_weight = state_dict.get('{}ref_conv.weight'.format(key_prefix))
|
||||
if ref_conv_weight is not None:
|
||||
dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1]
|
||||
|
||||
return dit_config
|
||||
|
||||
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
|
||||
|
||||
@@ -50,10 +50,16 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
|
||||
else:
|
||||
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
|
||||
|
||||
def is_html_file(file_path):
|
||||
with open(file_path, "rb") as f:
|
||||
content = f.read(100)
|
||||
return b"<!DOCTYPE html>" in content or b"<html" in content
|
||||
|
||||
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
if device is None:
|
||||
device = torch.device("cpu")
|
||||
metadata = None
|
||||
|
||||
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
|
||||
try:
|
||||
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
|
||||
@@ -66,6 +72,8 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
if return_metadata:
|
||||
metadata = f.metadata()
|
||||
except Exception as e:
|
||||
if is_html_file(ckpt):
|
||||
raise ValueError("{}\n\nFile path: {}\n\nThe requested file is an HTML document not a safetensors file. Please re-download the file, not the web page.".format(e, ckpt))
|
||||
if len(e.args) > 0:
|
||||
message = e.args[0]
|
||||
if "HeaderTooLarge" in message:
|
||||
@@ -93,6 +101,13 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
sd = pl_sd
|
||||
else:
|
||||
sd = pl_sd
|
||||
|
||||
try:
|
||||
from app.model_processor import model_processor
|
||||
model_processor.process_file(ckpt)
|
||||
except Exception as e:
|
||||
logging.error(f"Error processing file {ckpt}: {e}")
|
||||
|
||||
return (sd, metadata) if return_metadata else sd
|
||||
|
||||
def save_torch_file(sd, ckpt, metadata=None):
|
||||
|
||||
@@ -12,7 +12,7 @@ import torch
|
||||
try:
|
||||
import torchaudio
|
||||
TORCH_AUDIO_AVAILABLE = True
|
||||
except:
|
||||
except ImportError:
|
||||
TORCH_AUDIO_AVAILABLE = False
|
||||
from PIL import Image as PILImage
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
|
||||
@@ -1690,11 +1690,7 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
|
||||
):
|
||||
self.validate_prompt(prompt, negative_prompt)
|
||||
|
||||
if image is None:
|
||||
image_type = None
|
||||
elif model_name == KlingImageGenModelName.kling_v1:
|
||||
raise ValueError(f"The model {KlingImageGenModelName.kling_v1.value} does not support reference images.")
|
||||
else:
|
||||
if image is not None:
|
||||
image = tensor_to_base64_string(image)
|
||||
|
||||
initial_operation = SynchronousOperation(
|
||||
|
||||
@@ -103,63 +103,6 @@ class WanFunControlToVideo:
|
||||
out_latent["samples"] = latent
|
||||
return (positive, negative, out_latent)
|
||||
|
||||
class Wan22FunControlToVideo:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"vae": ("VAE", ),
|
||||
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
},
|
||||
"optional": {"ref_image": ("IMAGE", ),
|
||||
"control_video": ("IMAGE", ),
|
||||
# "start_image": ("IMAGE", ),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent")
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def encode(self, positive, negative, vae, width, height, length, batch_size, ref_image=None, start_image=None, control_video=None):
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
|
||||
concat_latent = concat_latent.repeat(1, 2, 1, 1, 1)
|
||||
mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
|
||||
|
||||
if start_image is not None:
|
||||
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
concat_latent_image = vae.encode(start_image[:, :, :, :3])
|
||||
concat_latent[:,16:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
|
||||
mask[:, :, :start_image.shape[0] + 3] = 0.0
|
||||
|
||||
ref_latent = None
|
||||
if ref_image is not None:
|
||||
ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
ref_latent = vae.encode(ref_image[:, :, :, :3])
|
||||
|
||||
if control_video is not None:
|
||||
control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
concat_latent_image = vae.encode(control_video[:, :, :, :3])
|
||||
concat_latent[:,:16,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
|
||||
|
||||
mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
|
||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": 16})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": 16})
|
||||
|
||||
if ref_latent is not None:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
|
||||
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True)
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return (positive, negative, out_latent)
|
||||
|
||||
class WanFirstLastFrameToVideo:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@@ -790,7 +733,6 @@ NODE_CLASS_MAPPINGS = {
|
||||
"WanTrackToVideo": WanTrackToVideo,
|
||||
"WanImageToVideo": WanImageToVideo,
|
||||
"WanFunControlToVideo": WanFunControlToVideo,
|
||||
"Wan22FunControlToVideo": Wan22FunControlToVideo,
|
||||
"WanFunInpaintToVideo": WanFunInpaintToVideo,
|
||||
"WanFirstLastFrameToVideo": WanFirstLastFrameToVideo,
|
||||
"WanVaceToVideo": WanVaceToVideo,
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.3.50"
|
||||
__version__ = "0.3.49"
|
||||
|
||||
@@ -275,10 +275,7 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str])
|
||||
|
||||
|
||||
|
||||
def get_full_path(folder_name: str, filename: str) -> str | None:
|
||||
"""
|
||||
Get the full path of a file in a folder, has to be a file
|
||||
"""
|
||||
def get_full_path(folder_name: str, filename: str, allow_missing: bool = False) -> str | None:
|
||||
global folder_names_and_paths
|
||||
folder_name = map_legacy(folder_name)
|
||||
if folder_name not in folder_names_and_paths:
|
||||
@@ -291,6 +288,8 @@ def get_full_path(folder_name: str, filename: str) -> str | None:
|
||||
return full_path
|
||||
elif os.path.islink(full_path):
|
||||
logging.warning("WARNING path {} exists but doesn't link anywhere, skipping.".format(full_path))
|
||||
elif allow_missing:
|
||||
return full_path
|
||||
|
||||
return None
|
||||
|
||||
@@ -305,6 +304,27 @@ def get_full_path_or_raise(folder_name: str, filename: str) -> str:
|
||||
return full_path
|
||||
|
||||
|
||||
def get_relative_path(full_path: str) -> tuple[str, str] | None:
|
||||
"""Convert a full path back to a type-relative path.
|
||||
|
||||
Args:
|
||||
full_path: The full path to the file
|
||||
|
||||
Returns:
|
||||
tuple[str, str] | None: A tuple of (model_type, relative_path) if found, None otherwise
|
||||
"""
|
||||
global folder_names_and_paths
|
||||
full_path = os.path.normpath(full_path)
|
||||
|
||||
for model_type, (paths, _) in folder_names_and_paths.items():
|
||||
for base_path in paths:
|
||||
base_path = os.path.normpath(base_path)
|
||||
if full_path.startswith(base_path):
|
||||
relative_path = os.path.relpath(full_path, base_path)
|
||||
return model_type, relative_path
|
||||
|
||||
return None
|
||||
|
||||
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
|
||||
folder_name = map_legacy(folder_name)
|
||||
global folder_names_and_paths
|
||||
|
||||
3
main.py
3
main.py
@@ -164,7 +164,6 @@ def cuda_malloc_warning():
|
||||
if cuda_malloc_warning:
|
||||
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
||||
|
||||
|
||||
def prompt_worker(q, server_instance):
|
||||
current_time: float = 0.0
|
||||
cache_type = execution.CacheType.CLASSIC
|
||||
@@ -279,7 +278,6 @@ def cleanup_temp():
|
||||
if os.path.exists(temp_dir):
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
def setup_database():
|
||||
try:
|
||||
from app.database.db import init_db, dependencies_available
|
||||
@@ -288,7 +286,6 @@ def setup_database():
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
|
||||
|
||||
|
||||
def start_comfyui(asyncio_loop=None):
|
||||
"""
|
||||
Starts the ComfyUI server using the provided asyncio event loop or creates a new one.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.3.50"
|
||||
version = "0.3.49"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.9"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.23.4
|
||||
comfyui-workflow-templates==0.1.59
|
||||
comfyui-frontend-package==1.24.4
|
||||
comfyui-workflow-templates==0.1.53
|
||||
comfyui-embedded-docs==0.2.6
|
||||
torch
|
||||
torchsde
|
||||
@@ -20,6 +20,7 @@ tqdm
|
||||
psutil
|
||||
alembic
|
||||
SQLAlchemy
|
||||
blake3
|
||||
|
||||
#non essential dependencies:
|
||||
kornia>=0.7.1
|
||||
|
||||
253
tests-unit/app_test/model_processor_test.py
Normal file
253
tests-unit/app_test/model_processor_test.py
Normal file
@@ -0,0 +1,253 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from app.model_processor import ModelProcessor
|
||||
from app.database.models import Model, Base
|
||||
import os
|
||||
|
||||
# Test data constants
|
||||
TEST_MODEL_TYPE = "checkpoints"
|
||||
TEST_URL = "http://example.com/model.safetensors"
|
||||
TEST_FILE_NAME = "model.safetensors"
|
||||
TEST_EXPECTED_HASH = "abc123"
|
||||
TEST_DESTINATION_PATH = "/path/to/model.safetensors"
|
||||
|
||||
|
||||
def create_test_model(session, file_name, model_type, hash_value, file_size=1000, source_url=None):
|
||||
"""Helper to create a test model in the database."""
|
||||
model = Model(path=file_name, type=model_type, hash=hash_value, file_size=file_size, source_url=source_url)
|
||||
session.add(model)
|
||||
session.commit()
|
||||
return model
|
||||
|
||||
|
||||
def setup_mock_hash_calculation(model_processor, hash_value):
|
||||
"""Helper to setup hash calculation mocks."""
|
||||
mock_hash = MagicMock()
|
||||
mock_hash.hexdigest.return_value = hash_value
|
||||
return patch.object(model_processor, "_get_hasher", return_value=mock_hash)
|
||||
|
||||
|
||||
def verify_model_in_db(session, file_name, expected_hash=None, expected_type=None):
|
||||
"""Helper to verify model exists in database with correct attributes."""
|
||||
db_model = session.query(Model).filter_by(path=file_name).first()
|
||||
assert db_model is not None
|
||||
if expected_hash:
|
||||
assert db_model.hash == expected_hash
|
||||
if expected_type:
|
||||
assert db_model.type == expected_type
|
||||
return db_model
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_engine():
|
||||
# Configure in-memory database
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
Base.metadata.create_all(engine)
|
||||
yield engine
|
||||
Base.metadata.drop_all(engine)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_session(db_engine):
|
||||
Session = sessionmaker(bind=db_engine)
|
||||
session = Session()
|
||||
yield session
|
||||
session.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_relative_path():
|
||||
with patch("app.model_processor.get_relative_path") as mock:
|
||||
mock.side_effect = lambda path: (TEST_MODEL_TYPE, os.path.basename(path))
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_full_path():
|
||||
with patch("app.model_processor.get_full_path") as mock:
|
||||
mock.return_value = TEST_DESTINATION_PATH
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_processor(db_session, mock_get_relative_path, mock_get_full_path):
|
||||
with patch("app.model_processor.create_session", return_value=db_session):
|
||||
with patch("app.model_processor.can_create_session", return_value=True):
|
||||
processor = ModelProcessor()
|
||||
# Setup test state
|
||||
processor.removed_files = []
|
||||
processor.downloaded_files = []
|
||||
processor.file_exists = {}
|
||||
|
||||
def mock_download_file(url, destination_path, hasher):
|
||||
processor.downloaded_files.append((url, destination_path))
|
||||
processor.file_exists[destination_path] = True
|
||||
# Simulate writing some data to the file
|
||||
test_data = b"test data"
|
||||
hasher.update(test_data)
|
||||
|
||||
def mock_remove_file(file_path):
|
||||
processor.removed_files.append(file_path)
|
||||
if file_path in processor.file_exists:
|
||||
del processor.file_exists[file_path]
|
||||
|
||||
# Setup common patches
|
||||
file_exists_patch = patch.object(
|
||||
processor,
|
||||
"_file_exists",
|
||||
side_effect=lambda path: processor.file_exists.get(path, False),
|
||||
)
|
||||
file_size_patch = patch.object(
|
||||
processor,
|
||||
"_get_file_size",
|
||||
side_effect=lambda path: (
|
||||
1000 if processor.file_exists.get(path, False) else 0
|
||||
),
|
||||
)
|
||||
download_file_patch = patch.object(
|
||||
processor, "_download_file", side_effect=mock_download_file
|
||||
)
|
||||
remove_file_patch = patch.object(
|
||||
processor, "_remove_file", side_effect=mock_remove_file
|
||||
)
|
||||
|
||||
with (
|
||||
file_exists_patch,
|
||||
file_size_patch,
|
||||
download_file_patch,
|
||||
remove_file_patch,
|
||||
):
|
||||
yield processor
|
||||
|
||||
|
||||
def test_ensure_downloaded_invalid_extension(model_processor):
|
||||
# Ensure that an unsupported file extension raises an error to prevent unsafe file downloads
|
||||
with pytest.raises(ValueError, match="Unsupported unsafe file for download"):
|
||||
model_processor.ensure_downloaded(TEST_MODEL_TYPE, TEST_URL, "model.exe")
|
||||
|
||||
|
||||
def test_ensure_downloaded_existing_file_with_hash(model_processor, db_session):
|
||||
# Ensure that a file with the same hash but from a different source is not downloaded again
|
||||
SOURCE_URL = "https://example.com/other.sft"
|
||||
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH, source_url=SOURCE_URL)
|
||||
model_processor.file_exists[TEST_DESTINATION_PATH] = True
|
||||
|
||||
result = model_processor.ensure_downloaded(
|
||||
TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
|
||||
)
|
||||
|
||||
assert result == TEST_DESTINATION_PATH
|
||||
model = verify_model_in_db(db_session, TEST_FILE_NAME, TEST_EXPECTED_HASH, TEST_MODEL_TYPE)
|
||||
assert model.source_url == SOURCE_URL # Ensure the source URL is not overwritten
|
||||
|
||||
|
||||
def test_ensure_downloaded_existing_file_hash_mismatch(model_processor, db_session):
|
||||
# Ensure that a file with a different hash raises an error
|
||||
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, "different_hash")
|
||||
model_processor.file_exists[TEST_DESTINATION_PATH] = True
|
||||
|
||||
with pytest.raises(ValueError, match="File .* exists with hash .* but expected .*"):
|
||||
model_processor.ensure_downloaded(
|
||||
TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
|
||||
)
|
||||
|
||||
|
||||
def test_ensure_downloaded_new_file(model_processor, db_session):
|
||||
# Ensure that a new file is downloaded
|
||||
model_processor.file_exists[TEST_DESTINATION_PATH] = False
|
||||
|
||||
with setup_mock_hash_calculation(model_processor, TEST_EXPECTED_HASH):
|
||||
result = model_processor.ensure_downloaded(
|
||||
TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
|
||||
)
|
||||
|
||||
assert result == TEST_DESTINATION_PATH
|
||||
assert len(model_processor.downloaded_files) == 1
|
||||
assert model_processor.downloaded_files[0] == (TEST_URL, TEST_DESTINATION_PATH)
|
||||
assert model_processor.file_exists[TEST_DESTINATION_PATH]
|
||||
verify_model_in_db(db_session, TEST_FILE_NAME, TEST_EXPECTED_HASH, TEST_MODEL_TYPE)
|
||||
|
||||
|
||||
def test_ensure_downloaded_hash_mismatch(model_processor, db_session):
|
||||
# Ensure that download that results in a different hash raises an error
|
||||
model_processor.file_exists[TEST_DESTINATION_PATH] = False
|
||||
|
||||
with setup_mock_hash_calculation(model_processor, "different_hash"):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Downloaded file hash .* does not match expected hash .*",
|
||||
):
|
||||
model_processor.ensure_downloaded(
|
||||
TEST_MODEL_TYPE,
|
||||
TEST_URL,
|
||||
TEST_FILE_NAME,
|
||||
TEST_EXPECTED_HASH,
|
||||
)
|
||||
|
||||
assert len(model_processor.removed_files) == 1
|
||||
assert model_processor.removed_files[0] == TEST_DESTINATION_PATH
|
||||
assert TEST_DESTINATION_PATH not in model_processor.file_exists
|
||||
assert db_session.query(Model).filter_by(path=TEST_FILE_NAME).first() is None
|
||||
|
||||
|
||||
def test_process_file_without_hash(model_processor, db_session):
|
||||
# Test processing file without provided hash
|
||||
model_processor.file_exists[TEST_DESTINATION_PATH] = True
|
||||
|
||||
with patch.object(model_processor, "_hash_file", return_value=TEST_EXPECTED_HASH):
|
||||
result = model_processor.process_file(TEST_DESTINATION_PATH)
|
||||
assert result is not None
|
||||
assert result.hash == TEST_EXPECTED_HASH
|
||||
|
||||
|
||||
def test_retrieve_model_by_hash(model_processor, db_session):
|
||||
# Test retrieving model by hash
|
||||
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
|
||||
result = model_processor.retrieve_model_by_hash(TEST_EXPECTED_HASH)
|
||||
assert result is not None
|
||||
assert result.hash == TEST_EXPECTED_HASH
|
||||
|
||||
|
||||
def test_retrieve_model_by_hash_and_type(model_processor, db_session):
|
||||
# Test retrieving model by hash and type
|
||||
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
|
||||
result = model_processor.retrieve_model_by_hash(TEST_EXPECTED_HASH, TEST_MODEL_TYPE)
|
||||
assert result is not None
|
||||
assert result.hash == TEST_EXPECTED_HASH
|
||||
assert result.type == TEST_MODEL_TYPE
|
||||
|
||||
|
||||
def test_retrieve_hash(model_processor, db_session):
|
||||
# Test retrieving hash for existing model
|
||||
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
|
||||
with patch.object(
|
||||
model_processor,
|
||||
"_validate_path",
|
||||
return_value=(TEST_MODEL_TYPE, TEST_FILE_NAME),
|
||||
):
|
||||
result = model_processor.retrieve_hash(TEST_DESTINATION_PATH, TEST_MODEL_TYPE)
|
||||
assert result == TEST_EXPECTED_HASH
|
||||
|
||||
|
||||
def test_validate_file_extension_valid_extensions(model_processor):
|
||||
# Test all valid file extensions
|
||||
valid_extensions = [".safetensors", ".sft", ".txt", ".csv", ".json", ".yaml"]
|
||||
for ext in valid_extensions:
|
||||
model_processor._validate_file_extension(f"test{ext}") # Should not raise
|
||||
|
||||
|
||||
def test_process_file_existing_without_source_url(model_processor, db_session):
|
||||
# Test processing an existing file that needs its source URL updated
|
||||
model_processor.file_exists[TEST_DESTINATION_PATH] = True
|
||||
|
||||
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
|
||||
result = model_processor.process_file(TEST_DESTINATION_PATH, source_url=TEST_URL)
|
||||
|
||||
assert result is not None
|
||||
assert result.hash == TEST_EXPECTED_HASH
|
||||
assert result.source_url == TEST_URL
|
||||
|
||||
db_model = db_session.query(Model).filter_by(path=TEST_FILE_NAME).first()
|
||||
assert db_model.source_url == TEST_URL
|
||||
Reference in New Issue
Block a user