mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-18 06:00:03 +00:00
Compare commits
81 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
db9e6edfa1 | ||
|
|
8af13b439b | ||
|
|
acd0e53653 | ||
|
|
117e7a5853 | ||
|
|
b3c0e4de57 | ||
|
|
ecaeeb990d | ||
|
|
c2b65e2fce | ||
|
|
fd5c0755af | ||
|
|
c881a1d689 | ||
|
|
a3b5d4996a | ||
|
|
c6238047ee | ||
|
|
5cd1113236 | ||
|
|
2f642d5d9b | ||
|
|
cd912963f1 | ||
|
|
6e4b1f9d00 | ||
|
|
dc202a2e51 | ||
|
|
153bc524bf | ||
|
|
393d2880dd | ||
|
|
4484b93d61 | ||
|
|
bd0e6825e8 | ||
|
|
ec0a832acb | ||
|
|
04c49a29b4 | ||
|
|
4609fcd260 | ||
|
|
6207f86c18 | ||
|
|
1dc3da6314 | ||
|
|
114fc73685 | ||
|
|
b48d6a83d4 | ||
|
|
027042db68 | ||
|
|
1a20656448 | ||
|
|
0f11869d55 | ||
|
|
5943fbf457 | ||
|
|
a60b7b86c5 | ||
|
|
2e9d51680a | ||
|
|
50d6e1caf4 | ||
|
|
ac12f77bed | ||
|
|
fcd9a236b0 | ||
|
|
21e8425087 | ||
|
|
b6c79a648a | ||
|
|
25bc1b5b57 | ||
|
|
3cd19e99c1 | ||
|
|
007b87e7ac | ||
|
|
34751fe9f9 | ||
|
|
1c705f7bfb | ||
|
|
48e5ea1dfd | ||
|
|
3cd7b32f1b | ||
|
|
c0c9720d77 | ||
|
|
fc0cb10bcb | ||
|
|
b7d7cc1d49 | ||
|
|
79e94544bd | ||
|
|
ce0000c4f2 | ||
|
|
c5cfb34c07 | ||
|
|
edee33f55e | ||
|
|
2c03884f5f | ||
|
|
6e9ee55cdd | ||
|
|
023cf13721 | ||
|
|
c3566c0d76 | ||
|
|
c3c3e93c5b | ||
|
|
6ffc159bdd | ||
|
|
96e0d0924e | ||
|
|
e14f3b6610 | ||
|
|
1618002411 | ||
|
|
6ef85c4915 | ||
|
|
6da00dd899 | ||
|
|
4f3f9e72a9 | ||
|
|
d157c3299d | ||
|
|
d1b9822f74 | ||
|
|
f2b002372b | ||
|
|
38d0493825 | ||
|
|
acbf08cd60 | ||
|
|
53e762a3af | ||
|
|
9a552df898 | ||
|
|
f2fda021ab | ||
|
|
303b1735f8 | ||
|
|
9e5f677746 | ||
|
|
65cfcf5b1b | ||
|
|
1bdc9a947f | ||
|
|
d622a61874 | ||
|
|
236b9e211d | ||
|
|
6ca3d5c011 | ||
|
|
0be8a76c93 | ||
|
|
0357ed7ec4 |
@@ -1,3 +1,3 @@
|
||||
..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes
|
||||
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
|
||||
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe
|
||||
pause
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
|
||||
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
|
||||
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe
|
||||
pause
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
|
||||
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
|
||||
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe
|
||||
pause
|
||||
|
||||
2
.github/workflows/stable-release.yml
vendored
2
.github/workflows/stable-release.yml
vendored
@@ -117,7 +117,7 @@ jobs:
|
||||
./python.exe get-pip.py
|
||||
./python.exe -s -m pip install ../${{ inputs.cache_tag }}_python_deps/*
|
||||
|
||||
grep comfyui ../ComfyUI/requirements.txt > ./requirements_comfyui.txt
|
||||
grep comfy ../ComfyUI/requirements.txt > ./requirements_comfyui.txt
|
||||
./python.exe -s -m pip install -r requirements_comfyui.txt
|
||||
rm requirements_comfyui.txt
|
||||
|
||||
|
||||
2
.github/workflows/test-build.yml
vendored
2
.github/workflows/test-build.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
|
||||
4
.github/workflows/test-launch.yml
vendored
4
.github/workflows/test-launch.yml
vendored
@@ -32,7 +32,9 @@ jobs:
|
||||
working-directory: ComfyUI
|
||||
- name: Check for unhandled exceptions in server log
|
||||
run: |
|
||||
if grep -qE "Exception|Error" console_output.log; then
|
||||
grep -v "Found comfy_kitchen backend triton: {'available': False, 'disabled': True, 'unavailable_reason': \"ImportError: No module named 'triton'\", 'capabilities': \[\]}" console_output.log | grep -v "Found comfy_kitchen backend triton: {'available': False, 'disabled': False, 'unavailable_reason': \"ImportError: No module named 'triton'\", 'capabilities': \[\]}" > console_output_filtered.log
|
||||
cat console_output_filtered.log
|
||||
if grep -qE "Exception|Error" console_output_filtered.log; then
|
||||
echo "Unhandled exception/error found in server log."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
@@ -183,7 +183,7 @@ Simply download, extract with [7-Zip](https://7-zip.org) or with the windows exp
|
||||
|
||||
If you have trouble extracting it, right click the file -> properties -> unblock
|
||||
|
||||
Update your Nvidia drivers if it doesn't start.
|
||||
The portable above currently comes with python 3.13 and pytorch cuda 13.0. Update your Nvidia drivers if it doesn't start.
|
||||
|
||||
#### Alternative Downloads:
|
||||
|
||||
@@ -212,7 +212,7 @@ Python 3.14 works but you may encounter issues with the torch compile node. The
|
||||
|
||||
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
|
||||
|
||||
torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch unless it is less than 2 weeks old.
|
||||
torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old.
|
||||
|
||||
### Instructions:
|
||||
|
||||
|
||||
174
alembic_db/versions/0001_assets.py
Normal file
174
alembic_db/versions/0001_assets.py
Normal file
@@ -0,0 +1,174 @@
|
||||
"""
|
||||
Initial assets schema
|
||||
Revision ID: 0001_assets
|
||||
Revises: None
|
||||
Create Date: 2025-12-10 00:00:00
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "0001_assets"
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ASSETS: content identity
|
||||
op.create_table(
|
||||
"assets",
|
||||
sa.Column("id", sa.String(length=36), primary_key=True),
|
||||
sa.Column("hash", sa.String(length=256), nullable=True),
|
||||
sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"),
|
||||
sa.Column("mime_type", sa.String(length=255), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
|
||||
)
|
||||
op.create_index("uq_assets_hash", "assets", ["hash"], unique=True)
|
||||
op.create_index("ix_assets_mime_type", "assets", ["mime_type"])
|
||||
|
||||
# ASSETS_INFO: user-visible references
|
||||
op.create_table(
|
||||
"assets_info",
|
||||
sa.Column("id", sa.String(length=36), primary_key=True),
|
||||
sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""),
|
||||
sa.Column("name", sa.String(length=512), nullable=False),
|
||||
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False),
|
||||
sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True),
|
||||
sa.Column("user_metadata", sa.JSON(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
|
||||
)
|
||||
op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"])
|
||||
op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"])
|
||||
op.create_index("ix_assets_info_name", "assets_info", ["name"])
|
||||
op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"])
|
||||
op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"])
|
||||
op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"])
|
||||
|
||||
# TAGS: normalized tag vocabulary
|
||||
op.create_table(
|
||||
"tags",
|
||||
sa.Column("name", sa.String(length=512), primary_key=True),
|
||||
sa.Column("tag_type", sa.String(length=32), nullable=False, server_default="user"),
|
||||
sa.CheckConstraint("name = lower(name)", name="ck_tags_lowercase"),
|
||||
)
|
||||
op.create_index("ix_tags_tag_type", "tags", ["tag_type"])
|
||||
|
||||
# ASSET_INFO_TAGS: many-to-many for tags on AssetInfo
|
||||
op.create_table(
|
||||
"asset_info_tags",
|
||||
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
|
||||
sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"),
|
||||
sa.Column("added_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"),
|
||||
)
|
||||
op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])
|
||||
op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"])
|
||||
|
||||
# ASSET_CACHE_STATE: N:1 local cache rows per Asset
|
||||
op.create_table(
|
||||
"asset_cache_state",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
|
||||
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("file_path", sa.Text(), nullable=False), # absolute local path to cached file
|
||||
sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
|
||||
sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")),
|
||||
sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
|
||||
sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
|
||||
)
|
||||
op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"])
|
||||
op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"])
|
||||
|
||||
# ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting
|
||||
op.create_table(
|
||||
"asset_info_meta",
|
||||
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("key", sa.String(length=256), nullable=False),
|
||||
sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("val_str", sa.String(length=2048), nullable=True),
|
||||
sa.Column("val_num", sa.Numeric(38, 10), nullable=True),
|
||||
sa.Column("val_bool", sa.Boolean(), nullable=True),
|
||||
sa.Column("val_json", sa.JSON(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"),
|
||||
)
|
||||
op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"])
|
||||
op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"])
|
||||
op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"])
|
||||
op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"])
|
||||
|
||||
# Tags vocabulary
|
||||
tags_table = sa.table(
|
||||
"tags",
|
||||
sa.column("name", sa.String(length=512)),
|
||||
sa.column("tag_type", sa.String()),
|
||||
)
|
||||
op.bulk_insert(
|
||||
tags_table,
|
||||
[
|
||||
{"name": "models", "tag_type": "system"},
|
||||
{"name": "input", "tag_type": "system"},
|
||||
{"name": "output", "tag_type": "system"},
|
||||
|
||||
{"name": "configs", "tag_type": "system"},
|
||||
{"name": "checkpoints", "tag_type": "system"},
|
||||
{"name": "loras", "tag_type": "system"},
|
||||
{"name": "vae", "tag_type": "system"},
|
||||
{"name": "text_encoders", "tag_type": "system"},
|
||||
{"name": "diffusion_models", "tag_type": "system"},
|
||||
{"name": "clip_vision", "tag_type": "system"},
|
||||
{"name": "style_models", "tag_type": "system"},
|
||||
{"name": "embeddings", "tag_type": "system"},
|
||||
{"name": "diffusers", "tag_type": "system"},
|
||||
{"name": "vae_approx", "tag_type": "system"},
|
||||
{"name": "controlnet", "tag_type": "system"},
|
||||
{"name": "gligen", "tag_type": "system"},
|
||||
{"name": "upscale_models", "tag_type": "system"},
|
||||
{"name": "hypernetworks", "tag_type": "system"},
|
||||
{"name": "photomaker", "tag_type": "system"},
|
||||
{"name": "classifiers", "tag_type": "system"},
|
||||
|
||||
{"name": "encoder", "tag_type": "system"},
|
||||
{"name": "decoder", "tag_type": "system"},
|
||||
|
||||
{"name": "missing", "tag_type": "system"},
|
||||
{"name": "rescan", "tag_type": "system"},
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta")
|
||||
op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta")
|
||||
op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta")
|
||||
op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta")
|
||||
op.drop_table("asset_info_meta")
|
||||
|
||||
op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state")
|
||||
op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state")
|
||||
op.drop_constraint("uq_asset_cache_state_file_path", table_name="asset_cache_state")
|
||||
op.drop_table("asset_cache_state")
|
||||
|
||||
op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags")
|
||||
op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags")
|
||||
op.drop_table("asset_info_tags")
|
||||
|
||||
op.drop_index("ix_tags_tag_type", table_name="tags")
|
||||
op.drop_table("tags")
|
||||
|
||||
op.drop_constraint("uq_assets_info_asset_owner_name", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_owner_name", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_last_access_time", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_created_at", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_name", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_asset_id", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_owner_id", table_name="assets_info")
|
||||
op.drop_table("assets_info")
|
||||
|
||||
op.drop_index("uq_assets_hash", table_name="assets")
|
||||
op.drop_index("ix_assets_mime_type", table_name="assets")
|
||||
op.drop_table("assets")
|
||||
102
app/assets/api/routes.py
Normal file
102
app/assets/api/routes.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import logging
|
||||
import uuid
|
||||
from aiohttp import web
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
import app.assets.manager as manager
|
||||
from app import user_manager
|
||||
from app.assets.api import schemas_in
|
||||
from app.assets.helpers import get_query_dict
|
||||
|
||||
ROUTES = web.RouteTableDef()
|
||||
USER_MANAGER: user_manager.UserManager | None = None
|
||||
|
||||
# UUID regex (canonical hyphenated form, case-insensitive)
|
||||
UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
|
||||
|
||||
def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None:
|
||||
global USER_MANAGER
|
||||
USER_MANAGER = user_manager_instance
|
||||
app.add_routes(ROUTES)
|
||||
|
||||
def _error_response(status: int, code: str, message: str, details: dict | None = None) -> web.Response:
|
||||
return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status)
|
||||
|
||||
|
||||
def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
|
||||
return _error_response(400, code, "Validation failed.", {"errors": ve.json()})
|
||||
|
||||
|
||||
@ROUTES.get("/api/assets")
|
||||
async def list_assets(request: web.Request) -> web.Response:
|
||||
"""
|
||||
GET request to list assets.
|
||||
"""
|
||||
query_dict = get_query_dict(request)
|
||||
try:
|
||||
q = schemas_in.ListAssetsQuery.model_validate(query_dict)
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_QUERY", ve)
|
||||
|
||||
payload = manager.list_assets(
|
||||
include_tags=q.include_tags,
|
||||
exclude_tags=q.exclude_tags,
|
||||
name_contains=q.name_contains,
|
||||
metadata_filter=q.metadata_filter,
|
||||
limit=q.limit,
|
||||
offset=q.offset,
|
||||
sort=q.sort,
|
||||
order=q.order,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return web.json_response(payload.model_dump(mode="json"))
|
||||
|
||||
|
||||
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
async def get_asset(request: web.Request) -> web.Response:
|
||||
"""
|
||||
GET request to get an asset's info as JSON.
|
||||
"""
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
result = manager.get_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except ValueError as e:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(e), {"id": asset_info_id})
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"get_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.get("/api/tags")
|
||||
async def get_tags(request: web.Request) -> web.Response:
|
||||
"""
|
||||
GET request to list all tags based on query parameters.
|
||||
"""
|
||||
query_map = dict(request.rel_url.query)
|
||||
|
||||
try:
|
||||
query = schemas_in.TagsListQuery.model_validate(query_map)
|
||||
except ValidationError as e:
|
||||
return web.json_response(
|
||||
{"error": {"code": "INVALID_QUERY", "message": "Invalid query parameters", "details": e.errors()}},
|
||||
status=400,
|
||||
)
|
||||
|
||||
result = manager.list_tags(
|
||||
prefix=query.prefix,
|
||||
limit=query.limit,
|
||||
offset=query.offset,
|
||||
order=query.order,
|
||||
include_zero=query.include_zero,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return web.json_response(result.model_dump(mode="json"))
|
||||
94
app/assets/api/schemas_in.py
Normal file
94
app/assets/api/schemas_in.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
conint,
|
||||
field_validator,
|
||||
)
|
||||
|
||||
|
||||
class ListAssetsQuery(BaseModel):
|
||||
include_tags: list[str] = Field(default_factory=list)
|
||||
exclude_tags: list[str] = Field(default_factory=list)
|
||||
name_contains: str | None = None
|
||||
|
||||
# Accept either a JSON string (query param) or a dict
|
||||
metadata_filter: dict[str, Any] | None = None
|
||||
|
||||
limit: conint(ge=1, le=500) = 20
|
||||
offset: conint(ge=0) = 0
|
||||
|
||||
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at"
|
||||
order: Literal["asc", "desc"] = "desc"
|
||||
|
||||
@field_validator("include_tags", "exclude_tags", mode="before")
|
||||
@classmethod
|
||||
def _split_csv_tags(cls, v):
|
||||
# Accept "a,b,c" or ["a","b"] (we are liberal in what we accept)
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, str):
|
||||
return [t.strip() for t in v.split(",") if t.strip()]
|
||||
if isinstance(v, list):
|
||||
out: list[str] = []
|
||||
for item in v:
|
||||
if isinstance(item, str):
|
||||
out.extend([t.strip() for t in item.split(",") if t.strip()])
|
||||
return out
|
||||
return v
|
||||
|
||||
@field_validator("metadata_filter", mode="before")
|
||||
@classmethod
|
||||
def _parse_metadata_json(cls, v):
|
||||
if v is None or isinstance(v, dict):
|
||||
return v
|
||||
if isinstance(v, str) and v.strip():
|
||||
try:
|
||||
parsed = json.loads(v)
|
||||
except Exception as e:
|
||||
raise ValueError(f"metadata_filter must be JSON: {e}") from e
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValueError("metadata_filter must be a JSON object")
|
||||
return parsed
|
||||
return None
|
||||
|
||||
|
||||
class TagsListQuery(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
prefix: str | None = Field(None, min_length=1, max_length=256)
|
||||
limit: int = Field(100, ge=1, le=1000)
|
||||
offset: int = Field(0, ge=0, le=10_000_000)
|
||||
order: Literal["count_desc", "name_asc"] = "count_desc"
|
||||
include_zero: bool = True
|
||||
|
||||
@field_validator("prefix")
|
||||
@classmethod
|
||||
def normalize_prefix(cls, v: str | None) -> str | None:
|
||||
if v is None:
|
||||
return v
|
||||
v = v.strip()
|
||||
return v.lower() or None
|
||||
|
||||
|
||||
class SetPreviewBody(BaseModel):
|
||||
"""Set or clear the preview for an AssetInfo. Provide an Asset.id or null."""
|
||||
preview_id: str | None = None
|
||||
|
||||
@field_validator("preview_id", mode="before")
|
||||
@classmethod
|
||||
def _norm_uuid(cls, v):
|
||||
if v is None:
|
||||
return None
|
||||
s = str(v).strip()
|
||||
if not s:
|
||||
return None
|
||||
try:
|
||||
uuid.UUID(s)
|
||||
except Exception:
|
||||
raise ValueError("preview_id must be a UUID")
|
||||
return s
|
||||
60
app/assets/api/schemas_out.py
Normal file
60
app/assets/api/schemas_out.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_serializer
|
||||
|
||||
|
||||
class AssetSummary(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: str | None = None
|
||||
size: int | None = None
|
||||
mime_type: str | None = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
preview_url: str | None = None
|
||||
created_at: datetime | None = None
|
||||
updated_at: datetime | None = None
|
||||
last_access_time: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("created_at", "updated_at", "last_access_time")
|
||||
def _ser_dt(self, v: datetime | None, _info):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class AssetsList(BaseModel):
|
||||
assets: list[AssetSummary]
|
||||
total: int
|
||||
has_more: bool
|
||||
|
||||
|
||||
class AssetDetail(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: str | None = None
|
||||
size: int | None = None
|
||||
mime_type: str | None = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
preview_id: str | None = None
|
||||
created_at: datetime | None = None
|
||||
last_access_time: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("created_at", "last_access_time")
|
||||
def _ser_dt(self, v: datetime | None, _info):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class TagUsage(BaseModel):
|
||||
name: str
|
||||
count: int
|
||||
type: str
|
||||
|
||||
|
||||
class TagsList(BaseModel):
|
||||
tags: list[TagUsage] = Field(default_factory=list)
|
||||
total: int
|
||||
has_more: bool
|
||||
204
app/assets/database/bulk_ops.py
Normal file
204
app/assets/database/bulk_ops.py
Normal file
@@ -0,0 +1,204 @@
|
||||
import os
|
||||
import uuid
|
||||
import sqlalchemy
|
||||
from typing import Iterable
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.dialects import sqlite
|
||||
|
||||
from app.assets.helpers import utcnow
|
||||
from app.assets.database.models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, AssetInfoMeta
|
||||
|
||||
MAX_BIND_PARAMS = 800
|
||||
|
||||
def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]:
|
||||
if not rows:
|
||||
return []
|
||||
rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row))
|
||||
for i in range(0, len(rows), rows_per_stmt):
|
||||
yield rows[i:i + rows_per_stmt]
|
||||
|
||||
def _iter_chunks(seq, n: int):
|
||||
for i in range(0, len(seq), n):
|
||||
yield seq[i:i + n]
|
||||
|
||||
def _rows_per_stmt(cols: int) -> int:
|
||||
return max(1, MAX_BIND_PARAMS // max(1, cols))
|
||||
|
||||
|
||||
def seed_from_paths_batch(
|
||||
session: Session,
|
||||
*,
|
||||
specs: list[dict],
|
||||
owner_id: str = "",
|
||||
) -> dict:
|
||||
"""Each spec is a dict with keys:
|
||||
- abs_path: str
|
||||
- size_bytes: int
|
||||
- mtime_ns: int
|
||||
- info_name: str
|
||||
- tags: list[str]
|
||||
- fname: Optional[str]
|
||||
"""
|
||||
if not specs:
|
||||
return {"inserted_infos": 0, "won_states": 0, "lost_states": 0}
|
||||
|
||||
now = utcnow()
|
||||
asset_rows: list[dict] = []
|
||||
state_rows: list[dict] = []
|
||||
path_to_asset: dict[str, str] = {}
|
||||
asset_to_info: dict[str, dict] = {} # asset_id -> prepared info row
|
||||
path_list: list[str] = []
|
||||
|
||||
for sp in specs:
|
||||
ap = os.path.abspath(sp["abs_path"])
|
||||
aid = str(uuid.uuid4())
|
||||
iid = str(uuid.uuid4())
|
||||
path_list.append(ap)
|
||||
path_to_asset[ap] = aid
|
||||
|
||||
asset_rows.append(
|
||||
{
|
||||
"id": aid,
|
||||
"hash": None,
|
||||
"size_bytes": sp["size_bytes"],
|
||||
"mime_type": None,
|
||||
"created_at": now,
|
||||
}
|
||||
)
|
||||
state_rows.append(
|
||||
{
|
||||
"asset_id": aid,
|
||||
"file_path": ap,
|
||||
"mtime_ns": sp["mtime_ns"],
|
||||
}
|
||||
)
|
||||
asset_to_info[aid] = {
|
||||
"id": iid,
|
||||
"owner_id": owner_id,
|
||||
"name": sp["info_name"],
|
||||
"asset_id": aid,
|
||||
"preview_id": None,
|
||||
"user_metadata": {"filename": sp["fname"]} if sp["fname"] else None,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"last_access_time": now,
|
||||
"_tags": sp["tags"],
|
||||
"_filename": sp["fname"],
|
||||
}
|
||||
|
||||
# insert all seed Assets (hash=NULL)
|
||||
ins_asset = sqlite.insert(Asset)
|
||||
for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)):
|
||||
session.execute(ins_asset, chunk)
|
||||
|
||||
# try to claim AssetCacheState (file_path)
|
||||
# Insert with ON CONFLICT DO NOTHING, then query to find which paths were actually inserted
|
||||
ins_state = (
|
||||
sqlite.insert(AssetCacheState)
|
||||
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
|
||||
)
|
||||
for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
|
||||
session.execute(ins_state, chunk)
|
||||
|
||||
# Query to find which of our paths won (were actually inserted)
|
||||
winners_by_path: set[str] = set()
|
||||
for chunk in _iter_chunks(path_list, MAX_BIND_PARAMS):
|
||||
result = session.execute(
|
||||
sqlalchemy.select(AssetCacheState.file_path)
|
||||
.where(AssetCacheState.file_path.in_(chunk))
|
||||
.where(AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk]))
|
||||
)
|
||||
winners_by_path.update(result.scalars().all())
|
||||
|
||||
all_paths_set = set(path_list)
|
||||
losers_by_path = all_paths_set - winners_by_path
|
||||
lost_assets = [path_to_asset[p] for p in losers_by_path]
|
||||
if lost_assets: # losers get their Asset removed
|
||||
for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS):
|
||||
session.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(id_chunk)))
|
||||
|
||||
if not winners_by_path:
|
||||
return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}
|
||||
|
||||
# insert AssetInfo only for winners
|
||||
# Insert with ON CONFLICT DO NOTHING, then query to find which were actually inserted
|
||||
winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
|
||||
ins_info = (
|
||||
sqlite.insert(AssetInfo)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
|
||||
)
|
||||
for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
|
||||
session.execute(ins_info, chunk)
|
||||
|
||||
# Query to find which info rows were actually inserted (by matching our generated IDs)
|
||||
all_info_ids = [row["id"] for row in winner_info_rows]
|
||||
inserted_info_ids: set[str] = set()
|
||||
for chunk in _iter_chunks(all_info_ids, MAX_BIND_PARAMS):
|
||||
result = session.execute(
|
||||
sqlalchemy.select(AssetInfo.id).where(AssetInfo.id.in_(chunk))
|
||||
)
|
||||
inserted_info_ids.update(result.scalars().all())
|
||||
|
||||
# build and insert tag + meta rows for the AssetInfo
|
||||
tag_rows: list[dict] = []
|
||||
meta_rows: list[dict] = []
|
||||
if inserted_info_ids:
|
||||
for row in winner_info_rows:
|
||||
iid = row["id"]
|
||||
if iid not in inserted_info_ids:
|
||||
continue
|
||||
for t in row["_tags"]:
|
||||
tag_rows.append({
|
||||
"asset_info_id": iid,
|
||||
"tag_name": t,
|
||||
"origin": "automatic",
|
||||
"added_at": now,
|
||||
})
|
||||
if row["_filename"]:
|
||||
meta_rows.append(
|
||||
{
|
||||
"asset_info_id": iid,
|
||||
"key": "filename",
|
||||
"ordinal": 0,
|
||||
"val_str": row["_filename"],
|
||||
"val_num": None,
|
||||
"val_bool": None,
|
||||
"val_json": None,
|
||||
}
|
||||
)
|
||||
|
||||
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS)
|
||||
return {
|
||||
"inserted_infos": len(inserted_info_ids),
|
||||
"won_states": len(winners_by_path),
|
||||
"lost_states": len(losers_by_path),
|
||||
}
|
||||
|
||||
|
||||
def bulk_insert_tags_and_meta(
|
||||
session: Session,
|
||||
*,
|
||||
tag_rows: list[dict],
|
||||
meta_rows: list[dict],
|
||||
max_bind_params: int,
|
||||
) -> None:
|
||||
"""Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.
|
||||
- tag_rows keys: asset_info_id, tag_name, origin, added_at
|
||||
- meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
|
||||
"""
|
||||
if tag_rows:
|
||||
ins_links = (
|
||||
sqlite.insert(AssetInfoTag)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
|
||||
)
|
||||
for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params):
|
||||
session.execute(ins_links, chunk)
|
||||
if meta_rows:
|
||||
ins_meta = (
|
||||
sqlite.insert(AssetInfoMeta)
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
|
||||
)
|
||||
)
|
||||
for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params):
|
||||
session.execute(ins_meta, chunk)
|
||||
233
app/assets/database/models.py
Normal file
233
app/assets/database/models.py
Normal file
@@ -0,0 +1,233 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from typing import Any
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
BigInteger,
|
||||
Boolean,
|
||||
CheckConstraint,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
Numeric,
|
||||
String,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship
|
||||
|
||||
from app.assets.helpers import utcnow
|
||||
from app.database.models import to_dict, Base
|
||||
|
||||
|
||||
class Asset(Base):
|
||||
__tablename__ = "assets"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
hash: Mapped[str | None] = mapped_column(String(256), nullable=True)
|
||||
size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
|
||||
mime_type: Mapped[str | None] = mapped_column(String(255))
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, default=utcnow
|
||||
)
|
||||
|
||||
infos: Mapped[list[AssetInfo]] = relationship(
|
||||
"AssetInfo",
|
||||
back_populates="asset",
|
||||
primaryjoin=lambda: Asset.id == foreign(AssetInfo.asset_id),
|
||||
foreign_keys=lambda: [AssetInfo.asset_id],
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
preview_of: Mapped[list[AssetInfo]] = relationship(
|
||||
"AssetInfo",
|
||||
back_populates="preview_asset",
|
||||
primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id),
|
||||
foreign_keys=lambda: [AssetInfo.preview_id],
|
||||
viewonly=True,
|
||||
)
|
||||
|
||||
cache_states: Mapped[list[AssetCacheState]] = relationship(
|
||||
back_populates="asset",
|
||||
cascade="all, delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("uq_assets_hash", "hash", unique=True),
|
||||
Index("ix_assets_mime_type", "mime_type"),
|
||||
CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
|
||||
)
|
||||
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
return to_dict(self, include_none=include_none)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Asset id={self.id} hash={(self.hash or '')[:12]}>"
|
||||
|
||||
|
||||
class AssetCacheState(Base):
|
||||
__tablename__ = "asset_cache_state"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False)
|
||||
file_path: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
|
||||
needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
asset: Mapped[Asset] = relationship(back_populates="cache_states")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_asset_cache_state_file_path", "file_path"),
|
||||
Index("ix_asset_cache_state_asset_id", "asset_id"),
|
||||
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
|
||||
UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
|
||||
)
|
||||
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
return to_dict(self, include_none=include_none)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<AssetCacheState id={self.id} asset_id={self.asset_id} path={self.file_path!r}>"
|
||||
|
||||
|
||||
class AssetInfo(Base):
|
||||
__tablename__ = "assets_info"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
|
||||
name: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False)
|
||||
preview_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL"))
|
||||
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON(none_as_null=True))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
|
||||
last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
|
||||
|
||||
asset: Mapped[Asset] = relationship(
|
||||
"Asset",
|
||||
back_populates="infos",
|
||||
foreign_keys=[asset_id],
|
||||
lazy="selectin",
|
||||
)
|
||||
preview_asset: Mapped[Asset | None] = relationship(
|
||||
"Asset",
|
||||
back_populates="preview_of",
|
||||
foreign_keys=[preview_id],
|
||||
)
|
||||
|
||||
metadata_entries: Mapped[list[AssetInfoMeta]] = relationship(
|
||||
back_populates="asset_info",
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
tag_links: Mapped[list[AssetInfoTag]] = relationship(
|
||||
back_populates="asset_info",
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
overlaps="tags,asset_infos",
|
||||
)
|
||||
|
||||
tags: Mapped[list[Tag]] = relationship(
|
||||
secondary="asset_info_tags",
|
||||
back_populates="asset_infos",
|
||||
lazy="selectin",
|
||||
viewonly=True,
|
||||
overlaps="tag_links,asset_info_links,asset_infos,tag",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
|
||||
Index("ix_assets_info_owner_name", "owner_id", "name"),
|
||||
Index("ix_assets_info_owner_id", "owner_id"),
|
||||
Index("ix_assets_info_asset_id", "asset_id"),
|
||||
Index("ix_assets_info_name", "name"),
|
||||
Index("ix_assets_info_created_at", "created_at"),
|
||||
Index("ix_assets_info_last_access_time", "last_access_time"),
|
||||
)
|
||||
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
data = to_dict(self, include_none=include_none)
|
||||
data["tags"] = [t.name for t in self.tags]
|
||||
return data
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<AssetInfo id={self.id} name={self.name!r} asset_id={self.asset_id}>"
|
||||
|
||||
|
||||
class AssetInfoMeta(Base):
|
||||
__tablename__ = "asset_info_meta"
|
||||
|
||||
asset_info_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
key: Mapped[str] = mapped_column(String(256), primary_key=True)
|
||||
ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0)
|
||||
|
||||
val_str: Mapped[str | None] = mapped_column(String(2048), nullable=True)
|
||||
val_num: Mapped[float | None] = mapped_column(Numeric(38, 10), nullable=True)
|
||||
val_bool: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
|
||||
val_json: Mapped[Any | None] = mapped_column(JSON(none_as_null=True), nullable=True)
|
||||
|
||||
asset_info: Mapped[AssetInfo] = relationship(back_populates="metadata_entries")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_asset_info_meta_key", "key"),
|
||||
Index("ix_asset_info_meta_key_val_str", "key", "val_str"),
|
||||
Index("ix_asset_info_meta_key_val_num", "key", "val_num"),
|
||||
Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"),
|
||||
)
|
||||
|
||||
|
||||
class AssetInfoTag(Base):
|
||||
__tablename__ = "asset_info_tags"
|
||||
|
||||
asset_info_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
tag_name: Mapped[str] = mapped_column(
|
||||
String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
|
||||
)
|
||||
origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
|
||||
added_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, default=utcnow
|
||||
)
|
||||
|
||||
asset_info: Mapped[AssetInfo] = relationship(back_populates="tag_links")
|
||||
tag: Mapped[Tag] = relationship(back_populates="asset_info_links")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_asset_info_tags_tag_name", "tag_name"),
|
||||
Index("ix_asset_info_tags_asset_info_id", "asset_info_id"),
|
||||
)
|
||||
|
||||
|
||||
class Tag(Base):
|
||||
__tablename__ = "tags"
|
||||
|
||||
name: Mapped[str] = mapped_column(String(512), primary_key=True)
|
||||
tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")
|
||||
|
||||
asset_info_links: Mapped[list[AssetInfoTag]] = relationship(
|
||||
back_populates="tag",
|
||||
overlaps="asset_infos,tags",
|
||||
)
|
||||
asset_infos: Mapped[list[AssetInfo]] = relationship(
|
||||
secondary="asset_info_tags",
|
||||
back_populates="tags",
|
||||
viewonly=True,
|
||||
overlaps="asset_info_links,tag_links,tags,asset_info",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_tags_tag_type", "tag_type"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Tag {self.name}>"
|
||||
267
app/assets/database/queries.py
Normal file
267
app/assets/database/queries.py
Normal file
@@ -0,0 +1,267 @@
|
||||
import sqlalchemy as sa
|
||||
from collections import defaultdict
|
||||
from sqlalchemy import select, exists, func
|
||||
from sqlalchemy.orm import Session, contains_eager, noload
|
||||
from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
|
||||
from app.assets.helpers import escape_like_prefix, normalize_tags
|
||||
from typing import Sequence
|
||||
|
||||
|
||||
def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
|
||||
"""Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
|
||||
owner_id = (owner_id or "").strip()
|
||||
if owner_id == "":
|
||||
return AssetInfo.owner_id == ""
|
||||
return AssetInfo.owner_id.in_(["", owner_id])
|
||||
|
||||
|
||||
def apply_tag_filters(
|
||||
stmt: sa.sql.Select,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
) -> sa.sql.Select:
|
||||
"""include_tags: every tag must be present; exclude_tags: none may be present."""
|
||||
include_tags = normalize_tags(include_tags)
|
||||
exclude_tags = normalize_tags(exclude_tags)
|
||||
|
||||
if include_tags:
|
||||
for tag_name in include_tags:
|
||||
stmt = stmt.where(
|
||||
exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name == tag_name)
|
||||
)
|
||||
)
|
||||
|
||||
if exclude_tags:
|
||||
stmt = stmt.where(
|
||||
~exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name.in_(exclude_tags))
|
||||
)
|
||||
)
|
||||
return stmt
|
||||
|
||||
def apply_metadata_filter(
|
||||
stmt: sa.sql.Select,
|
||||
metadata_filter: dict | None = None,
|
||||
) -> sa.sql.Select:
|
||||
"""Apply filters using asset_info_meta projection table."""
|
||||
if not metadata_filter:
|
||||
return stmt
|
||||
|
||||
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
|
||||
return sa.exists().where(
|
||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||
AssetInfoMeta.key == key,
|
||||
*preds,
|
||||
)
|
||||
|
||||
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
|
||||
if value is None:
|
||||
no_row_for_key = sa.not_(
|
||||
sa.exists().where(
|
||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||
AssetInfoMeta.key == key,
|
||||
)
|
||||
)
|
||||
null_row = _exists_for_pred(
|
||||
key,
|
||||
AssetInfoMeta.val_json.is_(None),
|
||||
AssetInfoMeta.val_str.is_(None),
|
||||
AssetInfoMeta.val_num.is_(None),
|
||||
AssetInfoMeta.val_bool.is_(None),
|
||||
)
|
||||
return sa.or_(no_row_for_key, null_row)
|
||||
|
||||
if isinstance(value, bool):
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
|
||||
if isinstance(value, (int, float)):
|
||||
from decimal import Decimal
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
|
||||
if isinstance(value, str):
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
|
||||
|
||||
for k, v in metadata_filter.items():
|
||||
if isinstance(v, list):
|
||||
ors = [_exists_clause_for_value(k, elem) for elem in v]
|
||||
if ors:
|
||||
stmt = stmt.where(sa.or_(*ors))
|
||||
else:
|
||||
stmt = stmt.where(_exists_clause_for_value(k, v))
|
||||
return stmt
|
||||
|
||||
|
||||
def asset_exists_by_hash(session: Session, asset_hash: str) -> bool:
|
||||
"""
|
||||
Check if an asset with a given hash exists in database.
|
||||
"""
|
||||
row = (
|
||||
session.execute(
|
||||
select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1)
|
||||
)
|
||||
).first()
|
||||
return row is not None
|
||||
|
||||
def get_asset_info_by_id(session: Session, asset_info_id: str) -> AssetInfo | None:
|
||||
return session.get(AssetInfo, asset_info_id)
|
||||
|
||||
def list_asset_infos_page(
|
||||
session: Session,
|
||||
owner_id: str = "",
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
) -> tuple[list[AssetInfo], dict[str, list[str]], int]:
|
||||
base = (
|
||||
select(AssetInfo)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
|
||||
.where(visible_owner_clause(owner_id))
|
||||
)
|
||||
|
||||
if name_contains:
|
||||
escaped, esc = escape_like_prefix(name_contains)
|
||||
base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
|
||||
|
||||
base = apply_tag_filters(base, include_tags, exclude_tags)
|
||||
base = apply_metadata_filter(base, metadata_filter)
|
||||
|
||||
sort = (sort or "created_at").lower()
|
||||
order = (order or "desc").lower()
|
||||
sort_map = {
|
||||
"name": AssetInfo.name,
|
||||
"created_at": AssetInfo.created_at,
|
||||
"updated_at": AssetInfo.updated_at,
|
||||
"last_access_time": AssetInfo.last_access_time,
|
||||
"size": Asset.size_bytes,
|
||||
}
|
||||
sort_col = sort_map.get(sort, AssetInfo.created_at)
|
||||
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
|
||||
|
||||
base = base.order_by(sort_exp).limit(limit).offset(offset)
|
||||
|
||||
count_stmt = (
|
||||
select(sa.func.count())
|
||||
.select_from(AssetInfo)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.where(visible_owner_clause(owner_id))
|
||||
)
|
||||
if name_contains:
|
||||
escaped, esc = escape_like_prefix(name_contains)
|
||||
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
|
||||
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
||||
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
|
||||
|
||||
total = int((session.execute(count_stmt)).scalar_one() or 0)
|
||||
|
||||
infos = (session.execute(base)).unique().scalars().all()
|
||||
|
||||
id_list: list[str] = [i.id for i in infos]
|
||||
tag_map: dict[str, list[str]] = defaultdict(list)
|
||||
if id_list:
|
||||
rows = session.execute(
|
||||
select(AssetInfoTag.asset_info_id, Tag.name)
|
||||
.join(Tag, Tag.name == AssetInfoTag.tag_name)
|
||||
.where(AssetInfoTag.asset_info_id.in_(id_list))
|
||||
)
|
||||
for aid, tag_name in rows.all():
|
||||
tag_map[aid].append(tag_name)
|
||||
|
||||
return infos, tag_map, total
|
||||
|
||||
def fetch_asset_info_asset_and_tags(
|
||||
session: Session,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> tuple[AssetInfo, Asset, list[str]] | None:
|
||||
stmt = (
|
||||
select(AssetInfo, Asset, Tag.name)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
|
||||
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
|
||||
.where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
.options(noload(AssetInfo.tags))
|
||||
.order_by(Tag.name.asc())
|
||||
)
|
||||
|
||||
rows = (session.execute(stmt)).all()
|
||||
if not rows:
|
||||
return None
|
||||
|
||||
first_info, first_asset, _ = rows[0]
|
||||
tags: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for _info, _asset, tag_name in rows:
|
||||
if tag_name and tag_name not in seen:
|
||||
seen.add(tag_name)
|
||||
tags.append(tag_name)
|
||||
return first_info, first_asset, tags
|
||||
|
||||
def list_tags_with_usage(
|
||||
session: Session,
|
||||
prefix: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
include_zero: bool = True,
|
||||
order: str = "count_desc",
|
||||
owner_id: str = "",
|
||||
) -> tuple[list[tuple[str, str, int]], int]:
|
||||
counts_sq = (
|
||||
select(
|
||||
AssetInfoTag.tag_name.label("tag_name"),
|
||||
func.count(AssetInfoTag.asset_info_id).label("cnt"),
|
||||
)
|
||||
.select_from(AssetInfoTag)
|
||||
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
|
||||
.where(visible_owner_clause(owner_id))
|
||||
.group_by(AssetInfoTag.tag_name)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
q = (
|
||||
select(
|
||||
Tag.name,
|
||||
Tag.tag_type,
|
||||
func.coalesce(counts_sq.c.cnt, 0).label("count"),
|
||||
)
|
||||
.select_from(Tag)
|
||||
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
|
||||
)
|
||||
|
||||
if prefix:
|
||||
escaped, esc = escape_like_prefix(prefix.strip().lower())
|
||||
q = q.where(Tag.name.like(escaped + "%", escape=esc))
|
||||
|
||||
if not include_zero:
|
||||
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
|
||||
|
||||
if order == "name_asc":
|
||||
q = q.order_by(Tag.name.asc())
|
||||
else:
|
||||
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
|
||||
|
||||
total_q = select(func.count()).select_from(Tag)
|
||||
if prefix:
|
||||
escaped, esc = escape_like_prefix(prefix.strip().lower())
|
||||
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
|
||||
if not include_zero:
|
||||
total_q = total_q.where(
|
||||
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
|
||||
)
|
||||
|
||||
rows = (session.execute(q.limit(limit).offset(offset))).all()
|
||||
total = (session.execute(total_q)).scalar_one()
|
||||
|
||||
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
|
||||
return rows_norm, int(total or 0)
|
||||
62
app/assets/database/tags.py
Normal file
62
app/assets/database/tags.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from typing import Iterable
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.dialects import sqlite
|
||||
|
||||
from app.assets.helpers import normalize_tags, utcnow
|
||||
from app.assets.database.models import Tag, AssetInfoTag, AssetInfo
|
||||
|
||||
|
||||
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
|
||||
wanted = normalize_tags(list(names))
|
||||
if not wanted:
|
||||
return
|
||||
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
|
||||
ins = (
|
||||
sqlite.insert(Tag)
|
||||
.values(rows)
|
||||
.on_conflict_do_nothing(index_elements=[Tag.name])
|
||||
)
|
||||
return session.execute(ins)
|
||||
|
||||
def add_missing_tag_for_asset_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_id: str,
|
||||
origin: str = "automatic",
|
||||
) -> None:
|
||||
select_rows = (
|
||||
sqlalchemy.select(
|
||||
AssetInfo.id.label("asset_info_id"),
|
||||
sqlalchemy.literal("missing").label("tag_name"),
|
||||
sqlalchemy.literal(origin).label("origin"),
|
||||
sqlalchemy.literal(utcnow()).label("added_at"),
|
||||
)
|
||||
.where(AssetInfo.asset_id == asset_id)
|
||||
.where(
|
||||
sqlalchemy.not_(
|
||||
sqlalchemy.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing"))
|
||||
)
|
||||
)
|
||||
)
|
||||
session.execute(
|
||||
sqlite.insert(AssetInfoTag)
|
||||
.from_select(
|
||||
["asset_info_id", "tag_name", "origin", "added_at"],
|
||||
select_rows,
|
||||
)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
|
||||
)
|
||||
|
||||
def remove_missing_tag_for_asset_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_id: str,
|
||||
) -> None:
|
||||
session.execute(
|
||||
sqlalchemy.delete(AssetInfoTag).where(
|
||||
AssetInfoTag.asset_info_id.in_(sqlalchemy.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
|
||||
AssetInfoTag.tag_name == "missing",
|
||||
)
|
||||
)
|
||||
75
app/assets/hashing.py
Normal file
75
app/assets/hashing.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from blake3 import blake3
|
||||
from typing import IO
|
||||
import os
|
||||
import asyncio
|
||||
|
||||
|
||||
DEFAULT_CHUNK = 8 * 1024 *1024 # 8MB
|
||||
|
||||
# NOTE: this allows hashing different representations of a file-like object
|
||||
def blake3_hash(
|
||||
fp: str | IO[bytes],
|
||||
chunk_size: int = DEFAULT_CHUNK,
|
||||
) -> str:
|
||||
"""
|
||||
Returns a BLAKE3 hex digest for ``fp``, which may be:
|
||||
- a filename (str/bytes) or PathLike
|
||||
- an open binary file object
|
||||
If ``fp`` is a file object, it must be opened in **binary** mode and support
|
||||
``read``, ``seek``, and ``tell``. The function will seek to the start before
|
||||
reading and will attempt to restore the original position afterward.
|
||||
"""
|
||||
# duck typing to check if input is a file-like object
|
||||
if hasattr(fp, "read"):
|
||||
return _hash_file_obj(fp, chunk_size)
|
||||
|
||||
with open(os.fspath(fp), "rb") as f:
|
||||
return _hash_file_obj(f, chunk_size)
|
||||
|
||||
|
||||
async def blake3_hash_async(
|
||||
fp: str | IO[bytes],
|
||||
chunk_size: int = DEFAULT_CHUNK,
|
||||
) -> str:
|
||||
"""Async wrapper for ``blake3_hash_sync``.
|
||||
Uses a worker thread so the event loop remains responsive.
|
||||
"""
|
||||
# If it is a path, open inside the worker thread to keep I/O off the loop.
|
||||
if hasattr(fp, "read"):
|
||||
return await asyncio.to_thread(blake3_hash, fp, chunk_size)
|
||||
|
||||
def _worker() -> str:
|
||||
with open(os.fspath(fp), "rb") as f:
|
||||
return _hash_file_obj(f, chunk_size)
|
||||
|
||||
return await asyncio.to_thread(_worker)
|
||||
|
||||
|
||||
def _hash_file_obj(file_obj: IO, chunk_size: int = DEFAULT_CHUNK) -> str:
|
||||
"""
|
||||
Hash an already-open binary file object by streaming in chunks.
|
||||
- Seeks to the beginning before reading (if supported).
|
||||
- Restores the original position afterward (if tell/seek are supported).
|
||||
"""
|
||||
if chunk_size <= 0:
|
||||
chunk_size = DEFAULT_CHUNK
|
||||
|
||||
# in case file object is already open and not at the beginning, track so can be restored after hashing
|
||||
orig_pos = file_obj.tell()
|
||||
|
||||
try:
|
||||
# seek to the beginning before reading
|
||||
if orig_pos != 0:
|
||||
file_obj.seek(0)
|
||||
|
||||
h = blake3()
|
||||
while True:
|
||||
chunk = file_obj.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
h.update(chunk)
|
||||
return h.hexdigest()
|
||||
finally:
|
||||
# restore original position in file object, if needed
|
||||
if orig_pos != 0:
|
||||
file_obj.seek(orig_pos)
|
||||
217
app/assets/helpers.py
Normal file
217
app/assets/helpers.py
Normal file
@@ -0,0 +1,217 @@
|
||||
import contextlib
|
||||
import os
|
||||
from aiohttp import web
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Literal, Any
|
||||
|
||||
import folder_paths
|
||||
|
||||
|
||||
RootType = Literal["models", "input", "output"]
|
||||
ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")
|
||||
|
||||
def get_query_dict(request: web.Request) -> dict[str, Any]:
|
||||
"""
|
||||
Gets a dictionary of query parameters from the request.
|
||||
|
||||
'request.query' is a MultiMapping[str], needs to be converted to a dictionary to be validated by Pydantic.
|
||||
"""
|
||||
query_dict = {
|
||||
key: request.query.getall(key) if len(request.query.getall(key)) > 1 else request.query.get(key)
|
||||
for key in request.query.keys()
|
||||
}
|
||||
return query_dict
|
||||
|
||||
def list_tree(base_dir: str) -> list[str]:
|
||||
out: list[str] = []
|
||||
base_abs = os.path.abspath(base_dir)
|
||||
if not os.path.isdir(base_abs):
|
||||
return out
|
||||
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
|
||||
for name in filenames:
|
||||
out.append(os.path.abspath(os.path.join(dirpath, name)))
|
||||
return out
|
||||
|
||||
def prefixes_for_root(root: RootType) -> list[str]:
|
||||
if root == "models":
|
||||
bases: list[str] = []
|
||||
for _bucket, paths in get_comfy_models_folders():
|
||||
bases.extend(paths)
|
||||
return [os.path.abspath(p) for p in bases]
|
||||
if root == "input":
|
||||
return [os.path.abspath(folder_paths.get_input_directory())]
|
||||
if root == "output":
|
||||
return [os.path.abspath(folder_paths.get_output_directory())]
|
||||
return []
|
||||
|
||||
def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]:
|
||||
"""Escapes %, _ and the escape char itself in a LIKE prefix.
|
||||
Returns (escaped_prefix, escape_char). Caller should append '%' and pass escape=escape_char to .like().
|
||||
"""
|
||||
s = s.replace(escape, escape + escape) # escape the escape char first
|
||||
s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards
|
||||
return s, escape
|
||||
|
||||
def fast_asset_file_check(
|
||||
*,
|
||||
mtime_db: int | None,
|
||||
size_db: int | None,
|
||||
stat_result: os.stat_result,
|
||||
) -> bool:
|
||||
if mtime_db is None:
|
||||
return False
|
||||
actual_mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000))
|
||||
if int(mtime_db) != int(actual_mtime_ns):
|
||||
return False
|
||||
sz = int(size_db or 0)
|
||||
if sz > 0:
|
||||
return int(stat_result.st_size) == sz
|
||||
return True
|
||||
|
||||
def utcnow() -> datetime:
|
||||
"""Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC."""
|
||||
return datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
|
||||
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
||||
"""Build a list of (folder_name, base_paths[]) categories that are configured for model locations.
|
||||
|
||||
We trust `folder_paths.folder_names_and_paths` and include a category if
|
||||
*any* of its base paths lies under the Comfy `models_dir`.
|
||||
"""
|
||||
targets: list[tuple[str, list[str]]] = []
|
||||
models_root = os.path.abspath(folder_paths.models_dir)
|
||||
for name, values in folder_paths.folder_names_and_paths.items():
|
||||
paths, _exts = values[0], values[1] # NOTE: this prevents nodepacks that hackily edit folder_... from breaking ComfyUI
|
||||
if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
|
||||
targets.append((name, paths))
|
||||
return targets
|
||||
|
||||
def compute_relative_filename(file_path: str) -> str | None:
|
||||
"""
|
||||
Return the model's path relative to the last well-known folder (the model category),
|
||||
using forward slashes, eg:
|
||||
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
|
||||
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
|
||||
|
||||
For non-model paths, returns None.
|
||||
NOTE: this is a temporary helper, used only for initializing metadata["filename"] field.
|
||||
"""
|
||||
try:
|
||||
root_category, rel_path = get_relative_to_root_category_path_of_asset(file_path)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
p = Path(rel_path)
|
||||
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
|
||||
if not parts:
|
||||
return None
|
||||
|
||||
if root_category == "models":
|
||||
# parts[0] is the category ("checkpoints", "vae", etc) – drop it
|
||||
inside = parts[1:] if len(parts) > 1 else [parts[0]]
|
||||
return "/".join(inside)
|
||||
return "/".join(parts) # input/output: keep all parts
|
||||
|
||||
|
||||
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
|
||||
"""Given an absolute or relative file path, determine which root category the path belongs to:
|
||||
- 'input' if the file resides under `folder_paths.get_input_directory()`
|
||||
- 'output' if the file resides under `folder_paths.get_output_directory()`
|
||||
- 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()`
|
||||
|
||||
Returns:
|
||||
(root_category, relative_path_inside_that_root)
|
||||
For 'models', the relative path is prefixed with the category name:
|
||||
e.g. ('models', 'vae/test/sub/ae.safetensors')
|
||||
|
||||
Raises:
|
||||
ValueError: if the path does not belong to input, output, or configured model bases.
|
||||
"""
|
||||
fp_abs = os.path.abspath(file_path)
|
||||
|
||||
def _is_within(child: str, parent: str) -> bool:
|
||||
try:
|
||||
return os.path.commonpath([child, parent]) == parent
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _rel(child: str, parent: str) -> str:
|
||||
return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep)
|
||||
|
||||
# 1) input
|
||||
input_base = os.path.abspath(folder_paths.get_input_directory())
|
||||
if _is_within(fp_abs, input_base):
|
||||
return "input", _rel(fp_abs, input_base)
|
||||
|
||||
# 2) output
|
||||
output_base = os.path.abspath(folder_paths.get_output_directory())
|
||||
if _is_within(fp_abs, output_base):
|
||||
return "output", _rel(fp_abs, output_base)
|
||||
|
||||
# 3) models (check deepest matching base to avoid ambiguity)
|
||||
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
|
||||
for bucket, bases in get_comfy_models_folders():
|
||||
for b in bases:
|
||||
base_abs = os.path.abspath(b)
|
||||
if not _is_within(fp_abs, base_abs):
|
||||
continue
|
||||
cand = (len(base_abs), bucket, _rel(fp_abs, base_abs))
|
||||
if best is None or cand[0] > best[0]:
|
||||
best = cand
|
||||
|
||||
if best is not None:
|
||||
_, bucket, rel_inside = best
|
||||
combined = os.path.join(bucket, rel_inside)
|
||||
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
|
||||
|
||||
raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}")
|
||||
|
||||
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
|
||||
"""Return a tuple (name, tags) derived from a filesystem path.
|
||||
|
||||
Semantics:
|
||||
- Root category is determined by `get_relative_to_root_category_path_of_asset`.
|
||||
- The returned `name` is the base filename with extension from the relative path.
|
||||
- The returned `tags` are:
|
||||
[root_category] + parent folders of the relative path (in order)
|
||||
For 'models', this means:
|
||||
file '/.../ModelsDir/vae/test_tag/ae.safetensors'
|
||||
-> root_category='models', some_path='vae/test_tag/ae.safetensors'
|
||||
-> name='ae.safetensors', tags=['models', 'vae', 'test_tag']
|
||||
|
||||
Raises:
|
||||
ValueError: if the path does not belong to input, output, or configured model bases.
|
||||
"""
|
||||
root_category, some_path = get_relative_to_root_category_path_of_asset(file_path)
|
||||
p = Path(some_path)
|
||||
parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)]
|
||||
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
|
||||
|
||||
def normalize_tags(tags: list[str] | None) -> list[str]:
|
||||
"""
|
||||
Normalize a list of tags by:
|
||||
- Stripping whitespace and converting to lowercase.
|
||||
- Removing duplicates.
|
||||
"""
|
||||
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||
|
||||
def collect_models_files() -> list[str]:
|
||||
out: list[str] = []
|
||||
for folder_name, bases in get_comfy_models_folders():
|
||||
rel_files = folder_paths.get_filename_list(folder_name) or []
|
||||
for rel_path in rel_files:
|
||||
abs_path = folder_paths.get_full_path(folder_name, rel_path)
|
||||
if not abs_path:
|
||||
continue
|
||||
abs_path = os.path.abspath(abs_path)
|
||||
allowed = False
|
||||
for b in bases:
|
||||
base_abs = os.path.abspath(b)
|
||||
with contextlib.suppress(Exception):
|
||||
if os.path.commonpath([abs_path, base_abs]) == base_abs:
|
||||
allowed = True
|
||||
break
|
||||
if allowed:
|
||||
out.append(abs_path)
|
||||
return out
|
||||
123
app/assets/manager.py
Normal file
123
app/assets/manager.py
Normal file
@@ -0,0 +1,123 @@
|
||||
from typing import Sequence
|
||||
|
||||
from app.database.db import create_session
|
||||
from app.assets.api import schemas_out
|
||||
from app.assets.database.queries import (
|
||||
asset_exists_by_hash,
|
||||
fetch_asset_info_asset_and_tags,
|
||||
list_asset_infos_page,
|
||||
list_tags_with_usage,
|
||||
)
|
||||
|
||||
|
||||
def _safe_sort_field(requested: str | None) -> str:
|
||||
if not requested:
|
||||
return "created_at"
|
||||
v = requested.lower()
|
||||
if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
|
||||
return v
|
||||
return "created_at"
|
||||
|
||||
|
||||
def asset_exists(asset_hash: str) -> bool:
|
||||
with create_session() as session:
|
||||
return asset_exists_by_hash(session, asset_hash=asset_hash)
|
||||
|
||||
def list_assets(
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetsList:
|
||||
sort = _safe_sort_field(sort)
|
||||
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
|
||||
|
||||
with create_session() as session:
|
||||
infos, tag_map, total = list_asset_infos_page(
|
||||
session,
|
||||
owner_id=owner_id,
|
||||
include_tags=include_tags,
|
||||
exclude_tags=exclude_tags,
|
||||
name_contains=name_contains,
|
||||
metadata_filter=metadata_filter,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
sort=sort,
|
||||
order=order,
|
||||
)
|
||||
|
||||
summaries: list[schemas_out.AssetSummary] = []
|
||||
for info in infos:
|
||||
asset = info.asset
|
||||
tags = tag_map.get(info.id, [])
|
||||
summaries.append(
|
||||
schemas_out.AssetSummary(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tags,
|
||||
preview_url=f"/api/assets/{info.id}/content",
|
||||
created_at=info.created_at,
|
||||
updated_at=info.updated_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
)
|
||||
|
||||
return schemas_out.AssetsList(
|
||||
assets=summaries,
|
||||
total=total,
|
||||
has_more=(offset + len(summaries)) < total,
|
||||
)
|
||||
|
||||
def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail:
|
||||
with create_session() as session:
|
||||
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not res:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
info, asset, tag_names = res
|
||||
preview_id = info.preview_id
|
||||
|
||||
return schemas_out.AssetDetail(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
|
||||
def list_tags(
|
||||
prefix: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
order: str = "count_desc",
|
||||
include_zero: bool = True,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsList:
|
||||
limit = max(1, min(1000, limit))
|
||||
offset = max(0, offset)
|
||||
|
||||
with create_session() as session:
|
||||
rows, total = list_tags_with_usage(
|
||||
session,
|
||||
prefix=prefix,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
include_zero=include_zero,
|
||||
order=order,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
|
||||
tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows]
|
||||
return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total)
|
||||
229
app/assets/scanner.py
Normal file
229
app/assets/scanner.py
Normal file
@@ -0,0 +1,229 @@
|
||||
import contextlib
|
||||
import time
|
||||
import logging
|
||||
import os
|
||||
import sqlalchemy
|
||||
|
||||
import folder_paths
|
||||
from app.database.db import create_session, dependencies_available
|
||||
from app.assets.helpers import (
|
||||
collect_models_files, compute_relative_filename, fast_asset_file_check, get_name_and_tags_from_asset_path,
|
||||
list_tree,prefixes_for_root, escape_like_prefix,
|
||||
RootType
|
||||
)
|
||||
from app.assets.database.tags import add_missing_tag_for_asset_id, ensure_tags_exist, remove_missing_tag_for_asset_id
|
||||
from app.assets.database.bulk_ops import seed_from_paths_batch
|
||||
from app.assets.database.models import Asset, AssetCacheState, AssetInfo
|
||||
|
||||
|
||||
def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> None:
|
||||
"""
|
||||
Scan the given roots and seed the assets into the database.
|
||||
"""
|
||||
if not dependencies_available():
|
||||
if enable_logging:
|
||||
logging.warning("Database dependencies not available, skipping assets scan")
|
||||
return
|
||||
t_start = time.perf_counter()
|
||||
created = 0
|
||||
skipped_existing = 0
|
||||
paths: list[str] = []
|
||||
try:
|
||||
existing_paths: set[str] = set()
|
||||
for r in roots:
|
||||
try:
|
||||
survivors: set[str] = _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True)
|
||||
if survivors:
|
||||
existing_paths.update(survivors)
|
||||
except Exception as e:
|
||||
logging.exception("fast DB scan failed for %s: %s", r, e)
|
||||
|
||||
if "models" in roots:
|
||||
paths.extend(collect_models_files())
|
||||
if "input" in roots:
|
||||
paths.extend(list_tree(folder_paths.get_input_directory()))
|
||||
if "output" in roots:
|
||||
paths.extend(list_tree(folder_paths.get_output_directory()))
|
||||
|
||||
specs: list[dict] = []
|
||||
tag_pool: set[str] = set()
|
||||
for p in paths:
|
||||
abs_p = os.path.abspath(p)
|
||||
if abs_p in existing_paths:
|
||||
skipped_existing += 1
|
||||
continue
|
||||
try:
|
||||
stat_p = os.stat(abs_p, follow_symlinks=False)
|
||||
except OSError:
|
||||
continue
|
||||
# skip empty files
|
||||
if not stat_p.st_size:
|
||||
continue
|
||||
name, tags = get_name_and_tags_from_asset_path(abs_p)
|
||||
specs.append(
|
||||
{
|
||||
"abs_path": abs_p,
|
||||
"size_bytes": stat_p.st_size,
|
||||
"mtime_ns": getattr(stat_p, "st_mtime_ns", int(stat_p.st_mtime * 1_000_000_000)),
|
||||
"info_name": name,
|
||||
"tags": tags,
|
||||
"fname": compute_relative_filename(abs_p),
|
||||
}
|
||||
)
|
||||
for t in tags:
|
||||
tag_pool.add(t)
|
||||
# if no file specs, nothing to do
|
||||
if not specs:
|
||||
return
|
||||
with create_session() as sess:
|
||||
if tag_pool:
|
||||
ensure_tags_exist(sess, tag_pool, tag_type="user")
|
||||
|
||||
result = seed_from_paths_batch(sess, specs=specs, owner_id="")
|
||||
created += result["inserted_infos"]
|
||||
sess.commit()
|
||||
finally:
|
||||
if enable_logging:
|
||||
logging.info(
|
||||
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, total_seen=%d)",
|
||||
roots,
|
||||
time.perf_counter() - t_start,
|
||||
created,
|
||||
skipped_existing,
|
||||
len(paths),
|
||||
)
|
||||
|
||||
|
||||
def _fast_db_consistency_pass(
|
||||
root: RootType,
|
||||
*,
|
||||
collect_existing_paths: bool = False,
|
||||
update_missing_tags: bool = False,
|
||||
) -> set[str] | None:
|
||||
"""Fast DB+FS pass for a root:
|
||||
- Toggle needs_verify per state using fast check
|
||||
- For hashed assets with at least one fast-ok state in this root: delete stale missing states
|
||||
- For seed assets with all states missing: delete Asset and its AssetInfos
|
||||
- Optionally add/remove 'missing' tags based on fast-ok in this root
|
||||
- Optionally return surviving absolute paths
|
||||
"""
|
||||
prefixes = prefixes_for_root(root)
|
||||
if not prefixes:
|
||||
return set() if collect_existing_paths else None
|
||||
|
||||
conds = []
|
||||
for p in prefixes:
|
||||
base = os.path.abspath(p)
|
||||
if not base.endswith(os.sep):
|
||||
base += os.sep
|
||||
escaped, esc = escape_like_prefix(base)
|
||||
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
|
||||
|
||||
with create_session() as sess:
|
||||
rows = (
|
||||
sess.execute(
|
||||
sqlalchemy.select(
|
||||
AssetCacheState.id,
|
||||
AssetCacheState.file_path,
|
||||
AssetCacheState.mtime_ns,
|
||||
AssetCacheState.needs_verify,
|
||||
AssetCacheState.asset_id,
|
||||
Asset.hash,
|
||||
Asset.size_bytes,
|
||||
)
|
||||
.join(Asset, Asset.id == AssetCacheState.asset_id)
|
||||
.where(sqlalchemy.or_(*conds))
|
||||
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
|
||||
)
|
||||
).all()
|
||||
|
||||
by_asset: dict[str, dict] = {}
|
||||
for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows:
|
||||
acc = by_asset.get(aid)
|
||||
if acc is None:
|
||||
acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []}
|
||||
by_asset[aid] = acc
|
||||
|
||||
fast_ok = False
|
||||
try:
|
||||
exists = True
|
||||
fast_ok = fast_asset_file_check(
|
||||
mtime_db=mtime_db,
|
||||
size_db=acc["size_db"],
|
||||
stat_result=os.stat(fp, follow_symlinks=True),
|
||||
)
|
||||
except FileNotFoundError:
|
||||
exists = False
|
||||
except OSError:
|
||||
exists = False
|
||||
|
||||
acc["states"].append({
|
||||
"sid": sid,
|
||||
"fp": fp,
|
||||
"exists": exists,
|
||||
"fast_ok": fast_ok,
|
||||
"needs_verify": bool(needs_verify),
|
||||
})
|
||||
|
||||
to_set_verify: list[int] = []
|
||||
to_clear_verify: list[int] = []
|
||||
stale_state_ids: list[int] = []
|
||||
survivors: set[str] = set()
|
||||
|
||||
for aid, acc in by_asset.items():
|
||||
a_hash = acc["hash"]
|
||||
states = acc["states"]
|
||||
any_fast_ok = any(s["fast_ok"] for s in states)
|
||||
all_missing = all(not s["exists"] for s in states)
|
||||
|
||||
for s in states:
|
||||
if not s["exists"]:
|
||||
continue
|
||||
if s["fast_ok"] and s["needs_verify"]:
|
||||
to_clear_verify.append(s["sid"])
|
||||
if not s["fast_ok"] and not s["needs_verify"]:
|
||||
to_set_verify.append(s["sid"])
|
||||
|
||||
if a_hash is None:
|
||||
if states and all_missing: # remove seed Asset completely, if no valid AssetCache exists
|
||||
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id == aid))
|
||||
asset = sess.get(Asset, aid)
|
||||
if asset:
|
||||
sess.delete(asset)
|
||||
else:
|
||||
for s in states:
|
||||
if s["exists"]:
|
||||
survivors.add(os.path.abspath(s["fp"]))
|
||||
continue
|
||||
|
||||
if any_fast_ok: # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records
|
||||
for s in states:
|
||||
if not s["exists"]:
|
||||
stale_state_ids.append(s["sid"])
|
||||
if update_missing_tags:
|
||||
with contextlib.suppress(Exception):
|
||||
remove_missing_tag_for_asset_id(sess, asset_id=aid)
|
||||
elif update_missing_tags:
|
||||
with contextlib.suppress(Exception):
|
||||
add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
|
||||
|
||||
for s in states:
|
||||
if s["exists"]:
|
||||
survivors.add(os.path.abspath(s["fp"]))
|
||||
|
||||
if stale_state_ids:
|
||||
sess.execute(sqlalchemy.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids)))
|
||||
if to_set_verify:
|
||||
sess.execute(
|
||||
sqlalchemy.update(AssetCacheState)
|
||||
.where(AssetCacheState.id.in_(to_set_verify))
|
||||
.values(needs_verify=True)
|
||||
)
|
||||
if to_clear_verify:
|
||||
sess.execute(
|
||||
sqlalchemy.update(AssetCacheState)
|
||||
.where(AssetCacheState.id.in_(to_clear_verify))
|
||||
.values(needs_verify=False)
|
||||
)
|
||||
sess.commit()
|
||||
return survivors if collect_existing_paths else None
|
||||
@@ -1,14 +1,21 @@
|
||||
from sqlalchemy.orm import declarative_base
|
||||
from typing import Any
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
Base = declarative_base()
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
def to_dict(obj):
|
||||
def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
|
||||
fields = obj.__table__.columns.keys()
|
||||
return {
|
||||
field: (val.to_dict() if hasattr(val, "to_dict") else val)
|
||||
for field in fields
|
||||
if (val := getattr(obj, field))
|
||||
}
|
||||
out: dict[str, Any] = {}
|
||||
for field in fields:
|
||||
val = getattr(obj, field)
|
||||
if val is None and not include_none:
|
||||
continue
|
||||
if isinstance(val, datetime):
|
||||
out[field] = val.isoformat()
|
||||
else:
|
||||
out[field] = val
|
||||
return out
|
||||
|
||||
# TODO: Define models here
|
||||
|
||||
@@ -44,7 +44,7 @@ class ModelFileManager:
|
||||
@routes.get("/experiment/models/{folder}")
|
||||
async def get_all_models(request):
|
||||
folder = request.match_info.get("folder", None)
|
||||
if not folder in folder_paths.folder_names_and_paths:
|
||||
if folder not in folder_paths.folder_names_and_paths:
|
||||
return web.Response(status=404)
|
||||
files = self.get_model_file_list(folder)
|
||||
return web.json_response(files)
|
||||
@@ -55,7 +55,7 @@ class ModelFileManager:
|
||||
path_index = int(request.match_info.get("path_index", None))
|
||||
filename = request.match_info.get("filename", None)
|
||||
|
||||
if not folder_name in folder_paths.folder_names_and_paths:
|
||||
if folder_name not in folder_paths.folder_names_and_paths:
|
||||
return web.Response(status=404)
|
||||
|
||||
folders = folder_paths.folder_names_and_paths[folder_name]
|
||||
|
||||
@@ -231,6 +231,7 @@ database_default_path = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
|
||||
)
|
||||
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
|
||||
parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.")
|
||||
|
||||
if comfy.options.args_parsing:
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -1,6 +1,59 @@
|
||||
import torch
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
import comfy.ops
|
||||
import math
|
||||
|
||||
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
|
||||
image = image[:, :, :, :3] if image.shape[3] > 3 else image
|
||||
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
|
||||
std = torch.tensor(std, device=image.device, dtype=image.dtype)
|
||||
image = image.movedim(-1, 1)
|
||||
if not (image.shape[2] == size and image.shape[3] == size):
|
||||
if crop:
|
||||
scale = (size / min(image.shape[2], image.shape[3]))
|
||||
scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
|
||||
else:
|
||||
scale_size = (size, size)
|
||||
|
||||
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
|
||||
h = (image.shape[2] - size)//2
|
||||
w = (image.shape[3] - size)//2
|
||||
image = image[:,:,h:h+size,w:w+size]
|
||||
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
||||
return (image - mean.view([3,1,1])) / std.view([3,1,1])
|
||||
|
||||
def siglip2_flex_calc_resolution(oh, ow, patch_size, max_num_patches, eps=1e-5):
|
||||
def scale_dim(size, scale):
|
||||
scaled = math.ceil(size * scale / patch_size) * patch_size
|
||||
return max(patch_size, int(scaled))
|
||||
|
||||
# Binary search for optimal scale
|
||||
lo, hi = eps / 10, 100.0
|
||||
while hi - lo >= eps:
|
||||
mid = (lo + hi) / 2
|
||||
h, w = scale_dim(oh, mid), scale_dim(ow, mid)
|
||||
if (h // patch_size) * (w // patch_size) <= max_num_patches:
|
||||
lo = mid
|
||||
else:
|
||||
hi = mid
|
||||
|
||||
return scale_dim(oh, lo), scale_dim(ow, lo)
|
||||
|
||||
def siglip2_preprocess(image, size, patch_size, num_patches, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True):
|
||||
if size > 0:
|
||||
return clip_preprocess(image, size=size, mean=mean, std=std, crop=crop)
|
||||
|
||||
image = image[:, :, :, :3] if image.shape[3] > 3 else image
|
||||
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
|
||||
std = torch.tensor(std, device=image.device, dtype=image.dtype)
|
||||
image = image.movedim(-1, 1)
|
||||
|
||||
b, c, h, w = image.shape
|
||||
h, w = siglip2_flex_calc_resolution(h, w, patch_size, num_patches)
|
||||
|
||||
image = torch.nn.functional.interpolate(image, size=(h, w), mode="bilinear", antialias=True)
|
||||
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
||||
return (image - mean.view([3, 1, 1])) / std.view([3, 1, 1])
|
||||
|
||||
class CLIPAttention(torch.nn.Module):
|
||||
def __init__(self, embed_dim, heads, dtype, device, operations):
|
||||
@@ -156,6 +209,27 @@ class CLIPTextModel(torch.nn.Module):
|
||||
out = self.text_projection(x[2])
|
||||
return (x[0], x[1], out, x[2])
|
||||
|
||||
def siglip2_pos_embed(embed_weight, embeds, orig_shape):
|
||||
embed_weight_len = round(embed_weight.shape[0] ** 0.5)
|
||||
embed_weight = comfy.ops.cast_to_input(embed_weight, embeds).movedim(1, 0).reshape(1, -1, embed_weight_len, embed_weight_len)
|
||||
embed_weight = torch.nn.functional.interpolate(embed_weight, size=orig_shape, mode="bilinear", align_corners=False, antialias=True)
|
||||
embed_weight = embed_weight.reshape(-1, embed_weight.shape[-2] * embed_weight.shape[-1]).movedim(0, 1)
|
||||
return embeds + embed_weight
|
||||
|
||||
class Siglip2Embeddings(torch.nn.Module):
|
||||
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", num_patches=None, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.patch_embedding = operations.Linear(num_channels * patch_size * patch_size, embed_dim, dtype=dtype, device=device)
|
||||
self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device)
|
||||
self.patch_size = patch_size
|
||||
|
||||
def forward(self, pixel_values):
|
||||
b, c, h, w = pixel_values.shape
|
||||
img = pixel_values.movedim(1, -1).reshape(b, h // self.patch_size, self.patch_size, w // self.patch_size, self.patch_size, c)
|
||||
img = img.permute(0, 1, 3, 2, 4, 5)
|
||||
img = img.reshape(b, img.shape[1] * img.shape[2], -1)
|
||||
img = self.patch_embedding(img)
|
||||
return siglip2_pos_embed(self.position_embedding.weight, img, (h // self.patch_size, w // self.patch_size))
|
||||
|
||||
class CLIPVisionEmbeddings(torch.nn.Module):
|
||||
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None):
|
||||
@@ -199,8 +273,11 @@ class CLIPVision(torch.nn.Module):
|
||||
intermediate_activation = config_dict["hidden_act"]
|
||||
model_type = config_dict["model_type"]
|
||||
|
||||
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
|
||||
if model_type == "siglip_vision_model":
|
||||
if model_type in ["siglip2_vision_model"]:
|
||||
self.embeddings = Siglip2Embeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, num_patches=config_dict.get("num_patches", None), dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
|
||||
if model_type in ["siglip_vision_model", "siglip2_vision_model"]:
|
||||
self.pre_layrnorm = lambda a: a
|
||||
self.output_layernorm = True
|
||||
else:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
|
||||
import os
|
||||
import torch
|
||||
import json
|
||||
import logging
|
||||
|
||||
@@ -17,28 +16,12 @@ class Output:
|
||||
def __setitem__(self, key, item):
|
||||
setattr(self, key, item)
|
||||
|
||||
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
|
||||
image = image[:, :, :, :3] if image.shape[3] > 3 else image
|
||||
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
|
||||
std = torch.tensor(std, device=image.device, dtype=image.dtype)
|
||||
image = image.movedim(-1, 1)
|
||||
if not (image.shape[2] == size and image.shape[3] == size):
|
||||
if crop:
|
||||
scale = (size / min(image.shape[2], image.shape[3]))
|
||||
scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
|
||||
else:
|
||||
scale_size = (size, size)
|
||||
|
||||
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
|
||||
h = (image.shape[2] - size)//2
|
||||
w = (image.shape[3] - size)//2
|
||||
image = image[:,:,h:h+size,w:w+size]
|
||||
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
||||
return (image - mean.view([3,1,1])) / std.view([3,1,1])
|
||||
clip_preprocess = comfy.clip_model.clip_preprocess # Prevent some stuff from breaking, TODO: remove eventually
|
||||
|
||||
IMAGE_ENCODERS = {
|
||||
"clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
|
||||
}
|
||||
|
||||
@@ -50,9 +33,10 @@ class ClipVisionModel():
|
||||
self.image_size = config.get("image_size", 224)
|
||||
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
|
||||
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
|
||||
model_type = config.get("model_type", "clip_vision_model")
|
||||
model_class = IMAGE_ENCODERS.get(model_type)
|
||||
if model_type == "siglip_vision_model":
|
||||
self.model_type = config.get("model_type", "clip_vision_model")
|
||||
self.config = config.copy()
|
||||
model_class = IMAGE_ENCODERS.get(self.model_type)
|
||||
if self.model_type == "siglip_vision_model":
|
||||
self.return_all_hidden_states = True
|
||||
else:
|
||||
self.return_all_hidden_states = False
|
||||
@@ -73,7 +57,10 @@ class ClipVisionModel():
|
||||
|
||||
def encode_image(self, image, crop=True):
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||
if self.model_type == "siglip2_vision_model":
|
||||
pixel_values = comfy.clip_model.siglip2_preprocess(image.to(self.load_device), size=self.image_size, patch_size=self.config.get("patch_size", 16), num_patches=self.config.get("num_patches", 256), mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||
else:
|
||||
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
|
||||
|
||||
outputs = Output()
|
||||
@@ -125,10 +112,14 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
|
||||
embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0]
|
||||
if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
|
||||
if embed_shape == 729:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
|
||||
elif embed_shape == 1024:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json")
|
||||
patch_embedding_shape = sd["vision_model.embeddings.patch_embedding.weight"].shape
|
||||
if len(patch_embedding_shape) == 2:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip2_base_naflex.json")
|
||||
else:
|
||||
if embed_shape == 729:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
|
||||
elif embed_shape == 1024:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json")
|
||||
elif embed_shape == 577:
|
||||
if "multi_modal_projector.linear_1.bias" in sd:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json")
|
||||
|
||||
14
comfy/clip_vision_siglip2_base_naflex.json
Normal file
14
comfy/clip_vision_siglip2_base_naflex.json
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"num_channels": 3,
|
||||
"hidden_act": "gelu_pytorch_tanh",
|
||||
"hidden_size": 1152,
|
||||
"image_size": -1,
|
||||
"intermediate_size": 4304,
|
||||
"model_type": "siglip2_vision_model",
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 27,
|
||||
"patch_size": 16,
|
||||
"num_patches": 256,
|
||||
"image_mean": [0.5, 0.5, 0.5],
|
||||
"image_std": [0.5, 0.5, 0.5]
|
||||
}
|
||||
118
comfy/float.py
118
comfy/float.py
@@ -65,3 +65,121 @@ def stochastic_rounding(value, dtype, seed=0):
|
||||
return output
|
||||
|
||||
return value.to(dtype=dtype)
|
||||
|
||||
|
||||
# TODO: improve this?
|
||||
def stochastic_float_to_fp4_e2m1(x, generator):
|
||||
orig_shape = x.shape
|
||||
sign = torch.signbit(x).to(torch.uint8)
|
||||
|
||||
exp = torch.floor(torch.log2(x.abs()) + 1.0).clamp(0, 3)
|
||||
x += (torch.rand(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator) - 0.5) * (2 ** (exp - 2.0)) * 1.25
|
||||
|
||||
x = x.abs()
|
||||
exp = torch.floor(torch.log2(x) + 1.1925).clamp(0, 3)
|
||||
|
||||
mantissa = torch.where(
|
||||
exp > 0,
|
||||
(x / (2.0 ** (exp - 1)) - 1.0) * 2.0,
|
||||
(x * 2.0),
|
||||
out=x
|
||||
).round().to(torch.uint8)
|
||||
del x
|
||||
|
||||
exp = exp.to(torch.uint8)
|
||||
|
||||
fp4 = (sign << 3) | (exp << 1) | mantissa
|
||||
del sign, exp, mantissa
|
||||
|
||||
fp4_flat = fp4.view(-1)
|
||||
packed = (fp4_flat[0::2] << 4) | fp4_flat[1::2]
|
||||
return packed.reshape(list(orig_shape)[:-1] + [-1])
|
||||
|
||||
|
||||
def to_blocked(input_matrix, flatten: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.
|
||||
See:
|
||||
https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
|
||||
|
||||
Args:
|
||||
input_matrix: Input tensor of shape (H, W)
|
||||
Returns:
|
||||
Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
|
||||
"""
|
||||
|
||||
def ceil_div(a, b):
|
||||
return (a + b - 1) // b
|
||||
|
||||
rows, cols = input_matrix.shape
|
||||
n_row_blocks = ceil_div(rows, 128)
|
||||
n_col_blocks = ceil_div(cols, 4)
|
||||
|
||||
# Calculate the padded shape
|
||||
padded_rows = n_row_blocks * 128
|
||||
padded_cols = n_col_blocks * 4
|
||||
|
||||
padded = input_matrix
|
||||
if (rows, cols) != (padded_rows, padded_cols):
|
||||
padded = torch.zeros(
|
||||
(padded_rows, padded_cols),
|
||||
device=input_matrix.device,
|
||||
dtype=input_matrix.dtype,
|
||||
)
|
||||
padded[:rows, :cols] = input_matrix
|
||||
|
||||
# Rearrange the blocks
|
||||
blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
|
||||
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
|
||||
if flatten:
|
||||
return rearranged.flatten()
|
||||
|
||||
return rearranged.reshape(padded_rows, padded_cols)
|
||||
|
||||
|
||||
def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
|
||||
F4_E2M1_MAX = 6.0
|
||||
F8_E4M3_MAX = 448.0
|
||||
|
||||
def roundup(x: int, multiple: int) -> int:
|
||||
"""Round up x to the nearest multiple."""
|
||||
return ((x + multiple - 1) // multiple) * multiple
|
||||
|
||||
orig_shape = x.shape
|
||||
|
||||
# Handle padding
|
||||
if pad_16x:
|
||||
rows, cols = x.shape
|
||||
padded_rows = roundup(rows, 16)
|
||||
padded_cols = roundup(cols, 16)
|
||||
if padded_rows != rows or padded_cols != cols:
|
||||
x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
|
||||
# Note: We update orig_shape because the output tensor logic below assumes x.shape matches
|
||||
# what we want to produce. If we pad here, we want the padded output.
|
||||
orig_shape = x.shape
|
||||
|
||||
block_size = 16
|
||||
|
||||
x = x.reshape(orig_shape[0], -1, block_size)
|
||||
max_abs = torch.amax(torch.abs(x), dim=-1)
|
||||
block_scale = max_abs / F4_E2M1_MAX
|
||||
scaled_block_scales = block_scale / per_tensor_scale.to(block_scale.dtype)
|
||||
scaled_block_scales_fp8 = torch.clamp(scaled_block_scales, max=F8_E4M3_MAX).to(torch.float8_e4m3fn)
|
||||
total_scale = per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)
|
||||
|
||||
# Handle zero blocks (from padding): avoid 0/0 NaN
|
||||
zero_scale_mask = (total_scale == 0)
|
||||
total_scale_safe = torch.where(zero_scale_mask, torch.ones_like(total_scale), total_scale)
|
||||
|
||||
x = x / total_scale_safe.unsqueeze(-1)
|
||||
|
||||
generator = torch.Generator(device=x.device)
|
||||
generator.manual_seed(seed)
|
||||
|
||||
x = torch.where(zero_scale_mask.unsqueeze(-1), torch.zeros_like(x), x)
|
||||
|
||||
x = x.view(orig_shape)
|
||||
data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator)
|
||||
|
||||
blocked_scales = to_blocked(scaled_block_scales_fp8, flatten=False)
|
||||
return data_lp, blocked_scales
|
||||
|
||||
@@ -527,7 +527,8 @@ class HookKeyframeGroup:
|
||||
if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0:
|
||||
break
|
||||
# if eval_c is outside the percent range, stop looking further
|
||||
else: break
|
||||
else:
|
||||
break
|
||||
# update steps current context is used
|
||||
self._current_used_steps += 1
|
||||
# update current timestep this was performed on
|
||||
|
||||
@@ -407,6 +407,11 @@ class LTXV(LatentFormat):
|
||||
|
||||
self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512]
|
||||
|
||||
class LTXAV(LTXV):
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors = None
|
||||
self.latent_rgb_factors_bias = None
|
||||
|
||||
class HunyuanVideo(LatentFormat):
|
||||
latent_channels = 16
|
||||
latent_dimensions = 3
|
||||
|
||||
@@ -270,7 +270,7 @@ class ChromaRadiance(Chroma):
|
||||
bad_keys = tuple(
|
||||
k
|
||||
for k, v in overrides.items()
|
||||
if type(v) != type(getattr(params, k)) and (v is not None or k not in nullable_keys)
|
||||
if not isinstance(v, type(getattr(params, k))) and (v is not None or k not in nullable_keys)
|
||||
)
|
||||
if bad_keys:
|
||||
e = f"Invalid value(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}"
|
||||
|
||||
@@ -4,6 +4,7 @@ from torch import Tensor
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.model_management
|
||||
import logging
|
||||
|
||||
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
|
||||
@@ -13,7 +14,6 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transforme
|
||||
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
|
||||
return x
|
||||
|
||||
|
||||
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
assert dim % 2 == 0
|
||||
if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled():
|
||||
@@ -28,13 +28,20 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
||||
return out.to(dtype=torch.float32, device=pos.device)
|
||||
|
||||
def apply_rope1(x: Tensor, freqs_cis: Tensor):
|
||||
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
||||
|
||||
x_out = freqs_cis[..., 0] * x_[..., 0]
|
||||
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
|
||||
try:
|
||||
import comfy.quant_ops
|
||||
apply_rope = comfy.quant_ops.ck.apply_rope
|
||||
apply_rope1 = comfy.quant_ops.ck.apply_rope1
|
||||
except:
|
||||
logging.warning("No comfy kitchen, using old apply_rope functions.")
|
||||
def apply_rope1(x: Tensor, freqs_cis: Tensor):
|
||||
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
||||
|
||||
return x_out.reshape(*x.shape).type_as(x)
|
||||
x_out = freqs_cis[..., 0] * x_[..., 0]
|
||||
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
|
||||
|
||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
||||
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||
return x_out.reshape(*x.shape).type_as(x)
|
||||
|
||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
||||
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||
|
||||
@@ -3,7 +3,8 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
|
||||
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm
|
||||
import model_management, model_patcher
|
||||
import comfy.model_management
|
||||
import comfy.model_patcher
|
||||
|
||||
class SRResidualCausalBlock3D(nn.Module):
|
||||
def __init__(self, channels: int):
|
||||
@@ -102,13 +103,13 @@ UPSAMPLERS = {
|
||||
|
||||
class HunyuanVideo15SRModel():
|
||||
def __init__(self, model_type, config):
|
||||
self.load_device = model_management.vae_device()
|
||||
offload_device = model_management.vae_offload_device()
|
||||
self.dtype = model_management.vae_dtype(self.load_device)
|
||||
self.load_device = comfy.model_management.vae_device()
|
||||
offload_device = comfy.model_management.vae_offload_device()
|
||||
self.dtype = comfy.model_management.vae_dtype(self.load_device)
|
||||
self.model_class = UPSAMPLERS.get(model_type)
|
||||
self.model = self.model_class(**config).eval()
|
||||
|
||||
self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.model.load_state_dict(sd, strict=True)
|
||||
@@ -117,5 +118,5 @@ class HunyuanVideo15SRModel():
|
||||
return self.model.state_dict()
|
||||
|
||||
def resample_latent(self, latent):
|
||||
model_management.load_model_gpu(self.patcher)
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
return self.model(latent.to(self.load_device))
|
||||
|
||||
913
comfy/ldm/lightricks/av_model.py
Normal file
913
comfy/ldm/lightricks/av_model.py
Normal file
@@ -0,0 +1,913 @@
|
||||
from typing import Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from comfy.ldm.lightricks.model import (
|
||||
CrossAttention,
|
||||
FeedForward,
|
||||
AdaLayerNormSingle,
|
||||
PixArtAlphaTextProjection,
|
||||
LTXVModel,
|
||||
)
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
class CompressedTimestep:
|
||||
"""Store video timestep embeddings in compressed form using per-frame indexing."""
|
||||
__slots__ = ('data', 'batch_size', 'num_frames', 'patches_per_frame', 'feature_dim')
|
||||
|
||||
def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
|
||||
"""
|
||||
tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
|
||||
patches_per_frame: Number of spatial patches per frame (height * width in latent space)
|
||||
"""
|
||||
self.batch_size, num_tokens, self.feature_dim = tensor.shape
|
||||
|
||||
# Check if compression is valid (num_tokens must be divisible by patches_per_frame)
|
||||
if num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
|
||||
self.patches_per_frame = patches_per_frame
|
||||
self.num_frames = num_tokens // patches_per_frame
|
||||
|
||||
# Reshape to [batch, frames, patches_per_frame, feature_dim] and store one value per frame
|
||||
# All patches in a frame are identical, so we only keep the first one
|
||||
reshaped = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)
|
||||
self.data = reshaped[:, :, 0, :].contiguous() # [batch, frames, feature_dim]
|
||||
else:
|
||||
# Not divisible or too small - store directly without compression
|
||||
self.patches_per_frame = 1
|
||||
self.num_frames = num_tokens
|
||||
self.data = tensor
|
||||
|
||||
def expand(self):
|
||||
"""Expand back to original tensor."""
|
||||
if self.patches_per_frame == 1:
|
||||
return self.data
|
||||
|
||||
# [batch, frames, feature_dim] -> [batch, frames, patches_per_frame, feature_dim] -> [batch, tokens, feature_dim]
|
||||
expanded = self.data.unsqueeze(2).expand(self.batch_size, self.num_frames, self.patches_per_frame, self.feature_dim)
|
||||
return expanded.reshape(self.batch_size, -1, self.feature_dim)
|
||||
|
||||
def expand_for_computation(self, scale_shift_table: torch.Tensor, batch_size: int, indices: slice = slice(None, None)):
|
||||
"""Compute ada values on compressed per-frame data, then expand spatially."""
|
||||
num_ada_params = scale_shift_table.shape[0]
|
||||
|
||||
# No compression - compute directly
|
||||
if self.patches_per_frame == 1:
|
||||
num_tokens = self.data.shape[1]
|
||||
dim_per_param = self.feature_dim // num_ada_params
|
||||
reshaped = self.data.reshape(batch_size, num_tokens, num_ada_params, dim_per_param)[:, :, indices, :]
|
||||
table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=self.data.device, dtype=self.data.dtype)
|
||||
ada_values = (table_values + reshaped).unbind(dim=2)
|
||||
return ada_values
|
||||
|
||||
# Compressed: compute on per-frame data then expand spatially
|
||||
# Reshape: [batch, frames, feature_dim] -> [batch, frames, num_ada_params, dim_per_param]
|
||||
frame_reshaped = self.data.reshape(batch_size, self.num_frames, num_ada_params, -1)[:, :, indices, :]
|
||||
table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(
|
||||
device=self.data.device, dtype=self.data.dtype
|
||||
)
|
||||
frame_ada = (table_values + frame_reshaped).unbind(dim=2)
|
||||
|
||||
# Expand each ada parameter spatially: [batch, frames, dim] -> [batch, frames, patches, dim] -> [batch, tokens, dim]
|
||||
return tuple(
|
||||
frame_val.unsqueeze(2).expand(batch_size, self.num_frames, self.patches_per_frame, -1)
|
||||
.reshape(batch_size, -1, frame_val.shape[-1])
|
||||
for frame_val in frame_ada
|
||||
)
|
||||
|
||||
class BasicAVTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
v_dim,
|
||||
a_dim,
|
||||
v_heads,
|
||||
a_heads,
|
||||
vd_head,
|
||||
ad_head,
|
||||
v_context_dim=None,
|
||||
a_context_dim=None,
|
||||
attn_precision=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attn_precision = attn_precision
|
||||
|
||||
self.attn1 = CrossAttention(
|
||||
query_dim=v_dim,
|
||||
heads=v_heads,
|
||||
dim_head=vd_head,
|
||||
context_dim=None,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
self.audio_attn1 = CrossAttention(
|
||||
query_dim=a_dim,
|
||||
heads=a_heads,
|
||||
dim_head=ad_head,
|
||||
context_dim=None,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.attn2 = CrossAttention(
|
||||
query_dim=v_dim,
|
||||
context_dim=v_context_dim,
|
||||
heads=v_heads,
|
||||
dim_head=vd_head,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
self.audio_attn2 = CrossAttention(
|
||||
query_dim=a_dim,
|
||||
context_dim=a_context_dim,
|
||||
heads=a_heads,
|
||||
dim_head=ad_head,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
# Q: Video, K,V: Audio
|
||||
self.audio_to_video_attn = CrossAttention(
|
||||
query_dim=v_dim,
|
||||
context_dim=a_dim,
|
||||
heads=a_heads,
|
||||
dim_head=ad_head,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
# Q: Audio, K,V: Video
|
||||
self.video_to_audio_attn = CrossAttention(
|
||||
query_dim=a_dim,
|
||||
context_dim=v_dim,
|
||||
heads=a_heads,
|
||||
dim_head=ad_head,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.ff = FeedForward(
|
||||
v_dim, dim_out=v_dim, glu=True, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.audio_ff = FeedForward(
|
||||
a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype))
|
||||
self.audio_scale_shift_table = nn.Parameter(
|
||||
torch.empty(6, a_dim, device=device, dtype=dtype)
|
||||
)
|
||||
|
||||
self.scale_shift_table_a2v_ca_audio = nn.Parameter(
|
||||
torch.empty(5, a_dim, device=device, dtype=dtype)
|
||||
)
|
||||
self.scale_shift_table_a2v_ca_video = nn.Parameter(
|
||||
torch.empty(5, v_dim, device=device, dtype=dtype)
|
||||
)
|
||||
|
||||
def get_ada_values(
|
||||
self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice = slice(None, None)
|
||||
):
|
||||
if isinstance(timestep, CompressedTimestep):
|
||||
return timestep.expand_for_computation(scale_shift_table, batch_size, indices)
|
||||
|
||||
num_ada_params = scale_shift_table.shape[0]
|
||||
|
||||
ada_values = (
|
||||
scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=timestep.device, dtype=timestep.dtype)
|
||||
+ timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[:, :, indices, :]
|
||||
).unbind(dim=2)
|
||||
return ada_values
|
||||
|
||||
def get_av_ca_ada_values(
|
||||
self,
|
||||
scale_shift_table: torch.Tensor,
|
||||
batch_size: int,
|
||||
scale_shift_timestep: torch.Tensor,
|
||||
gate_timestep: torch.Tensor,
|
||||
num_scale_shift_values: int = 4,
|
||||
):
|
||||
scale_shift_ada_values = self.get_ada_values(
|
||||
scale_shift_table[:num_scale_shift_values, :],
|
||||
batch_size,
|
||||
scale_shift_timestep,
|
||||
)
|
||||
gate_ada_values = self.get_ada_values(
|
||||
scale_shift_table[num_scale_shift_values:, :],
|
||||
batch_size,
|
||||
gate_timestep,
|
||||
)
|
||||
|
||||
return (*scale_shift_ada_values, *gate_ada_values)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tuple[torch.Tensor, torch.Tensor],
|
||||
v_context=None,
|
||||
a_context=None,
|
||||
attention_mask=None,
|
||||
v_timestep=None,
|
||||
a_timestep=None,
|
||||
v_pe=None,
|
||||
a_pe=None,
|
||||
v_cross_pe=None,
|
||||
a_cross_pe=None,
|
||||
v_cross_scale_shift_timestep=None,
|
||||
a_cross_scale_shift_timestep=None,
|
||||
v_cross_gate_timestep=None,
|
||||
a_cross_gate_timestep=None,
|
||||
transformer_options=None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
run_vx = transformer_options.get("run_vx", True)
|
||||
run_ax = transformer_options.get("run_ax", True)
|
||||
|
||||
vx, ax = x
|
||||
run_ax = run_ax and ax.numel() > 0
|
||||
run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0
|
||||
run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True)
|
||||
|
||||
if run_vx:
|
||||
vshift_msa, vscale_msa, vgate_msa = (
|
||||
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 3))
|
||||
)
|
||||
|
||||
norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
|
||||
vx += self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) * vgate_msa
|
||||
vx += self.attn2(
|
||||
comfy.ldm.common_dit.rms_norm(vx),
|
||||
context=v_context,
|
||||
mask=attention_mask,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
del vshift_msa, vscale_msa, vgate_msa
|
||||
|
||||
if run_ax:
|
||||
ashift_msa, ascale_msa, agate_msa = (
|
||||
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 3))
|
||||
)
|
||||
|
||||
norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa
|
||||
ax += (
|
||||
self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
|
||||
* agate_msa
|
||||
)
|
||||
ax += self.audio_attn2(
|
||||
comfy.ldm.common_dit.rms_norm(ax),
|
||||
context=a_context,
|
||||
mask=attention_mask,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
del ashift_msa, ascale_msa, agate_msa
|
||||
|
||||
# Audio - Video cross attention.
|
||||
if run_a2v or run_v2a:
|
||||
# norm3
|
||||
vx_norm3 = comfy.ldm.common_dit.rms_norm(vx)
|
||||
ax_norm3 = comfy.ldm.common_dit.rms_norm(ax)
|
||||
|
||||
(
|
||||
scale_ca_audio_hidden_states_a2v,
|
||||
shift_ca_audio_hidden_states_a2v,
|
||||
scale_ca_audio_hidden_states_v2a,
|
||||
shift_ca_audio_hidden_states_v2a,
|
||||
gate_out_v2a,
|
||||
) = self.get_av_ca_ada_values(
|
||||
self.scale_shift_table_a2v_ca_audio,
|
||||
ax.shape[0],
|
||||
a_cross_scale_shift_timestep,
|
||||
a_cross_gate_timestep,
|
||||
)
|
||||
|
||||
(
|
||||
scale_ca_video_hidden_states_a2v,
|
||||
shift_ca_video_hidden_states_a2v,
|
||||
scale_ca_video_hidden_states_v2a,
|
||||
shift_ca_video_hidden_states_v2a,
|
||||
gate_out_a2v,
|
||||
) = self.get_av_ca_ada_values(
|
||||
self.scale_shift_table_a2v_ca_video,
|
||||
vx.shape[0],
|
||||
v_cross_scale_shift_timestep,
|
||||
v_cross_gate_timestep,
|
||||
)
|
||||
|
||||
if run_a2v:
|
||||
vx_scaled = (
|
||||
vx_norm3 * (1 + scale_ca_video_hidden_states_a2v)
|
||||
+ shift_ca_video_hidden_states_a2v
|
||||
)
|
||||
ax_scaled = (
|
||||
ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v)
|
||||
+ shift_ca_audio_hidden_states_a2v
|
||||
)
|
||||
vx += (
|
||||
self.audio_to_video_attn(
|
||||
vx_scaled,
|
||||
context=ax_scaled,
|
||||
pe=v_cross_pe,
|
||||
k_pe=a_cross_pe,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
* gate_out_a2v
|
||||
)
|
||||
|
||||
del gate_out_a2v
|
||||
del scale_ca_video_hidden_states_a2v,\
|
||||
shift_ca_video_hidden_states_a2v,\
|
||||
scale_ca_audio_hidden_states_a2v,\
|
||||
shift_ca_audio_hidden_states_a2v,\
|
||||
|
||||
if run_v2a:
|
||||
ax_scaled = (
|
||||
ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a)
|
||||
+ shift_ca_audio_hidden_states_v2a
|
||||
)
|
||||
vx_scaled = (
|
||||
vx_norm3 * (1 + scale_ca_video_hidden_states_v2a)
|
||||
+ shift_ca_video_hidden_states_v2a
|
||||
)
|
||||
ax += (
|
||||
self.video_to_audio_attn(
|
||||
ax_scaled,
|
||||
context=vx_scaled,
|
||||
pe=a_cross_pe,
|
||||
k_pe=v_cross_pe,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
* gate_out_v2a
|
||||
)
|
||||
|
||||
del gate_out_v2a
|
||||
del scale_ca_video_hidden_states_v2a,\
|
||||
shift_ca_video_hidden_states_v2a,\
|
||||
scale_ca_audio_hidden_states_v2a,\
|
||||
shift_ca_audio_hidden_states_v2a
|
||||
|
||||
if run_vx:
|
||||
vshift_mlp, vscale_mlp, vgate_mlp = (
|
||||
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, None))
|
||||
)
|
||||
|
||||
vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp
|
||||
vx += self.ff(vx_scaled) * vgate_mlp
|
||||
del vshift_mlp, vscale_mlp, vgate_mlp
|
||||
|
||||
if run_ax:
|
||||
ashift_mlp, ascale_mlp, agate_mlp = (
|
||||
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, None))
|
||||
)
|
||||
|
||||
ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp
|
||||
ax += self.audio_ff(ax_scaled) * agate_mlp
|
||||
|
||||
del ashift_mlp, ascale_mlp, agate_mlp
|
||||
|
||||
|
||||
return vx, ax
|
||||
|
||||
|
||||
class LTXAVModel(LTXVModel):
|
||||
"""LTXAV model for audio-video generation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=128,
|
||||
audio_in_channels=128,
|
||||
cross_attention_dim=4096,
|
||||
audio_cross_attention_dim=2048,
|
||||
attention_head_dim=128,
|
||||
audio_attention_head_dim=64,
|
||||
num_attention_heads=32,
|
||||
audio_num_attention_heads=32,
|
||||
caption_channels=3840,
|
||||
num_layers=48,
|
||||
positional_embedding_theta=10000.0,
|
||||
positional_embedding_max_pos=[20, 2048, 2048],
|
||||
audio_positional_embedding_max_pos=[20],
|
||||
causal_temporal_positioning=False,
|
||||
vae_scale_factors=(8, 32, 32),
|
||||
use_middle_indices_grid=False,
|
||||
timestep_scale_multiplier=1000.0,
|
||||
av_ca_timestep_scale_multiplier=1.0,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Store audio-specific parameters
|
||||
self.audio_in_channels = audio_in_channels
|
||||
self.audio_cross_attention_dim = audio_cross_attention_dim
|
||||
self.audio_attention_head_dim = audio_attention_head_dim
|
||||
self.audio_num_attention_heads = audio_num_attention_heads
|
||||
self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos
|
||||
|
||||
# Calculate audio dimensions
|
||||
self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim
|
||||
self.audio_out_channels = audio_in_channels
|
||||
|
||||
# Audio-specific constants
|
||||
self.num_audio_channels = 8
|
||||
self.audio_frequency_bins = 16
|
||||
|
||||
self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier
|
||||
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
caption_channels=caption_channels,
|
||||
num_layers=num_layers,
|
||||
positional_embedding_theta=positional_embedding_theta,
|
||||
positional_embedding_max_pos=positional_embedding_max_pos,
|
||||
causal_temporal_positioning=causal_temporal_positioning,
|
||||
vae_scale_factors=vae_scale_factors,
|
||||
use_middle_indices_grid=use_middle_indices_grid,
|
||||
timestep_scale_multiplier=timestep_scale_multiplier,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _init_model_components(self, device, dtype, **kwargs):
|
||||
"""Initialize LTXAV-specific components."""
|
||||
# Audio-specific projections
|
||||
self.audio_patchify_proj = self.operations.Linear(
|
||||
self.audio_in_channels, self.audio_inner_dim, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
# Audio-specific AdaLN
|
||||
self.audio_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim,
|
||||
use_additional_conditions=False,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
num_scale_shift_values = 4
|
||||
self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim,
|
||||
use_additional_conditions=False,
|
||||
embedding_coefficient=num_scale_shift_values,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim,
|
||||
use_additional_conditions=False,
|
||||
embedding_coefficient=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim,
|
||||
use_additional_conditions=False,
|
||||
embedding_coefficient=num_scale_shift_values,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim,
|
||||
use_additional_conditions=False,
|
||||
embedding_coefficient=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
# Audio caption projection
|
||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=self.caption_channels,
|
||||
hidden_size=self.audio_inner_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||
"""Initialize transformer blocks for LTXAV."""
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicAVTransformerBlock(
|
||||
v_dim=self.inner_dim,
|
||||
a_dim=self.audio_inner_dim,
|
||||
v_heads=self.num_attention_heads,
|
||||
a_heads=self.audio_num_attention_heads,
|
||||
vd_head=self.attention_head_dim,
|
||||
ad_head=self.audio_attention_head_dim,
|
||||
v_context_dim=self.cross_attention_dim,
|
||||
a_context_dim=self.audio_cross_attention_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
for _ in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def _init_output_components(self, device, dtype):
|
||||
"""Initialize output components for LTXAV."""
|
||||
# Video output components
|
||||
super()._init_output_components(device, dtype)
|
||||
# Audio output components
|
||||
self.audio_scale_shift_table = nn.Parameter(
|
||||
torch.empty(2, self.audio_inner_dim, dtype=dtype, device=device)
|
||||
)
|
||||
self.audio_norm_out = self.operations.LayerNorm(
|
||||
self.audio_inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
|
||||
)
|
||||
self.audio_proj_out = self.operations.Linear(
|
||||
self.audio_inner_dim, self.audio_out_channels, dtype=dtype, device=device
|
||||
)
|
||||
self.a_patchifier = AudioPatchifier(1, start_end=True)
|
||||
|
||||
def separate_audio_and_video_latents(self, x, audio_length):
|
||||
"""Separate audio and video latents from combined input."""
|
||||
# vx = x[:, : self.in_channels]
|
||||
# ax = x[:, self.in_channels :]
|
||||
#
|
||||
# ax = ax.reshape(ax.shape[0], -1)
|
||||
# ax = ax[:, : audio_length * self.num_audio_channels * self.audio_frequency_bins]
|
||||
#
|
||||
# ax = ax.reshape(
|
||||
# ax.shape[0], self.num_audio_channels, audio_length, self.audio_frequency_bins
|
||||
# )
|
||||
|
||||
vx = x[0]
|
||||
ax = x[1] if len(x) > 1 else torch.zeros(
|
||||
(vx.shape[0], self.num_audio_channels, 0, self.audio_frequency_bins),
|
||||
device=vx.device, dtype=vx.dtype
|
||||
)
|
||||
return vx, ax
|
||||
|
||||
def recombine_audio_and_video_latents(self, vx, ax, target_shape=None):
|
||||
if ax.numel() == 0:
|
||||
return vx
|
||||
else:
|
||||
return [vx, ax]
|
||||
"""Recombine audio and video latents for output."""
|
||||
# if ax.device != vx.device or ax.dtype != vx.dtype:
|
||||
# logging.warning("Audio and video latents are on different devices or dtypes.")
|
||||
# ax = ax.to(device=vx.device, dtype=vx.dtype)
|
||||
# logging.warning(f"Audio audio latent moved to device: {ax.device}, dtype: {ax.dtype}")
|
||||
#
|
||||
# ax = ax.reshape(ax.shape[0], -1)
|
||||
# # pad to f x h x w of the video latents
|
||||
# divisor = vx.shape[-1] * vx.shape[-2] * vx.shape[-3]
|
||||
# if target_shape is None:
|
||||
# repetitions = math.ceil(ax.shape[-1] / divisor)
|
||||
# else:
|
||||
# repetitions = target_shape[1] - vx.shape[1]
|
||||
# padded_len = repetitions * divisor
|
||||
# ax = F.pad(ax, (0, padded_len - ax.shape[-1]))
|
||||
# ax = ax.reshape(ax.shape[0], -1, vx.shape[-3], vx.shape[-2], vx.shape[-1])
|
||||
# return torch.cat([vx, ax], dim=1)
|
||||
|
||||
def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
|
||||
"""Process input for LTXAV - separate audio and video, then patchify."""
|
||||
audio_length = kwargs.get("audio_length", 0)
|
||||
# Separate audio and video latents
|
||||
vx, ax = self.separate_audio_and_video_latents(x, audio_length)
|
||||
[vx, v_pixel_coords, additional_args] = super()._process_input(
|
||||
vx, keyframe_idxs, denoise_mask, **kwargs
|
||||
)
|
||||
|
||||
ax, a_latent_coords = self.a_patchifier.patchify(ax)
|
||||
ax = self.audio_patchify_proj(ax)
|
||||
|
||||
# additional_args.update({"av_orig_shape": list(x.shape)})
|
||||
return [vx, ax], [v_pixel_coords, a_latent_coords], additional_args
|
||||
|
||||
def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
|
||||
"""Prepare timestep embeddings."""
|
||||
# TODO: some code reuse is needed here.
|
||||
grid_mask = kwargs.get("grid_mask", None)
|
||||
if grid_mask is not None:
|
||||
timestep = timestep[:, grid_mask]
|
||||
|
||||
timestep_scaled = timestep * self.timestep_scale_multiplier
|
||||
|
||||
v_timestep, v_embedded_timestep = self.adaln_single(
|
||||
timestep_scaled.flatten(),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
|
||||
# Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
|
||||
# Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
|
||||
orig_shape = kwargs.get("orig_shape")
|
||||
v_patches_per_frame = None
|
||||
if orig_shape is not None and len(orig_shape) == 5:
|
||||
# orig_shape[3] = height, orig_shape[4] = width (in latent space)
|
||||
v_patches_per_frame = orig_shape[3] * orig_shape[4]
|
||||
|
||||
# Reshape to [batch_size, num_tokens, dim] and compress for storage
|
||||
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
|
||||
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
|
||||
|
||||
# Prepare audio timestep
|
||||
a_timestep = kwargs.get("a_timestep")
|
||||
if a_timestep is not None:
|
||||
a_timestep_scaled = a_timestep * self.timestep_scale_multiplier
|
||||
a_timestep_flat = a_timestep_scaled.flatten()
|
||||
timestep_flat = timestep_scaled.flatten()
|
||||
av_ca_factor = self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier
|
||||
|
||||
# Cross-attention timesteps - compress these too
|
||||
av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
|
||||
a_timestep_flat,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
|
||||
timestep_flat,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
|
||||
timestep_flat * av_ca_factor,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
|
||||
a_timestep_flat * av_ca_factor,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
|
||||
# Compress cross-attention timesteps (only video side, audio is too small to benefit)
|
||||
cross_av_timestep_ss = [
|
||||
av_ca_audio_scale_shift_timestep.view(batch_size, -1, av_ca_audio_scale_shift_timestep.shape[-1]),
|
||||
CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed
|
||||
CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed
|
||||
av_ca_v2a_gate_noise_timestep.view(batch_size, -1, av_ca_v2a_gate_noise_timestep.shape[-1]),
|
||||
]
|
||||
|
||||
a_timestep, a_embedded_timestep = self.audio_adaln_single(
|
||||
a_timestep_flat,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
# Audio timesteps
|
||||
a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
|
||||
a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1])
|
||||
else:
|
||||
a_timestep = timestep_scaled
|
||||
a_embedded_timestep = kwargs.get("embedded_timestep")
|
||||
cross_av_timestep_ss = []
|
||||
|
||||
return [v_timestep, a_timestep, cross_av_timestep_ss], [
|
||||
v_embedded_timestep,
|
||||
a_embedded_timestep,
|
||||
]
|
||||
|
||||
def _prepare_context(self, context, batch_size, x, attention_mask=None):
|
||||
vx = x[0]
|
||||
ax = x[1]
|
||||
v_context, a_context = torch.split(
|
||||
context, int(context.shape[-1] / 2), len(context.shape) - 1
|
||||
)
|
||||
|
||||
v_context, attention_mask = super()._prepare_context(
|
||||
v_context, batch_size, vx, attention_mask
|
||||
)
|
||||
if self.audio_caption_projection is not None:
|
||||
a_context = self.audio_caption_projection(a_context)
|
||||
a_context = a_context.view(batch_size, -1, ax.shape[-1])
|
||||
|
||||
return [v_context, a_context], attention_mask
|
||||
|
||||
def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype):
|
||||
v_pixel_coords = pixel_coords[0]
|
||||
v_pe = super()._prepare_positional_embeddings(v_pixel_coords, frame_rate, x_dtype)
|
||||
|
||||
a_latent_coords = pixel_coords[1]
|
||||
a_pe = self._precompute_freqs_cis(
|
||||
a_latent_coords,
|
||||
dim=self.audio_inner_dim,
|
||||
out_dtype=x_dtype,
|
||||
max_pos=self.audio_positional_embedding_max_pos,
|
||||
use_middle_indices_grid=self.use_middle_indices_grid,
|
||||
num_attention_heads=self.audio_num_attention_heads,
|
||||
)
|
||||
|
||||
# calculate positional embeddings for the middle of the token duration, to use in av cross attention layers.
|
||||
max_pos = max(
|
||||
self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0]
|
||||
)
|
||||
v_pixel_coords = v_pixel_coords.to(torch.float32)
|
||||
v_pixel_coords[:, 0] = v_pixel_coords[:, 0] * (1.0 / frame_rate)
|
||||
av_cross_video_freq_cis = self._precompute_freqs_cis(
|
||||
v_pixel_coords[:, 0:1, :],
|
||||
dim=self.audio_cross_attention_dim,
|
||||
out_dtype=x_dtype,
|
||||
max_pos=[max_pos],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=self.audio_num_attention_heads,
|
||||
)
|
||||
av_cross_audio_freq_cis = self._precompute_freqs_cis(
|
||||
a_latent_coords[:, 0:1, :],
|
||||
dim=self.audio_cross_attention_dim,
|
||||
out_dtype=x_dtype,
|
||||
max_pos=[max_pos],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=self.audio_num_attention_heads,
|
||||
)
|
||||
|
||||
return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)]
|
||||
|
||||
def _process_transformer_blocks(
|
||||
self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs
|
||||
):
|
||||
vx = x[0]
|
||||
ax = x[1]
|
||||
v_context = context[0]
|
||||
a_context = context[1]
|
||||
v_timestep = timestep[0]
|
||||
a_timestep = timestep[1]
|
||||
v_pe, av_cross_video_freq_cis = pe[0]
|
||||
a_pe, av_cross_audio_freq_cis = pe[1]
|
||||
|
||||
(
|
||||
av_ca_audio_scale_shift_timestep,
|
||||
av_ca_video_scale_shift_timestep,
|
||||
av_ca_a2v_gate_noise_timestep,
|
||||
av_ca_v2a_gate_noise_timestep,
|
||||
) = timestep[2]
|
||||
|
||||
"""Process transformer blocks for LTXAV."""
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
|
||||
# Process transformer blocks
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(
|
||||
args["img"],
|
||||
v_context=args["v_context"],
|
||||
a_context=args["a_context"],
|
||||
attention_mask=args["attention_mask"],
|
||||
v_timestep=args["v_timestep"],
|
||||
a_timestep=args["a_timestep"],
|
||||
v_pe=args["v_pe"],
|
||||
a_pe=args["a_pe"],
|
||||
v_cross_pe=args["v_cross_pe"],
|
||||
a_cross_pe=args["a_cross_pe"],
|
||||
v_cross_scale_shift_timestep=args["v_cross_scale_shift_timestep"],
|
||||
a_cross_scale_shift_timestep=args["a_cross_scale_shift_timestep"],
|
||||
v_cross_gate_timestep=args["v_cross_gate_timestep"],
|
||||
a_cross_gate_timestep=args["a_cross_gate_timestep"],
|
||||
transformer_options=args["transformer_options"],
|
||||
)
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)](
|
||||
{
|
||||
"img": (vx, ax),
|
||||
"v_context": v_context,
|
||||
"a_context": a_context,
|
||||
"attention_mask": attention_mask,
|
||||
"v_timestep": v_timestep,
|
||||
"a_timestep": a_timestep,
|
||||
"v_pe": v_pe,
|
||||
"a_pe": a_pe,
|
||||
"v_cross_pe": av_cross_video_freq_cis,
|
||||
"a_cross_pe": av_cross_audio_freq_cis,
|
||||
"v_cross_scale_shift_timestep": av_ca_video_scale_shift_timestep,
|
||||
"a_cross_scale_shift_timestep": av_ca_audio_scale_shift_timestep,
|
||||
"v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep,
|
||||
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
|
||||
"transformer_options": transformer_options,
|
||||
},
|
||||
{"original_block": block_wrap},
|
||||
)
|
||||
vx, ax = out["img"]
|
||||
else:
|
||||
vx, ax = block(
|
||||
(vx, ax),
|
||||
v_context=v_context,
|
||||
a_context=a_context,
|
||||
attention_mask=attention_mask,
|
||||
v_timestep=v_timestep,
|
||||
a_timestep=a_timestep,
|
||||
v_pe=v_pe,
|
||||
a_pe=a_pe,
|
||||
v_cross_pe=av_cross_video_freq_cis,
|
||||
a_cross_pe=av_cross_audio_freq_cis,
|
||||
v_cross_scale_shift_timestep=av_ca_video_scale_shift_timestep,
|
||||
a_cross_scale_shift_timestep=av_ca_audio_scale_shift_timestep,
|
||||
v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep,
|
||||
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
return [vx, ax]
|
||||
|
||||
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
|
||||
vx = x[0]
|
||||
ax = x[1]
|
||||
v_embedded_timestep = embedded_timestep[0]
|
||||
a_embedded_timestep = embedded_timestep[1]
|
||||
|
||||
# Expand compressed video timestep if needed
|
||||
if isinstance(v_embedded_timestep, CompressedTimestep):
|
||||
v_embedded_timestep = v_embedded_timestep.expand()
|
||||
|
||||
vx = super()._process_output(vx, v_embedded_timestep, keyframe_idxs, **kwargs)
|
||||
|
||||
# Process audio output
|
||||
a_scale_shift_values = (
|
||||
self.audio_scale_shift_table[None, None].to(device=a_embedded_timestep.device, dtype=a_embedded_timestep.dtype)
|
||||
+ a_embedded_timestep[:, :, None]
|
||||
)
|
||||
a_shift, a_scale = a_scale_shift_values[:, :, 0], a_scale_shift_values[:, :, 1]
|
||||
|
||||
ax = self.audio_norm_out(ax)
|
||||
ax = ax * (1 + a_scale) + a_shift
|
||||
ax = self.audio_proj_out(ax)
|
||||
|
||||
# Unpatchify audio
|
||||
ax = self.a_patchifier.unpatchify(
|
||||
ax, channels=self.num_audio_channels, freq=self.audio_frequency_bins
|
||||
)
|
||||
|
||||
# Recombine audio and video
|
||||
original_shape = kwargs.get("av_orig_shape")
|
||||
return self.recombine_audio_and_video_latents(vx, ax, original_shape)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
timestep,
|
||||
context,
|
||||
attention_mask=None,
|
||||
frame_rate=25,
|
||||
transformer_options={},
|
||||
keyframe_idxs=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Forward pass for LTXAV model.
|
||||
|
||||
Args:
|
||||
x: Combined audio-video input tensor
|
||||
timestep: Tuple of (video_timestep, audio_timestep) or single timestep
|
||||
context: Context tensor (e.g., text embeddings)
|
||||
attention_mask: Attention mask tensor
|
||||
frame_rate: Frame rate for temporal processing
|
||||
transformer_options: Additional options for transformer blocks
|
||||
keyframe_idxs: Keyframe indices for temporal processing
|
||||
**kwargs: Additional keyword arguments including audio_length
|
||||
|
||||
Returns:
|
||||
Combined audio-video output tensor
|
||||
"""
|
||||
# Handle timestep format
|
||||
if isinstance(timestep, (tuple, list)) and len(timestep) == 2:
|
||||
v_timestep, a_timestep = timestep
|
||||
kwargs["a_timestep"] = a_timestep
|
||||
timestep = v_timestep
|
||||
else:
|
||||
kwargs["a_timestep"] = timestep
|
||||
|
||||
# Call parent forward method
|
||||
return super().forward(
|
||||
x,
|
||||
timestep,
|
||||
context,
|
||||
attention_mask,
|
||||
frame_rate,
|
||||
transformer_options,
|
||||
keyframe_idxs,
|
||||
**kwargs,
|
||||
)
|
||||
305
comfy/ldm/lightricks/embeddings_connector.py
Normal file
305
comfy/ldm/lightricks/embeddings_connector.py
Normal file
@@ -0,0 +1,305 @@
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import comfy.ldm.common_dit
|
||||
import torch
|
||||
from comfy.ldm.lightricks.model import (
|
||||
CrossAttention,
|
||||
FeedForward,
|
||||
generate_freq_grid_np,
|
||||
interleaved_freqs_cis,
|
||||
split_freqs_cis,
|
||||
)
|
||||
from torch import nn
|
||||
|
||||
|
||||
class BasicTransformerBlock1D(nn.Module):
|
||||
r"""
|
||||
A basic Transformer block.
|
||||
|
||||
Parameters:
|
||||
|
||||
dim (`int`): The number of channels in the input and output.
|
||||
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`): The number of channels in each head.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
attention_bias (:
|
||||
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
||||
upcast_attention (`bool`, *optional*):
|
||||
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
||||
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use learnable elementwise affine parameters for normalization.
|
||||
standardization_norm (`str`, *optional*, defaults to `"layer_norm"`): The type of pre-normalization to use. Can be `"layer_norm"` or `"rms_norm"`.
|
||||
norm_eps (`float`, *optional*, defaults to 1e-5): Epsilon value for normalization layers.
|
||||
qk_norm (`str`, *optional*, defaults to None):
|
||||
Set to 'layer_norm' or `rms_norm` to perform query and key normalization.
|
||||
final_dropout (`bool` *optional*, defaults to False):
|
||||
Whether to apply a final dropout after the last feed-forward layer.
|
||||
ff_inner_dim (`int`, *optional*): Dimension of the inner feed-forward layer. If not provided, defaults to `dim * 4`.
|
||||
ff_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the feed-forward layer.
|
||||
attention_out_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the attention output layer.
|
||||
use_rope (`bool`, *optional*, defaults to `False`): Whether to use Rotary Position Embeddings (RoPE).
|
||||
ffn_dim_mult (`int`, *optional*, defaults to 4): Multiplier for the inner dimension of the feed-forward layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
context_dim=None,
|
||||
attn_precision=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Define 3 blocks. Each block has its own normalization layer.
|
||||
# 1. Self-Attn
|
||||
self.attn1 = CrossAttention(
|
||||
query_dim=dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
context_dim=None,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
# 3. Feed-forward
|
||||
self.ff = FeedForward(
|
||||
dim,
|
||||
dim_out=dim,
|
||||
glu=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, pe=None) -> torch.FloatTensor:
|
||||
|
||||
# Notice that normalization is always applied before the real computation in the following blocks.
|
||||
|
||||
# 1. Normalization Before Self-Attention
|
||||
norm_hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states)
|
||||
|
||||
norm_hidden_states = norm_hidden_states.squeeze(1)
|
||||
|
||||
# 2. Self-Attention
|
||||
attn_output = self.attn1(norm_hidden_states, mask=attention_mask, pe=pe)
|
||||
|
||||
hidden_states = attn_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
# 3. Normalization before Feed-Forward
|
||||
norm_hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states)
|
||||
|
||||
# 4. Feed-forward
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
hidden_states = ff_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Embeddings1DConnector(nn.Module):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=128,
|
||||
cross_attention_dim=2048,
|
||||
attention_head_dim=128,
|
||||
num_attention_heads=30,
|
||||
num_layers=2,
|
||||
positional_embedding_theta=10000.0,
|
||||
positional_embedding_max_pos=[4096],
|
||||
causal_temporal_positioning=False,
|
||||
num_learnable_registers: Optional[int] = 128,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
split_rope=False,
|
||||
double_precision_rope=False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.out_channels = in_channels
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.causal_temporal_positioning = causal_temporal_positioning
|
||||
self.positional_embedding_theta = positional_embedding_theta
|
||||
self.positional_embedding_max_pos = positional_embedding_max_pos
|
||||
self.split_rope = split_rope
|
||||
self.double_precision_rope = double_precision_rope
|
||||
self.transformer_1d_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock1D(
|
||||
self.inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
context_dim=cross_attention_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
self.num_learnable_registers = num_learnable_registers
|
||||
if self.num_learnable_registers:
|
||||
self.learnable_registers = nn.Parameter(
|
||||
torch.rand(
|
||||
self.num_learnable_registers, inner_dim, dtype=dtype, device=device
|
||||
)
|
||||
* 2.0
|
||||
- 1.0
|
||||
)
|
||||
|
||||
def get_fractional_positions(self, indices_grid):
|
||||
fractional_positions = torch.stack(
|
||||
[
|
||||
indices_grid[:, i] / self.positional_embedding_max_pos[i]
|
||||
for i in range(1)
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
return fractional_positions
|
||||
|
||||
def precompute_freqs(self, indices_grid, spacing):
|
||||
source_dtype = indices_grid.dtype
|
||||
dtype = (
|
||||
torch.float32
|
||||
if source_dtype in (torch.bfloat16, torch.float16)
|
||||
else source_dtype
|
||||
)
|
||||
|
||||
fractional_positions = self.get_fractional_positions(indices_grid)
|
||||
indices = (
|
||||
generate_freq_grid_np(
|
||||
self.positional_embedding_theta,
|
||||
indices_grid.shape[1],
|
||||
self.inner_dim,
|
||||
)
|
||||
if self.double_precision_rope
|
||||
else self.generate_freq_grid(spacing, dtype, fractional_positions.device)
|
||||
).to(device=fractional_positions.device)
|
||||
|
||||
if spacing == "exp_2":
|
||||
freqs = (
|
||||
(indices * fractional_positions.unsqueeze(-1))
|
||||
.transpose(-1, -2)
|
||||
.flatten(2)
|
||||
)
|
||||
else:
|
||||
freqs = (
|
||||
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
|
||||
.transpose(-1, -2)
|
||||
.flatten(2)
|
||||
)
|
||||
return freqs
|
||||
|
||||
def generate_freq_grid(self, spacing, dtype, device):
|
||||
dim = self.inner_dim
|
||||
theta = self.positional_embedding_theta
|
||||
n_pos_dims = 1
|
||||
n_elem = 2 * n_pos_dims # 2 for cos and sin e.g. x 3 = 6
|
||||
start = 1
|
||||
end = theta
|
||||
|
||||
if spacing == "exp":
|
||||
indices = theta ** (torch.arange(0, dim, n_elem, device="cpu", dtype=torch.float32) / (dim - n_elem))
|
||||
indices = indices.to(dtype=dtype, device=device)
|
||||
elif spacing == "exp_2":
|
||||
indices = 1.0 / theta ** (torch.arange(0, dim, n_elem, device=device) / dim)
|
||||
indices = indices.to(dtype=dtype)
|
||||
elif spacing == "linear":
|
||||
indices = torch.linspace(
|
||||
start, end, dim // n_elem, device=device, dtype=dtype
|
||||
)
|
||||
elif spacing == "sqrt":
|
||||
indices = torch.linspace(
|
||||
start**2, end**2, dim // n_elem, device=device, dtype=dtype
|
||||
).sqrt()
|
||||
|
||||
indices = indices * math.pi / 2
|
||||
|
||||
return indices
|
||||
|
||||
def precompute_freqs_cis(self, indices_grid, spacing="exp"):
|
||||
dim = self.inner_dim
|
||||
n_elem = 2 # 2 because of cos and sin
|
||||
freqs = self.precompute_freqs(indices_grid, spacing)
|
||||
if self.split_rope:
|
||||
expected_freqs = dim // 2
|
||||
current_freqs = freqs.shape[-1]
|
||||
pad_size = expected_freqs - current_freqs
|
||||
cos_freq, sin_freq = split_freqs_cis(
|
||||
freqs, pad_size, self.num_attention_heads
|
||||
)
|
||||
else:
|
||||
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
|
||||
return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""
|
||||
The [`Transformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
||||
Input `hidden_states`.
|
||||
indices_grid (`torch.LongTensor` of shape `(batch size, 3, num latent pixels)`):
|
||||
attention_mask ( `torch.Tensor`, *optional*):
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||
negative values to the attention scores corresponding to "discard" tokens.
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
# 1. Input
|
||||
|
||||
if self.num_learnable_registers:
|
||||
num_registers_duplications = math.ceil(
|
||||
max(1024, hidden_states.shape[1]) / self.num_learnable_registers
|
||||
)
|
||||
learnable_registers = torch.tile(
|
||||
self.learnable_registers.to(hidden_states), (num_registers_duplications, 1)
|
||||
)
|
||||
|
||||
hidden_states = torch.cat((hidden_states, learnable_registers[hidden_states.shape[1]:].unsqueeze(0).repeat(hidden_states.shape[0], 1, 1)), dim=1)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = torch.zeros([1, 1, 1, hidden_states.shape[1]], dtype=attention_mask.dtype, device=attention_mask.device)
|
||||
|
||||
indices_grid = torch.arange(
|
||||
hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device
|
||||
)
|
||||
indices_grid = indices_grid[None, None, :]
|
||||
freqs_cis = self.precompute_freqs_cis(indices_grid)
|
||||
|
||||
# 2. Blocks
|
||||
for block_idx, block in enumerate(self.transformer_1d_blocks):
|
||||
hidden_states = block(
|
||||
hidden_states, attention_mask=attention_mask, pe=freqs_cis
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
# if self.output_scale is not None:
|
||||
# hidden_states = hidden_states / self.output_scale
|
||||
|
||||
hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states)
|
||||
|
||||
return hidden_states, attention_mask
|
||||
292
comfy/ldm/lightricks/latent_upsampler.py
Normal file
292
comfy/ldm/lightricks/latent_upsampler.py
Normal file
@@ -0,0 +1,292 @@
|
||||
from typing import Optional, Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
def _rational_for_scale(scale: float) -> Tuple[int, int]:
|
||||
mapping = {0.75: (3, 4), 1.5: (3, 2), 2.0: (2, 1), 4.0: (4, 1)}
|
||||
if float(scale) not in mapping:
|
||||
raise ValueError(
|
||||
f"Unsupported spatial_scale {scale}. Choose from {list(mapping.keys())}"
|
||||
)
|
||||
return mapping[float(scale)]
|
||||
|
||||
|
||||
class PixelShuffleND(nn.Module):
|
||||
def __init__(self, dims, upscale_factors=(2, 2, 2)):
|
||||
super().__init__()
|
||||
assert dims in [1, 2, 3], "dims must be 1, 2, or 3"
|
||||
self.dims = dims
|
||||
self.upscale_factors = upscale_factors
|
||||
|
||||
def forward(self, x):
|
||||
if self.dims == 3:
|
||||
return rearrange(
|
||||
x,
|
||||
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
|
||||
p1=self.upscale_factors[0],
|
||||
p2=self.upscale_factors[1],
|
||||
p3=self.upscale_factors[2],
|
||||
)
|
||||
elif self.dims == 2:
|
||||
return rearrange(
|
||||
x,
|
||||
"b (c p1 p2) h w -> b c (h p1) (w p2)",
|
||||
p1=self.upscale_factors[0],
|
||||
p2=self.upscale_factors[1],
|
||||
)
|
||||
elif self.dims == 1:
|
||||
return rearrange(
|
||||
x,
|
||||
"b (c p1) f h w -> b c (f p1) h w",
|
||||
p1=self.upscale_factors[0],
|
||||
)
|
||||
|
||||
|
||||
class BlurDownsample(nn.Module):
|
||||
"""
|
||||
Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel.
|
||||
Applies only on H,W. Works for dims=2 or dims=3 (per-frame).
|
||||
"""
|
||||
|
||||
def __init__(self, dims: int, stride: int):
|
||||
super().__init__()
|
||||
assert dims in (2, 3)
|
||||
assert stride >= 1 and isinstance(stride, int)
|
||||
self.dims = dims
|
||||
self.stride = stride
|
||||
|
||||
# 5x5 separable binomial kernel [1,4,6,4,1] (outer product), normalized
|
||||
k = torch.tensor([1.0, 4.0, 6.0, 4.0, 1.0])
|
||||
k2d = k[:, None] @ k[None, :]
|
||||
k2d = (k2d / k2d.sum()).float() # shape (5,5)
|
||||
self.register_buffer("kernel", k2d[None, None, :, :]) # (1,1,5,5)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.stride == 1:
|
||||
return x
|
||||
|
||||
def _apply_2d(x2d: torch.Tensor) -> torch.Tensor:
|
||||
# x2d: (B, C, H, W)
|
||||
B, C, H, W = x2d.shape
|
||||
weight = self.kernel.expand(C, 1, 5, 5) # depthwise
|
||||
x2d = F.conv2d(
|
||||
x2d, weight=weight, bias=None, stride=self.stride, padding=2, groups=C
|
||||
)
|
||||
return x2d
|
||||
|
||||
if self.dims == 2:
|
||||
return _apply_2d(x)
|
||||
else:
|
||||
# dims == 3: apply per-frame on H,W
|
||||
b, c, f, h, w = x.shape
|
||||
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||
x = _apply_2d(x)
|
||||
h2, w2 = x.shape[-2:]
|
||||
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f, h=h2, w=w2)
|
||||
return x
|
||||
|
||||
|
||||
class SpatialRationalResampler(nn.Module):
|
||||
"""
|
||||
Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased
|
||||
downsample by 'den' using fixed blur + stride. Operates on H,W only.
|
||||
|
||||
For dims==3, work per-frame for spatial scaling (temporal axis untouched).
|
||||
"""
|
||||
|
||||
def __init__(self, mid_channels: int, scale: float):
|
||||
super().__init__()
|
||||
self.scale = float(scale)
|
||||
self.num, self.den = _rational_for_scale(self.scale)
|
||||
self.conv = nn.Conv2d(
|
||||
mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1
|
||||
)
|
||||
self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num))
|
||||
self.blur_down = BlurDownsample(dims=2, stride=self.den)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
b, c, f, h, w = x.shape
|
||||
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||
x = self.conv(x)
|
||||
x = self.pixel_shuffle(x)
|
||||
x = self.blur_down(x)
|
||||
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
|
||||
return x
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(
|
||||
self, channels: int, mid_channels: Optional[int] = None, dims: int = 3
|
||||
):
|
||||
super().__init__()
|
||||
if mid_channels is None:
|
||||
mid_channels = channels
|
||||
|
||||
Conv = nn.Conv2d if dims == 2 else nn.Conv3d
|
||||
|
||||
self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1)
|
||||
self.norm1 = nn.GroupNorm(32, mid_channels)
|
||||
self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1)
|
||||
self.norm2 = nn.GroupNorm(32, channels)
|
||||
self.activation = nn.SiLU()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
residual = x
|
||||
x = self.conv1(x)
|
||||
x = self.norm1(x)
|
||||
x = self.activation(x)
|
||||
x = self.conv2(x)
|
||||
x = self.norm2(x)
|
||||
x = self.activation(x + residual)
|
||||
return x
|
||||
|
||||
|
||||
class LatentUpsampler(nn.Module):
|
||||
"""
|
||||
Model to spatially upsample VAE latents.
|
||||
|
||||
Args:
|
||||
in_channels (`int`): Number of channels in the input latent
|
||||
mid_channels (`int`): Number of channels in the middle layers
|
||||
num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling)
|
||||
dims (`int`): Number of dimensions for convolutions (2 or 3)
|
||||
spatial_upsample (`bool`): Whether to spatially upsample the latent
|
||||
temporal_upsample (`bool`): Whether to temporally upsample the latent
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 128,
|
||||
mid_channels: int = 512,
|
||||
num_blocks_per_stage: int = 4,
|
||||
dims: int = 3,
|
||||
spatial_upsample: bool = True,
|
||||
temporal_upsample: bool = False,
|
||||
spatial_scale: float = 2.0,
|
||||
rational_resampler: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.mid_channels = mid_channels
|
||||
self.num_blocks_per_stage = num_blocks_per_stage
|
||||
self.dims = dims
|
||||
self.spatial_upsample = spatial_upsample
|
||||
self.temporal_upsample = temporal_upsample
|
||||
self.spatial_scale = float(spatial_scale)
|
||||
self.rational_resampler = rational_resampler
|
||||
|
||||
Conv = nn.Conv2d if dims == 2 else nn.Conv3d
|
||||
|
||||
self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1)
|
||||
self.initial_norm = nn.GroupNorm(32, mid_channels)
|
||||
self.initial_activation = nn.SiLU()
|
||||
|
||||
self.res_blocks = nn.ModuleList(
|
||||
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
|
||||
)
|
||||
|
||||
if spatial_upsample and temporal_upsample:
|
||||
self.upsampler = nn.Sequential(
|
||||
nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
|
||||
PixelShuffleND(3),
|
||||
)
|
||||
elif spatial_upsample:
|
||||
if rational_resampler:
|
||||
self.upsampler = SpatialRationalResampler(
|
||||
mid_channels=mid_channels, scale=self.spatial_scale
|
||||
)
|
||||
else:
|
||||
self.upsampler = nn.Sequential(
|
||||
nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
|
||||
PixelShuffleND(2),
|
||||
)
|
||||
elif temporal_upsample:
|
||||
self.upsampler = nn.Sequential(
|
||||
nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
|
||||
PixelShuffleND(1),
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Either spatial_upsample or temporal_upsample must be True"
|
||||
)
|
||||
|
||||
self.post_upsample_res_blocks = nn.ModuleList(
|
||||
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
|
||||
)
|
||||
|
||||
self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, latent: torch.Tensor) -> torch.Tensor:
|
||||
b, c, f, h, w = latent.shape
|
||||
|
||||
if self.dims == 2:
|
||||
x = rearrange(latent, "b c f h w -> (b f) c h w")
|
||||
x = self.initial_conv(x)
|
||||
x = self.initial_norm(x)
|
||||
x = self.initial_activation(x)
|
||||
|
||||
for block in self.res_blocks:
|
||||
x = block(x)
|
||||
|
||||
x = self.upsampler(x)
|
||||
|
||||
for block in self.post_upsample_res_blocks:
|
||||
x = block(x)
|
||||
|
||||
x = self.final_conv(x)
|
||||
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
|
||||
else:
|
||||
x = self.initial_conv(latent)
|
||||
x = self.initial_norm(x)
|
||||
x = self.initial_activation(x)
|
||||
|
||||
for block in self.res_blocks:
|
||||
x = block(x)
|
||||
|
||||
if self.temporal_upsample:
|
||||
x = self.upsampler(x)
|
||||
x = x[:, :, 1:, :, :]
|
||||
else:
|
||||
if isinstance(self.upsampler, SpatialRationalResampler):
|
||||
x = self.upsampler(x)
|
||||
else:
|
||||
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||
x = self.upsampler(x)
|
||||
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
|
||||
|
||||
for block in self.post_upsample_res_blocks:
|
||||
x = block(x)
|
||||
|
||||
x = self.final_conv(x)
|
||||
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
return cls(
|
||||
in_channels=config.get("in_channels", 4),
|
||||
mid_channels=config.get("mid_channels", 128),
|
||||
num_blocks_per_stage=config.get("num_blocks_per_stage", 4),
|
||||
dims=config.get("dims", 2),
|
||||
spatial_upsample=config.get("spatial_upsample", True),
|
||||
temporal_upsample=config.get("temporal_upsample", False),
|
||||
spatial_scale=config.get("spatial_scale", 2.0),
|
||||
rational_resampler=config.get("rational_resampler", False),
|
||||
)
|
||||
|
||||
def config(self):
|
||||
return {
|
||||
"_class_name": "LatentUpsampler",
|
||||
"in_channels": self.in_channels,
|
||||
"mid_channels": self.mid_channels,
|
||||
"num_blocks_per_stage": self.num_blocks_per_stage,
|
||||
"dims": self.dims,
|
||||
"spatial_upsample": self.spatial_upsample,
|
||||
"temporal_upsample": self.temporal_upsample,
|
||||
"spatial_scale": self.spatial_scale,
|
||||
"rational_resampler": self.rational_resampler,
|
||||
}
|
||||
@@ -1,13 +1,47 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
import functools
|
||||
import math
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.modules.attention
|
||||
import comfy.ldm.common_dit
|
||||
import math
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
|
||||
def _log_base(x, base):
|
||||
return np.log(x) / np.log(base)
|
||||
|
||||
class LTXRopeType(str, Enum):
|
||||
INTERLEAVED = "interleaved"
|
||||
SPLIT = "split"
|
||||
|
||||
KEY = "rope_type"
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, kwargs, default=None):
|
||||
if default is None:
|
||||
default = cls.INTERLEAVED
|
||||
return cls(kwargs.get(cls.KEY, default))
|
||||
|
||||
|
||||
class LTXFrequenciesPrecision(str, Enum):
|
||||
FLOAT32 = "float32"
|
||||
FLOAT64 = "float64"
|
||||
|
||||
KEY = "frequencies_precision"
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, kwargs, default=None):
|
||||
if default is None:
|
||||
default = cls.FLOAT32
|
||||
return cls(kwargs.get(cls.KEY, default))
|
||||
|
||||
|
||||
def get_timestep_embedding(
|
||||
timesteps: torch.Tensor,
|
||||
@@ -39,9 +73,7 @@ def get_timestep_embedding(
|
||||
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
exponent = -math.log(max_period) * torch.arange(
|
||||
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
||||
)
|
||||
exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
|
||||
exponent = exponent / (half_dim - downscale_freq_shift)
|
||||
|
||||
emb = torch.exp(exponent)
|
||||
@@ -73,7 +105,9 @@ class TimestepEmbedding(nn.Module):
|
||||
post_act_fn: Optional[str] = None,
|
||||
cond_proj_dim=None,
|
||||
sample_proj_bias=True,
|
||||
dtype=None, device=None, operations=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -90,7 +124,9 @@ class TimestepEmbedding(nn.Module):
|
||||
time_embed_dim_out = out_dim
|
||||
else:
|
||||
time_embed_dim_out = time_embed_dim
|
||||
self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device)
|
||||
self.linear_2 = operations.Linear(
|
||||
time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
if post_act_fn is None:
|
||||
self.post_act = None
|
||||
@@ -139,12 +175,22 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
|
||||
https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim,
|
||||
size_emb_dim,
|
||||
use_additional_conditions: bool = False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.outdim = size_emb_dim
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.timestep_embedder = TimestepEmbedding(
|
||||
in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
@@ -163,15 +209,22 @@ class AdaLayerNormSingle(nn.Module):
|
||||
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
|
||||
def __init__(
|
||||
self, embedding_dim: int, embedding_coefficient: int = 6, use_additional_conditions: bool = False, dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
||||
embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions, dtype=dtype, device=device, operations=operations
|
||||
embedding_dim,
|
||||
size_emb_dim=embedding_dim // 3,
|
||||
use_additional_conditions=use_additional_conditions,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = operations.Linear(embedding_dim, 6 * embedding_dim, bias=True, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -185,6 +238,7 @@ class AdaLayerNormSingle(nn.Module):
|
||||
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
|
||||
return self.linear(self.silu(embedded_timestep)), embedded_timestep
|
||||
|
||||
|
||||
class PixArtAlphaTextProjection(nn.Module):
|
||||
"""
|
||||
Projects caption embeddings. Also handles dropout for classifier-free guidance.
|
||||
@@ -192,18 +246,24 @@ class PixArtAlphaTextProjection(nn.Module):
|
||||
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
|
||||
"""
|
||||
|
||||
def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None):
|
||||
def __init__(
|
||||
self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
if out_features is None:
|
||||
out_features = hidden_size
|
||||
self.linear_1 = operations.Linear(in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device)
|
||||
self.linear_1 = operations.Linear(
|
||||
in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
if act_fn == "gelu_tanh":
|
||||
self.act_1 = nn.GELU(approximate="tanh")
|
||||
elif act_fn == "silu":
|
||||
self.act_1 = nn.SiLU()
|
||||
else:
|
||||
raise ValueError(f"Unknown activation function: {act_fn}")
|
||||
self.linear_2 = operations.Linear(in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device)
|
||||
self.linear_2 = operations.Linear(
|
||||
in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
def forward(self, caption):
|
||||
hidden_states = self.linear_1(caption)
|
||||
@@ -222,23 +282,68 @@ class GELU_approx(nn.Module):
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=None):
|
||||
def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0.0, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
project_in = GELU_approx(dim, inner_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
project_in,
|
||||
nn.Dropout(dropout),
|
||||
operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
|
||||
project_in, nn.Dropout(dropout), operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
def apply_rotary_emb(input_tensor, freqs_cis):
|
||||
cos_freqs, sin_freqs = freqs_cis[0], freqs_cis[1]
|
||||
split_pe = freqs_cis[2] if len(freqs_cis) > 2 else False
|
||||
return (
|
||||
apply_split_rotary_emb(input_tensor, cos_freqs, sin_freqs)
|
||||
if split_pe else
|
||||
apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs)
|
||||
)
|
||||
|
||||
def apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs): # TODO: remove duplicate funcs and pick the best/fastest one
|
||||
t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
|
||||
t1, t2 = t_dup.unbind(dim=-1)
|
||||
t_dup = torch.stack((-t2, t1), dim=-1)
|
||||
input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
|
||||
|
||||
out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
|
||||
|
||||
return out
|
||||
|
||||
def apply_split_rotary_emb(input_tensor, cos, sin):
|
||||
needs_reshape = False
|
||||
if input_tensor.ndim != 4 and cos.ndim == 4:
|
||||
B, H, T, _ = cos.shape
|
||||
input_tensor = input_tensor.reshape(B, T, H, -1).swapaxes(1, 2)
|
||||
needs_reshape = True
|
||||
split_input = rearrange(input_tensor, "... (d r) -> ... d r", d=2)
|
||||
first_half_input = split_input[..., :1, :]
|
||||
second_half_input = split_input[..., 1:, :]
|
||||
output = split_input * cos.unsqueeze(-2)
|
||||
first_half_output = output[..., :1, :]
|
||||
second_half_output = output[..., 1:, :]
|
||||
first_half_output.addcmul_(-sin.unsqueeze(-2), second_half_input)
|
||||
second_half_output.addcmul_(sin.unsqueeze(-2), first_half_input)
|
||||
output = rearrange(output, "... d r -> ... (d r)")
|
||||
return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
|
||||
def __init__(
|
||||
self,
|
||||
query_dim,
|
||||
context_dim=None,
|
||||
heads=8,
|
||||
dim_head=64,
|
||||
dropout=0.0,
|
||||
attn_precision=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = query_dim if context_dim is None else context_dim
|
||||
@@ -254,9 +359,11 @@ class CrossAttention(nn.Module):
|
||||
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
|
||||
self.to_out = nn.Sequential(
|
||||
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, context=None, mask=None, pe=None, transformer_options={}):
|
||||
def forward(self, x, context=None, mask=None, pe=None, k_pe=None, transformer_options={}):
|
||||
q = self.to_q(x)
|
||||
context = x if context is None else context
|
||||
k = self.to_k(context)
|
||||
@@ -266,8 +373,8 @@ class CrossAttention(nn.Module):
|
||||
k = self.k_norm(k)
|
||||
|
||||
if pe is not None:
|
||||
q = apply_rope1(q.unsqueeze(1), pe).squeeze(1)
|
||||
k = apply_rope1(k.unsqueeze(1), pe).squeeze(1)
|
||||
q = apply_rotary_emb(q, pe)
|
||||
k = apply_rotary_emb(k, pe if k_pe is None else k_pe)
|
||||
|
||||
if mask is None:
|
||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
@@ -277,14 +384,34 @@ class CrossAttention(nn.Module):
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None):
|
||||
def __init__(
|
||||
self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attn_precision = attn_precision
|
||||
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, context_dim=None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
|
||||
self.attn1 = CrossAttention(
|
||||
query_dim=dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
context_dim=None,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
|
||||
self.attn2 = CrossAttention(
|
||||
query_dim=dim,
|
||||
context_dim=context_dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
|
||||
|
||||
@@ -306,116 +433,446 @@ class BasicTransformerBlock(nn.Module):
|
||||
return x
|
||||
|
||||
def get_fractional_positions(indices_grid, max_pos):
|
||||
n_pos_dims = indices_grid.shape[1]
|
||||
assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})'
|
||||
fractional_positions = torch.stack(
|
||||
[
|
||||
indices_grid[:, i] / max_pos[i]
|
||||
for i in range(3)
|
||||
],
|
||||
dim=-1,
|
||||
[indices_grid[:, i] / max_pos[i] for i in range(n_pos_dims)],
|
||||
axis=-1,
|
||||
)
|
||||
return fractional_positions
|
||||
|
||||
|
||||
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
|
||||
dtype = torch.float32
|
||||
device = indices_grid.device
|
||||
@functools.lru_cache(maxsize=5)
|
||||
def generate_freq_grid_np(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, _ = None):
|
||||
theta = positional_embedding_theta
|
||||
start = 1
|
||||
end = theta
|
||||
|
||||
n_elem = 2 * positional_embedding_max_pos_count
|
||||
pow_indices = np.power(
|
||||
theta,
|
||||
np.linspace(
|
||||
_log_base(start, theta),
|
||||
_log_base(end, theta),
|
||||
inner_dim // n_elem,
|
||||
dtype=np.float64,
|
||||
),
|
||||
)
|
||||
return torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32)
|
||||
|
||||
def generate_freq_grid_pytorch(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, device):
|
||||
theta = positional_embedding_theta
|
||||
start = 1
|
||||
end = theta
|
||||
n_elem = 2 * positional_embedding_max_pos_count
|
||||
|
||||
indices = theta ** (
|
||||
torch.linspace(
|
||||
math.log(start, theta),
|
||||
math.log(end, theta),
|
||||
inner_dim // n_elem,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
)
|
||||
indices = indices.to(dtype=torch.float32)
|
||||
|
||||
indices = indices * math.pi / 2
|
||||
|
||||
return indices
|
||||
|
||||
def generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid):
|
||||
if use_middle_indices_grid:
|
||||
assert(len(indices_grid.shape) == 4 and indices_grid.shape[-1] ==2)
|
||||
indices_grid_start, indices_grid_end = indices_grid[..., 0], indices_grid[..., 1]
|
||||
indices_grid = (indices_grid_start + indices_grid_end) / 2.0
|
||||
elif len(indices_grid.shape) == 4:
|
||||
indices_grid = indices_grid[..., 0]
|
||||
|
||||
# Get fractional positions and compute frequency indices
|
||||
fractional_positions = get_fractional_positions(indices_grid, max_pos)
|
||||
indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2
|
||||
indices = indices.to(device=fractional_positions.device)
|
||||
|
||||
# Compute frequencies and apply cos/sin
|
||||
freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2)
|
||||
cos_vals = freqs.cos().repeat_interleave(2, dim=-1)
|
||||
sin_vals = freqs.sin().repeat_interleave(2, dim=-1)
|
||||
freqs = (
|
||||
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
|
||||
.transpose(-1, -2)
|
||||
.flatten(2)
|
||||
)
|
||||
return freqs
|
||||
|
||||
# Pad if dim is not divisible by 6
|
||||
if dim % 6 != 0:
|
||||
padding_size = dim % 6
|
||||
cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1)
|
||||
sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1)
|
||||
def interleaved_freqs_cis(freqs, pad_size):
|
||||
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
|
||||
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
|
||||
if pad_size != 0:
|
||||
cos_padding = torch.ones_like(cos_freq[:, :, : pad_size])
|
||||
sin_padding = torch.zeros_like(cos_freq[:, :, : pad_size])
|
||||
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
|
||||
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
|
||||
return cos_freq, sin_freq
|
||||
|
||||
# Reshape and extract one value per pair (since repeat_interleave duplicates each value)
|
||||
cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
|
||||
sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
|
||||
def split_freqs_cis(freqs, pad_size, num_attention_heads):
|
||||
cos_freq = freqs.cos()
|
||||
sin_freq = freqs.sin()
|
||||
|
||||
# Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension
|
||||
freqs_cis = torch.stack([
|
||||
torch.stack([cos_vals, -sin_vals], dim=-1),
|
||||
torch.stack([sin_vals, cos_vals], dim=-1)
|
||||
], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2]
|
||||
if pad_size != 0:
|
||||
cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])
|
||||
sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size])
|
||||
|
||||
return freqs_cis
|
||||
cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1)
|
||||
sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1)
|
||||
|
||||
# Reshape freqs to be compatible with multi-head attention
|
||||
B , T, half_HD = cos_freq.shape
|
||||
|
||||
class LTXVModel(torch.nn.Module):
|
||||
def __init__(self,
|
||||
in_channels=128,
|
||||
cross_attention_dim=2048,
|
||||
attention_head_dim=64,
|
||||
num_attention_heads=32,
|
||||
cos_freq = cos_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads)
|
||||
sin_freq = sin_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads)
|
||||
|
||||
caption_channels=4096,
|
||||
num_layers=28,
|
||||
cos_freq = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2)
|
||||
sin_freq = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2)
|
||||
return cos_freq, sin_freq
|
||||
|
||||
class LTXBaseModel(torch.nn.Module, ABC):
|
||||
"""
|
||||
Abstract base class for LTX models (Lightricks Transformer models).
|
||||
|
||||
positional_embedding_theta=10000.0,
|
||||
positional_embedding_max_pos=[20, 2048, 2048],
|
||||
causal_temporal_positioning=False,
|
||||
vae_scale_factors=(8, 32, 32),
|
||||
dtype=None, device=None, operations=None, **kwargs):
|
||||
This class defines the common interface and shared functionality for all LTX models,
|
||||
including LTXV (video) and LTXAV (audio-video) variants.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
cross_attention_dim: int,
|
||||
attention_head_dim: int,
|
||||
num_attention_heads: int,
|
||||
caption_channels: int,
|
||||
num_layers: int,
|
||||
positional_embedding_theta: float = 10000.0,
|
||||
positional_embedding_max_pos: list = [20, 2048, 2048],
|
||||
causal_temporal_positioning: bool = False,
|
||||
vae_scale_factors: tuple = (8, 32, 32),
|
||||
use_middle_indices_grid=False,
|
||||
timestep_scale_multiplier = 1000.0,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.generator = None
|
||||
self.vae_scale_factors = vae_scale_factors
|
||||
self.use_middle_indices_grid = use_middle_indices_grid
|
||||
self.dtype = dtype
|
||||
self.out_channels = in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.in_channels = in_channels
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.caption_channels = caption_channels
|
||||
self.num_layers = num_layers
|
||||
self.positional_embedding_theta = positional_embedding_theta
|
||||
self.positional_embedding_max_pos = positional_embedding_max_pos
|
||||
self.split_positional_embedding = LTXRopeType.from_dict(kwargs)
|
||||
self.freq_grid_generator = (
|
||||
generate_freq_grid_np if LTXFrequenciesPrecision.from_dict(kwargs) == LTXFrequenciesPrecision.FLOAT64
|
||||
else generate_freq_grid_pytorch
|
||||
)
|
||||
self.causal_temporal_positioning = causal_temporal_positioning
|
||||
self.operations = operations
|
||||
self.timestep_scale_multiplier = timestep_scale_multiplier
|
||||
|
||||
self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
|
||||
# Common dimensions
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.out_channels = in_channels
|
||||
|
||||
# Initialize common components
|
||||
self._init_common_components(device, dtype)
|
||||
|
||||
# Initialize model-specific components
|
||||
self._init_model_components(device, dtype, **kwargs)
|
||||
|
||||
# Initialize transformer blocks
|
||||
self._init_transformer_blocks(device, dtype, **kwargs)
|
||||
|
||||
# Initialize output components
|
||||
self._init_output_components(device, dtype)
|
||||
|
||||
def _init_common_components(self, device, dtype):
|
||||
"""Initialize components common to all LTX models
|
||||
- patchify_proj: Linear projection for patchifying input
|
||||
- adaln_single: AdaLN layer for timestep embedding
|
||||
- caption_projection: Linear projection for caption embedding
|
||||
"""
|
||||
self.patchify_proj = self.operations.Linear(
|
||||
self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
self.adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=operations
|
||||
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
|
||||
)
|
||||
|
||||
# self.adaln_single.linear = operations.Linear(self.inner_dim, 4 * self.inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=caption_channels, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations
|
||||
in_features=self.caption_channels,
|
||||
hidden_size=self.inner_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def _init_model_components(self, device, dtype, **kwargs):
|
||||
"""Initialize model-specific components. Must be implemented by subclasses."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||
"""Initialize transformer blocks. Must be implemented by subclasses."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _init_output_components(self, device, dtype):
|
||||
"""Initialize output components. Must be implemented by subclasses."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
|
||||
"""Process input data. Must be implemented by subclasses."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs):
|
||||
"""Process transformer blocks. Must be implemented by subclasses."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
|
||||
"""Process output data. Must be implemented by subclasses."""
|
||||
pass
|
||||
|
||||
def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
|
||||
"""Prepare timestep embeddings."""
|
||||
grid_mask = kwargs.get("grid_mask", None)
|
||||
if grid_mask is not None:
|
||||
timestep = timestep[:, grid_mask]
|
||||
|
||||
timestep = timestep * self.timestep_scale_multiplier
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
timestep.flatten(),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
|
||||
# Second dimension is 1 or number of tokens (if timestep_per_token)
|
||||
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
|
||||
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])
|
||||
|
||||
return timestep, embedded_timestep
|
||||
|
||||
def _prepare_context(self, context, batch_size, x, attention_mask=None):
|
||||
"""Prepare context for transformer blocks."""
|
||||
if self.caption_projection is not None:
|
||||
context = self.caption_projection(context)
|
||||
context = context.view(batch_size, -1, x.shape[-1])
|
||||
|
||||
return context, attention_mask
|
||||
|
||||
def _precompute_freqs_cis(
|
||||
self,
|
||||
indices_grid,
|
||||
dim,
|
||||
out_dtype,
|
||||
theta=10000.0,
|
||||
max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=False,
|
||||
num_attention_heads=32,
|
||||
):
|
||||
split_mode = self.split_positional_embedding == LTXRopeType.SPLIT
|
||||
indices = self.freq_grid_generator(theta, indices_grid.shape[1], dim, indices_grid.device)
|
||||
freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid)
|
||||
|
||||
if split_mode:
|
||||
expected_freqs = dim // 2
|
||||
current_freqs = freqs.shape[-1]
|
||||
pad_size = expected_freqs - current_freqs
|
||||
cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads)
|
||||
else:
|
||||
# 2 because of cos and sin by 3 for (t, x, y), 1 for temporal only
|
||||
n_elem = 2 * indices_grid.shape[1]
|
||||
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
|
||||
return cos_freq.to(out_dtype), sin_freq.to(out_dtype), split_mode
|
||||
|
||||
def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype):
|
||||
"""Prepare positional embeddings."""
|
||||
fractional_coords = pixel_coords.to(torch.float32)
|
||||
fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
|
||||
pe = self._precompute_freqs_cis(
|
||||
fractional_coords,
|
||||
dim=self.inner_dim,
|
||||
out_dtype=x_dtype,
|
||||
max_pos=self.positional_embedding_max_pos,
|
||||
use_middle_indices_grid=self.use_middle_indices_grid,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
)
|
||||
return pe
|
||||
|
||||
def _prepare_attention_mask(self, attention_mask, x_dtype):
|
||||
"""Prepare attention mask."""
|
||||
if attention_mask is not None and not torch.is_floating_point(attention_mask):
|
||||
attention_mask = (attention_mask - 1).to(x_dtype).reshape(
|
||||
(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
|
||||
) * torch.finfo(x_dtype).max
|
||||
return attention_mask
|
||||
|
||||
def forward(
|
||||
self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs
|
||||
):
|
||||
"""
|
||||
Forward pass for LTX models.
|
||||
|
||||
Args:
|
||||
x: Input tensor
|
||||
timestep: Timestep tensor
|
||||
context: Context tensor (e.g., text embeddings)
|
||||
attention_mask: Attention mask tensor
|
||||
frame_rate: Frame rate for temporal processing
|
||||
transformer_options: Additional options for transformer blocks
|
||||
keyframe_idxs: Keyframe indices for temporal processing
|
||||
**kwargs: Additional keyword arguments
|
||||
|
||||
Returns:
|
||||
Processed output tensor
|
||||
"""
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(
|
||||
comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options
|
||||
),
|
||||
).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, denoise_mask=denoise_mask, **kwargs)
|
||||
|
||||
def _forward(
|
||||
self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs
|
||||
):
|
||||
"""
|
||||
Internal forward pass for LTX models.
|
||||
|
||||
Args:
|
||||
x: Input tensor
|
||||
timestep: Timestep tensor
|
||||
context: Context tensor (e.g., text embeddings)
|
||||
attention_mask: Attention mask tensor
|
||||
frame_rate: Frame rate for temporal processing
|
||||
transformer_options: Additional options for transformer blocks
|
||||
keyframe_idxs: Keyframe indices for temporal processing
|
||||
**kwargs: Additional keyword arguments
|
||||
|
||||
Returns:
|
||||
Processed output tensor
|
||||
"""
|
||||
if isinstance(x, list):
|
||||
input_dtype = x[0].dtype
|
||||
batch_size = x[0].shape[0]
|
||||
else:
|
||||
input_dtype = x.dtype
|
||||
batch_size = x.shape[0]
|
||||
# Process input
|
||||
merged_args = {**transformer_options, **kwargs}
|
||||
x, pixel_coords, additional_args = self._process_input(x, keyframe_idxs, denoise_mask, **merged_args)
|
||||
merged_args.update(additional_args)
|
||||
|
||||
# Prepare timestep and context
|
||||
timestep, embedded_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
|
||||
context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask)
|
||||
|
||||
# Prepare attention mask and positional embeddings
|
||||
attention_mask = self._prepare_attention_mask(attention_mask, input_dtype)
|
||||
pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype)
|
||||
|
||||
# Process transformer blocks
|
||||
x = self._process_transformer_blocks(
|
||||
x, context, attention_mask, timestep, pe, transformer_options=transformer_options, **merged_args
|
||||
)
|
||||
|
||||
# Process output
|
||||
x = self._process_output(x, embedded_timestep, keyframe_idxs, **merged_args)
|
||||
return x
|
||||
|
||||
|
||||
class LTXVModel(LTXBaseModel):
|
||||
"""LTXV model for video generation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=128,
|
||||
cross_attention_dim=2048,
|
||||
attention_head_dim=64,
|
||||
num_attention_heads=32,
|
||||
caption_channels=4096,
|
||||
num_layers=28,
|
||||
positional_embedding_theta=10000.0,
|
||||
positional_embedding_max_pos=[20, 2048, 2048],
|
||||
causal_temporal_positioning=False,
|
||||
vae_scale_factors=(8, 32, 32),
|
||||
use_middle_indices_grid=False,
|
||||
timestep_scale_multiplier = 1000.0,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
caption_channels=caption_channels,
|
||||
num_layers=num_layers,
|
||||
positional_embedding_theta=positional_embedding_theta,
|
||||
positional_embedding_max_pos=positional_embedding_max_pos,
|
||||
causal_temporal_positioning=causal_temporal_positioning,
|
||||
vae_scale_factors=vae_scale_factors,
|
||||
use_middle_indices_grid=use_middle_indices_grid,
|
||||
timestep_scale_multiplier=timestep_scale_multiplier,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _init_model_components(self, device, dtype, **kwargs):
|
||||
"""Initialize LTXV-specific components."""
|
||||
# No additional components needed for LTXV beyond base class
|
||||
pass
|
||||
|
||||
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||
"""Initialize transformer blocks for LTXV."""
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
self.inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
context_dim=cross_attention_dim,
|
||||
# attn_precision=attn_precision,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
self.num_attention_heads,
|
||||
self.attention_head_dim,
|
||||
context_dim=self.cross_attention_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
for _ in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def _init_output_components(self, device, dtype):
|
||||
"""Initialize output components for LTXV."""
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device))
|
||||
self.norm_out = operations.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.proj_out = operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)
|
||||
|
||||
self.patchifier = SymmetricPatchifier(1)
|
||||
|
||||
def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
|
||||
orig_shape = list(x.shape)
|
||||
self.norm_out = self.operations.LayerNorm(
|
||||
self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
|
||||
)
|
||||
self.proj_out = self.operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)
|
||||
self.patchifier = SymmetricPatchifier(1, start_end=True)
|
||||
|
||||
def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
|
||||
"""Process input for LTXV."""
|
||||
additional_args = {"orig_shape": list(x.shape)}
|
||||
x, latent_coords = self.patchifier.patchify(x)
|
||||
pixel_coords = latent_to_pixel_coords(
|
||||
latent_coords=latent_coords,
|
||||
@@ -423,44 +880,30 @@ class LTXVModel(torch.nn.Module):
|
||||
causal_fix=self.causal_temporal_positioning,
|
||||
)
|
||||
|
||||
grid_mask = None
|
||||
if keyframe_idxs is not None:
|
||||
pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs
|
||||
additional_args.update({ "orig_patchified_shape": list(x.shape)})
|
||||
denoise_mask = self.patchifier.patchify(denoise_mask)[0]
|
||||
grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0]
|
||||
additional_args.update({"grid_mask": grid_mask})
|
||||
x = x[:, grid_mask, :]
|
||||
pixel_coords = pixel_coords[:, :, grid_mask, ...]
|
||||
|
||||
fractional_coords = pixel_coords.to(torch.float32)
|
||||
fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
|
||||
kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:]
|
||||
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
|
||||
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
|
||||
|
||||
x = self.patchify_proj(x)
|
||||
timestep = timestep * 1000.0
|
||||
|
||||
if attention_mask is not None and not torch.is_floating_point(attention_mask):
|
||||
attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max
|
||||
|
||||
pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype)
|
||||
|
||||
batch_size = x.shape[0]
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
timestep.flatten(),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=x.dtype,
|
||||
)
|
||||
# Second dimension is 1 or number of tokens (if timestep_per_token)
|
||||
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
|
||||
embedded_timestep = embedded_timestep.view(
|
||||
batch_size, -1, embedded_timestep.shape[-1]
|
||||
)
|
||||
|
||||
# 2. Blocks
|
||||
if self.caption_projection is not None:
|
||||
batch_size = x.shape[0]
|
||||
context = self.caption_projection(context)
|
||||
context = context.view(
|
||||
batch_size, -1, x.shape[-1]
|
||||
)
|
||||
return x, pixel_coords, additional_args
|
||||
|
||||
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs):
|
||||
"""Process transformer blocks for LTXV."""
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
|
||||
@@ -478,16 +921,28 @@ class LTXVModel(torch.nn.Module):
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
return x
|
||||
|
||||
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
|
||||
"""Process output for LTXV."""
|
||||
# Apply scale-shift modulation
|
||||
scale_shift_values = (
|
||||
self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
|
||||
)
|
||||
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
|
||||
|
||||
x = self.norm_out(x)
|
||||
# Modulation
|
||||
x = torch.addcmul(x, x, scale).add_(shift)
|
||||
x = x * (1 + scale) + shift
|
||||
x = self.proj_out(x)
|
||||
|
||||
if keyframe_idxs is not None:
|
||||
grid_mask = kwargs["grid_mask"]
|
||||
orig_patchified_shape = kwargs["orig_patchified_shape"]
|
||||
full_x = torch.zeros(orig_patchified_shape, dtype=x.dtype, device=x.device)
|
||||
full_x[:, grid_mask, :] = x
|
||||
x = full_x
|
||||
# Unpatchify to restore original dimensions
|
||||
orig_shape = kwargs["orig_shape"]
|
||||
x = self.patchifier.unpatchify(
|
||||
latents=x,
|
||||
output_height=orig_shape[3],
|
||||
|
||||
@@ -21,20 +21,23 @@ def latent_to_pixel_coords(
|
||||
Returns:
|
||||
Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates.
|
||||
"""
|
||||
shape = [1] * latent_coords.ndim
|
||||
shape[1] = -1
|
||||
pixel_coords = (
|
||||
latent_coords
|
||||
* torch.tensor(scale_factors, device=latent_coords.device)[None, :, None]
|
||||
* torch.tensor(scale_factors, device=latent_coords.device).view(*shape)
|
||||
)
|
||||
if causal_fix:
|
||||
# Fix temporal scale for first frame to 1 due to causality
|
||||
pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0)
|
||||
pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0)
|
||||
return pixel_coords
|
||||
|
||||
|
||||
class Patchifier(ABC):
|
||||
def __init__(self, patch_size: int):
|
||||
def __init__(self, patch_size: int, start_end: bool=False):
|
||||
super().__init__()
|
||||
self._patch_size = (1, patch_size, patch_size)
|
||||
self.start_end = start_end
|
||||
|
||||
@abstractmethod
|
||||
def patchify(
|
||||
@@ -71,11 +74,23 @@ class Patchifier(ABC):
|
||||
torch.arange(0, latent_width, self._patch_size[2], device=device),
|
||||
indexing="ij",
|
||||
)
|
||||
latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
|
||||
latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
||||
latent_coords = rearrange(
|
||||
latent_coords, "b c f h w -> b c (f h w)", b=batch_size
|
||||
latent_sample_coords_start = torch.stack(latent_sample_coords, dim=0)
|
||||
delta = torch.tensor(self._patch_size, device=latent_sample_coords_start.device, dtype=latent_sample_coords_start.dtype)[:, None, None, None]
|
||||
latent_sample_coords_end = latent_sample_coords_start + delta
|
||||
|
||||
latent_sample_coords_start = latent_sample_coords_start.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
||||
latent_sample_coords_start = rearrange(
|
||||
latent_sample_coords_start, "b c f h w -> b c (f h w)", b=batch_size
|
||||
)
|
||||
if self.start_end:
|
||||
latent_sample_coords_end = latent_sample_coords_end.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
||||
latent_sample_coords_end = rearrange(
|
||||
latent_sample_coords_end, "b c f h w -> b c (f h w)", b=batch_size
|
||||
)
|
||||
|
||||
latent_coords = torch.stack((latent_sample_coords_start, latent_sample_coords_end), dim=-1)
|
||||
else:
|
||||
latent_coords = latent_sample_coords_start
|
||||
return latent_coords
|
||||
|
||||
|
||||
@@ -115,3 +130,61 @@ class SymmetricPatchifier(Patchifier):
|
||||
q=self._patch_size[2],
|
||||
)
|
||||
return latents
|
||||
|
||||
|
||||
class AudioPatchifier(Patchifier):
|
||||
def __init__(self, patch_size: int,
|
||||
sample_rate=16000,
|
||||
hop_length=160,
|
||||
audio_latent_downsample_factor=4,
|
||||
is_causal=True,
|
||||
start_end=False,
|
||||
shift = 0
|
||||
):
|
||||
super().__init__(patch_size, start_end=start_end)
|
||||
self.hop_length = hop_length
|
||||
self.sample_rate = sample_rate
|
||||
self.audio_latent_downsample_factor = audio_latent_downsample_factor
|
||||
self.is_causal = is_causal
|
||||
self.shift = shift
|
||||
|
||||
def copy_with_shift(self, shift):
|
||||
return AudioPatchifier(
|
||||
self.patch_size, self.sample_rate, self.hop_length, self.audio_latent_downsample_factor,
|
||||
self.is_causal, self.start_end, shift
|
||||
)
|
||||
|
||||
def _get_audio_latent_time_in_sec(self, start_latent, end_latent: int, dtype: torch.dtype, device=torch.device):
|
||||
audio_latent_frame = torch.arange(start_latent, end_latent, dtype=dtype, device=device)
|
||||
audio_mel_frame = audio_latent_frame * self.audio_latent_downsample_factor
|
||||
if self.is_causal:
|
||||
audio_mel_frame = (audio_mel_frame + 1 - self.audio_latent_downsample_factor).clip(min=0)
|
||||
return audio_mel_frame * self.hop_length / self.sample_rate
|
||||
|
||||
|
||||
def patchify(self, audio_latents: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# audio_latents: (batch, channels, time, freq)
|
||||
b, _, t, _ = audio_latents.shape
|
||||
audio_latents = rearrange(
|
||||
audio_latents,
|
||||
"b c t f -> b t (c f)",
|
||||
)
|
||||
|
||||
audio_latents_start_timings = self._get_audio_latent_time_in_sec(self.shift, t + self.shift, torch.float32, audio_latents.device)
|
||||
audio_latents_start_timings = audio_latents_start_timings.unsqueeze(0).expand(b, -1).unsqueeze(1)
|
||||
|
||||
if self.start_end:
|
||||
audio_latents_end_timings = self._get_audio_latent_time_in_sec(self.shift + 1, t + self.shift + 1, torch.float32, audio_latents.device)
|
||||
audio_latents_end_timings = audio_latents_end_timings.unsqueeze(0).expand(b, -1).unsqueeze(1)
|
||||
|
||||
audio_latents_timings = torch.stack([audio_latents_start_timings, audio_latents_end_timings], dim=-1)
|
||||
else:
|
||||
audio_latents_timings = audio_latents_start_timings
|
||||
return audio_latents, audio_latents_timings
|
||||
|
||||
def unpatchify(self, audio_latents: torch.Tensor, channels: int, freq: int) -> torch.Tensor:
|
||||
# audio_latents: (batch, time, freq * channels)
|
||||
audio_latents = rearrange(
|
||||
audio_latents, "b t (c f) -> b c t f", c=channels, f=freq
|
||||
)
|
||||
return audio_latents
|
||||
|
||||
286
comfy/ldm/lightricks/vae/audio_vae.py
Normal file
286
comfy/ldm/lightricks/vae/audio_vae.py
Normal file
@@ -0,0 +1,286 @@
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
import math
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.model_patcher
|
||||
import comfy.utils as utils
|
||||
from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||
from comfy.ldm.lightricks.vae.causal_audio_autoencoder import (
|
||||
CausalityAxis,
|
||||
CausalAudioAutoencoder,
|
||||
)
|
||||
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder
|
||||
|
||||
LATENT_DOWNSAMPLE_FACTOR = 4
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AudioVAEComponentConfig:
|
||||
"""Container for model component configuration extracted from metadata."""
|
||||
|
||||
autoencoder: dict
|
||||
vocoder: dict
|
||||
|
||||
@classmethod
|
||||
def from_metadata(cls, metadata: dict) -> "AudioVAEComponentConfig":
|
||||
assert metadata is not None and "config" in metadata, "Metadata is required for audio VAE"
|
||||
|
||||
raw_config = metadata["config"]
|
||||
if isinstance(raw_config, str):
|
||||
parsed_config = json.loads(raw_config)
|
||||
else:
|
||||
parsed_config = raw_config
|
||||
|
||||
audio_config = parsed_config.get("audio_vae")
|
||||
vocoder_config = parsed_config.get("vocoder")
|
||||
|
||||
assert audio_config is not None, "Audio VAE config is required for audio VAE"
|
||||
assert vocoder_config is not None, "Vocoder config is required for audio VAE"
|
||||
|
||||
return cls(autoencoder=audio_config, vocoder=vocoder_config)
|
||||
|
||||
|
||||
class ModelDeviceManager:
|
||||
"""Manages device placement and GPU residency for the composed model."""
|
||||
|
||||
def __init__(self, module: torch.nn.Module):
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
offload_device = comfy.model_management.vae_offload_device()
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(module, load_device, offload_device)
|
||||
|
||||
def ensure_model_loaded(self) -> None:
|
||||
comfy.model_management.free_memory(
|
||||
self.patcher.model_size(),
|
||||
self.patcher.load_device,
|
||||
)
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
|
||||
def move_to_load_device(self, tensor: torch.Tensor) -> torch.Tensor:
|
||||
return tensor.to(self.patcher.load_device)
|
||||
|
||||
@property
|
||||
def load_device(self):
|
||||
return self.patcher.load_device
|
||||
|
||||
|
||||
class AudioLatentNormalizer:
|
||||
"""Applies per-channel statistics in patch space and restores original layout."""
|
||||
|
||||
def __init__(self, patchfier: AudioPatchifier, statistics_processor: torch.nn.Module):
|
||||
self.patchifier = patchfier
|
||||
self.statistics = statistics_processor
|
||||
|
||||
def normalize(self, latents: torch.Tensor) -> torch.Tensor:
|
||||
channels = latents.shape[1]
|
||||
freq = latents.shape[3]
|
||||
patched, _ = self.patchifier.patchify(latents)
|
||||
normalized = self.statistics.normalize(patched)
|
||||
return self.patchifier.unpatchify(normalized, channels=channels, freq=freq)
|
||||
|
||||
def denormalize(self, latents: torch.Tensor) -> torch.Tensor:
|
||||
channels = latents.shape[1]
|
||||
freq = latents.shape[3]
|
||||
patched, _ = self.patchifier.patchify(latents)
|
||||
denormalized = self.statistics.un_normalize(patched)
|
||||
return self.patchifier.unpatchify(denormalized, channels=channels, freq=freq)
|
||||
|
||||
|
||||
class AudioPreprocessor:
|
||||
"""Prepares raw waveforms for the autoencoder by matching training conditions."""
|
||||
|
||||
def __init__(self, target_sample_rate: int, mel_bins: int, mel_hop_length: int, n_fft: int):
|
||||
self.target_sample_rate = target_sample_rate
|
||||
self.mel_bins = mel_bins
|
||||
self.mel_hop_length = mel_hop_length
|
||||
self.n_fft = n_fft
|
||||
|
||||
def resample(self, waveform: torch.Tensor, source_rate: int) -> torch.Tensor:
|
||||
if source_rate == self.target_sample_rate:
|
||||
return waveform
|
||||
return torchaudio.functional.resample(waveform, source_rate, self.target_sample_rate)
|
||||
|
||||
@staticmethod
|
||||
def normalize_amplitude(
|
||||
waveform: torch.Tensor, max_amplitude: float = 0.5, eps: float = 1e-5
|
||||
) -> torch.Tensor:
|
||||
waveform = waveform - waveform.mean(dim=2, keepdim=True)
|
||||
peak = torch.max(torch.abs(waveform)) + eps
|
||||
scale = peak.clamp(max=max_amplitude) / peak
|
||||
return waveform * scale
|
||||
|
||||
def waveform_to_mel(
|
||||
self, waveform: torch.Tensor, waveform_sample_rate: int, device
|
||||
) -> torch.Tensor:
|
||||
waveform = self.resample(waveform, waveform_sample_rate)
|
||||
waveform = self.normalize_amplitude(waveform)
|
||||
|
||||
mel_transform = torchaudio.transforms.MelSpectrogram(
|
||||
sample_rate=self.target_sample_rate,
|
||||
n_fft=self.n_fft,
|
||||
win_length=self.n_fft,
|
||||
hop_length=self.mel_hop_length,
|
||||
f_min=0.0,
|
||||
f_max=self.target_sample_rate / 2.0,
|
||||
n_mels=self.mel_bins,
|
||||
window_fn=torch.hann_window,
|
||||
center=True,
|
||||
pad_mode="reflect",
|
||||
power=1.0,
|
||||
mel_scale="slaney",
|
||||
norm="slaney",
|
||||
).to(device)
|
||||
|
||||
mel = mel_transform(waveform)
|
||||
mel = torch.log(torch.clamp(mel, min=1e-5))
|
||||
return mel.permute(0, 1, 3, 2).contiguous()
|
||||
|
||||
|
||||
class AudioVAE(torch.nn.Module):
|
||||
"""High-level Audio VAE wrapper exposing encode and decode entry points."""
|
||||
|
||||
def __init__(self, state_dict: dict, metadata: dict):
|
||||
super().__init__()
|
||||
|
||||
component_config = AudioVAEComponentConfig.from_metadata(metadata)
|
||||
|
||||
vae_sd = utils.state_dict_prefix_replace(state_dict, {"audio_vae.": ""}, filter_keys=True)
|
||||
vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)
|
||||
|
||||
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
|
||||
self.vocoder = Vocoder(config=component_config.vocoder)
|
||||
|
||||
self.autoencoder.load_state_dict(vae_sd, strict=False)
|
||||
self.vocoder.load_state_dict(vocoder_sd, strict=False)
|
||||
|
||||
autoencoder_config = self.autoencoder.get_config()
|
||||
self.normalizer = AudioLatentNormalizer(
|
||||
AudioPatchifier(
|
||||
patch_size=1,
|
||||
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
|
||||
sample_rate=autoencoder_config["sampling_rate"],
|
||||
hop_length=autoencoder_config["mel_hop_length"],
|
||||
is_causal=autoencoder_config["is_causal"],
|
||||
),
|
||||
self.autoencoder.per_channel_statistics,
|
||||
)
|
||||
|
||||
self.preprocessor = AudioPreprocessor(
|
||||
target_sample_rate=autoencoder_config["sampling_rate"],
|
||||
mel_bins=autoencoder_config["mel_bins"],
|
||||
mel_hop_length=autoencoder_config["mel_hop_length"],
|
||||
n_fft=autoencoder_config["n_fft"],
|
||||
)
|
||||
|
||||
self.device_manager = ModelDeviceManager(self)
|
||||
|
||||
def encode(self, audio: dict) -> torch.Tensor:
|
||||
"""Encode a waveform dictionary into normalized latent tensors."""
|
||||
|
||||
waveform = audio["waveform"]
|
||||
waveform_sample_rate = audio["sample_rate"]
|
||||
input_device = waveform.device
|
||||
# Ensure that Audio VAE is loaded on the correct device.
|
||||
self.device_manager.ensure_model_loaded()
|
||||
|
||||
waveform = self.device_manager.move_to_load_device(waveform)
|
||||
expected_channels = self.autoencoder.encoder.in_channels
|
||||
if waveform.shape[1] != expected_channels:
|
||||
raise ValueError(
|
||||
f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}"
|
||||
)
|
||||
|
||||
mel_spec = self.preprocessor.waveform_to_mel(
|
||||
waveform, waveform_sample_rate, device=self.device_manager.load_device
|
||||
)
|
||||
|
||||
latents = self.autoencoder.encode(mel_spec)
|
||||
posterior = DiagonalGaussianDistribution(latents)
|
||||
latent_mode = posterior.mode()
|
||||
|
||||
normalized = self.normalizer.normalize(latent_mode)
|
||||
return normalized.to(input_device)
|
||||
|
||||
def decode(self, latents: torch.Tensor) -> torch.Tensor:
|
||||
"""Decode normalized latent tensors into an audio waveform."""
|
||||
original_shape = latents.shape
|
||||
|
||||
# Ensure that Audio VAE is loaded on the correct device.
|
||||
self.device_manager.ensure_model_loaded()
|
||||
|
||||
latents = self.device_manager.move_to_load_device(latents)
|
||||
latents = self.normalizer.denormalize(latents)
|
||||
|
||||
target_shape = self.target_shape_from_latents(original_shape)
|
||||
mel_spec = self.autoencoder.decode(latents, target_shape=target_shape)
|
||||
|
||||
waveform = self.run_vocoder(mel_spec)
|
||||
return self.device_manager.move_to_load_device(waveform)
|
||||
|
||||
def target_shape_from_latents(self, latents_shape):
|
||||
batch, _, time, _ = latents_shape
|
||||
target_length = time * LATENT_DOWNSAMPLE_FACTOR
|
||||
if self.autoencoder.causality_axis != CausalityAxis.NONE:
|
||||
target_length -= LATENT_DOWNSAMPLE_FACTOR - 1
|
||||
return (
|
||||
batch,
|
||||
self.autoencoder.decoder.out_ch,
|
||||
target_length,
|
||||
self.autoencoder.mel_bins,
|
||||
)
|
||||
|
||||
def num_of_latents_from_frames(self, frames_number: int, frame_rate: int) -> int:
|
||||
return math.ceil((float(frames_number) / frame_rate) * self.latents_per_second)
|
||||
|
||||
def run_vocoder(self, mel_spec: torch.Tensor) -> torch.Tensor:
|
||||
audio_channels = self.autoencoder.decoder.out_ch
|
||||
vocoder_input = mel_spec.transpose(2, 3)
|
||||
|
||||
if audio_channels == 1:
|
||||
vocoder_input = vocoder_input.squeeze(1)
|
||||
elif audio_channels != 2:
|
||||
raise ValueError(f"Unsupported audio_channels: {audio_channels}")
|
||||
|
||||
return self.vocoder(vocoder_input)
|
||||
|
||||
@property
|
||||
def sample_rate(self) -> int:
|
||||
return int(self.autoencoder.sampling_rate)
|
||||
|
||||
@property
|
||||
def mel_hop_length(self) -> int:
|
||||
return int(self.autoencoder.mel_hop_length)
|
||||
|
||||
@property
|
||||
def mel_bins(self) -> int:
|
||||
return int(self.autoencoder.mel_bins)
|
||||
|
||||
@property
|
||||
def latent_channels(self) -> int:
|
||||
return int(self.autoencoder.decoder.z_channels)
|
||||
|
||||
@property
|
||||
def latent_frequency_bins(self) -> int:
|
||||
return int(self.mel_bins // LATENT_DOWNSAMPLE_FACTOR)
|
||||
|
||||
@property
|
||||
def latents_per_second(self) -> float:
|
||||
return self.sample_rate / self.mel_hop_length / LATENT_DOWNSAMPLE_FACTOR
|
||||
|
||||
@property
|
||||
def output_sample_rate(self) -> int:
|
||||
output_rate = getattr(self.vocoder, "output_sample_rate", None)
|
||||
if output_rate is not None:
|
||||
return int(output_rate)
|
||||
upsample_factor = getattr(self.vocoder, "upsample_factor", None)
|
||||
if upsample_factor is None:
|
||||
raise AttributeError(
|
||||
"Vocoder is missing upsample_factor; cannot infer output sample rate"
|
||||
)
|
||||
return int(self.sample_rate * upsample_factor / self.mel_hop_length)
|
||||
|
||||
def memory_required(self, input_shape):
|
||||
return self.device_manager.patcher.model_size()
|
||||
909
comfy/ldm/lightricks/vae/causal_audio_autoencoder.py
Normal file
909
comfy/ldm/lightricks/vae/causal_audio_autoencoder.py
Normal file
@@ -0,0 +1,909 @@
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from typing import Optional
|
||||
from enum import Enum
|
||||
from .pixel_norm import PixelNorm
|
||||
import comfy.ops
|
||||
import logging
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
class StringConvertibleEnum(Enum):
|
||||
"""
|
||||
Base enum class that provides string-to-enum conversion functionality.
|
||||
|
||||
This mixin adds a str_to_enum() class method that handles conversion from
|
||||
strings, None, or existing enum instances with case-insensitive matching.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def str_to_enum(cls, value):
|
||||
"""
|
||||
Convert a string, enum instance, or None to the appropriate enum member.
|
||||
|
||||
Args:
|
||||
value: Can be an enum instance of this class, a string, or None
|
||||
|
||||
Returns:
|
||||
Enum member of this class
|
||||
|
||||
Raises:
|
||||
ValueError: If the value cannot be converted to a valid enum member
|
||||
"""
|
||||
# Already an enum instance of this class
|
||||
if isinstance(value, cls):
|
||||
return value
|
||||
|
||||
# None maps to NONE member if it exists
|
||||
if value is None:
|
||||
if hasattr(cls, "NONE"):
|
||||
return cls.NONE
|
||||
raise ValueError(f"{cls.__name__} does not have a NONE member to map None to")
|
||||
|
||||
# String conversion (case-insensitive)
|
||||
if isinstance(value, str):
|
||||
value_lower = value.lower()
|
||||
|
||||
# Try to match against enum values
|
||||
for member in cls:
|
||||
# Handle members with None values
|
||||
if member.value is None:
|
||||
if value_lower == "none":
|
||||
return member
|
||||
# Handle members with string values
|
||||
elif isinstance(member.value, str) and member.value.lower() == value_lower:
|
||||
return member
|
||||
|
||||
# Build helpful error message with valid values
|
||||
valid_values = []
|
||||
for member in cls:
|
||||
if member.value is None:
|
||||
valid_values.append("none")
|
||||
elif isinstance(member.value, str):
|
||||
valid_values.append(member.value)
|
||||
|
||||
raise ValueError(f"Invalid {cls.__name__} string: '{value}'. " f"Valid values are: {valid_values}")
|
||||
|
||||
raise ValueError(
|
||||
f"Cannot convert type {type(value).__name__} to {cls.__name__} enum. "
|
||||
f"Expected string, None, or {cls.__name__} instance."
|
||||
)
|
||||
|
||||
|
||||
class AttentionType(StringConvertibleEnum):
|
||||
"""Enum for specifying the attention mechanism type."""
|
||||
|
||||
VANILLA = "vanilla"
|
||||
LINEAR = "linear"
|
||||
NONE = "none"
|
||||
|
||||
|
||||
class CausalityAxis(StringConvertibleEnum):
|
||||
"""Enum for specifying the causality axis in causal convolutions."""
|
||||
|
||||
NONE = None
|
||||
WIDTH = "width"
|
||||
HEIGHT = "height"
|
||||
WIDTH_COMPATIBILITY = "width-compatibility"
|
||||
|
||||
|
||||
def Normalize(in_channels, *, num_groups=32, normtype="group"):
|
||||
if normtype == "group":
|
||||
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
elif normtype == "pixel":
|
||||
return PixelNorm(dim=1, eps=1e-6)
|
||||
else:
|
||||
raise ValueError(f"Invalid normalization type: {normtype}")
|
||||
|
||||
|
||||
class CausalConv2d(nn.Module):
|
||||
"""
|
||||
A causal 2D convolution.
|
||||
|
||||
This layer ensures that the output at time `t` only depends on inputs
|
||||
at time `t` and earlier. It achieves this by applying asymmetric padding
|
||||
to the time dimension (width) before the convolution.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
# Ensure kernel_size and dilation are tuples
|
||||
kernel_size = nn.modules.utils._pair(kernel_size)
|
||||
dilation = nn.modules.utils._pair(dilation)
|
||||
|
||||
# Calculate padding dimensions
|
||||
pad_h = (kernel_size[0] - 1) * dilation[0]
|
||||
pad_w = (kernel_size[1] - 1) * dilation[1]
|
||||
|
||||
# The padding tuple for F.pad is (pad_left, pad_right, pad_top, pad_bottom)
|
||||
match self.causality_axis:
|
||||
case CausalityAxis.NONE:
|
||||
self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
|
||||
case CausalityAxis.WIDTH | CausalityAxis.WIDTH_COMPATIBILITY:
|
||||
self.padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2)
|
||||
case CausalityAxis.HEIGHT:
|
||||
self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0)
|
||||
case _:
|
||||
raise ValueError(f"Invalid causality_axis: {causality_axis}")
|
||||
|
||||
# The internal convolution layer uses no padding, as we handle it manually
|
||||
self.conv = ops.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=0,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# Apply causal padding before convolution
|
||||
x = F.pad(x, self.padding)
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
def make_conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=None,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
causality_axis: Optional[CausalityAxis] = None,
|
||||
):
|
||||
"""
|
||||
Create a 2D convolution layer that can be either causal or non-causal.
|
||||
|
||||
Args:
|
||||
in_channels: Number of input channels
|
||||
out_channels: Number of output channels
|
||||
kernel_size: Size of the convolution kernel
|
||||
stride: Convolution stride
|
||||
padding: Padding (if None, will be calculated based on causal flag)
|
||||
dilation: Dilation rate
|
||||
groups: Number of groups for grouped convolution
|
||||
bias: Whether to use bias
|
||||
causality_axis: Dimension along which to apply causality.
|
||||
|
||||
Returns:
|
||||
Either a regular Conv2d or CausalConv2d layer
|
||||
"""
|
||||
if causality_axis is not None:
|
||||
# For causal convolution, padding is handled internally by CausalConv2d
|
||||
return CausalConv2d(in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis)
|
||||
else:
|
||||
# For non-causal convolution, use symmetric padding if not specified
|
||||
if padding is None:
|
||||
if isinstance(kernel_size, int):
|
||||
padding = kernel_size // 2
|
||||
else:
|
||||
padding = tuple(k // 2 for k in kernel_size)
|
||||
return ops.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
groups,
|
||||
bias,
|
||||
)
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, in_channels, with_conv, causality_axis: CausalityAxis = CausalityAxis.HEIGHT):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
self.causality_axis = causality_axis
|
||||
if self.with_conv:
|
||||
self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
# Drop FIRST element in the causal axis to undo encoder's padding, while keeping the length 1 + 2 * n.
|
||||
# For example, if the input is [0, 1, 2], after interpolation, the output is [0, 0, 1, 1, 2, 2].
|
||||
# The causal convolution will pad the first element as [-, -, 0, 0, 1, 1, 2, 2],
|
||||
# So the output elements rely on the following windows:
|
||||
# 0: [-,-,0]
|
||||
# 1: [-,0,0]
|
||||
# 2: [0,0,1]
|
||||
# 3: [0,1,1]
|
||||
# 4: [1,1,2]
|
||||
# 5: [1,2,2]
|
||||
# Notice that the first and second elements in the output rely only on the first element in the input,
|
||||
# while all other elements rely on two elements in the input.
|
||||
# So we can drop the first element to undo the padding (rather than the last element).
|
||||
# This is a no-op for non-causal convolutions.
|
||||
match self.causality_axis:
|
||||
case CausalityAxis.NONE:
|
||||
pass # x remains unchanged
|
||||
case CausalityAxis.HEIGHT:
|
||||
x = x[:, :, 1:, :]
|
||||
case CausalityAxis.WIDTH:
|
||||
x = x[:, :, :, 1:]
|
||||
case CausalityAxis.WIDTH_COMPATIBILITY:
|
||||
pass # x remains unchanged
|
||||
case _:
|
||||
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
"""
|
||||
A downsampling layer that can use either a strided convolution
|
||||
or average pooling. Supports standard and causal padding for the
|
||||
convolutional mode.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, with_conv, causality_axis: CausalityAxis = CausalityAxis.WIDTH):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
if self.causality_axis != CausalityAxis.NONE and not self.with_conv:
|
||||
raise ValueError("causality is only supported when `with_conv=True`.")
|
||||
|
||||
if self.with_conv:
|
||||
# Do time downsampling here
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = ops.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
if self.with_conv:
|
||||
# (pad_left, pad_right, pad_top, pad_bottom)
|
||||
match self.causality_axis:
|
||||
case CausalityAxis.NONE:
|
||||
pad = (0, 1, 0, 1)
|
||||
case CausalityAxis.WIDTH:
|
||||
pad = (2, 0, 0, 1)
|
||||
case CausalityAxis.HEIGHT:
|
||||
pad = (0, 1, 2, 0)
|
||||
case CausalityAxis.WIDTH_COMPATIBILITY:
|
||||
pad = (1, 0, 0, 1)
|
||||
case _:
|
||||
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
|
||||
|
||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
# This branch is only taken if with_conv=False, which implies causality_axis is NONE.
|
||||
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
in_channels,
|
||||
out_channels=None,
|
||||
conv_shortcut=False,
|
||||
dropout,
|
||||
temb_channels=512,
|
||||
norm_type="group",
|
||||
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
|
||||
):
|
||||
super().__init__()
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
if self.causality_axis != CausalityAxis.NONE and norm_type == "group":
|
||||
raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
self.norm1 = Normalize(in_channels, normtype=norm_type)
|
||||
self.non_linearity = nn.SiLU()
|
||||
self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = ops.Linear(temb_channels, out_channels)
|
||||
self.norm2 = Normalize(out_channels, normtype=norm_type)
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = make_conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
else:
|
||||
self.nin_shortcut = make_conv2d(
|
||||
in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
|
||||
def forward(self, x, temb):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = self.non_linearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None]
|
||||
|
||||
h = self.norm2(h)
|
||||
h = self.non_linearity(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels, norm_type="group"):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels, normtype=norm_type)
|
||||
self.q = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b, c, h, w = q.shape
|
||||
q = q.reshape(b, c, h * w).contiguous()
|
||||
q = q.permute(0, 2, 1).contiguous() # b,hw,c
|
||||
k = k.reshape(b, c, h * w).contiguous() # b,c,hw
|
||||
w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||
w_ = w_ * (int(c) ** (-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = v.reshape(b, c, h * w).contiguous()
|
||||
w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q)
|
||||
h_ = torch.bmm(v, w_).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
||||
h_ = h_.reshape(b, c, h, w).contiguous()
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x + h_
|
||||
|
||||
|
||||
def make_attn(in_channels, attn_type="vanilla", norm_type="group"):
|
||||
# Convert string to enum if needed
|
||||
attn_type = AttentionType.str_to_enum(attn_type)
|
||||
|
||||
if attn_type != AttentionType.NONE:
|
||||
logging.info(f"making attention of type '{attn_type.value}' with {in_channels} in_channels")
|
||||
else:
|
||||
logging.info(f"making identity attention with {in_channels} in_channels")
|
||||
|
||||
match attn_type:
|
||||
case AttentionType.VANILLA:
|
||||
return AttnBlock(in_channels, norm_type=norm_type)
|
||||
case AttentionType.NONE:
|
||||
return nn.Identity(in_channels)
|
||||
case AttentionType.LINEAR:
|
||||
raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")
|
||||
case _:
|
||||
raise ValueError(f"Unknown attention type: {attn_type}")
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
double_z=True,
|
||||
attn_type="vanilla",
|
||||
mid_block_add_attention=True,
|
||||
norm_type="group",
|
||||
causality_axis=CausalityAxis.WIDTH.value,
|
||||
**ignore_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.z_channels = z_channels
|
||||
self.double_z = double_z
|
||||
self.norm_type = norm_type
|
||||
# Convert string to enum if needed (for config loading)
|
||||
causality_axis = CausalityAxis.str_to_enum(causality_axis)
|
||||
self.attn_type = AttentionType.str_to_enum(attn_type)
|
||||
|
||||
# downsampling
|
||||
self.conv_in = make_conv2d(
|
||||
in_channels,
|
||||
self.ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
|
||||
self.non_linearity = nn.SiLU()
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
self.down = nn.ModuleList()
|
||||
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
|
||||
for _ in range(self.num_res_blocks):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type))
|
||||
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
down.downsample = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
if mid_block_add_attention:
|
||||
self.mid.attn_1 = make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type)
|
||||
else:
|
||||
self.mid.attn_1 = nn.Identity()
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in, normtype=self.norm_type)
|
||||
self.conv_out = make_conv2d(
|
||||
block_in,
|
||||
2 * z_channels if double_z else z_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass through the encoder.
|
||||
|
||||
Args:
|
||||
x: Input tensor of shape [batch, channels, time, n_mels]
|
||||
|
||||
Returns:
|
||||
Encoded latent representation
|
||||
"""
|
||||
feature_maps = [self.conv_in(x)]
|
||||
|
||||
# Process each resolution level (from high to low resolution)
|
||||
for resolution_level in range(self.num_resolutions):
|
||||
# Apply residual blocks at current resolution level
|
||||
for block_idx in range(self.num_res_blocks):
|
||||
# Apply ResNet block with optional timestep embedding
|
||||
current_features = self.down[resolution_level].block[block_idx](feature_maps[-1], temb=None)
|
||||
|
||||
# Apply attention if configured for this resolution level
|
||||
if len(self.down[resolution_level].attn) > 0:
|
||||
current_features = self.down[resolution_level].attn[block_idx](current_features)
|
||||
|
||||
# Store processed features
|
||||
feature_maps.append(current_features)
|
||||
|
||||
# Downsample spatial dimensions (except at the final resolution level)
|
||||
if resolution_level != self.num_resolutions - 1:
|
||||
downsampled_features = self.down[resolution_level].downsample(feature_maps[-1])
|
||||
feature_maps.append(downsampled_features)
|
||||
|
||||
# === MIDDLE PROCESSING PHASE ===
|
||||
# Take the lowest resolution features for middle processing
|
||||
bottleneck_features = feature_maps[-1]
|
||||
|
||||
# Apply first middle ResNet block
|
||||
bottleneck_features = self.mid.block_1(bottleneck_features, temb=None)
|
||||
|
||||
# Apply middle attention block
|
||||
bottleneck_features = self.mid.attn_1(bottleneck_features)
|
||||
|
||||
# Apply second middle ResNet block
|
||||
bottleneck_features = self.mid.block_2(bottleneck_features, temb=None)
|
||||
|
||||
# === OUTPUT PHASE ===
|
||||
# Normalize the bottleneck features
|
||||
output_features = self.norm_out(bottleneck_features)
|
||||
|
||||
# Apply non-linearity (SiLU activation)
|
||||
output_features = self.non_linearity(output_features)
|
||||
|
||||
# Final convolution to produce latent representation
|
||||
# [batch, channels, time, n_mels] -> [batch, 2 * z_channels if double_z else z_channels, time, n_mels]
|
||||
return self.conv_out(output_features)
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
give_pre_end=False,
|
||||
tanh_out=False,
|
||||
attn_type="vanilla",
|
||||
mid_block_add_attention=True,
|
||||
norm_type="group",
|
||||
causality_axis=CausalityAxis.WIDTH.value,
|
||||
**ignorekwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.out_ch = out_ch
|
||||
self.give_pre_end = give_pre_end
|
||||
self.tanh_out = tanh_out
|
||||
self.norm_type = norm_type
|
||||
self.z_channels = z_channels
|
||||
# Convert string to enum if needed (for config loading)
|
||||
causality_axis = CausalityAxis.str_to_enum(causality_axis)
|
||||
self.attn_type = AttentionType.str_to_enum(attn_type)
|
||||
|
||||
# compute block_in and curr_res at lowest res
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = make_conv2d(z_channels, block_in, kernel_size=3, stride=1, causality_axis=causality_axis)
|
||||
|
||||
self.non_linearity = nn.SiLU()
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
if mid_block_add_attention:
|
||||
self.mid.attn_1 = make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type)
|
||||
else:
|
||||
self.mid.attn_1 = nn.Identity()
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for _ in range(self.num_res_blocks + 1):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in, normtype=self.norm_type)
|
||||
self.conv_out = make_conv2d(block_in, out_ch, kernel_size=3, stride=1, causality_axis=causality_axis)
|
||||
|
||||
def _adjust_output_shape(self, decoded_output, target_shape):
|
||||
"""
|
||||
Adjust output shape to match target dimensions for variable-length audio.
|
||||
|
||||
This function handles the common case where decoded audio spectrograms need to be
|
||||
resized to match a specific target shape.
|
||||
|
||||
Args:
|
||||
decoded_output: Tensor of shape (batch, channels, time, frequency)
|
||||
target_shape: Target shape tuple (batch, channels, time, frequency)
|
||||
|
||||
Returns:
|
||||
Tensor adjusted to match target_shape exactly
|
||||
"""
|
||||
# Current output shape: (batch, channels, time, frequency)
|
||||
_, _, current_time, current_freq = decoded_output.shape
|
||||
_, target_channels, target_time, target_freq = target_shape
|
||||
|
||||
# Step 1: Crop first to avoid exceeding target dimensions
|
||||
decoded_output = decoded_output[
|
||||
:, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
|
||||
]
|
||||
|
||||
# Step 2: Calculate padding needed for time and frequency dimensions
|
||||
time_padding_needed = target_time - decoded_output.shape[2]
|
||||
freq_padding_needed = target_freq - decoded_output.shape[3]
|
||||
|
||||
# Step 3: Apply padding if needed
|
||||
if time_padding_needed > 0 or freq_padding_needed > 0:
|
||||
# PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom)
|
||||
# For audio: pad_left/right = frequency, pad_top/bottom = time
|
||||
padding = (
|
||||
0,
|
||||
max(freq_padding_needed, 0), # frequency padding (left, right)
|
||||
0,
|
||||
max(time_padding_needed, 0), # time padding (top, bottom)
|
||||
)
|
||||
decoded_output = F.pad(decoded_output, padding)
|
||||
|
||||
# Step 4: Final safety crop to ensure exact target shape
|
||||
decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]
|
||||
|
||||
return decoded_output
|
||||
|
||||
def get_config(self):
|
||||
return {
|
||||
"ch": self.ch,
|
||||
"out_ch": self.out_ch,
|
||||
"ch_mult": self.ch_mult,
|
||||
"num_res_blocks": self.num_res_blocks,
|
||||
"in_channels": self.in_channels,
|
||||
"resolution": self.resolution,
|
||||
"z_channels": self.z_channels,
|
||||
}
|
||||
|
||||
def forward(self, latent_features, target_shape=None):
|
||||
"""
|
||||
Decode latent features back to audio spectrograms.
|
||||
|
||||
Args:
|
||||
latent_features: Encoded latent representation of shape (batch, channels, height, width)
|
||||
target_shape: Optional target output shape (batch, channels, time, frequency)
|
||||
If provided, output will be cropped/padded to match this shape
|
||||
|
||||
Returns:
|
||||
Reconstructed audio spectrogram of shape (batch, channels, time, frequency)
|
||||
"""
|
||||
assert target_shape is not None, "Target shape is required for CausalAudioAutoencoder Decoder"
|
||||
|
||||
# Transform latent features to decoder's internal feature dimension
|
||||
hidden_features = self.conv_in(latent_features)
|
||||
|
||||
# Middle processing
|
||||
hidden_features = self.mid.block_1(hidden_features, temb=None)
|
||||
hidden_features = self.mid.attn_1(hidden_features)
|
||||
hidden_features = self.mid.block_2(hidden_features, temb=None)
|
||||
|
||||
# Upsampling
|
||||
# Progressively increase spatial resolution from lowest to highest
|
||||
for resolution_level in reversed(range(self.num_resolutions)):
|
||||
# Apply residual blocks at current resolution level
|
||||
for block_index in range(self.num_res_blocks + 1):
|
||||
hidden_features = self.up[resolution_level].block[block_index](hidden_features, temb=None)
|
||||
|
||||
if len(self.up[resolution_level].attn) > 0:
|
||||
hidden_features = self.up[resolution_level].attn[block_index](hidden_features)
|
||||
|
||||
if resolution_level != 0:
|
||||
hidden_features = self.up[resolution_level].upsample(hidden_features)
|
||||
|
||||
# Output
|
||||
if self.give_pre_end:
|
||||
# Return intermediate features before final processing (for debugging/analysis)
|
||||
decoded_output = hidden_features
|
||||
else:
|
||||
# Standard output path: normalize, activate, and convert to output channels
|
||||
# Final normalization layer
|
||||
hidden_features = self.norm_out(hidden_features)
|
||||
|
||||
# Apply SiLU (Swish) activation function
|
||||
hidden_features = self.non_linearity(hidden_features)
|
||||
|
||||
# Final convolution to map to output channels (typically 2 for stereo audio)
|
||||
decoded_output = self.conv_out(hidden_features)
|
||||
|
||||
# Optional tanh activation to bound output values to [-1, 1] range
|
||||
if self.tanh_out:
|
||||
decoded_output = torch.tanh(decoded_output)
|
||||
|
||||
# Adjust shape for audio data
|
||||
if target_shape is not None:
|
||||
decoded_output = self._adjust_output_shape(decoded_output, target_shape)
|
||||
|
||||
return decoded_output
|
||||
|
||||
|
||||
class processor(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.register_buffer("std-of-means", torch.empty(128))
|
||||
self.register_buffer("mean-of-means", torch.empty(128))
|
||||
|
||||
def un_normalize(self, x):
|
||||
return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x)
|
||||
|
||||
def normalize(self, x):
|
||||
return (x - self.get_buffer("mean-of-means").to(x)) / self.get_buffer("std-of-means").to(x)
|
||||
|
||||
|
||||
class CausalAudioAutoencoder(nn.Module):
|
||||
def __init__(self, config=None):
|
||||
super().__init__()
|
||||
|
||||
if config is None:
|
||||
config = self._guess_config()
|
||||
|
||||
# Extract encoder and decoder configs from the new format
|
||||
model_config = config.get("model", {}).get("params", {})
|
||||
variables_config = config.get("variables", {})
|
||||
|
||||
self.sampling_rate = variables_config.get(
|
||||
"sampling_rate",
|
||||
model_config.get("sampling_rate", config.get("sampling_rate", 16000)),
|
||||
)
|
||||
encoder_config = model_config.get("encoder", model_config.get("ddconfig", {}))
|
||||
decoder_config = model_config.get("decoder", encoder_config)
|
||||
|
||||
# Load mel spectrogram parameters
|
||||
self.mel_bins = encoder_config.get("mel_bins", 64)
|
||||
self.mel_hop_length = model_config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
|
||||
self.n_fft = model_config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
|
||||
|
||||
# Store causality configuration at VAE level (not just in encoder internals)
|
||||
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.WIDTH.value)
|
||||
self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value)
|
||||
self.is_causal = self.causality_axis == CausalityAxis.HEIGHT
|
||||
|
||||
self.encoder = Encoder(**encoder_config)
|
||||
self.decoder = Decoder(**decoder_config)
|
||||
|
||||
self.per_channel_statistics = processor()
|
||||
|
||||
def _guess_config(self):
|
||||
encoder_config = {
|
||||
# Required parameters - based on ltx-video-av-1679000 model metadata
|
||||
"ch": 128,
|
||||
"out_ch": 8,
|
||||
"ch_mult": [1, 2, 4], # Based on metadata: [1, 2, 4] not [1, 2, 4, 8]
|
||||
"num_res_blocks": 2,
|
||||
"attn_resolutions": [], # Based on metadata: empty list, no attention
|
||||
"dropout": 0.0,
|
||||
"resamp_with_conv": True,
|
||||
"in_channels": 2, # stereo
|
||||
"resolution": 256,
|
||||
"z_channels": 8,
|
||||
"double_z": True,
|
||||
"attn_type": "vanilla",
|
||||
"mid_block_add_attention": False, # Based on metadata: false
|
||||
"norm_type": "pixel",
|
||||
"causality_axis": "height", # Based on metadata
|
||||
"mel_bins": 64, # Based on metadata: mel_bins = 64
|
||||
}
|
||||
|
||||
decoder_config = {
|
||||
# Inherits encoder config, can override specific params
|
||||
**encoder_config,
|
||||
"out_ch": 2, # Stereo audio output (2 channels)
|
||||
"give_pre_end": False,
|
||||
"tanh_out": False,
|
||||
}
|
||||
|
||||
config = {
|
||||
"_class_name": "CausalAudioAutoencoder",
|
||||
"sampling_rate": 16000,
|
||||
"model": {
|
||||
"params": {
|
||||
"encoder": encoder_config,
|
||||
"decoder": decoder_config,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
return config
|
||||
|
||||
def get_config(self):
|
||||
return {
|
||||
"sampling_rate": self.sampling_rate,
|
||||
"mel_bins": self.mel_bins,
|
||||
"mel_hop_length": self.mel_hop_length,
|
||||
"n_fft": self.n_fft,
|
||||
"causality_axis": self.causality_axis.value,
|
||||
"is_causal": self.is_causal,
|
||||
}
|
||||
|
||||
def encode(self, x):
|
||||
return self.encoder(x)
|
||||
|
||||
def decode(self, x, target_shape=None):
|
||||
return self.decoder(x, target_shape=target_shape)
|
||||
213
comfy/ldm/lightricks/vocoders/vocoder.py
Normal file
213
comfy/ldm/lightricks/vocoders/vocoder.py
Normal file
@@ -0,0 +1,213 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
import comfy.ops
|
||||
import numpy as np
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
def get_padding(kernel_size, dilation=1):
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
||||
class ResBlock1(torch.nn.Module):
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
||||
super(ResBlock1, self).__init__()
|
||||
self.convs1 = nn.ModuleList(
|
||||
[
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
),
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
),
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[2],
|
||||
padding=get_padding(kernel_size, dilation[2]),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
self.convs2 = nn.ModuleList(
|
||||
[
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
),
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
),
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for c1, c2 in zip(self.convs1, self.convs2):
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
xt = c1(xt)
|
||||
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
||||
xt = c2(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
|
||||
class ResBlock2(torch.nn.Module):
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
|
||||
super(ResBlock2, self).__init__()
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
),
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for c in self.convs:
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
xt = c(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
|
||||
class Vocoder(torch.nn.Module):
|
||||
"""
|
||||
Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config=None):
|
||||
super(Vocoder, self).__init__()
|
||||
|
||||
if config is None:
|
||||
config = self.get_default_config()
|
||||
|
||||
resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11])
|
||||
upsample_rates = config.get("upsample_rates", [6, 5, 2, 2, 2])
|
||||
upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 15, 8, 4, 4])
|
||||
resblock_dilation_sizes = config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
|
||||
upsample_initial_channel = config.get("upsample_initial_channel", 1024)
|
||||
stereo = config.get("stereo", True)
|
||||
resblock = config.get("resblock", "1")
|
||||
|
||||
self.output_sample_rate = config.get("output_sample_rate")
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.num_upsamples = len(upsample_rates)
|
||||
in_channels = 128 if stereo else 64
|
||||
self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
|
||||
resblock_class = ResBlock1 if resblock == "1" else ResBlock2
|
||||
|
||||
self.ups = nn.ModuleList()
|
||||
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||
self.ups.append(
|
||||
ops.ConvTranspose1d(
|
||||
upsample_initial_channel // (2**i),
|
||||
upsample_initial_channel // (2 ** (i + 1)),
|
||||
k,
|
||||
u,
|
||||
padding=(k - u) // 2,
|
||||
)
|
||||
)
|
||||
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock_class(ch, k, d))
|
||||
|
||||
out_channels = 2 if stereo else 1
|
||||
self.conv_post = ops.Conv1d(ch, out_channels, 7, 1, padding=3)
|
||||
|
||||
self.upsample_factor = np.prod([self.ups[i].stride[0] for i in range(len(self.ups))])
|
||||
|
||||
def get_default_config(self):
|
||||
"""Generate default configuration for the vocoder."""
|
||||
|
||||
config = {
|
||||
"resblock_kernel_sizes": [3, 7, 11],
|
||||
"upsample_rates": [6, 5, 2, 2, 2],
|
||||
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
|
||||
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
"upsample_initial_channel": 1024,
|
||||
"stereo": True,
|
||||
"resblock": "1",
|
||||
}
|
||||
|
||||
return config
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass of the vocoder.
|
||||
|
||||
Args:
|
||||
x: Input spectrogram tensor. Can be:
|
||||
- 3D: (batch_size, channels, time_steps) for mono
|
||||
- 4D: (batch_size, 2, channels, time_steps) for stereo
|
||||
|
||||
Returns:
|
||||
Audio tensor of shape (batch_size, out_channels, audio_length)
|
||||
"""
|
||||
if x.dim() == 4: # stereo
|
||||
assert x.shape[1] == 2, "Input must have 2 channels for stereo"
|
||||
x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1)
|
||||
x = self.conv_pre(x)
|
||||
for i in range(self.num_upsamples):
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
x = self.ups[i](x)
|
||||
xs = None
|
||||
for j in range(self.num_kernels):
|
||||
if xs is None:
|
||||
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||
else:
|
||||
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||
x = xs / self.num_kernels
|
||||
x = F.leaky_relu(x)
|
||||
x = self.conv_post(x)
|
||||
x = torch.tanh(x)
|
||||
|
||||
return x
|
||||
@@ -30,6 +30,13 @@ except ImportError as e:
|
||||
raise e
|
||||
exit(-1)
|
||||
|
||||
SAGE_ATTENTION3_IS_AVAILABLE = False
|
||||
try:
|
||||
from sageattn3 import sageattn3_blackwell
|
||||
SAGE_ATTENTION3_IS_AVAILABLE = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
FLASH_ATTENTION_IS_AVAILABLE = False
|
||||
try:
|
||||
from flash_attn import flash_attn_func
|
||||
@@ -563,6 +570,93 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
||||
out = out.reshape(b, -1, heads * dim_head)
|
||||
return out
|
||||
|
||||
@wrap_attn
|
||||
def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
exception_fallback = False
|
||||
if (q.device.type != "cuda" or
|
||||
q.dtype not in (torch.float16, torch.bfloat16) or
|
||||
mask is not None):
|
||||
return attention_pytorch(
|
||||
q, k, v, heads,
|
||||
mask=mask,
|
||||
attn_precision=attn_precision,
|
||||
skip_reshape=skip_reshape,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if skip_reshape:
|
||||
B, H, L, D = q.shape
|
||||
if H != heads:
|
||||
return attention_pytorch(
|
||||
q, k, v, heads,
|
||||
mask=mask,
|
||||
attn_precision=attn_precision,
|
||||
skip_reshape=True,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
**kwargs
|
||||
)
|
||||
q_s, k_s, v_s = q, k, v
|
||||
N = q.shape[2]
|
||||
dim_head = D
|
||||
else:
|
||||
B, N, inner_dim = q.shape
|
||||
if inner_dim % heads != 0:
|
||||
return attention_pytorch(
|
||||
q, k, v, heads,
|
||||
mask=mask,
|
||||
attn_precision=attn_precision,
|
||||
skip_reshape=False,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
**kwargs
|
||||
)
|
||||
dim_head = inner_dim // heads
|
||||
|
||||
if dim_head >= 256 or N <= 1024:
|
||||
return attention_pytorch(
|
||||
q, k, v, heads,
|
||||
mask=mask,
|
||||
attn_precision=attn_precision,
|
||||
skip_reshape=skip_reshape,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if not skip_reshape:
|
||||
q_s, k_s, v_s = map(
|
||||
lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
B, H, L, D = q_s.shape
|
||||
|
||||
try:
|
||||
out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False)
|
||||
except Exception as e:
|
||||
exception_fallback = True
|
||||
logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e)
|
||||
|
||||
if exception_fallback:
|
||||
if not skip_reshape:
|
||||
del q_s, k_s, v_s
|
||||
return attention_pytorch(
|
||||
q, k, v, heads,
|
||||
mask=mask,
|
||||
attn_precision=attn_precision,
|
||||
skip_reshape=False,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if skip_reshape:
|
||||
if not skip_output_reshape:
|
||||
out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
|
||||
else:
|
||||
if skip_output_reshape:
|
||||
pass
|
||||
else:
|
||||
out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
|
||||
|
||||
return out
|
||||
|
||||
try:
|
||||
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
|
||||
@@ -650,6 +744,8 @@ optimized_attention_masked = optimized_attention
|
||||
# register core-supported attention functions
|
||||
if SAGE_ATTENTION_IS_AVAILABLE:
|
||||
register_attention_function("sage", attention_sage)
|
||||
if SAGE_ATTENTION3_IS_AVAILABLE:
|
||||
register_attention_function("sage3", attention3_sage)
|
||||
if FLASH_ATTENTION_IS_AVAILABLE:
|
||||
register_attention_function("flash", attention_flash)
|
||||
if model_management.xformers_enabled():
|
||||
|
||||
@@ -394,7 +394,8 @@ class Model(nn.Module):
|
||||
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||
resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
|
||||
super().__init__()
|
||||
if use_linear_attn: attn_type = "linear"
|
||||
if use_linear_attn:
|
||||
attn_type = "linear"
|
||||
self.ch = ch
|
||||
self.temb_ch = self.ch*4
|
||||
self.num_resolutions = len(ch_mult)
|
||||
@@ -548,7 +549,8 @@ class Encoder(nn.Module):
|
||||
conv3d=False, time_compress=None,
|
||||
**ignore_kwargs):
|
||||
super().__init__()
|
||||
if use_linear_attn: attn_type = "linear"
|
||||
if use_linear_attn:
|
||||
attn_type = "linear"
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
|
||||
@@ -45,7 +45,7 @@ class LitEma(nn.Module):
|
||||
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
|
||||
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
|
||||
else:
|
||||
assert not key in self.m_name2s_name
|
||||
assert key not in self.m_name2s_name
|
||||
|
||||
def copy_to(self, model):
|
||||
m_param = dict(model.named_parameters())
|
||||
@@ -54,7 +54,7 @@ class LitEma(nn.Module):
|
||||
if m_param[key].requires_grad:
|
||||
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
|
||||
else:
|
||||
assert not key in self.m_name2s_name
|
||||
assert key not in self.m_name2s_name
|
||||
|
||||
def store(self, parameters):
|
||||
"""
|
||||
|
||||
@@ -71,7 +71,7 @@ def count_params(model, verbose=False):
|
||||
|
||||
|
||||
def instantiate_from_config(config):
|
||||
if not "target" in config:
|
||||
if "target" not in config:
|
||||
if config == '__is_first_stage__':
|
||||
return None
|
||||
elif config == "__is_unconditional__":
|
||||
|
||||
@@ -322,6 +322,7 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_map["diffusion_model.{}".format(key_lora)] = to
|
||||
key_map["transformer.{}".format(key_lora)] = to
|
||||
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
|
||||
key_map[key_lora] = to
|
||||
|
||||
if isinstance(model, comfy.model_base.Kandinsky5):
|
||||
for k in sdk:
|
||||
|
||||
@@ -20,6 +20,7 @@ import comfy.ldm.hunyuan3dv2_1
|
||||
import comfy.ldm.hunyuan3dv2_1.hunyuandit
|
||||
import torch
|
||||
import logging
|
||||
import comfy.ldm.lightricks.av_model
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||
from comfy.ldm.cascade.stage_c import StageC
|
||||
from comfy.ldm.cascade.stage_b import StageB
|
||||
@@ -946,7 +947,7 @@ class GenmoMochi(BaseModel):
|
||||
|
||||
class LTXV(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel) #TODO
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
@@ -977,6 +978,60 @@ class LTXV(BaseModel):
|
||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||
return latent_image
|
||||
|
||||
class LTXAV(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.av_model.LTXAVModel) #TODO
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
|
||||
|
||||
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
|
||||
audio_denoise_mask = None
|
||||
if denoise_mask is not None and "latent_shapes" in kwargs:
|
||||
denoise_mask = utils.unpack_latents(denoise_mask, kwargs["latent_shapes"])
|
||||
if len(denoise_mask) > 1:
|
||||
audio_denoise_mask = denoise_mask[1]
|
||||
denoise_mask = denoise_mask[0]
|
||||
|
||||
if denoise_mask is not None:
|
||||
out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
|
||||
|
||||
if audio_denoise_mask is not None:
|
||||
out["audio_denoise_mask"] = comfy.conds.CONDRegular(audio_denoise_mask)
|
||||
|
||||
keyframe_idxs = kwargs.get("keyframe_idxs", None)
|
||||
if keyframe_idxs is not None:
|
||||
out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
|
||||
|
||||
latent_shapes = kwargs.get("latent_shapes", None)
|
||||
if latent_shapes is not None:
|
||||
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
|
||||
|
||||
return out
|
||||
|
||||
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
|
||||
v_timestep = timestep
|
||||
a_timestep = timestep
|
||||
|
||||
if denoise_mask is not None:
|
||||
v_timestep = self.diffusion_model.patchifier.patchify(((denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1)))[:, :1])[0]
|
||||
if audio_denoise_mask is not None:
|
||||
a_timestep = self.diffusion_model.a_patchifier.patchify(((audio_denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (audio_denoise_mask.ndim - 1)))[:, :1, :, :1])[0]
|
||||
|
||||
return v_timestep, a_timestep
|
||||
|
||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||
return latent_image
|
||||
|
||||
class HunyuanVideo(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
|
||||
|
||||
@@ -237,6 +237,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
else:
|
||||
dit_config["vec_in_dim"] = None
|
||||
|
||||
dit_config["num_heads"] = dit_config["hidden_size"] // sum(dit_config["axes_dim"])
|
||||
|
||||
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
|
||||
if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
|
||||
@@ -305,7 +307,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "ltxv"
|
||||
dit_config["image_model"] = "ltxav" if f'{key_prefix}audio_adaln_single.linear.weight' in state_dict_keys else "ltxv"
|
||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
|
||||
shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape
|
||||
dit_config["attention_head_dim"] = shape[0] // 32
|
||||
|
||||
@@ -22,7 +22,6 @@ from enum import Enum
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
import torch
|
||||
import sys
|
||||
import importlib
|
||||
import platform
|
||||
import weakref
|
||||
import gc
|
||||
@@ -349,15 +348,27 @@ try:
|
||||
except:
|
||||
rocm_version = (6, -1)
|
||||
|
||||
def aotriton_supported(gpu_arch):
|
||||
path = torch.__path__[0]
|
||||
path = os.path.join(os.path.join(path, "lib"), "aotriton.images")
|
||||
gfx = set(map(lambda a: a[4:], filter(lambda a: a.startswith("amd-gfx"), os.listdir(path))))
|
||||
if gpu_arch in gfx:
|
||||
return True
|
||||
if "{}x".format(gpu_arch[:-1]) in gfx:
|
||||
return True
|
||||
if "{}xx".format(gpu_arch[:-2]) in gfx:
|
||||
return True
|
||||
return False
|
||||
|
||||
logging.info("AMD arch: {}".format(arch))
|
||||
logging.info("ROCm version: {}".format(rocm_version))
|
||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
if importlib.util.find_spec('triton') is not None: # AMD efficient attention implementation depends on triton. TODO: better way of detecting if it's compiled in or not.
|
||||
if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton.
|
||||
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
|
||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
if rocm_version >= (7, 0):
|
||||
if any((a in arch) for a in ["gfx1201"]):
|
||||
if any((a in arch) for a in ["gfx1200", "gfx1201"]):
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
|
||||
if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx950"]): # TODO: more arches, "gfx942" gives error on pytorch nightly 2.10 1013 rocm7.0
|
||||
@@ -456,7 +467,7 @@ def module_size(module):
|
||||
sd = module.state_dict()
|
||||
for k in sd:
|
||||
t = sd[k]
|
||||
module_mem += t.nelement() * t.element_size()
|
||||
module_mem += t.nbytes
|
||||
return module_mem
|
||||
|
||||
class LoadedModel:
|
||||
@@ -1156,7 +1167,7 @@ def pin_memory(tensor):
|
||||
if not tensor.is_contiguous():
|
||||
return False
|
||||
|
||||
size = tensor.numel() * tensor.element_size()
|
||||
size = tensor.nbytes
|
||||
if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
|
||||
return False
|
||||
|
||||
@@ -1183,7 +1194,7 @@ def unpin_memory(tensor):
|
||||
return False
|
||||
|
||||
ptr = tensor.data_ptr()
|
||||
size = tensor.numel() * tensor.element_size()
|
||||
size = tensor.nbytes
|
||||
|
||||
size_stored = PINNED_MEMORY.get(ptr, None)
|
||||
if size_stored is None:
|
||||
@@ -1504,6 +1515,16 @@ def supports_fp8_compute(device=None):
|
||||
|
||||
return True
|
||||
|
||||
def supports_nvfp4_compute(device=None):
|
||||
if not is_nvidia():
|
||||
return False
|
||||
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
if props.major < 10:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def extended_fp16_support():
|
||||
# TODO: check why some models work with fp16 on newer torch versions but not on older
|
||||
if torch_version_numeric < (2, 7):
|
||||
@@ -1542,6 +1563,10 @@ def soft_empty_cache(force=False):
|
||||
def unload_all_models():
|
||||
free_memory(1e30, get_torch_device())
|
||||
|
||||
def debug_memory_summary():
|
||||
if is_amd() or is_nvidia():
|
||||
return torch.cuda.memory.memory_summary()
|
||||
return ""
|
||||
|
||||
#TODO: might be cleaner to put this somewhere else
|
||||
import threading
|
||||
|
||||
@@ -718,6 +718,7 @@ class ModelPatcher:
|
||||
continue
|
||||
|
||||
cast_weight = self.force_cast_weights
|
||||
m.comfy_force_cast_weights = self.force_cast_weights
|
||||
if lowvram_weight:
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
m.weight_function = []
|
||||
@@ -790,11 +791,12 @@ class ModelPatcher:
|
||||
for param in params:
|
||||
self.pin_weight_to_device("{}.{}".format(n, param))
|
||||
|
||||
usable_stat = "{:.2f} MB usable,".format(lowvram_model_memory / (1024 * 1024)) if lowvram_model_memory < 1e32 else ""
|
||||
if lowvram_counter > 0:
|
||||
logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter))
|
||||
logging.info("loaded partially; {} {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(usable_stat, mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter))
|
||||
self.model.model_lowvram = True
|
||||
else:
|
||||
logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
|
||||
logging.info("loaded completely; {} {:.2f} MB loaded, full load: {}".format(usable_stat, mem_counter / (1024 * 1024), full_load))
|
||||
self.model.model_lowvram = False
|
||||
if full_load:
|
||||
self.model.to(device_to)
|
||||
|
||||
204
comfy/ops.py
204
comfy/ops.py
@@ -79,7 +79,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
||||
if input is not None:
|
||||
if dtype is None:
|
||||
if isinstance(input, QuantizedTensor):
|
||||
dtype = input._layout_params["orig_dtype"]
|
||||
dtype = input.params.orig_dtype
|
||||
else:
|
||||
dtype = input.dtype
|
||||
if bias_dtype is None:
|
||||
@@ -412,26 +412,34 @@ def fp8_linear(self, input):
|
||||
return None
|
||||
|
||||
input_dtype = input.dtype
|
||||
input_shape = input.shape
|
||||
tensor_3d = input.ndim == 3
|
||||
|
||||
if input.ndim == 3 or input.ndim == 2:
|
||||
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
|
||||
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
if tensor_3d:
|
||||
input = input.reshape(-1, input_shape[2])
|
||||
|
||||
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
input = torch.clamp(input, min=-448, max=448, out=input)
|
||||
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
|
||||
quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
|
||||
if input.ndim != 2:
|
||||
return None
|
||||
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
|
||||
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
|
||||
# Wrap weight in QuantizedTensor - this enables unified dispatch
|
||||
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
|
||||
layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
|
||||
quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
|
||||
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
|
||||
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
input = torch.clamp(input, min=-448, max=448, out=input)
|
||||
input_fp8 = input.to(dtype).contiguous()
|
||||
layout_params_input = TensorCoreFP8Layout.Params(scale=scale_input, orig_dtype=input_dtype, orig_shape=tuple(input_fp8.shape))
|
||||
quantized_input = QuantizedTensor(input_fp8, "TensorCoreFP8Layout", layout_params_input)
|
||||
|
||||
uncast_bias_weight(self, w, bias, offload_stream)
|
||||
return o
|
||||
# Wrap weight in QuantizedTensor - this enables unified dispatch
|
||||
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
|
||||
layout_params_weight = TensorCoreFP8Layout.Params(scale=scale_weight, orig_dtype=input_dtype, orig_shape=tuple(w.shape))
|
||||
quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
|
||||
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
|
||||
|
||||
return None
|
||||
uncast_bias_weight(self, w, bias, offload_stream)
|
||||
if tensor_3d:
|
||||
o = o.reshape((input_shape[0], input_shape[1], w.shape[0]))
|
||||
|
||||
return o
|
||||
|
||||
class fp8_ops(manual_cast):
|
||||
class Linear(manual_cast.Linear):
|
||||
@@ -477,14 +485,20 @@ if CUBLAS_IS_AVAILABLE:
|
||||
# ==============================================================================
|
||||
# Mixed Precision Operations
|
||||
# ==============================================================================
|
||||
from .quant_ops import QuantizedTensor, QUANT_ALGOS
|
||||
from .quant_ops import (
|
||||
QuantizedTensor,
|
||||
QUANT_ALGOS,
|
||||
TensorCoreFP8Layout,
|
||||
get_layout_class,
|
||||
)
|
||||
|
||||
|
||||
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
|
||||
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
|
||||
class MixedPrecisionOps(manual_cast):
|
||||
_quant_config = quant_config
|
||||
_compute_dtype = compute_dtype
|
||||
_full_precision_mm = full_precision_mm
|
||||
_disabled = disabled
|
||||
|
||||
class Linear(torch.nn.Module, CastWeightBiasOp):
|
||||
def __init__(
|
||||
@@ -497,21 +511,33 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
if dtype is None:
|
||||
dtype = MixedPrecisionOps._compute_dtype
|
||||
|
||||
self.factory_kwargs = {"device": device, "dtype": dtype}
|
||||
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
|
||||
# self.factory_kwargs = {"device": device, "dtype": dtype}
|
||||
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self._has_bias = bias
|
||||
if bias:
|
||||
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
self.tensor_class = None
|
||||
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
|
||||
self._full_precision_mm_config = False
|
||||
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def _load_scale_param(self, state_dict, prefix, param_name, device, manually_loaded_keys, dtype=None):
|
||||
key = f"{prefix}{param_name}"
|
||||
value = state_dict.pop(key, None)
|
||||
if value is not None:
|
||||
value = value.to(device=device)
|
||||
if dtype is not None:
|
||||
value = value.view(dtype=dtype)
|
||||
manually_loaded_keys.append(key)
|
||||
return value
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
strict, missing_keys, unexpected_keys, error_msgs):
|
||||
|
||||
@@ -520,7 +546,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
weight_key = f"{prefix}weight"
|
||||
weight = state_dict.pop(weight_key, None)
|
||||
if weight is None:
|
||||
raise ValueError(f"Missing weight for layer {layer_name}")
|
||||
logging.warning(f"Missing weight for layer {layer_name}")
|
||||
return
|
||||
|
||||
manually_loaded_keys = [weight_key]
|
||||
|
||||
@@ -529,49 +556,61 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
layer_conf = json.loads(layer_conf.numpy().tobytes())
|
||||
|
||||
if layer_conf is None:
|
||||
dtype = self.factory_kwargs["dtype"]
|
||||
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False)
|
||||
if dtype != MixedPrecisionOps._compute_dtype:
|
||||
self.comfy_cast_weights = True
|
||||
if self._has_bias:
|
||||
self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=dtype))
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
|
||||
else:
|
||||
self.quant_format = layer_conf.get("format", None)
|
||||
self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
|
||||
if not self._full_precision_mm:
|
||||
self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False)
|
||||
self._full_precision_mm = self._full_precision_mm_config
|
||||
|
||||
if self.quant_format in MixedPrecisionOps._disabled:
|
||||
self._full_precision_mm = True
|
||||
|
||||
if self.quant_format is None:
|
||||
raise ValueError(f"Unknown quantization format for layer {layer_name}")
|
||||
|
||||
qconfig = QUANT_ALGOS[self.quant_format]
|
||||
self.layout_type = qconfig["comfy_tensor_layout"]
|
||||
layout_cls = get_layout_class(self.layout_type)
|
||||
|
||||
weight_scale_key = f"{prefix}weight_scale"
|
||||
scale = state_dict.pop(weight_scale_key, None)
|
||||
if scale is not None:
|
||||
scale = scale.to(device)
|
||||
layout_params = {
|
||||
'scale': scale,
|
||||
'orig_dtype': MixedPrecisionOps._compute_dtype,
|
||||
'block_size': qconfig.get("group_size", None),
|
||||
}
|
||||
# Load format-specific parameters
|
||||
if self.quant_format in ["float8_e4m3fn", "float8_e5m2"]:
|
||||
# FP8: single tensor scale
|
||||
scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
|
||||
|
||||
if scale is not None:
|
||||
manually_loaded_keys.append(weight_scale_key)
|
||||
params = layout_cls.Params(
|
||||
scale=scale,
|
||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||
orig_shape=(self.out_features, self.in_features),
|
||||
)
|
||||
|
||||
elif self.quant_format == "nvfp4":
|
||||
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
|
||||
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
|
||||
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
|
||||
dtype=torch.float8_e4m3fn)
|
||||
|
||||
if tensor_scale is None or block_scale is None:
|
||||
raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
|
||||
|
||||
params = layout_cls.Params(
|
||||
scale=tensor_scale,
|
||||
block_scale=block_scale,
|
||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||
orig_shape=(self.out_features, self.in_features),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization format: {self.quant_format}")
|
||||
|
||||
self.weight = torch.nn.Parameter(
|
||||
QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params),
|
||||
QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
|
||||
requires_grad=False
|
||||
)
|
||||
|
||||
if self._has_bias:
|
||||
self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=MixedPrecisionOps._compute_dtype))
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
for param_name in qconfig["parameters"]:
|
||||
if param_name in {"weight_scale", "weight_scale_2"}:
|
||||
continue # Already handled above
|
||||
|
||||
param_key = f"{prefix}{param_name}"
|
||||
_v = state_dict.pop(param_key, None)
|
||||
if _v is None:
|
||||
@@ -586,13 +625,29 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
missing_keys.remove(key)
|
||||
|
||||
def state_dict(self, *args, destination=None, prefix="", **kwargs):
|
||||
sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
|
||||
if destination is not None:
|
||||
sd = destination
|
||||
else:
|
||||
sd = {}
|
||||
|
||||
if self.bias is not None:
|
||||
sd["{}bias".format(prefix)] = self.bias
|
||||
|
||||
if isinstance(self.weight, QuantizedTensor):
|
||||
sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale']
|
||||
sd_out = self.weight.state_dict("{}weight".format(prefix))
|
||||
for k in sd_out:
|
||||
sd[k] = sd_out[k]
|
||||
|
||||
quant_conf = {"format": self.quant_format}
|
||||
if self._full_precision_mm:
|
||||
if self._full_precision_mm_config:
|
||||
quant_conf["full_precision_matrix_mult"] = True
|
||||
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
|
||||
|
||||
input_scale = getattr(self, 'input_scale', None)
|
||||
if input_scale is not None:
|
||||
sd["{}input_scale".format(prefix)] = input_scale
|
||||
else:
|
||||
sd["{}weight".format(prefix)] = self.weight
|
||||
return sd
|
||||
|
||||
def _forward(self, input, weight, bias):
|
||||
@@ -607,12 +662,33 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
def forward(self, input, *args, **kwargs):
|
||||
run_every_op()
|
||||
|
||||
if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(input, *args, **kwargs)
|
||||
input_shape = input.shape
|
||||
reshaped_3d = False
|
||||
|
||||
if (getattr(self, 'layout_type', None) is not None and
|
||||
not isinstance(input, QuantizedTensor)):
|
||||
input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype)
|
||||
return self._forward(input, self.weight, self.bias)
|
||||
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
|
||||
not getattr(self, 'comfy_force_cast_weights', False) and
|
||||
len(self.weight_function) == 0 and len(self.bias_function) == 0):
|
||||
|
||||
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
|
||||
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
|
||||
|
||||
# Fall back to non-quantized for non-2D tensors
|
||||
if input_reshaped.ndim == 2:
|
||||
reshaped_3d = input.ndim == 3
|
||||
# dtype is now implicit in the layout class
|
||||
scale = getattr(self, 'input_scale', None)
|
||||
if scale is not None:
|
||||
scale = comfy.model_management.cast_to_device(scale, input.device, None)
|
||||
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
|
||||
|
||||
output = self.forward_comfy_cast_weights(input)
|
||||
|
||||
# Reshape output back to 3D if input was 3D
|
||||
if reshaped_3d:
|
||||
output = output.reshape((input_shape[0], input_shape[1], self.weight.shape[0]))
|
||||
|
||||
return output
|
||||
|
||||
def convert_weight(self, weight, inplace=False, **kwargs):
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
@@ -622,7 +698,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
|
||||
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
|
||||
if getattr(self, 'layout_type', None) is not None:
|
||||
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
|
||||
# dtype is now implicit in the layout class
|
||||
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
|
||||
else:
|
||||
weight = weight.to(self.weight.dtype)
|
||||
if return_weight:
|
||||
@@ -649,10 +726,17 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
|
||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
|
||||
nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
|
||||
|
||||
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
|
||||
logging.info("Using mixed precision operations")
|
||||
return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute)
|
||||
disabled = set()
|
||||
if not nvfp4_compute:
|
||||
disabled.add("nvfp4")
|
||||
if not fp8_compute:
|
||||
disabled.add("float8_e4m3fn")
|
||||
disabled.add("float8_e5m2")
|
||||
return mixed_precision_ops(model_config.quant_config, compute_dtype, disabled=disabled)
|
||||
|
||||
if (
|
||||
fp8_compute and
|
||||
|
||||
@@ -1,580 +1,174 @@
|
||||
import torch
|
||||
import logging
|
||||
from typing import Tuple, Dict
|
||||
|
||||
try:
|
||||
import comfy_kitchen as ck
|
||||
from comfy_kitchen.tensor import (
|
||||
QuantizedTensor,
|
||||
QuantizedLayout,
|
||||
TensorCoreFP8Layout as _CKFp8Layout,
|
||||
TensorCoreNVFP4Layout as _CKNvfp4Layout,
|
||||
register_layout_op,
|
||||
register_layout_class,
|
||||
get_layout_class,
|
||||
)
|
||||
_CK_AVAILABLE = True
|
||||
if torch.version.cuda is None:
|
||||
ck.registry.disable("cuda")
|
||||
else:
|
||||
cuda_version = tuple(map(int, str(torch.version.cuda).split('.')))
|
||||
if cuda_version < (13,):
|
||||
ck.registry.disable("cuda")
|
||||
logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.")
|
||||
|
||||
ck.registry.disable("triton")
|
||||
for k, v in ck.list_backends().items():
|
||||
logging.info(f"Found comfy_kitchen backend {k}: {v}")
|
||||
except ImportError as e:
|
||||
logging.error(f"Failed to import comfy_kitchen, Error: {e}, fp8 and fp4 support will not be available.")
|
||||
_CK_AVAILABLE = False
|
||||
|
||||
class QuantizedTensor:
|
||||
pass
|
||||
|
||||
class _CKFp8Layout:
|
||||
pass
|
||||
|
||||
class _CKNvfp4Layout:
|
||||
pass
|
||||
|
||||
def register_layout_class(name, cls):
|
||||
pass
|
||||
|
||||
def get_layout_class(name):
|
||||
return None
|
||||
|
||||
import comfy.float
|
||||
|
||||
_LAYOUT_REGISTRY = {}
|
||||
_GENERIC_UTILS = {}
|
||||
|
||||
|
||||
def register_layout_op(torch_op, layout_type):
|
||||
"""
|
||||
Decorator to register a layout-specific operation handler.
|
||||
Args:
|
||||
torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default)
|
||||
layout_type: Layout class (e.g., TensorCoreFP8Layout)
|
||||
Example:
|
||||
@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
|
||||
def fp8_linear(func, args, kwargs):
|
||||
# FP8-specific linear implementation
|
||||
...
|
||||
"""
|
||||
def decorator(handler_func):
|
||||
if torch_op not in _LAYOUT_REGISTRY:
|
||||
_LAYOUT_REGISTRY[torch_op] = {}
|
||||
_LAYOUT_REGISTRY[torch_op][layout_type] = handler_func
|
||||
return handler_func
|
||||
return decorator
|
||||
|
||||
|
||||
def register_generic_util(torch_op):
|
||||
"""
|
||||
Decorator to register a generic utility that works for all layouts.
|
||||
Args:
|
||||
torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default)
|
||||
|
||||
Example:
|
||||
@register_generic_util(torch.ops.aten.detach.default)
|
||||
def generic_detach(func, args, kwargs):
|
||||
# Works for any layout
|
||||
...
|
||||
"""
|
||||
def decorator(handler_func):
|
||||
_GENERIC_UTILS[torch_op] = handler_func
|
||||
return handler_func
|
||||
return decorator
|
||||
|
||||
|
||||
def _get_layout_from_args(args):
|
||||
for arg in args:
|
||||
if isinstance(arg, QuantizedTensor):
|
||||
return arg._layout_type
|
||||
elif isinstance(arg, (list, tuple)):
|
||||
for item in arg:
|
||||
if isinstance(item, QuantizedTensor):
|
||||
return item._layout_type
|
||||
return None
|
||||
|
||||
|
||||
def _move_layout_params_to_device(params, device):
|
||||
new_params = {}
|
||||
for k, v in params.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
new_params[k] = v.to(device=device)
|
||||
else:
|
||||
new_params[k] = v
|
||||
return new_params
|
||||
|
||||
|
||||
def _copy_layout_params(params):
|
||||
new_params = {}
|
||||
for k, v in params.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
new_params[k] = v.clone()
|
||||
else:
|
||||
new_params[k] = v
|
||||
return new_params
|
||||
|
||||
def _copy_layout_params_inplace(src, dst, non_blocking=False):
|
||||
for k, v in src.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
dst[k].copy_(v, non_blocking=non_blocking)
|
||||
else:
|
||||
dst[k] = v
|
||||
|
||||
class QuantizedLayout:
|
||||
"""
|
||||
Base class for quantization layouts.
|
||||
|
||||
A layout encapsulates the format-specific logic for quantization/dequantization
|
||||
and provides a uniform interface for extracting raw tensors needed for computation.
|
||||
|
||||
New quantization formats should subclass this and implement the required methods.
|
||||
"""
|
||||
@classmethod
|
||||
def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]:
|
||||
raise NotImplementedError(f"{cls.__name__} must implement quantize()")
|
||||
|
||||
@staticmethod
|
||||
def dequantize(qdata, **layout_params) -> torch.Tensor:
|
||||
raise NotImplementedError("TensorLayout must implement dequantize()")
|
||||
|
||||
@classmethod
|
||||
def get_plain_tensors(cls, qtensor) -> torch.Tensor:
|
||||
raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()")
|
||||
|
||||
|
||||
class QuantizedTensor(torch.Tensor):
|
||||
"""
|
||||
Universal quantized tensor that works with any layout.
|
||||
|
||||
This tensor subclass uses a pluggable layout system to support multiple
|
||||
quantization formats (FP8, INT4, INT8, etc.) without code duplication.
|
||||
|
||||
The layout_type determines format-specific behavior, while common operations
|
||||
(detach, clone, to) are handled generically.
|
||||
|
||||
Attributes:
|
||||
_qdata: The quantized tensor data
|
||||
_layout_type: Layout class (e.g., TensorCoreFP8Layout)
|
||||
_layout_params: Dict with layout-specific params (scale, zero_point, etc.)
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def __new__(cls, qdata, layout_type, layout_params):
|
||||
"""
|
||||
Create a quantized tensor.
|
||||
|
||||
Args:
|
||||
qdata: The quantized data tensor
|
||||
layout_type: Layout class (subclass of QuantizedLayout)
|
||||
layout_params: Dict with layout-specific parameters
|
||||
"""
|
||||
return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
|
||||
|
||||
def __init__(self, qdata, layout_type, layout_params):
|
||||
self._qdata = qdata
|
||||
self._layout_type = layout_type
|
||||
self._layout_params = layout_params
|
||||
|
||||
def __repr__(self):
|
||||
layout_name = self._layout_type
|
||||
param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
|
||||
return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"
|
||||
|
||||
@property
|
||||
def layout_type(self):
|
||||
return self._layout_type
|
||||
|
||||
def __tensor_flatten__(self):
|
||||
"""
|
||||
Tensor flattening protocol for proper device movement.
|
||||
"""
|
||||
inner_tensors = ["_qdata"]
|
||||
ctx = {
|
||||
"layout_type": self._layout_type,
|
||||
}
|
||||
|
||||
tensor_params = {}
|
||||
non_tensor_params = {}
|
||||
for k, v in self._layout_params.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
tensor_params[k] = v
|
||||
else:
|
||||
non_tensor_params[k] = v
|
||||
|
||||
ctx["tensor_param_keys"] = list(tensor_params.keys())
|
||||
ctx["non_tensor_params"] = non_tensor_params
|
||||
|
||||
for k, v in tensor_params.items():
|
||||
attr_name = f"_layout_param_{k}"
|
||||
object.__setattr__(self, attr_name, v)
|
||||
inner_tensors.append(attr_name)
|
||||
|
||||
return inner_tensors, ctx
|
||||
|
||||
@staticmethod
|
||||
def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
|
||||
"""
|
||||
Tensor unflattening protocol for proper device movement.
|
||||
Reconstructs the QuantizedTensor after device movement.
|
||||
"""
|
||||
layout_type = ctx["layout_type"]
|
||||
layout_params = dict(ctx["non_tensor_params"])
|
||||
|
||||
for key in ctx["tensor_param_keys"]:
|
||||
attr_name = f"_layout_param_{key}"
|
||||
layout_params[key] = inner_tensors[attr_name]
|
||||
|
||||
return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params)
|
||||
|
||||
@classmethod
|
||||
def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
|
||||
qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs)
|
||||
return cls(qdata, layout_type, layout_params)
|
||||
|
||||
def dequantize(self) -> torch.Tensor:
|
||||
return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
kwargs = kwargs or {}
|
||||
|
||||
# Step 1: Check generic utilities first (detach, clone, to, etc.)
|
||||
if func in _GENERIC_UTILS:
|
||||
return _GENERIC_UTILS[func](func, args, kwargs)
|
||||
|
||||
# Step 2: Check layout-specific handlers (linear, matmul, etc.)
|
||||
layout_type = _get_layout_from_args(args)
|
||||
if layout_type and func in _LAYOUT_REGISTRY:
|
||||
handler = _LAYOUT_REGISTRY[func].get(layout_type)
|
||||
if handler:
|
||||
return handler(func, args, kwargs)
|
||||
|
||||
# Step 3: Fallback to dequantization
|
||||
if isinstance(args[0] if args else None, QuantizedTensor):
|
||||
logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
|
||||
return cls._dequant_and_fallback(func, args, kwargs)
|
||||
|
||||
@classmethod
|
||||
def _dequant_and_fallback(cls, func, args, kwargs):
|
||||
def dequant_arg(arg):
|
||||
if isinstance(arg, QuantizedTensor):
|
||||
return arg.dequantize()
|
||||
elif isinstance(arg, (list, tuple)):
|
||||
return type(arg)(dequant_arg(a) for a in arg)
|
||||
return arg
|
||||
|
||||
new_args = dequant_arg(args)
|
||||
new_kwargs = dequant_arg(kwargs)
|
||||
return func(*new_args, **new_kwargs)
|
||||
|
||||
def data_ptr(self):
|
||||
return self._qdata.data_ptr()
|
||||
|
||||
def is_pinned(self):
|
||||
return self._qdata.is_pinned()
|
||||
|
||||
def is_contiguous(self, *arg, **kwargs):
|
||||
return self._qdata.is_contiguous(*arg, **kwargs)
|
||||
|
||||
def storage(self):
|
||||
return self._qdata.storage()
|
||||
|
||||
# ==============================================================================
|
||||
# Generic Utilities (Layout-Agnostic Operations)
|
||||
# FP8 Layouts with Comfy-Specific Extensions
|
||||
# ==============================================================================
|
||||
|
||||
def _create_transformed_qtensor(qt, transform_fn):
|
||||
new_data = transform_fn(qt._qdata)
|
||||
new_params = _copy_layout_params(qt._layout_params)
|
||||
return QuantizedTensor(new_data, qt._layout_type, new_params)
|
||||
class _TensorCoreFP8LayoutBase(_CKFp8Layout):
|
||||
FP8_DTYPE = None # Must be overridden in subclass
|
||||
|
||||
|
||||
def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
|
||||
if target_layout is not None and target_layout != torch.strided:
|
||||
logging.warning(
|
||||
f"QuantizedTensor: layout change requested to {target_layout}, "
|
||||
f"but not supported. Ignoring layout."
|
||||
)
|
||||
|
||||
# Handle device transfer
|
||||
current_device = qt._qdata.device
|
||||
if target_device is not None:
|
||||
# Normalize device for comparison
|
||||
if isinstance(target_device, str):
|
||||
target_device = torch.device(target_device)
|
||||
if isinstance(current_device, str):
|
||||
current_device = torch.device(current_device)
|
||||
|
||||
if target_device != current_device:
|
||||
logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
|
||||
new_q_data = qt._qdata.to(device=target_device)
|
||||
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
|
||||
if target_dtype is not None:
|
||||
new_params["orig_dtype"] = target_dtype
|
||||
new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
|
||||
logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
|
||||
return new_qt
|
||||
|
||||
logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original")
|
||||
return qt
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten.detach.default)
|
||||
def generic_detach(func, args, kwargs):
|
||||
"""Detach operation - creates a detached copy of the quantized tensor."""
|
||||
qt = args[0]
|
||||
if isinstance(qt, QuantizedTensor):
|
||||
return _create_transformed_qtensor(qt, lambda x: x.detach())
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten.clone.default)
|
||||
def generic_clone(func, args, kwargs):
|
||||
"""Clone operation - creates a deep copy of the quantized tensor."""
|
||||
qt = args[0]
|
||||
if isinstance(qt, QuantizedTensor):
|
||||
return _create_transformed_qtensor(qt, lambda x: x.clone())
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten._to_copy.default)
|
||||
def generic_to_copy(func, args, kwargs):
|
||||
"""Device/dtype transfer operation - handles .to(device) calls."""
|
||||
qt = args[0]
|
||||
if isinstance(qt, QuantizedTensor):
|
||||
return _handle_device_transfer(
|
||||
qt,
|
||||
target_device=kwargs.get('device', None),
|
||||
target_dtype=kwargs.get('dtype', None),
|
||||
op_name="_to_copy"
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten.to.dtype_layout)
|
||||
def generic_to_dtype_layout(func, args, kwargs):
|
||||
"""Handle .to(device) calls using the dtype_layout variant."""
|
||||
qt = args[0]
|
||||
if isinstance(qt, QuantizedTensor):
|
||||
return _handle_device_transfer(
|
||||
qt,
|
||||
target_device=kwargs.get('device', None),
|
||||
target_dtype=kwargs.get('dtype', None),
|
||||
target_layout=kwargs.get('layout', None),
|
||||
op_name="to"
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten.copy_.default)
|
||||
def generic_copy_(func, args, kwargs):
|
||||
qt_dest = args[0]
|
||||
src = args[1]
|
||||
non_blocking = args[2] if len(args) > 2 else False
|
||||
if isinstance(qt_dest, QuantizedTensor):
|
||||
if isinstance(src, QuantizedTensor):
|
||||
# Copy from another quantized tensor
|
||||
qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking)
|
||||
qt_dest._layout_type = src._layout_type
|
||||
orig_dtype = qt_dest._layout_params["orig_dtype"]
|
||||
_copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking)
|
||||
qt_dest._layout_params["orig_dtype"] = orig_dtype
|
||||
else:
|
||||
# Copy from regular tensor - just copy raw data
|
||||
qt_dest._qdata.copy_(src)
|
||||
return qt_dest
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten.to.dtype)
|
||||
def generic_to_dtype(func, args, kwargs):
|
||||
"""Handle .to(dtype) calls - dtype conversion only."""
|
||||
src = args[0]
|
||||
if isinstance(src, QuantizedTensor):
|
||||
# For dtype-only conversion, just change the orig_dtype, no real cast is needed
|
||||
target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype')
|
||||
src._layout_params["orig_dtype"] = target_dtype
|
||||
return src
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
|
||||
def generic_has_compatible_shallow_copy_type(func, args, kwargs):
|
||||
return True
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten.empty_like.default)
|
||||
def generic_empty_like(func, args, kwargs):
|
||||
"""Empty_like operation - creates an empty tensor with the same quantized structure."""
|
||||
qt = args[0]
|
||||
if isinstance(qt, QuantizedTensor):
|
||||
# Create empty tensor with same shape and dtype as the quantized data
|
||||
hp_dtype = kwargs.pop('dtype', qt._layout_params["orig_dtype"])
|
||||
new_qdata = torch.empty_like(qt._qdata, **kwargs)
|
||||
|
||||
# Handle device transfer for layout params
|
||||
target_device = kwargs.get('device', new_qdata.device)
|
||||
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
|
||||
|
||||
# Update orig_dtype if dtype is specified
|
||||
new_params['orig_dtype'] = hp_dtype
|
||||
|
||||
return QuantizedTensor(new_qdata, qt._layout_type, new_params)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# ==============================================================================
|
||||
# FP8 Layout + Operation Handlers
|
||||
# ==============================================================================
|
||||
class TensorCoreFP8Layout(QuantizedLayout):
|
||||
"""
|
||||
Storage format:
|
||||
- qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
|
||||
- scale: Scalar tensor (float32) for dequantization
|
||||
- orig_dtype: Original dtype before quantization (for casting back)
|
||||
"""
|
||||
@classmethod
|
||||
def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False):
|
||||
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
|
||||
if cls.FP8_DTYPE is None:
|
||||
raise NotImplementedError(f"{cls.__name__} must define FP8_DTYPE")
|
||||
|
||||
orig_dtype = tensor.dtype
|
||||
orig_shape = tuple(tensor.shape)
|
||||
|
||||
if isinstance(scale, str) and scale == "recalculate":
|
||||
scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(dtype).max
|
||||
scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(cls.FP8_DTYPE).max
|
||||
if tensor.dtype not in [torch.float32, torch.bfloat16]: # Prevent scale from being too small
|
||||
tensor_info = torch.finfo(tensor.dtype)
|
||||
scale = (1.0 / torch.clamp((1.0 / scale), min=tensor_info.min, max=tensor_info.max))
|
||||
|
||||
if scale is not None:
|
||||
if not isinstance(scale, torch.Tensor):
|
||||
scale = torch.tensor(scale)
|
||||
scale = scale.to(device=tensor.device, dtype=torch.float32)
|
||||
if scale is None:
|
||||
scale = torch.ones((), device=tensor.device, dtype=torch.float32)
|
||||
if not isinstance(scale, torch.Tensor):
|
||||
scale = torch.tensor(scale, device=tensor.device, dtype=torch.float32)
|
||||
|
||||
if stochastic_rounding > 0:
|
||||
if inplace_ops:
|
||||
tensor *= (1.0 / scale).to(tensor.dtype)
|
||||
else:
|
||||
tensor = tensor * (1.0 / scale).to(tensor.dtype)
|
||||
qdata = comfy.float.stochastic_rounding(tensor, dtype=cls.FP8_DTYPE, seed=stochastic_rounding)
|
||||
else:
|
||||
scale = torch.ones((), device=tensor.device, dtype=torch.float32)
|
||||
qdata = ck.quantize_per_tensor_fp8(tensor, scale, cls.FP8_DTYPE)
|
||||
|
||||
params = cls.Params(scale=scale.float(), orig_dtype=orig_dtype, orig_shape=orig_shape)
|
||||
return qdata, params
|
||||
|
||||
|
||||
class TensorCoreNVFP4Layout(_CKNvfp4Layout):
|
||||
@classmethod
|
||||
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
|
||||
if tensor.dim() != 2:
|
||||
raise ValueError(f"NVFP4 requires 2D tensor, got {tensor.dim()}D")
|
||||
|
||||
orig_dtype = tensor.dtype
|
||||
orig_shape = tuple(tensor.shape)
|
||||
|
||||
if scale is None or (isinstance(scale, str) and scale == "recalculate"):
|
||||
scale = torch.amax(tensor.abs()) / (ck.float_utils.F8_E4M3_MAX * ck.float_utils.F4_E2M1_MAX)
|
||||
|
||||
if not isinstance(scale, torch.Tensor):
|
||||
scale = torch.tensor(scale)
|
||||
scale = scale.to(device=tensor.device, dtype=torch.float32)
|
||||
|
||||
padded_shape = cls.get_padded_shape(orig_shape)
|
||||
needs_padding = padded_shape != orig_shape
|
||||
|
||||
if stochastic_rounding > 0:
|
||||
tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding)
|
||||
qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4(tensor, scale, pad_16x=needs_padding, seed=stochastic_rounding)
|
||||
else:
|
||||
lp_amax = torch.finfo(dtype).max
|
||||
torch.clamp(tensor, min=-lp_amax, max=lp_amax, out=tensor)
|
||||
tensor = tensor.to(dtype, memory_format=torch.contiguous_format)
|
||||
qdata, block_scale = ck.quantize_nvfp4(tensor, scale, pad_16x=needs_padding)
|
||||
|
||||
layout_params = {
|
||||
'scale': scale,
|
||||
'orig_dtype': orig_dtype
|
||||
}
|
||||
return tensor, layout_params
|
||||
params = cls.Params(
|
||||
scale=scale,
|
||||
orig_dtype=orig_dtype,
|
||||
orig_shape=orig_shape,
|
||||
block_scale=block_scale,
|
||||
)
|
||||
return qdata, params
|
||||
|
||||
@staticmethod
|
||||
def dequantize(qdata, scale, orig_dtype, **kwargs):
|
||||
plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
|
||||
plain_tensor.mul_(scale)
|
||||
return plain_tensor
|
||||
|
||||
@classmethod
|
||||
def get_plain_tensors(cls, qtensor):
|
||||
return qtensor._qdata, qtensor._layout_params['scale']
|
||||
class TensorCoreFP8E4M3Layout(_TensorCoreFP8LayoutBase):
|
||||
FP8_DTYPE = torch.float8_e4m3fn
|
||||
|
||||
|
||||
class TensorCoreFP8E5M2Layout(_TensorCoreFP8LayoutBase):
|
||||
FP8_DTYPE = torch.float8_e5m2
|
||||
|
||||
|
||||
# Backward compatibility alias - default to E4M3
|
||||
TensorCoreFP8Layout = TensorCoreFP8E4M3Layout
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Registry
|
||||
# ==============================================================================
|
||||
|
||||
register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
|
||||
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
|
||||
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
|
||||
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
|
||||
|
||||
QUANT_ALGOS = {
|
||||
"float8_e4m3fn": {
|
||||
"storage_t": torch.float8_e4m3fn,
|
||||
"parameters": {"weight_scale", "input_scale"},
|
||||
"comfy_tensor_layout": "TensorCoreFP8Layout",
|
||||
"comfy_tensor_layout": "TensorCoreFP8E4M3Layout",
|
||||
},
|
||||
"float8_e5m2": {
|
||||
"storage_t": torch.float8_e5m2,
|
||||
"parameters": {"weight_scale", "input_scale"},
|
||||
"comfy_tensor_layout": "TensorCoreFP8E5M2Layout",
|
||||
},
|
||||
"nvfp4": {
|
||||
"storage_t": torch.uint8,
|
||||
"parameters": {"weight_scale", "weight_scale_2", "input_scale"},
|
||||
"comfy_tensor_layout": "TensorCoreNVFP4Layout",
|
||||
"group_size": 16,
|
||||
},
|
||||
}
|
||||
|
||||
LAYOUTS = {
|
||||
"TensorCoreFP8Layout": TensorCoreFP8Layout,
|
||||
}
|
||||
|
||||
# ==============================================================================
|
||||
# Re-exports for backward compatibility
|
||||
# ==============================================================================
|
||||
|
||||
@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout")
|
||||
def fp8_linear(func, args, kwargs):
|
||||
input_tensor = args[0]
|
||||
weight = args[1]
|
||||
bias = args[2] if len(args) > 2 else None
|
||||
|
||||
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
|
||||
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
|
||||
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
|
||||
|
||||
out_dtype = kwargs.get("out_dtype")
|
||||
if out_dtype is None:
|
||||
out_dtype = input_tensor._layout_params['orig_dtype']
|
||||
|
||||
weight_t = plain_weight.t()
|
||||
|
||||
tensor_2d = False
|
||||
if len(plain_input.shape) == 2:
|
||||
tensor_2d = True
|
||||
plain_input = plain_input.unsqueeze(1)
|
||||
|
||||
input_shape = plain_input.shape
|
||||
if len(input_shape) != 3:
|
||||
return None
|
||||
|
||||
try:
|
||||
output = torch._scaled_mm(
|
||||
plain_input.reshape(-1, input_shape[2]).contiguous(),
|
||||
weight_t,
|
||||
bias=bias,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=out_dtype,
|
||||
)
|
||||
|
||||
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
|
||||
output = output[0]
|
||||
|
||||
if not tensor_2d:
|
||||
output = output.reshape((-1, input_shape[1], weight.shape[0]))
|
||||
|
||||
if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
||||
output_scale = scale_a * scale_b
|
||||
output_params = {
|
||||
'scale': output_scale,
|
||||
'orig_dtype': input_tensor._layout_params['orig_dtype']
|
||||
}
|
||||
return QuantizedTensor(output, "TensorCoreFP8Layout", output_params)
|
||||
else:
|
||||
return output
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
|
||||
|
||||
# Case 2: DQ Fallback
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight = weight.dequantize()
|
||||
if isinstance(input_tensor, QuantizedTensor):
|
||||
input_tensor = input_tensor.dequantize()
|
||||
|
||||
return torch.nn.functional.linear(input_tensor, weight, bias)
|
||||
|
||||
def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None):
|
||||
if out_dtype is None:
|
||||
out_dtype = input_tensor._layout_params['orig_dtype']
|
||||
|
||||
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
|
||||
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
|
||||
|
||||
output = torch._scaled_mm(
|
||||
plain_input.contiguous(),
|
||||
plain_weight,
|
||||
bias=bias,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=out_dtype,
|
||||
)
|
||||
|
||||
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
|
||||
output = output[0]
|
||||
return output
|
||||
|
||||
@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout")
|
||||
def fp8_addmm(func, args, kwargs):
|
||||
input_tensor = args[1]
|
||||
weight = args[2]
|
||||
bias = args[0]
|
||||
|
||||
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
|
||||
return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None))
|
||||
|
||||
a = list(args)
|
||||
if isinstance(args[0], QuantizedTensor):
|
||||
a[0] = args[0].dequantize()
|
||||
if isinstance(args[1], QuantizedTensor):
|
||||
a[1] = args[1].dequantize()
|
||||
if isinstance(args[2], QuantizedTensor):
|
||||
a[2] = args[2].dequantize()
|
||||
|
||||
return func(*a, **kwargs)
|
||||
|
||||
@register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout")
|
||||
def fp8_mm(func, args, kwargs):
|
||||
input_tensor = args[0]
|
||||
weight = args[1]
|
||||
|
||||
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
|
||||
return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None))
|
||||
|
||||
a = list(args)
|
||||
if isinstance(args[0], QuantizedTensor):
|
||||
a[0] = args[0].dequantize()
|
||||
if isinstance(args[1], QuantizedTensor):
|
||||
a[1] = args[1].dequantize()
|
||||
return func(*a, **kwargs)
|
||||
|
||||
@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
|
||||
@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
|
||||
def fp8_func(func, args, kwargs):
|
||||
input_tensor = args[0]
|
||||
if isinstance(input_tensor, QuantizedTensor):
|
||||
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
|
||||
ar = list(args)
|
||||
ar[0] = plain_input
|
||||
return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
|
||||
return func(*args, **kwargs)
|
||||
__all__ = [
|
||||
"QuantizedTensor",
|
||||
"QuantizedLayout",
|
||||
"TensorCoreFP8Layout",
|
||||
"TensorCoreFP8E4M3Layout",
|
||||
"TensorCoreFP8E5M2Layout",
|
||||
"TensorCoreNVFP4Layout",
|
||||
"QUANT_ALGOS",
|
||||
"register_layout_op",
|
||||
]
|
||||
|
||||
28
comfy/sd.py
28
comfy/sd.py
@@ -218,7 +218,7 @@ class CLIP:
|
||||
if unprojected:
|
||||
self.cond_stage_model.set_clip_options({"projected_pooled": False})
|
||||
|
||||
self.load_model()
|
||||
self.load_model(tokens)
|
||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
||||
all_hooks.reset()
|
||||
self.patcher.patch_hooks(None)
|
||||
@@ -266,7 +266,7 @@ class CLIP:
|
||||
if return_pooled == "unprojected":
|
||||
self.cond_stage_model.set_clip_options({"projected_pooled": False})
|
||||
|
||||
self.load_model()
|
||||
self.load_model(tokens)
|
||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
||||
o = self.cond_stage_model.encode_token_weights(tokens)
|
||||
cond, pooled = o[:2]
|
||||
@@ -299,8 +299,11 @@ class CLIP:
|
||||
sd_clip[k] = sd_tokenizer[k]
|
||||
return sd_clip
|
||||
|
||||
def load_model(self):
|
||||
model_management.load_model_gpu(self.patcher)
|
||||
def load_model(self, tokens={}):
|
||||
memory_used = 0
|
||||
if hasattr(self.cond_stage_model, "memory_estimation_function"):
|
||||
memory_used = self.cond_stage_model.memory_estimation_function(tokens, device=self.patcher.load_device)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||
return self.patcher
|
||||
|
||||
def get_key_patches(self):
|
||||
@@ -476,8 +479,8 @@ class VAE:
|
||||
self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE(version=version, config=vae_config)
|
||||
self.latent_channels = 128
|
||||
self.latent_dim = 3
|
||||
self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (1200 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (80 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
|
||||
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
|
||||
self.upscale_index_formula = (8, 32, 32)
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
|
||||
@@ -1041,7 +1044,8 @@ class TEModel(Enum):
|
||||
MISTRAL3_24B_PRUNED_FLUX2 = 15
|
||||
QWEN3_4B = 16
|
||||
QWEN3_2B = 17
|
||||
JINA_CLIP_2 = 18
|
||||
GEMMA_3_12B = 18
|
||||
JINA_CLIP_2 = 19
|
||||
|
||||
|
||||
def detect_te_model(sd):
|
||||
@@ -1055,9 +1059,9 @@ def detect_te_model(sd):
|
||||
return TEModel.JINA_CLIP_2
|
||||
if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
|
||||
weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
|
||||
if weight.shape[-1] == 4096:
|
||||
if weight.shape[0] == 10240:
|
||||
return TEModel.T5_XXL
|
||||
elif weight.shape[-1] == 2048:
|
||||
elif weight.shape[0] == 5120:
|
||||
return TEModel.T5_XL
|
||||
if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd:
|
||||
return TEModel.T5_XXL_OLD
|
||||
@@ -1067,6 +1071,8 @@ def detect_te_model(sd):
|
||||
return TEModel.BYT5_SMALL_GLYPH
|
||||
return TEModel.T5_BASE
|
||||
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
|
||||
if 'model.layers.47.self_attn.q_norm.weight' in sd:
|
||||
return TEModel.GEMMA_3_12B
|
||||
if 'model.layers.0.self_attn.q_norm.weight' in sd:
|
||||
return TEModel.GEMMA_3_4B
|
||||
return TEModel.GEMMA_2_2B
|
||||
@@ -1271,6 +1277,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
elif clip_type == CLIPType.KANDINSKY5_IMAGE:
|
||||
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
|
||||
elif clip_type == CLIPType.LTXV:
|
||||
clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.lt.LTXAVGemmaTokenizer
|
||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||
elif clip_type == CLIPType.NEWBIE:
|
||||
clip_target.clip = comfy.text_encoders.newbie.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.newbie.NewBieTokenizer
|
||||
|
||||
@@ -836,6 +836,21 @@ class LTXV(supported_models_base.BASE):
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.lt.LTXVT5Tokenizer, comfy.text_encoders.lt.ltxv_te(**t5_detect))
|
||||
|
||||
class LTXAV(LTXV):
|
||||
unet_config = {
|
||||
"image_model": "ltxav",
|
||||
}
|
||||
|
||||
latent_format = latent_formats.LTXAV
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
self.memory_usage_factor = 0.061 # TODO
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.LTXAV(self, device=device)
|
||||
return out
|
||||
|
||||
class HunyuanVideo(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan_video",
|
||||
@@ -1536,6 +1551,6 @@ class Kandinsky5Image(Kandinsky5):
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
|
||||
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5]
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
@@ -154,7 +154,8 @@ class TAEHV(nn.Module):
|
||||
self._show_progress_bar = value
|
||||
|
||||
def encode(self, x, **kwargs):
|
||||
if self.patch_size > 1: x = F.pixel_unshuffle(x, self.patch_size)
|
||||
if self.patch_size > 1:
|
||||
x = F.pixel_unshuffle(x, self.patch_size)
|
||||
x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
||||
if x.shape[1] % 4 != 0:
|
||||
# pad at end to multiple of 4
|
||||
@@ -167,5 +168,6 @@ class TAEHV(nn.Module):
|
||||
def decode(self, x, **kwargs):
|
||||
x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
||||
x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar)
|
||||
if self.patch_size > 1: x = F.pixel_shuffle(x, self.patch_size)
|
||||
if self.patch_size > 1:
|
||||
x = F.pixel_shuffle(x, self.patch_size)
|
||||
return x[:, self.frames_to_trim:].movedim(2, 1)
|
||||
|
||||
@@ -36,7 +36,7 @@ def te(dtype_t5=None, t5_quantization_metadata=None):
|
||||
if t5_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
||||
if dtype is None:
|
||||
if dtype_t5 is not None:
|
||||
dtype = dtype_t5
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return CosmosTEModel_
|
||||
|
||||
@@ -32,7 +32,7 @@ def mochi_te(dtype_t5=None, t5_quantization_metadata=None):
|
||||
if t5_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
||||
if dtype is None:
|
||||
if dtype_t5 is not None:
|
||||
dtype = dtype_t5
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return MochiTEModel_
|
||||
|
||||
@@ -7,8 +7,8 @@ import math
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
import comfy.model_management
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.clip_model
|
||||
|
||||
import comfy.model_management
|
||||
from . import qwen_vl
|
||||
|
||||
@dataclass
|
||||
@@ -189,6 +189,31 @@ class Gemma3_4B_Config:
|
||||
rope_scale = [8.0, 1.0]
|
||||
final_norm: bool = True
|
||||
|
||||
@dataclass
|
||||
class Gemma3_12B_Config:
|
||||
vocab_size: int = 262208
|
||||
hidden_size: int = 3840
|
||||
intermediate_size: int = 15360
|
||||
num_hidden_layers: int = 48
|
||||
num_attention_heads: int = 16
|
||||
num_key_value_heads: int = 8
|
||||
max_position_embeddings: int = 131072
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta = [1000000.0, 10000.0]
|
||||
transformer_type: str = "gemma3"
|
||||
head_dim = 256
|
||||
rms_norm_add = True
|
||||
mlp_activation = "gelu_pytorch_tanh"
|
||||
qkv_bias = False
|
||||
rope_dims = None
|
||||
q_norm = "gemma3"
|
||||
k_norm = "gemma3"
|
||||
sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
|
||||
rope_scale = [8.0, 1.0]
|
||||
final_norm: bool = True
|
||||
vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
|
||||
mm_tokens_per_image = 256
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
|
||||
super().__init__()
|
||||
@@ -521,6 +546,41 @@ class Llama2_(nn.Module):
|
||||
|
||||
return x, intermediate
|
||||
|
||||
|
||||
class Gemma3MultiModalProjector(torch.nn.Module):
|
||||
def __init__(self, config, dtype, device, operations):
|
||||
super().__init__()
|
||||
|
||||
self.mm_input_projection_weight = nn.Parameter(
|
||||
torch.empty(config.vision_config["hidden_size"], config.hidden_size, device=device, dtype=dtype)
|
||||
)
|
||||
|
||||
self.mm_soft_emb_norm = RMSNorm(config.vision_config["hidden_size"], eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
|
||||
self.patches_per_image = int(config.vision_config["image_size"] // config.vision_config["patch_size"])
|
||||
self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
|
||||
self.kernel_size = self.patches_per_image // self.tokens_per_side
|
||||
self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size)
|
||||
|
||||
def forward(self, vision_outputs: torch.Tensor):
|
||||
batch_size, _, seq_length = vision_outputs.shape
|
||||
|
||||
reshaped_vision_outputs = vision_outputs.transpose(1, 2)
|
||||
reshaped_vision_outputs = reshaped_vision_outputs.reshape(
|
||||
batch_size, seq_length, self.patches_per_image, self.patches_per_image
|
||||
)
|
||||
reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
|
||||
|
||||
pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
|
||||
pooled_vision_outputs = pooled_vision_outputs.flatten(2)
|
||||
pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)
|
||||
|
||||
normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)
|
||||
|
||||
projected_vision_outputs = torch.matmul(normed_vision_outputs, comfy.model_management.cast_to_device(self.mm_input_projection_weight, device=normed_vision_outputs.device, dtype=normed_vision_outputs.dtype))
|
||||
return projected_vision_outputs.type_as(vision_outputs)
|
||||
|
||||
|
||||
class BaseLlama:
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
@@ -637,3 +697,21 @@ class Gemma3_4B(BaseLlama, torch.nn.Module):
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Gemma3_12B(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Gemma3_12B_Config(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.multi_modal_projector = Gemma3MultiModalProjector(config, dtype, device, operations)
|
||||
self.vision_model = comfy.clip_model.CLIPVision(config.vision_config, dtype, device, operations)
|
||||
self.dtype = dtype
|
||||
self.image_size = config.vision_config["image_size"]
|
||||
|
||||
def preprocess_embed(self, embed, device):
|
||||
if embed["type"] == "image":
|
||||
image = comfy.clip_model.clip_preprocess(embed["data"], size=self.image_size, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True)
|
||||
return self.multi_modal_projector(self.vision_model(image.to(device, dtype=torch.float32))[0]), None
|
||||
return None, None
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
from comfy import sd1_clip
|
||||
import os
|
||||
from transformers import T5TokenizerFast
|
||||
from .spiece_tokenizer import SPieceTokenizer
|
||||
import comfy.text_encoders.genmo
|
||||
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
||||
import torch
|
||||
import comfy.utils
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
@@ -16,3 +20,123 @@ class LTXVT5Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
|
||||
def ltxv_te(*args, **kwargs):
|
||||
return comfy.text_encoders.genmo.mochi_te(*args, **kwargs)
|
||||
|
||||
|
||||
class Gemma3_12BTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
|
||||
|
||||
def state_dict(self):
|
||||
return {"spiece_model": self.tokenizer.serialize_model()}
|
||||
|
||||
class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_12b", tokenizer=Gemma3_12BTokenizer)
|
||||
|
||||
class Gemma3_12BModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template="{}", image_embeds=None, **kwargs):
|
||||
text = llama_template.format(text)
|
||||
text_tokens = super().tokenize_with_weights(text, return_word_ids)
|
||||
embed_count = 0
|
||||
for k in text_tokens:
|
||||
tt = text_tokens[k]
|
||||
for r in tt:
|
||||
for i in range(len(r)):
|
||||
if r[i][0] == 262144:
|
||||
if image_embeds is not None and embed_count < image_embeds.shape[0]:
|
||||
r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image"},) + r[i][1:]
|
||||
embed_count += 1
|
||||
return text_tokens
|
||||
|
||||
class LTXAVTEModel(torch.nn.Module):
|
||||
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__()
|
||||
self.dtypes = set()
|
||||
self.dtypes.add(dtype)
|
||||
|
||||
self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None)
|
||||
self.dtypes.add(dtype_llama)
|
||||
|
||||
operations = self.gemma3_12b.operations # TODO
|
||||
self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
|
||||
|
||||
self.audio_embeddings_connector = Embeddings1DConnector(
|
||||
split_rope=True,
|
||||
double_precision_rope=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.video_embeddings_connector = Embeddings1DConnector(
|
||||
split_rope=True,
|
||||
double_precision_rope=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
def set_clip_options(self, options):
|
||||
self.execution_device = options.get("execution_device", self.execution_device)
|
||||
self.gemma3_12b.set_clip_options(options)
|
||||
|
||||
def reset_clip_options(self):
|
||||
self.gemma3_12b.reset_clip_options()
|
||||
self.execution_device = None
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
token_weight_pairs = token_weight_pairs["gemma3_12b"]
|
||||
|
||||
out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs)
|
||||
out_device = out.device
|
||||
if comfy.model_management.should_use_bf16(self.execution_device):
|
||||
out = out.to(device=self.execution_device, dtype=torch.bfloat16)
|
||||
out = out.movedim(1, -1).to(self.execution_device)
|
||||
out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
|
||||
out = out.reshape((out.shape[0], out.shape[1], -1))
|
||||
out = self.text_embedding_projection(out)
|
||||
out = out.float()
|
||||
out_vid = self.video_embeddings_connector(out)[0]
|
||||
out_audio = self.audio_embeddings_connector(out)[0]
|
||||
out = torch.concat((out_vid, out_audio), dim=-1)
|
||||
|
||||
return out.to(out_device), pooled
|
||||
|
||||
def load_sd(self, sd):
|
||||
if "model.layers.47.self_attn.q_norm.weight" in sd:
|
||||
return self.gemma3_12b.load_sd(sd)
|
||||
else:
|
||||
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True)
|
||||
if len(sdo) == 0:
|
||||
sdo = sd
|
||||
|
||||
return self.load_state_dict(sdo, strict=False)
|
||||
|
||||
def memory_estimation_function(self, token_weight_pairs, device=None):
|
||||
constant = 6.0
|
||||
if comfy.model_management.should_use_bf16(device):
|
||||
constant /= 2.0
|
||||
|
||||
token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
|
||||
num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
|
||||
return num_tokens * constant * 1024 * 1024
|
||||
|
||||
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class LTXAVTEModel_(LTXAVTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
|
||||
return LTXAVTEModel_
|
||||
|
||||
@@ -36,7 +36,7 @@ def pixart_te(dtype_t5=None, t5_quantization_metadata=None):
|
||||
if t5_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
||||
if dtype is None:
|
||||
if dtype_t5 is not None:
|
||||
dtype = dtype_t5
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return PixArtTEModel_
|
||||
|
||||
@@ -1198,7 +1198,7 @@ def unpack_latents(combined_latent, latent_shapes):
|
||||
combined_latent = combined_latent[:, :, cut:]
|
||||
output_tensors.append(tens.reshape([tens.shape[0]] + list(shape)[1:]))
|
||||
else:
|
||||
output_tensors = combined_latent
|
||||
output_tensors = [combined_latent]
|
||||
return output_tensors
|
||||
|
||||
def detect_layer_quantization(state_dict, prefix):
|
||||
@@ -1230,6 +1230,8 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
|
||||
out_sd = {}
|
||||
layers = {}
|
||||
for k in list(state_dict.keys()):
|
||||
if k == scaled_fp8_key:
|
||||
continue
|
||||
if not k.startswith(model_prefix):
|
||||
out_sd[k] = state_dict[k]
|
||||
continue
|
||||
|
||||
@@ -10,7 +10,6 @@ from ._input_impl import VideoFromFile, VideoFromComponents
|
||||
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
|
||||
from . import _io_public as io
|
||||
from . import _ui_public as ui
|
||||
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
|
||||
from comfy_execution.utils import get_executing_context
|
||||
from comfy_execution.progress import get_progress_state, PreviewImageTuple
|
||||
from PIL import Image
|
||||
|
||||
@@ -26,7 +26,6 @@ if TYPE_CHECKING:
|
||||
from comfy_api.input import VideoInput
|
||||
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
||||
prune_dict, shallow_clone_class)
|
||||
from ._resources import Resources, ResourcesLocal
|
||||
from comfy_execution.graph_utils import ExecutionBlocker
|
||||
from ._util import MESH, VOXEL, SVG as _SVG
|
||||
|
||||
@@ -76,16 +75,6 @@ class NumberDisplay(str, Enum):
|
||||
slider = "slider"
|
||||
|
||||
|
||||
class _StringIOType(str):
|
||||
def __ne__(self, value: object) -> bool:
|
||||
if self == "*" or value == "*":
|
||||
return False
|
||||
if not isinstance(value, str):
|
||||
return True
|
||||
a = frozenset(self.split(","))
|
||||
b = frozenset(value.split(","))
|
||||
return not (b.issubset(a) or a.issubset(b))
|
||||
|
||||
class _ComfyType(ABC):
|
||||
Type = Any
|
||||
io_type: str = None
|
||||
@@ -125,8 +114,7 @@ def comfytype(io_type: str, **kwargs):
|
||||
new_cls.__module__ = cls.__module__
|
||||
new_cls.__doc__ = cls.__doc__
|
||||
# assign ComfyType attributes, if needed
|
||||
# NOTE: use __ne__ trick for io_type (see node_typing.IO.__ne__ for details)
|
||||
new_cls.io_type = _StringIOType(io_type)
|
||||
new_cls.io_type = io_type
|
||||
if hasattr(new_cls, "Input") and new_cls.Input is not None:
|
||||
new_cls.Input.Parent = new_cls
|
||||
if hasattr(new_cls, "Output") and new_cls.Output is not None:
|
||||
@@ -165,7 +153,7 @@ class Input(_IO_V3):
|
||||
'''
|
||||
Base class for a V3 Input.
|
||||
'''
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None):
|
||||
super().__init__()
|
||||
self.id = id
|
||||
self.display_name = display_name
|
||||
@@ -173,6 +161,7 @@ class Input(_IO_V3):
|
||||
self.tooltip = tooltip
|
||||
self.lazy = lazy
|
||||
self.extra_dict = extra_dict if extra_dict is not None else {}
|
||||
self.rawLink = raw_link
|
||||
|
||||
def as_dict(self):
|
||||
return prune_dict({
|
||||
@@ -180,10 +169,11 @@ class Input(_IO_V3):
|
||||
"optional": self.optional,
|
||||
"tooltip": self.tooltip,
|
||||
"lazy": self.lazy,
|
||||
"rawLink": self.rawLink,
|
||||
}) | prune_dict(self.extra_dict)
|
||||
|
||||
def get_io_type(self):
|
||||
return _StringIOType(self.io_type)
|
||||
return self.io_type
|
||||
|
||||
def get_all(self) -> list[Input]:
|
||||
return [self]
|
||||
@@ -194,8 +184,8 @@ class WidgetInput(Input):
|
||||
'''
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
default: Any=None,
|
||||
socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
||||
socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link)
|
||||
self.default = default
|
||||
self.socketless = socketless
|
||||
self.widget_type = widget_type
|
||||
@@ -217,13 +207,14 @@ class Output(_IO_V3):
|
||||
def __init__(self, id: str=None, display_name: str=None, tooltip: str=None,
|
||||
is_output_list=False):
|
||||
self.id = id
|
||||
self.display_name = display_name
|
||||
self.display_name = display_name if display_name else id
|
||||
self.tooltip = tooltip
|
||||
self.is_output_list = is_output_list
|
||||
|
||||
def as_dict(self):
|
||||
display_name = self.display_name if self.display_name else self.id
|
||||
return prune_dict({
|
||||
"display_name": self.display_name,
|
||||
"display_name": display_name,
|
||||
"tooltip": self.tooltip,
|
||||
"is_output_list": self.is_output_list,
|
||||
})
|
||||
@@ -251,8 +242,8 @@ class Boolean(ComfyTypeIO):
|
||||
'''Boolean input.'''
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
default: bool=None, label_on: str=None, label_off: str=None,
|
||||
socketless: bool=None, force_input: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input)
|
||||
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
|
||||
self.label_on = label_on
|
||||
self.label_off = label_off
|
||||
self.default: bool
|
||||
@@ -271,8 +262,8 @@ class Int(ComfyTypeIO):
|
||||
'''Integer input.'''
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
|
||||
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input)
|
||||
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
|
||||
self.min = min
|
||||
self.max = max
|
||||
self.step = step
|
||||
@@ -297,8 +288,8 @@ class Float(ComfyTypeIO):
|
||||
'''Float input.'''
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
|
||||
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input)
|
||||
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
|
||||
self.min = min
|
||||
self.max = max
|
||||
self.step = step
|
||||
@@ -323,8 +314,8 @@ class String(ComfyTypeIO):
|
||||
'''String input.'''
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None,
|
||||
socketless: bool=None, force_input: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input)
|
||||
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
|
||||
self.multiline = multiline
|
||||
self.placeholder = placeholder
|
||||
self.dynamic_prompts = dynamic_prompts
|
||||
@@ -357,12 +348,14 @@ class Combo(ComfyTypeIO):
|
||||
image_folder: FolderType=None,
|
||||
remote: RemoteOptions=None,
|
||||
socketless: bool=None,
|
||||
extra_dict=None,
|
||||
raw_link: bool=None,
|
||||
):
|
||||
if isinstance(options, type) and issubclass(options, Enum):
|
||||
options = [v.value for v in options]
|
||||
if isinstance(default, Enum):
|
||||
default = default.value
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless)
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link)
|
||||
self.multiselect = False
|
||||
self.options = options
|
||||
self.control_after_generate = control_after_generate
|
||||
@@ -386,10 +379,6 @@ class Combo(ComfyTypeIO):
|
||||
super().__init__(id, display_name, tooltip, is_output_list)
|
||||
self.options = options if options is not None else []
|
||||
|
||||
@property
|
||||
def io_type(self):
|
||||
return self.options
|
||||
|
||||
@comfytype(io_type="COMBO")
|
||||
class MultiCombo(ComfyTypeI):
|
||||
'''Multiselect Combo input (dropdown for selecting potentially more than one value).'''
|
||||
@@ -398,8 +387,8 @@ class MultiCombo(ComfyTypeI):
|
||||
class Input(Combo.Input):
|
||||
def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None,
|
||||
socketless: bool=None):
|
||||
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless)
|
||||
socketless: bool=None, extra_dict=None, raw_link: bool=None):
|
||||
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link)
|
||||
self.multiselect = True
|
||||
self.placeholder = placeholder
|
||||
self.chip = chip
|
||||
@@ -432,9 +421,9 @@ class Webcam(ComfyTypeIO):
|
||||
Type = str
|
||||
def __init__(
|
||||
self, id: str, display_name: str=None, optional=False,
|
||||
tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None
|
||||
tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None, extra_dict=None, raw_link: bool=None
|
||||
):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless)
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link)
|
||||
|
||||
|
||||
@comfytype(io_type="MASK")
|
||||
@@ -787,7 +776,7 @@ class MultiType:
|
||||
'''
|
||||
Input that permits more than one input type; if `id` is an instance of `ComfyType.Input`, then that input will be used to create a widget (if applicable) with overridden values.
|
||||
'''
|
||||
def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
|
||||
def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None):
|
||||
# if id is an Input, then use that Input with overridden values
|
||||
self.input_override = None
|
||||
if isinstance(id, Input):
|
||||
@@ -800,7 +789,7 @@ class MultiType:
|
||||
# if is a widget input, make sure widget_type is set appropriately
|
||||
if isinstance(self.input_override, WidgetInput):
|
||||
self.input_override.widget_type = self.input_override.get_io_type()
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link)
|
||||
self._io_types = types
|
||||
|
||||
@property
|
||||
@@ -854,8 +843,8 @@ class MatchType(ComfyTypeIO):
|
||||
|
||||
class Input(Input):
|
||||
def __init__(self, id: str, template: MatchType.Template,
|
||||
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
||||
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link)
|
||||
self.template = template
|
||||
|
||||
def as_dict(self):
|
||||
@@ -866,6 +855,8 @@ class MatchType(ComfyTypeIO):
|
||||
class Output(Output):
|
||||
def __init__(self, template: MatchType.Template, id: str=None, display_name: str=None, tooltip: str=None,
|
||||
is_output_list=False):
|
||||
if not id and not display_name:
|
||||
display_name = "MATCHTYPE"
|
||||
super().__init__(id, display_name, tooltip, is_output_list)
|
||||
self.template = template
|
||||
|
||||
@@ -878,24 +869,30 @@ class DynamicInput(Input, ABC):
|
||||
'''
|
||||
Abstract class for dynamic input registration.
|
||||
'''
|
||||
def get_dynamic(self) -> list[Input]:
|
||||
return []
|
||||
|
||||
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
class DynamicOutput(Output, ABC):
|
||||
'''
|
||||
Abstract class for dynamic output registration.
|
||||
'''
|
||||
def __init__(self, id: str=None, display_name: str=None, tooltip: str=None,
|
||||
is_output_list=False):
|
||||
super().__init__(id, display_name, tooltip, is_output_list)
|
||||
pass
|
||||
|
||||
def get_dynamic(self) -> list[Output]:
|
||||
return []
|
||||
|
||||
def handle_prefix(prefix_list: list[str] | None, id: str | None = None) -> list[str]:
|
||||
if prefix_list is None:
|
||||
prefix_list = []
|
||||
if id is not None:
|
||||
prefix_list = prefix_list + [id]
|
||||
return prefix_list
|
||||
|
||||
def finalize_prefix(prefix_list: list[str] | None, id: str | None = None) -> str:
|
||||
assert not (prefix_list is None and id is None)
|
||||
if prefix_list is None:
|
||||
return id
|
||||
elif id is not None:
|
||||
prefix_list = prefix_list + [id]
|
||||
return ".".join(prefix_list)
|
||||
|
||||
@comfytype(io_type="COMFY_AUTOGROW_V3")
|
||||
class Autogrow(ComfyTypeI):
|
||||
@@ -932,14 +929,6 @@ class Autogrow(ComfyTypeI):
|
||||
def validate(self):
|
||||
self.input.validate()
|
||||
|
||||
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
|
||||
real_inputs = []
|
||||
for name, input in self.cached_inputs.items():
|
||||
if name in live_inputs:
|
||||
real_inputs.append(input)
|
||||
add_to_input_dict_v1(d, real_inputs, live_inputs, curr_prefix)
|
||||
add_dynamic_id_mapping(d, real_inputs, curr_prefix)
|
||||
|
||||
class TemplatePrefix(_AutogrowTemplate):
|
||||
def __init__(self, input: Input, prefix: str, min: int=1, max: int=10):
|
||||
super().__init__(input)
|
||||
@@ -984,22 +973,45 @@ class Autogrow(ComfyTypeI):
|
||||
"template": self.template.as_dict(),
|
||||
})
|
||||
|
||||
def get_dynamic(self) -> list[Input]:
|
||||
return self.template.get_all()
|
||||
|
||||
def get_all(self) -> list[Input]:
|
||||
return [self] + self.template.get_all()
|
||||
|
||||
def validate(self):
|
||||
self.template.validate()
|
||||
|
||||
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
|
||||
curr_prefix = f"{curr_prefix}{self.id}."
|
||||
# need to remove self from expected inputs dictionary; replaced by template inputs in frontend
|
||||
for inner_dict in d.values():
|
||||
if self.id in inner_dict:
|
||||
del inner_dict[self.id]
|
||||
self.template.expand_schema_for_dynamic(d, live_inputs, curr_prefix)
|
||||
@staticmethod
|
||||
def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None):
|
||||
# NOTE: purposely do not include self in out_dict; instead use only the template inputs
|
||||
# need to figure out names based on template type
|
||||
is_names = ("names" in value[1]["template"])
|
||||
is_prefix = ("prefix" in value[1]["template"])
|
||||
input = value[1]["template"]["input"]
|
||||
if is_names:
|
||||
min = value[1]["template"]["min"]
|
||||
names = value[1]["template"]["names"]
|
||||
max = len(names)
|
||||
elif is_prefix:
|
||||
prefix = value[1]["template"]["prefix"]
|
||||
min = value[1]["template"]["min"]
|
||||
max = value[1]["template"]["max"]
|
||||
names = [f"{prefix}{i}" for i in range(max)]
|
||||
# need to create a new input based on the contents of input
|
||||
template_input = None
|
||||
for _, dict_input in input.items():
|
||||
# for now, get just the first value from dict_input
|
||||
template_input = list(dict_input.values())[0]
|
||||
new_dict = {}
|
||||
for i, name in enumerate(names):
|
||||
expected_id = finalize_prefix(curr_prefix, name)
|
||||
if expected_id in live_inputs:
|
||||
# required
|
||||
if i < min:
|
||||
type_dict = new_dict.setdefault("required", {})
|
||||
# optional
|
||||
else:
|
||||
type_dict = new_dict.setdefault("optional", {})
|
||||
type_dict[name] = template_input
|
||||
parse_class_inputs(out_dict, live_inputs, new_dict, curr_prefix)
|
||||
|
||||
@comfytype(io_type="COMFY_DYNAMICCOMBO_V3")
|
||||
class DynamicCombo(ComfyTypeI):
|
||||
@@ -1022,23 +1034,6 @@ class DynamicCombo(ComfyTypeI):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
||||
self.options = options
|
||||
|
||||
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
|
||||
# check if dynamic input's id is in live_inputs
|
||||
if self.id in live_inputs:
|
||||
curr_prefix = f"{curr_prefix}{self.id}."
|
||||
key = live_inputs[self.id]
|
||||
selected_option = None
|
||||
for option in self.options:
|
||||
if option.key == key:
|
||||
selected_option = option
|
||||
break
|
||||
if selected_option is not None:
|
||||
add_to_input_dict_v1(d, selected_option.inputs, live_inputs, curr_prefix)
|
||||
add_dynamic_id_mapping(d, selected_option.inputs, curr_prefix, self)
|
||||
|
||||
def get_dynamic(self) -> list[Input]:
|
||||
return [input for option in self.options for input in option.inputs]
|
||||
|
||||
def get_all(self) -> list[Input]:
|
||||
return [self] + [input for option in self.options for input in option.inputs]
|
||||
|
||||
@@ -1053,6 +1048,24 @@ class DynamicCombo(ComfyTypeI):
|
||||
for input in option.inputs:
|
||||
input.validate()
|
||||
|
||||
@staticmethod
|
||||
def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None):
|
||||
finalized_id = finalize_prefix(curr_prefix)
|
||||
if finalized_id in live_inputs:
|
||||
key = live_inputs[finalized_id]
|
||||
selected_option = None
|
||||
# get options from dict
|
||||
options: list[dict[str, str | dict[str, Any]]] = value[1]["options"]
|
||||
for option in options:
|
||||
if option["key"] == key:
|
||||
selected_option = option
|
||||
break
|
||||
if selected_option is not None:
|
||||
parse_class_inputs(out_dict, live_inputs, selected_option["inputs"], curr_prefix)
|
||||
# add self to inputs
|
||||
out_dict[input_type][finalized_id] = value
|
||||
out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1])
|
||||
|
||||
@comfytype(io_type="COMFY_DYNAMICSLOT_V3")
|
||||
class DynamicSlot(ComfyTypeI):
|
||||
Type = dict[str, Any]
|
||||
@@ -1075,17 +1088,8 @@ class DynamicSlot(ComfyTypeI):
|
||||
self.force_input = True
|
||||
self.slot.force_input = True
|
||||
|
||||
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
|
||||
if self.id in live_inputs:
|
||||
curr_prefix = f"{curr_prefix}{self.id}."
|
||||
add_to_input_dict_v1(d, self.inputs, live_inputs, curr_prefix)
|
||||
add_dynamic_id_mapping(d, [self.slot] + self.inputs, curr_prefix)
|
||||
|
||||
def get_dynamic(self) -> list[Input]:
|
||||
return [self.slot] + self.inputs
|
||||
|
||||
def get_all(self) -> list[Input]:
|
||||
return [self] + [self.slot] + self.inputs
|
||||
return [self.slot] + self.inputs
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict() | prune_dict({
|
||||
@@ -1099,17 +1103,53 @@ class DynamicSlot(ComfyTypeI):
|
||||
for input in self.inputs:
|
||||
input.validate()
|
||||
|
||||
def add_dynamic_id_mapping(d: dict[str, Any], inputs: list[Input], curr_prefix: str, self: DynamicInput=None):
|
||||
dynamic = d.setdefault("dynamic_paths", {})
|
||||
if self is not None:
|
||||
dynamic[self.id] = f"{curr_prefix}{self.id}"
|
||||
for i in inputs:
|
||||
if not isinstance(i, DynamicInput):
|
||||
dynamic[f"{i.id}"] = f"{curr_prefix}{i.id}"
|
||||
@staticmethod
|
||||
def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None):
|
||||
finalized_id = finalize_prefix(curr_prefix)
|
||||
if finalized_id in live_inputs:
|
||||
inputs = value[1]["inputs"]
|
||||
parse_class_inputs(out_dict, live_inputs, inputs, curr_prefix)
|
||||
# add self to inputs
|
||||
out_dict[input_type][finalized_id] = value
|
||||
out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1])
|
||||
|
||||
@comfytype(io_type="IMAGECOMPARE")
|
||||
class ImageCompare(ComfyTypeI):
|
||||
Type = dict
|
||||
|
||||
class Input(WidgetInput):
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||
socketless: bool=True):
|
||||
super().__init__(id, display_name, optional, tooltip, None, None, socketless)
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict()
|
||||
|
||||
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
||||
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
||||
DYNAMIC_INPUT_LOOKUP[io_type] = func
|
||||
|
||||
def get_dynamic_input_func(io_type: str) -> Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]:
|
||||
return DYNAMIC_INPUT_LOOKUP[io_type]
|
||||
|
||||
def setup_dynamic_input_funcs():
|
||||
# Autogrow.Input
|
||||
register_dynamic_input_func(Autogrow.io_type, Autogrow._expand_schema_for_dynamic)
|
||||
# DynamicCombo.Input
|
||||
register_dynamic_input_func(DynamicCombo.io_type, DynamicCombo._expand_schema_for_dynamic)
|
||||
# DynamicSlot.Input
|
||||
register_dynamic_input_func(DynamicSlot.io_type, DynamicSlot._expand_schema_for_dynamic)
|
||||
|
||||
if len(DYNAMIC_INPUT_LOOKUP) == 0:
|
||||
setup_dynamic_input_funcs()
|
||||
|
||||
class V3Data(TypedDict):
|
||||
hidden_inputs: dict[str, Any]
|
||||
'Dictionary where the keys are the hidden input ids and the values are the values of the hidden inputs.'
|
||||
dynamic_paths: dict[str, Any]
|
||||
'Dictionary where the keys are the input ids and the values dictate how to turn the inputs into a nested dictionary.'
|
||||
create_dynamic_tuple: bool
|
||||
'When True, the value of the dynamic input will be in the format (value, path_key).'
|
||||
|
||||
class HiddenHolder:
|
||||
def __init__(self, unique_id: str, prompt: Any,
|
||||
@@ -1145,6 +1185,10 @@ class HiddenHolder:
|
||||
api_key_comfy_org=d.get(Hidden.api_key_comfy_org, None),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_v3_data(cls, v3_data: V3Data | None) -> HiddenHolder:
|
||||
return cls.from_dict(v3_data["hidden_inputs"] if v3_data else None)
|
||||
|
||||
class Hidden(str, Enum):
|
||||
'''
|
||||
Enumerator for requesting hidden variables in nodes.
|
||||
@@ -1250,61 +1294,56 @@ class Schema:
|
||||
- verify ids on inputs and outputs are unique - both internally and in relation to each other
|
||||
'''
|
||||
nested_inputs: list[Input] = []
|
||||
if self.inputs is not None:
|
||||
for input in self.inputs:
|
||||
for input in self.inputs:
|
||||
if not isinstance(input, DynamicInput):
|
||||
nested_inputs.extend(input.get_all())
|
||||
input_ids = [i.id for i in nested_inputs] if nested_inputs is not None else []
|
||||
output_ids = [o.id for o in self.outputs] if self.outputs is not None else []
|
||||
input_ids = [i.id for i in nested_inputs]
|
||||
output_ids = [o.id for o in self.outputs]
|
||||
input_set = set(input_ids)
|
||||
output_set = set(output_ids)
|
||||
issues = []
|
||||
issues: list[str] = []
|
||||
# verify ids are unique per list
|
||||
if len(input_set) != len(input_ids):
|
||||
issues.append(f"Input ids must be unique, but {[item for item, count in Counter(input_ids).items() if count > 1]} are not.")
|
||||
if len(output_set) != len(output_ids):
|
||||
issues.append(f"Output ids must be unique, but {[item for item, count in Counter(output_ids).items() if count > 1]} are not.")
|
||||
# verify ids are unique between lists
|
||||
intersection = input_set & output_set
|
||||
if len(intersection) > 0:
|
||||
issues.append(f"Ids must be unique between inputs and outputs, but {intersection} are not.")
|
||||
if len(issues) > 0:
|
||||
raise ValueError("\n".join(issues))
|
||||
# validate inputs and outputs
|
||||
if self.inputs is not None:
|
||||
for input in self.inputs:
|
||||
input.validate()
|
||||
if self.outputs is not None:
|
||||
for output in self.outputs:
|
||||
output.validate()
|
||||
for input in self.inputs:
|
||||
input.validate()
|
||||
for output in self.outputs:
|
||||
output.validate()
|
||||
|
||||
def finalize(self):
|
||||
"""Add hidden based on selected schema options, and give outputs without ids default ids."""
|
||||
# ensure inputs, outputs, and hidden are lists
|
||||
if self.inputs is None:
|
||||
self.inputs = []
|
||||
if self.outputs is None:
|
||||
self.outputs = []
|
||||
if self.hidden is None:
|
||||
self.hidden = []
|
||||
# if is an api_node, will need key-related hidden
|
||||
if self.is_api_node:
|
||||
if self.hidden is None:
|
||||
self.hidden = []
|
||||
if Hidden.auth_token_comfy_org not in self.hidden:
|
||||
self.hidden.append(Hidden.auth_token_comfy_org)
|
||||
if Hidden.api_key_comfy_org not in self.hidden:
|
||||
self.hidden.append(Hidden.api_key_comfy_org)
|
||||
# if is an output_node, will need prompt and extra_pnginfo
|
||||
if self.is_output_node:
|
||||
if self.hidden is None:
|
||||
self.hidden = []
|
||||
if Hidden.prompt not in self.hidden:
|
||||
self.hidden.append(Hidden.prompt)
|
||||
if Hidden.extra_pnginfo not in self.hidden:
|
||||
self.hidden.append(Hidden.extra_pnginfo)
|
||||
# give outputs without ids default ids
|
||||
if self.outputs is not None:
|
||||
for i, output in enumerate(self.outputs):
|
||||
if output.id is None:
|
||||
output.id = f"_{i}_{output.io_type}_"
|
||||
for i, output in enumerate(self.outputs):
|
||||
if output.id is None:
|
||||
output.id = f"_{i}_{output.io_type}_"
|
||||
|
||||
def get_v1_info(self, cls, live_inputs: dict[str, Any]=None) -> NodeInfoV1:
|
||||
# NOTE: live_inputs will not be used anymore very soon and this will be done another way
|
||||
def get_v1_info(self, cls) -> NodeInfoV1:
|
||||
# get V1 inputs
|
||||
input = create_input_dict_v1(self.inputs, live_inputs)
|
||||
input = create_input_dict_v1(self.inputs)
|
||||
if self.hidden:
|
||||
for hidden in self.hidden:
|
||||
input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
|
||||
@@ -1384,33 +1423,54 @@ class Schema:
|
||||
)
|
||||
return info
|
||||
|
||||
def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], include_hidden=False) -> tuple[dict[str, Any], V3Data]:
|
||||
out_dict = {
|
||||
"required": {},
|
||||
"optional": {},
|
||||
"dynamic_paths": {},
|
||||
}
|
||||
d = d.copy()
|
||||
# ignore hidden for parsing
|
||||
hidden = d.pop("hidden", None)
|
||||
parse_class_inputs(out_dict, live_inputs, d)
|
||||
if hidden is not None and include_hidden:
|
||||
out_dict["hidden"] = hidden
|
||||
v3_data = {}
|
||||
dynamic_paths = out_dict.pop("dynamic_paths", None)
|
||||
if dynamic_paths is not None:
|
||||
v3_data["dynamic_paths"] = dynamic_paths
|
||||
return out_dict, hidden, v3_data
|
||||
|
||||
def create_input_dict_v1(inputs: list[Input], live_inputs: dict[str, Any]=None) -> dict:
|
||||
def parse_class_inputs(out_dict: dict[str, Any], live_inputs: dict[str, Any], curr_dict: dict[str, Any], curr_prefix: list[str] | None=None) -> None:
|
||||
for input_type, inner_d in curr_dict.items():
|
||||
for id, value in inner_d.items():
|
||||
io_type = value[0]
|
||||
if io_type in DYNAMIC_INPUT_LOOKUP:
|
||||
# dynamic inputs need to be handled with lookup functions
|
||||
dynamic_input_func = get_dynamic_input_func(io_type)
|
||||
new_prefix = handle_prefix(curr_prefix, id)
|
||||
dynamic_input_func(out_dict, live_inputs, value, input_type, new_prefix)
|
||||
else:
|
||||
# non-dynamic inputs get directly transferred
|
||||
finalized_id = finalize_prefix(curr_prefix, id)
|
||||
out_dict[input_type][finalized_id] = value
|
||||
if curr_prefix:
|
||||
out_dict["dynamic_paths"][finalized_id] = finalized_id
|
||||
|
||||
def create_input_dict_v1(inputs: list[Input]) -> dict:
|
||||
input = {
|
||||
"required": {}
|
||||
}
|
||||
add_to_input_dict_v1(input, inputs, live_inputs)
|
||||
for i in inputs:
|
||||
add_to_dict_v1(i, input)
|
||||
return input
|
||||
|
||||
def add_to_input_dict_v1(d: dict[str, Any], inputs: list[Input], live_inputs: dict[str, Any]=None, curr_prefix=''):
|
||||
for i in inputs:
|
||||
if isinstance(i, DynamicInput):
|
||||
add_to_dict_v1(i, d)
|
||||
if live_inputs is not None:
|
||||
i.expand_schema_for_dynamic(d, live_inputs, curr_prefix)
|
||||
else:
|
||||
add_to_dict_v1(i, d)
|
||||
|
||||
def add_to_dict_v1(i: Input, d: dict, dynamic_dict: dict=None):
|
||||
def add_to_dict_v1(i: Input, d: dict):
|
||||
key = "optional" if i.optional else "required"
|
||||
as_dict = i.as_dict()
|
||||
# for v1, we don't want to include the optional key
|
||||
as_dict.pop("optional", None)
|
||||
if dynamic_dict is None:
|
||||
value = (i.get_io_type(), as_dict)
|
||||
else:
|
||||
value = (i.get_io_type(), as_dict, dynamic_dict)
|
||||
d.setdefault(key, {})[i.id] = value
|
||||
d.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict)
|
||||
|
||||
def add_to_dict_v3(io: Input | Output, d: dict):
|
||||
d[io.id] = (io.get_io_type(), io.as_dict())
|
||||
@@ -1422,6 +1482,8 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
||||
values = values.copy()
|
||||
result = {}
|
||||
|
||||
create_tuple = v3_data.get("create_dynamic_tuple", False)
|
||||
|
||||
for key, path in paths.items():
|
||||
parts = path.split(".")
|
||||
current = result
|
||||
@@ -1430,7 +1492,10 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
||||
is_last = (i == len(parts) - 1)
|
||||
|
||||
if is_last:
|
||||
current[p] = values.pop(key, None)
|
||||
value = values.pop(key, None)
|
||||
if create_tuple:
|
||||
value = (value, key)
|
||||
current[p] = value
|
||||
else:
|
||||
current = current.setdefault(p, {})
|
||||
|
||||
@@ -1445,7 +1510,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
||||
SCHEMA = None
|
||||
|
||||
# filled in during execution
|
||||
resources: Resources = None
|
||||
hidden: HiddenHolder = None
|
||||
|
||||
@classmethod
|
||||
@@ -1492,7 +1556,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
||||
return [name for name in kwargs if kwargs[name] is None]
|
||||
|
||||
def __init__(self):
|
||||
self.local_resources: ResourcesLocal = None
|
||||
self.__class__.VALIDATE_CLASS()
|
||||
|
||||
@classmethod
|
||||
@@ -1560,7 +1623,7 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
||||
c_type: type[ComfyNode] = cls if is_class(cls) else type(cls)
|
||||
type_clone: type[ComfyNode] = shallow_clone_class(c_type)
|
||||
# set hidden
|
||||
type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"] if v3_data else None)
|
||||
type_clone.hidden = HiddenHolder.from_v3_data(v3_data)
|
||||
return type_clone
|
||||
|
||||
@final
|
||||
@@ -1677,19 +1740,10 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
||||
|
||||
@final
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls, include_hidden=True, return_schema=False, live_inputs=None) -> dict[str, dict] | tuple[dict[str, dict], Schema, V3Data]:
|
||||
def INPUT_TYPES(cls) -> dict[str, dict]:
|
||||
schema = cls.FINALIZE_SCHEMA()
|
||||
info = schema.get_v1_info(cls, live_inputs)
|
||||
input = info.input
|
||||
if not include_hidden:
|
||||
input.pop("hidden", None)
|
||||
if return_schema:
|
||||
v3_data: V3Data = {}
|
||||
dynamic = input.pop("dynamic_paths", None)
|
||||
if dynamic is not None:
|
||||
v3_data["dynamic_paths"] = dynamic
|
||||
return input, schema, v3_data
|
||||
return input
|
||||
info = schema.get_v1_info(cls)
|
||||
return info.input
|
||||
|
||||
@final
|
||||
@classmethod
|
||||
@@ -1808,7 +1862,7 @@ class NodeOutput(_NodeOutputInternal):
|
||||
return self.args if len(self.args) > 0 else None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "NodeOutput":
|
||||
def from_dict(cls, data: dict[str, Any]) -> NodeOutput:
|
||||
args = ()
|
||||
ui = None
|
||||
expand = None
|
||||
@@ -1903,8 +1957,8 @@ __all__ = [
|
||||
"Tracks",
|
||||
# Dynamic Types
|
||||
"MatchType",
|
||||
# "DynamicCombo",
|
||||
# "Autogrow",
|
||||
"DynamicCombo",
|
||||
"Autogrow",
|
||||
# Other classes
|
||||
"HiddenHolder",
|
||||
"Hidden",
|
||||
@@ -1916,4 +1970,5 @@ __all__ = [
|
||||
"add_to_dict_v1",
|
||||
"add_to_dict_v3",
|
||||
"V3Data",
|
||||
"ImageCompare",
|
||||
]
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
import torch
|
||||
|
||||
class ResourceKey(ABC):
|
||||
Type = Any
|
||||
def __init__(self):
|
||||
...
|
||||
|
||||
class TorchDictFolderFilename(ResourceKey):
|
||||
'''Key for requesting a torch file via file_name from a folder category.'''
|
||||
Type = dict[str, torch.Tensor]
|
||||
def __init__(self, folder_name: str, file_name: str):
|
||||
self.folder_name = folder_name
|
||||
self.file_name = file_name
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.folder_name, self.file_name))
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, TorchDictFolderFilename):
|
||||
return False
|
||||
return self.folder_name == other.folder_name and self.file_name == other.file_name
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.folder_name} -> {self.file_name}"
|
||||
|
||||
class Resources(ABC):
|
||||
def __init__(self):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get(self, key: ResourceKey, default: Any=...) -> Any:
|
||||
pass
|
||||
|
||||
class ResourcesLocal(Resources):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.local_resources: dict[ResourceKey, Any] = {}
|
||||
|
||||
def get(self, key: ResourceKey, default: Any=...) -> Any:
|
||||
cached = self.local_resources.get(key, None)
|
||||
if cached is not None:
|
||||
logging.info(f"Using cached resource '{key}'")
|
||||
return cached
|
||||
logging.info(f"Loading resource '{key}'")
|
||||
to_return = None
|
||||
if isinstance(key, TorchDictFolderFilename):
|
||||
if default is ...:
|
||||
to_return = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(key.folder_name, key.file_name), safe_load=True)
|
||||
else:
|
||||
full_path = folder_paths.get_full_path(key.folder_name, key.file_name)
|
||||
if full_path is not None:
|
||||
to_return = comfy.utils.load_torch_file(full_path, safe_load=True)
|
||||
|
||||
if to_return is not None:
|
||||
self.local_resources[key] = to_return
|
||||
return to_return
|
||||
if default is not ...:
|
||||
return default
|
||||
raise Exception(f"Unsupported resource key type: {type(key)}")
|
||||
|
||||
|
||||
class _RESOURCES:
|
||||
ResourceKey = ResourceKey
|
||||
TorchDictFolderFilename = TorchDictFolderFilename
|
||||
Resources = Resources
|
||||
ResourcesLocal = ResourcesLocal
|
||||
41
comfy_api_nodes/apis/vidu.py
Normal file
41
comfy_api_nodes/apis/vidu.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SubjectReference(BaseModel):
|
||||
id: str = Field(...)
|
||||
images: list[str] = Field(...)
|
||||
|
||||
|
||||
class TaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(..., max_length=2000)
|
||||
duration: int = Field(...)
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
aspect_ratio: str | None = Field(None)
|
||||
resolution: str | None = Field(None)
|
||||
movement_amplitude: str | None = Field(None)
|
||||
images: list[str] | None = Field(None, description="Base64 encoded string or image URL")
|
||||
subjects: list[SubjectReference] | None = Field(None)
|
||||
bgm: bool | None = Field(None)
|
||||
audio: bool | None = Field(None)
|
||||
|
||||
|
||||
class TaskCreationResponse(BaseModel):
|
||||
task_id: str = Field(...)
|
||||
state: str = Field(...)
|
||||
created_at: str = Field(...)
|
||||
code: int | None = Field(None, description="Error code")
|
||||
|
||||
|
||||
class TaskResult(BaseModel):
|
||||
id: str = Field(..., description="Creation id")
|
||||
url: str = Field(..., description="The URL of the generated results, valid for one hour")
|
||||
cover_url: str = Field(..., description="The cover URL of the generated results, valid for one hour")
|
||||
|
||||
|
||||
class TaskStatusResponse(BaseModel):
|
||||
state: str = Field(...)
|
||||
err_code: str | None = Field(None)
|
||||
progress: float | None = Field(None)
|
||||
credits: int | None = Field(None)
|
||||
creations: list[TaskResult] = Field(..., description="Generated results")
|
||||
@@ -567,7 +567,7 @@ async def execute_lipsync(
|
||||
# Upload the audio file to Comfy API and get download URL
|
||||
if audio:
|
||||
audio_url = await upload_audio_to_comfyapi(
|
||||
cls, audio, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mpeg", filename="output.mp3"
|
||||
cls, audio, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mpeg"
|
||||
)
|
||||
logging.info("Uploaded audio to Comfy API. URL: %s", audio_url)
|
||||
else:
|
||||
@@ -807,6 +807,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
|
||||
),
|
||||
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
|
||||
IO.Combo.Input("duration", options=[5, 10]),
|
||||
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
@@ -826,6 +827,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
|
||||
prompt: str,
|
||||
aspect_ratio: str,
|
||||
duration: int,
|
||||
resolution: str = "1080p",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=2500)
|
||||
response = await sync_op(
|
||||
@@ -837,6 +839,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
|
||||
prompt=prompt,
|
||||
aspect_ratio=aspect_ratio,
|
||||
duration=str(duration),
|
||||
mode="pro" if resolution == "1080p" else "std",
|
||||
),
|
||||
)
|
||||
return await finish_omni_video_task(cls, response)
|
||||
@@ -872,6 +875,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
|
||||
optional=True,
|
||||
tooltip="Up to 6 additional reference images.",
|
||||
),
|
||||
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
@@ -893,6 +897,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
|
||||
first_frame: Input.Image,
|
||||
end_frame: Input.Image | None = None,
|
||||
reference_images: Input.Image | None = None,
|
||||
resolution: str = "1080p",
|
||||
) -> IO.NodeOutput:
|
||||
prompt = normalize_omni_prompt_references(prompt)
|
||||
validate_string(prompt, min_length=1, max_length=2500)
|
||||
@@ -936,6 +941,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
|
||||
prompt=prompt,
|
||||
duration=str(duration),
|
||||
image_list=image_list,
|
||||
mode="pro" if resolution == "1080p" else "std",
|
||||
),
|
||||
)
|
||||
return await finish_omni_video_task(cls, response)
|
||||
@@ -964,6 +970,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
|
||||
"reference_images",
|
||||
tooltip="Up to 7 reference images.",
|
||||
),
|
||||
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
@@ -984,6 +991,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
|
||||
aspect_ratio: str,
|
||||
duration: int,
|
||||
reference_images: Input.Image,
|
||||
resolution: str = "1080p",
|
||||
) -> IO.NodeOutput:
|
||||
prompt = normalize_omni_prompt_references(prompt)
|
||||
validate_string(prompt, min_length=1, max_length=2500)
|
||||
@@ -1005,6 +1013,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
|
||||
aspect_ratio=aspect_ratio,
|
||||
duration=str(duration),
|
||||
image_list=image_list,
|
||||
mode="pro" if resolution == "1080p" else "std",
|
||||
),
|
||||
)
|
||||
return await finish_omni_video_task(cls, response)
|
||||
@@ -1036,6 +1045,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
|
||||
tooltip="Up to 4 additional reference images.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
@@ -1058,6 +1068,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
|
||||
reference_video: Input.Video,
|
||||
keep_original_sound: bool,
|
||||
reference_images: Input.Image | None = None,
|
||||
resolution: str = "1080p",
|
||||
) -> IO.NodeOutput:
|
||||
prompt = normalize_omni_prompt_references(prompt)
|
||||
validate_string(prompt, min_length=1, max_length=2500)
|
||||
@@ -1090,6 +1101,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
|
||||
duration=str(duration),
|
||||
image_list=image_list if image_list else None,
|
||||
video_list=video_list,
|
||||
mode="pro" if resolution == "1080p" else "std",
|
||||
),
|
||||
)
|
||||
return await finish_omni_video_task(cls, response)
|
||||
@@ -1119,6 +1131,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
|
||||
tooltip="Up to 4 additional reference images.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
@@ -1139,6 +1152,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
|
||||
video: Input.Video,
|
||||
keep_original_sound: bool,
|
||||
reference_images: Input.Image | None = None,
|
||||
resolution: str = "1080p",
|
||||
) -> IO.NodeOutput:
|
||||
prompt = normalize_omni_prompt_references(prompt)
|
||||
validate_string(prompt, min_length=1, max_length=2500)
|
||||
@@ -1171,6 +1185,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
|
||||
duration=None,
|
||||
image_list=image_list if image_list else None,
|
||||
video_list=video_list,
|
||||
mode="pro" if resolution == "1080p" else "std",
|
||||
),
|
||||
)
|
||||
return await finish_omni_video_task(cls, response)
|
||||
|
||||
@@ -2,7 +2,6 @@ import builtins
|
||||
from io import BytesIO
|
||||
|
||||
import aiohttp
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
@@ -138,7 +137,7 @@ class TopazImageEnhance(IO.ComfyNode):
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
image: torch.Tensor,
|
||||
image: Input.Image,
|
||||
prompt: str = "",
|
||||
subject_detection: str = "All",
|
||||
face_enhancement: bool = True,
|
||||
@@ -153,7 +152,9 @@ class TopazImageEnhance(IO.ComfyNode):
|
||||
) -> IO.NodeOutput:
|
||||
if get_number_of_images(image) != 1:
|
||||
raise ValueError("Only one input image is supported.")
|
||||
download_url = await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png")
|
||||
download_url = await upload_images_to_comfyapi(
|
||||
cls, image, max_images=1, mime_type="image/png", total_pixels=4096*4096
|
||||
)
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/topaz/image/v1/enhance-gen/async", method="POST"),
|
||||
|
||||
@@ -155,7 +155,7 @@ class TripoTextToModelNode(IO.ComfyNode):
|
||||
model_seed=model_seed,
|
||||
texture_seed=texture_seed,
|
||||
texture_quality=texture_quality,
|
||||
face_limit=face_limit,
|
||||
face_limit=face_limit if face_limit != -1 else None,
|
||||
geometry_quality=geometry_quality,
|
||||
auto_size=True,
|
||||
quad=quad,
|
||||
@@ -255,7 +255,7 @@ class TripoImageToModelNode(IO.ComfyNode):
|
||||
texture_alignment=texture_alignment,
|
||||
texture_seed=texture_seed,
|
||||
texture_quality=texture_quality,
|
||||
face_limit=face_limit,
|
||||
face_limit=face_limit if face_limit != -1 else None,
|
||||
auto_size=True,
|
||||
quad=quad,
|
||||
),
|
||||
@@ -369,7 +369,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
|
||||
texture_quality=texture_quality,
|
||||
geometry_quality=geometry_quality,
|
||||
texture_alignment=texture_alignment,
|
||||
face_limit=face_limit,
|
||||
face_limit=face_limit if face_limit != -1 else None,
|
||||
quad=quad,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional, TypeVar
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.vidu import (
|
||||
SubjectReference,
|
||||
TaskCreationRequest,
|
||||
TaskCreationResponse,
|
||||
TaskResult,
|
||||
TaskStatusResponse,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_video_output,
|
||||
@@ -17,6 +18,7 @@ from comfy_api_nodes.util import (
|
||||
validate_image_aspect_ratio,
|
||||
validate_image_dimensions,
|
||||
validate_images_aspect_ratio_closeness,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video"
|
||||
@@ -25,98 +27,33 @@ VIDU_REFERENCE_VIDEO = "/proxy/vidu/reference2video"
|
||||
VIDU_START_END_VIDEO = "/proxy/vidu/start-end2video"
|
||||
VIDU_GET_GENERATION_STATUS = "/proxy/vidu/tasks/%s/creations"
|
||||
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class VideoModelName(str, Enum):
|
||||
vidu_q1 = "viduq1"
|
||||
|
||||
|
||||
class AspectRatio(str, Enum):
|
||||
r_16_9 = "16:9"
|
||||
r_9_16 = "9:16"
|
||||
r_1_1 = "1:1"
|
||||
|
||||
|
||||
class Resolution(str, Enum):
|
||||
r_1080p = "1080p"
|
||||
|
||||
|
||||
class MovementAmplitude(str, Enum):
|
||||
auto = "auto"
|
||||
small = "small"
|
||||
medium = "medium"
|
||||
large = "large"
|
||||
|
||||
|
||||
class TaskCreationRequest(BaseModel):
|
||||
model: VideoModelName = VideoModelName.vidu_q1
|
||||
prompt: Optional[str] = Field(None, max_length=1500)
|
||||
duration: Optional[Literal[5]] = 5
|
||||
seed: Optional[int] = Field(0, ge=0, le=2147483647)
|
||||
aspect_ratio: Optional[AspectRatio] = AspectRatio.r_16_9
|
||||
resolution: Optional[Resolution] = Resolution.r_1080p
|
||||
movement_amplitude: Optional[MovementAmplitude] = MovementAmplitude.auto
|
||||
images: Optional[list[str]] = Field(None, description="Base64 encoded string or image URL")
|
||||
|
||||
|
||||
class TaskCreationResponse(BaseModel):
|
||||
task_id: str = Field(...)
|
||||
state: str = Field(...)
|
||||
created_at: str = Field(...)
|
||||
code: Optional[int] = Field(None, description="Error code")
|
||||
|
||||
|
||||
class TaskResult(BaseModel):
|
||||
id: str = Field(..., description="Creation id")
|
||||
url: str = Field(..., description="The URL of the generated results, valid for one hour")
|
||||
cover_url: str = Field(..., description="The cover URL of the generated results, valid for one hour")
|
||||
|
||||
|
||||
class TaskStatusResponse(BaseModel):
|
||||
state: str = Field(...)
|
||||
err_code: Optional[str] = Field(None)
|
||||
creations: list[TaskResult] = Field(..., description="Generated results")
|
||||
|
||||
|
||||
def get_video_url_from_response(response) -> Optional[str]:
|
||||
if response.creations:
|
||||
return response.creations[0].url
|
||||
return None
|
||||
|
||||
|
||||
def get_video_from_response(response) -> TaskResult:
|
||||
if not response.creations:
|
||||
error_msg = f"Vidu request does not contain results. State: {response.state}, Error Code: {response.err_code}"
|
||||
logging.info(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
logging.info("Vidu task %s succeeded. Video URL: %s", response.creations[0].id, response.creations[0].url)
|
||||
return response.creations[0]
|
||||
|
||||
|
||||
async def execute_task(
|
||||
cls: type[IO.ComfyNode],
|
||||
vidu_endpoint: str,
|
||||
payload: TaskCreationRequest,
|
||||
estimated_duration: int,
|
||||
) -> R:
|
||||
response = await sync_op(
|
||||
) -> list[TaskResult]:
|
||||
task_creation_response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=vidu_endpoint, method="POST"),
|
||||
response_model=TaskCreationResponse,
|
||||
data=payload,
|
||||
)
|
||||
if response.state == "failed":
|
||||
error_msg = f"Vidu request failed. Code: {response.code}"
|
||||
logging.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
return await poll_op(
|
||||
if task_creation_response.state == "failed":
|
||||
raise RuntimeError(f"Vidu request failed. Code: {task_creation_response.code}")
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % response.task_id),
|
||||
ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % task_creation_response.task_id),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: r.state,
|
||||
estimated_duration=estimated_duration,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
max_poll_attempts=320,
|
||||
)
|
||||
if not response.creations:
|
||||
raise RuntimeError(
|
||||
f"Vidu request does not contain results. State: {response.state}, Error Code: {response.err_code}"
|
||||
)
|
||||
return response.creations
|
||||
|
||||
|
||||
class ViduTextToVideoNode(IO.ComfyNode):
|
||||
@@ -127,14 +64,9 @@ class ViduTextToVideoNode(IO.ComfyNode):
|
||||
node_id="ViduTextToVideoNode",
|
||||
display_name="Vidu Text To Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
description="Generate video from text prompt",
|
||||
description="Generate video from a text prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=VideoModelName,
|
||||
default=VideoModelName.vidu_q1,
|
||||
tooltip="Model name",
|
||||
),
|
||||
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
@@ -163,22 +95,19 @@ class ViduTextToVideoNode(IO.ComfyNode):
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=AspectRatio,
|
||||
default=AspectRatio.r_16_9,
|
||||
options=["16:9", "9:16", "1:1"],
|
||||
tooltip="The aspect ratio of the output video",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=Resolution,
|
||||
default=Resolution.r_1080p,
|
||||
options=["1080p"],
|
||||
tooltip="Supported values may vary by model & duration",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"movement_amplitude",
|
||||
options=MovementAmplitude,
|
||||
default=MovementAmplitude.auto,
|
||||
options=["auto", "small", "medium", "large"],
|
||||
tooltip="The movement amplitude of objects in the frame",
|
||||
optional=True,
|
||||
),
|
||||
@@ -208,7 +137,7 @@ class ViduTextToVideoNode(IO.ComfyNode):
|
||||
if not prompt:
|
||||
raise ValueError("The prompt field is required and cannot be empty.")
|
||||
payload = TaskCreationRequest(
|
||||
model_name=model,
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
@@ -216,8 +145,8 @@ class ViduTextToVideoNode(IO.ComfyNode):
|
||||
resolution=resolution,
|
||||
movement_amplitude=movement_amplitude,
|
||||
)
|
||||
results = await execute_task(cls, VIDU_TEXT_TO_VIDEO, payload, 320)
|
||||
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||
results = await execute_task(cls, VIDU_TEXT_TO_VIDEO, payload)
|
||||
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
|
||||
|
||||
|
||||
class ViduImageToVideoNode(IO.ComfyNode):
|
||||
@@ -230,12 +159,7 @@ class ViduImageToVideoNode(IO.ComfyNode):
|
||||
category="api node/video/Vidu",
|
||||
description="Generate video from image and optional prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=VideoModelName,
|
||||
default=VideoModelName.vidu_q1,
|
||||
tooltip="Model name",
|
||||
),
|
||||
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="An image to be used as the start frame of the generated video",
|
||||
@@ -270,15 +194,13 @@ class ViduImageToVideoNode(IO.ComfyNode):
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=Resolution,
|
||||
default=Resolution.r_1080p,
|
||||
options=["1080p"],
|
||||
tooltip="Supported values may vary by model & duration",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"movement_amplitude",
|
||||
options=MovementAmplitude,
|
||||
default=MovementAmplitude.auto.value,
|
||||
options=["auto", "small", "medium", "large"],
|
||||
tooltip="The movement amplitude of objects in the frame",
|
||||
optional=True,
|
||||
),
|
||||
@@ -298,7 +220,7 @@ class ViduImageToVideoNode(IO.ComfyNode):
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
image: torch.Tensor,
|
||||
image: Input.Image,
|
||||
prompt: str,
|
||||
duration: int,
|
||||
seed: int,
|
||||
@@ -309,7 +231,7 @@ class ViduImageToVideoNode(IO.ComfyNode):
|
||||
raise ValueError("Only one input image is allowed.")
|
||||
validate_image_aspect_ratio(image, (1, 4), (4, 1))
|
||||
payload = TaskCreationRequest(
|
||||
model_name=model,
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
@@ -322,8 +244,8 @@ class ViduImageToVideoNode(IO.ComfyNode):
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
)
|
||||
results = await execute_task(cls, VIDU_IMAGE_TO_VIDEO, payload, 120)
|
||||
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||
results = await execute_task(cls, VIDU_IMAGE_TO_VIDEO, payload)
|
||||
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
|
||||
|
||||
|
||||
class ViduReferenceVideoNode(IO.ComfyNode):
|
||||
@@ -334,14 +256,9 @@ class ViduReferenceVideoNode(IO.ComfyNode):
|
||||
node_id="ViduReferenceVideoNode",
|
||||
display_name="Vidu Reference To Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
description="Generate video from multiple images and prompt",
|
||||
description="Generate video from multiple images and a prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=VideoModelName,
|
||||
default=VideoModelName.vidu_q1,
|
||||
tooltip="Model name",
|
||||
),
|
||||
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
|
||||
IO.Image.Input(
|
||||
"images",
|
||||
tooltip="Images to use as references to generate a video with consistent subjects (max 7 images).",
|
||||
@@ -374,22 +291,19 @@ class ViduReferenceVideoNode(IO.ComfyNode):
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=AspectRatio,
|
||||
default=AspectRatio.r_16_9,
|
||||
options=["16:9", "9:16", "1:1"],
|
||||
tooltip="The aspect ratio of the output video",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=[model.value for model in Resolution],
|
||||
default=Resolution.r_1080p.value,
|
||||
options=["1080p"],
|
||||
tooltip="Supported values may vary by model & duration",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"movement_amplitude",
|
||||
options=[model.value for model in MovementAmplitude],
|
||||
default=MovementAmplitude.auto.value,
|
||||
options=["auto", "small", "medium", "large"],
|
||||
tooltip="The movement amplitude of objects in the frame",
|
||||
optional=True,
|
||||
),
|
||||
@@ -409,7 +323,7 @@ class ViduReferenceVideoNode(IO.ComfyNode):
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
images: torch.Tensor,
|
||||
images: Input.Image,
|
||||
prompt: str,
|
||||
duration: int,
|
||||
seed: int,
|
||||
@@ -426,7 +340,7 @@ class ViduReferenceVideoNode(IO.ComfyNode):
|
||||
validate_image_aspect_ratio(image, (1, 4), (4, 1))
|
||||
validate_image_dimensions(image, min_width=128, min_height=128)
|
||||
payload = TaskCreationRequest(
|
||||
model_name=model,
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
@@ -440,8 +354,8 @@ class ViduReferenceVideoNode(IO.ComfyNode):
|
||||
max_images=7,
|
||||
mime_type="image/png",
|
||||
)
|
||||
results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload, 120)
|
||||
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||
results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload)
|
||||
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
|
||||
|
||||
|
||||
class ViduStartEndToVideoNode(IO.ComfyNode):
|
||||
@@ -454,12 +368,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
|
||||
category="api node/video/Vidu",
|
||||
description="Generate a video from start and end frames and a prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=[model.value for model in VideoModelName],
|
||||
default=VideoModelName.vidu_q1.value,
|
||||
tooltip="Model name",
|
||||
),
|
||||
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
|
||||
IO.Image.Input(
|
||||
"first_frame",
|
||||
tooltip="Start frame",
|
||||
@@ -497,15 +406,13 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=[model.value for model in Resolution],
|
||||
default=Resolution.r_1080p.value,
|
||||
options=["1080p"],
|
||||
tooltip="Supported values may vary by model & duration",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"movement_amplitude",
|
||||
options=[model.value for model in MovementAmplitude],
|
||||
default=MovementAmplitude.auto.value,
|
||||
options=["auto", "small", "medium", "large"],
|
||||
tooltip="The movement amplitude of objects in the frame",
|
||||
optional=True,
|
||||
),
|
||||
@@ -525,8 +432,8 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
first_frame: torch.Tensor,
|
||||
end_frame: torch.Tensor,
|
||||
first_frame: Input.Image,
|
||||
end_frame: Input.Image,
|
||||
prompt: str,
|
||||
duration: int,
|
||||
seed: int,
|
||||
@@ -535,7 +442,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
|
||||
) -> IO.NodeOutput:
|
||||
validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
|
||||
payload = TaskCreationRequest(
|
||||
model_name=model,
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
@@ -546,8 +453,391 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
|
||||
(await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0]
|
||||
for frame in (first_frame, end_frame)
|
||||
]
|
||||
results = await execute_task(cls, VIDU_START_END_VIDEO, payload, 96)
|
||||
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||
results = await execute_task(cls, VIDU_START_END_VIDEO, payload)
|
||||
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
|
||||
|
||||
|
||||
class Vidu2TextToVideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Vidu2TextToVideoNode",
|
||||
display_name="Vidu2 Text-to-Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
description="Generate video from a text prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["viduq2"]),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="A textual description for video generation, with a maximum length of 2000 characters.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=1,
|
||||
max=10,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=1,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
),
|
||||
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "3:4", "4:3", "1:1"]),
|
||||
IO.Combo.Input("resolution", options=["720p", "1080p"]),
|
||||
IO.Boolean.Input(
|
||||
"background_music",
|
||||
default=False,
|
||||
tooltip="Whether to add background music to the generated video.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
duration: int,
|
||||
seed: int,
|
||||
aspect_ratio: str,
|
||||
resolution: str,
|
||||
background_music: bool,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=2000)
|
||||
results = await execute_task(
|
||||
cls,
|
||||
VIDU_TEXT_TO_VIDEO,
|
||||
TaskCreationRequest(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
aspect_ratio=aspect_ratio,
|
||||
resolution=resolution,
|
||||
bgm=background_music,
|
||||
),
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
|
||||
|
||||
|
||||
class Vidu2ImageToVideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Vidu2ImageToVideoNode",
|
||||
display_name="Vidu2 Image-to-Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
description="Generate a video from an image and an optional prompt.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]),
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="An image to be used as the start frame of the generated video.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="An optional text prompt for video generation (max 2000 characters).",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=1,
|
||||
max=10,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=1,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["720p", "1080p"],
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"movement_amplitude",
|
||||
options=["auto", "small", "medium", "large"],
|
||||
tooltip="The movement amplitude of objects in the frame.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
image: Input.Image,
|
||||
prompt: str,
|
||||
duration: int,
|
||||
seed: int,
|
||||
resolution: str,
|
||||
movement_amplitude: str,
|
||||
) -> IO.NodeOutput:
|
||||
if get_number_of_images(image) > 1:
|
||||
raise ValueError("Only one input image is allowed.")
|
||||
validate_image_aspect_ratio(image, (1, 4), (4, 1))
|
||||
validate_string(prompt, max_length=2000)
|
||||
results = await execute_task(
|
||||
cls,
|
||||
VIDU_IMAGE_TO_VIDEO,
|
||||
TaskCreationRequest(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
resolution=resolution,
|
||||
movement_amplitude=movement_amplitude,
|
||||
images=await upload_images_to_comfyapi(
|
||||
cls,
|
||||
image,
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
),
|
||||
),
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
|
||||
|
||||
|
||||
class Vidu2ReferenceVideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Vidu2ReferenceVideoNode",
|
||||
display_name="Vidu2 Reference-to-Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
description="Generate a video from multiple reference images and a prompt.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["viduq2"]),
|
||||
IO.Autogrow.Input(
|
||||
"subjects",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Image.Input("reference_images"),
|
||||
names=["subject1", "subject2", "subject3"],
|
||||
min=1,
|
||||
),
|
||||
tooltip="For each subject, provide up to 3 reference images (7 images total across all subjects). "
|
||||
"Reference them in prompts via @subject{subject_id}.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="When enabled, the video will include generated speech and background music "
|
||||
"based on the prompt.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"audio",
|
||||
default=False,
|
||||
tooltip="When enabled video will contain generated speech and background music based on the prompt.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=1,
|
||||
max=10,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=1,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
),
|
||||
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "4:3", "3:4", "1:1"]),
|
||||
IO.Combo.Input("resolution", options=["720p"]),
|
||||
IO.Combo.Input(
|
||||
"movement_amplitude",
|
||||
options=["auto", "small", "medium", "large"],
|
||||
tooltip="The movement amplitude of objects in the frame.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
subjects: IO.Autogrow.Type,
|
||||
prompt: str,
|
||||
audio: bool,
|
||||
duration: int,
|
||||
seed: int,
|
||||
aspect_ratio: str,
|
||||
resolution: str,
|
||||
movement_amplitude: str,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=2000)
|
||||
total_images = 0
|
||||
for i in subjects:
|
||||
if get_number_of_images(subjects[i]) > 3:
|
||||
raise ValueError("Maximum number of images per subject is 3.")
|
||||
for im in subjects[i]:
|
||||
total_images += 1
|
||||
validate_image_aspect_ratio(im, (1, 4), (4, 1))
|
||||
validate_image_dimensions(im, min_width=128, min_height=128)
|
||||
if total_images > 7:
|
||||
raise ValueError("Too many reference images; the maximum allowed is 7.")
|
||||
subjects_param: list[SubjectReference] = []
|
||||
for i in subjects:
|
||||
subjects_param.append(
|
||||
SubjectReference(
|
||||
id=i,
|
||||
images=await upload_images_to_comfyapi(
|
||||
cls,
|
||||
subjects[i],
|
||||
max_images=3,
|
||||
mime_type="image/png",
|
||||
wait_label=f"Uploading reference images for {i}",
|
||||
),
|
||||
),
|
||||
)
|
||||
payload = TaskCreationRequest(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
audio=audio,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
aspect_ratio=aspect_ratio,
|
||||
resolution=resolution,
|
||||
movement_amplitude=movement_amplitude,
|
||||
subjects=subjects_param,
|
||||
)
|
||||
results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload)
|
||||
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
|
||||
|
||||
|
||||
class Vidu2StartEndToVideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Vidu2StartEndToVideoNode",
|
||||
display_name="Vidu2 Start/End Frame-to-Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
description="Generate a video from a start frame, an end frame, and a prompt.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]),
|
||||
IO.Image.Input("first_frame"),
|
||||
IO.Image.Input("end_frame"),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="Prompt description (max 2000 characters).",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=2,
|
||||
max=8,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=1,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
),
|
||||
IO.Combo.Input("resolution", options=["720p", "1080p"]),
|
||||
IO.Combo.Input(
|
||||
"movement_amplitude",
|
||||
options=["auto", "small", "medium", "large"],
|
||||
tooltip="The movement amplitude of objects in the frame.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
first_frame: Input.Image,
|
||||
end_frame: Input.Image,
|
||||
prompt: str,
|
||||
duration: int,
|
||||
seed: int,
|
||||
resolution: str,
|
||||
movement_amplitude: str,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, max_length=2000)
|
||||
if get_number_of_images(first_frame) > 1:
|
||||
raise ValueError("Only one input image is allowed for `first_frame`.")
|
||||
if get_number_of_images(end_frame) > 1:
|
||||
raise ValueError("Only one input image is allowed for `end_frame`.")
|
||||
validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
|
||||
payload = TaskCreationRequest(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
resolution=resolution,
|
||||
movement_amplitude=movement_amplitude,
|
||||
images=[
|
||||
(await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0]
|
||||
for frame in (first_frame, end_frame)
|
||||
],
|
||||
)
|
||||
results = await execute_task(cls, VIDU_START_END_VIDEO, payload)
|
||||
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
|
||||
|
||||
|
||||
class ViduExtension(ComfyExtension):
|
||||
@@ -558,6 +848,10 @@ class ViduExtension(ComfyExtension):
|
||||
ViduImageToVideoNode,
|
||||
ViduReferenceVideoNode,
|
||||
ViduStartEndToVideoNode,
|
||||
Vidu2TextToVideoNode,
|
||||
Vidu2ImageToVideoNode,
|
||||
Vidu2ReferenceVideoNode,
|
||||
Vidu2StartEndToVideoNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -13,7 +13,9 @@ from comfy_api_nodes.util import (
|
||||
poll_op,
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
upload_video_to_comfyapi,
|
||||
validate_audio_duration,
|
||||
validate_video_duration,
|
||||
)
|
||||
|
||||
|
||||
@@ -41,6 +43,12 @@ class Image2VideoInputField(BaseModel):
|
||||
audio_url: str | None = Field(None)
|
||||
|
||||
|
||||
class Reference2VideoInputField(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: str | None = Field(None)
|
||||
reference_video_urls: list[str] = Field(...)
|
||||
|
||||
|
||||
class Txt2ImageParametersField(BaseModel):
|
||||
size: str = Field(...)
|
||||
n: int = Field(1, description="Number of images to generate.") # we support only value=1
|
||||
@@ -76,6 +84,14 @@ class Image2VideoParametersField(BaseModel):
|
||||
shot_type: str = Field("single")
|
||||
|
||||
|
||||
class Reference2VideoParametersField(BaseModel):
|
||||
size: str = Field(...)
|
||||
duration: int = Field(5, ge=5, le=15)
|
||||
shot_type: str = Field("single")
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
watermark: bool = Field(False)
|
||||
|
||||
|
||||
class Text2ImageTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
input: Text2ImageInputField = Field(...)
|
||||
@@ -100,6 +116,12 @@ class Image2VideoTaskCreationRequest(BaseModel):
|
||||
parameters: Image2VideoParametersField = Field(...)
|
||||
|
||||
|
||||
class Reference2VideoTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
input: Reference2VideoInputField = Field(...)
|
||||
parameters: Reference2VideoParametersField = Field(...)
|
||||
|
||||
|
||||
class TaskCreationOutputField(BaseModel):
|
||||
task_id: str = Field(...)
|
||||
task_status: str = Field(...)
|
||||
@@ -721,6 +743,143 @@ class WanImageToVideoApi(IO.ComfyNode):
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.output.video_url))
|
||||
|
||||
|
||||
class WanReferenceVideoApi(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="WanReferenceVideoApi",
|
||||
display_name="Wan Reference to Video",
|
||||
category="api node/video/Wan",
|
||||
description="Use the character and voice from input videos, combined with a prompt, "
|
||||
"to generate a new video that maintains character consistency.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["wan2.6-r2v"]),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt describing the elements and visual features. Supports English and Chinese. "
|
||||
"Use identifiers such as `character1` and `character2` to refer to the reference characters.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Negative prompt describing what to avoid.",
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"reference_videos",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Video.Input("reference_video"),
|
||||
names=["character1", "character2", "character3"],
|
||||
min=1,
|
||||
),
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"size",
|
||||
options=[
|
||||
"720p: 1:1 (960x960)",
|
||||
"720p: 16:9 (1280x720)",
|
||||
"720p: 9:16 (720x1280)",
|
||||
"720p: 4:3 (1088x832)",
|
||||
"720p: 3:4 (832x1088)",
|
||||
"1080p: 1:1 (1440x1440)",
|
||||
"1080p: 16:9 (1920x1080)",
|
||||
"1080p: 9:16 (1080x1920)",
|
||||
"1080p: 4:3 (1632x1248)",
|
||||
"1080p: 3:4 (1248x1632)",
|
||||
],
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=5,
|
||||
max=10,
|
||||
step=5,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"shot_type",
|
||||
options=["single", "multi"],
|
||||
tooltip="Specifies the shot type for the generated video, that is, whether the video is a "
|
||||
"single continuous shot or multiple shots with cuts.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=False,
|
||||
tooltip="Whether to add an AI-generated watermark to the result.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
reference_videos: IO.Autogrow.Type,
|
||||
size: str,
|
||||
duration: int,
|
||||
seed: int,
|
||||
shot_type: str,
|
||||
watermark: bool,
|
||||
):
|
||||
reference_video_urls = []
|
||||
for i in reference_videos:
|
||||
validate_video_duration(reference_videos[i], min_duration=2, max_duration=30)
|
||||
for i in reference_videos:
|
||||
reference_video_urls.append(await upload_video_to_comfyapi(cls, reference_videos[i]))
|
||||
width, height = RES_IN_PARENS.search(size).groups()
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"),
|
||||
response_model=TaskCreationResponse,
|
||||
data=Reference2VideoTaskCreationRequest(
|
||||
model=model,
|
||||
input=Reference2VideoInputField(
|
||||
prompt=prompt, negative_prompt=negative_prompt, reference_video_urls=reference_video_urls
|
||||
),
|
||||
parameters=Reference2VideoParametersField(
|
||||
size=f"{width}*{height}",
|
||||
duration=duration,
|
||||
shot_type=shot_type,
|
||||
watermark=watermark,
|
||||
seed=seed,
|
||||
),
|
||||
),
|
||||
)
|
||||
if not initial_response.output:
|
||||
raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
||||
response_model=VideoTaskStatusResponse,
|
||||
status_extractor=lambda x: x.output.task_status,
|
||||
poll_interval=6,
|
||||
max_poll_attempts=280,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.output.video_url))
|
||||
|
||||
|
||||
class WanApiExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@@ -729,6 +888,7 @@ class WanApiExtension(ComfyExtension):
|
||||
WanImageToImageApi,
|
||||
WanTextToVideoApi,
|
||||
WanImageToVideoApi,
|
||||
WanReferenceVideoApi,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1,16 +1,22 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from io import BytesIO
|
||||
|
||||
from yarl import URL
|
||||
|
||||
from comfy.cli_args import args
|
||||
from comfy.model_management import processing_interrupted
|
||||
from comfy_api.latest import IO
|
||||
|
||||
from .common_exceptions import ProcessingInterrupted
|
||||
|
||||
_HAS_PCT_ESC = re.compile(r"%[0-9A-Fa-f]{2}") # any % followed by 2 hex digits
|
||||
_HAS_BAD_PCT = re.compile(r"%(?![0-9A-Fa-f]{2})") # any % not followed by 2 hex digits
|
||||
|
||||
|
||||
def is_processing_interrupted() -> bool:
|
||||
"""Return True if user/runtime requested interruption."""
|
||||
@@ -69,3 +75,17 @@ def get_fs_object_size(path_or_object: str | BytesIO) -> int:
|
||||
if isinstance(path_or_object, str):
|
||||
return os.path.getsize(path_or_object)
|
||||
return len(path_or_object.getvalue())
|
||||
|
||||
|
||||
def to_aiohttp_url(url: str) -> URL:
|
||||
"""If `url` appears to be already percent-encoded (contains at least one valid %HH
|
||||
escape and no malformed '%' sequences) and contains no raw whitespace/control
|
||||
characters preserve the original encoding byte-for-byte (important for signed/presigned URLs).
|
||||
Otherwise, return `URL(url)` and allow yarl to normalize/quote as needed."""
|
||||
if any(c.isspace() for c in url) or any(ord(c) < 0x20 for c in url):
|
||||
# Avoid encoded=True if URL contains raw whitespace/control chars
|
||||
return URL(url)
|
||||
if _HAS_PCT_ESC.search(url) and not _HAS_BAD_PCT.search(url):
|
||||
# Preserve encoding only if it appears pre-encoded AND has no invalid % sequences
|
||||
return URL(url, encoded=True)
|
||||
return URL(url)
|
||||
|
||||
@@ -55,7 +55,7 @@ def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> to
|
||||
|
||||
def tensor_to_bytesio(
|
||||
image: torch.Tensor,
|
||||
name: str | None = None,
|
||||
*,
|
||||
total_pixels: int = 2048 * 2048,
|
||||
mime_type: str = "image/png",
|
||||
) -> BytesIO:
|
||||
@@ -75,7 +75,7 @@ def tensor_to_bytesio(
|
||||
|
||||
pil_image = tensor_to_pil(image, total_pixels=total_pixels)
|
||||
img_binary = pil_to_bytesio(pil_image, mime_type=mime_type)
|
||||
img_binary.name = f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}"
|
||||
img_binary.name = f"{uuid.uuid4()}.{mimetype_to_extension(mime_type)}"
|
||||
return img_binary
|
||||
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from ._helpers import (
|
||||
get_auth_header,
|
||||
is_processing_interrupted,
|
||||
sleep_with_interrupt,
|
||||
to_aiohttp_url,
|
||||
)
|
||||
from .client import _diagnose_connectivity
|
||||
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
|
||||
@@ -94,7 +95,7 @@ async def download_url_to_bytesio(
|
||||
|
||||
monitor_task = asyncio.create_task(_monitor())
|
||||
|
||||
req_task = asyncio.create_task(session.get(url, headers=headers))
|
||||
req_task = asyncio.create_task(session.get(to_aiohttp_url(url), headers=headers))
|
||||
done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED)
|
||||
|
||||
if monitor_task in done and req_task in pending:
|
||||
|
||||
@@ -49,6 +49,7 @@ async def upload_images_to_comfyapi(
|
||||
mime_type: str | None = None,
|
||||
wait_label: str | None = "Uploading",
|
||||
show_batch_index: bool = True,
|
||||
total_pixels: int = 2048 * 2048,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Uploads images to ComfyUI API and returns download URLs.
|
||||
@@ -63,7 +64,7 @@ async def upload_images_to_comfyapi(
|
||||
|
||||
for idx in range(num_to_upload):
|
||||
tensor = image[idx] if is_batch else image
|
||||
img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
|
||||
img_io = tensor_to_bytesio(tensor, total_pixels=total_pixels, mime_type=mime_type)
|
||||
|
||||
effective_label = wait_label
|
||||
if wait_label and show_batch_index and num_to_upload > 1:
|
||||
@@ -81,7 +82,6 @@ async def upload_audio_to_comfyapi(
|
||||
container_format: str = "mp4",
|
||||
codec_name: str = "aac",
|
||||
mime_type: str = "audio/mp4",
|
||||
filename: str = "uploaded_audio.mp4",
|
||||
) -> str:
|
||||
"""
|
||||
Uploads a single audio input to ComfyUI API and returns its download URL.
|
||||
@@ -91,7 +91,7 @@ async def upload_audio_to_comfyapi(
|
||||
waveform: torch.Tensor = audio["waveform"]
|
||||
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
|
||||
audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name)
|
||||
return await upload_file_to_comfyapi(cls, audio_bytes_io, filename, mime_type)
|
||||
return await upload_file_to_comfyapi(cls, audio_bytes_io, f"{uuid.uuid4()}.{container_format}", mime_type)
|
||||
|
||||
|
||||
async def upload_video_to_comfyapi(
|
||||
@@ -119,7 +119,7 @@ async def upload_video_to_comfyapi(
|
||||
raise ValueError(f"Could not verify video duration from source: {e}") from e
|
||||
|
||||
upload_mime_type = f"video/{container.value.lower()}"
|
||||
filename = f"uploaded_video.{container.value.lower()}"
|
||||
filename = f"{uuid.uuid4()}.{container.value.lower()}"
|
||||
|
||||
# Convert VideoInput to BytesIO using specified container/codec
|
||||
video_bytes_io = BytesIO()
|
||||
|
||||
@@ -97,6 +97,11 @@ def get_input_info(
|
||||
extra_info = input_info[1]
|
||||
else:
|
||||
extra_info = {}
|
||||
# if input_type is a list, it is a Combo defined in outdated format; convert it.
|
||||
# NOTE: uncomment this when we are confident old format going away won't cause too much trouble.
|
||||
# if isinstance(input_type, list):
|
||||
# extra_info["options"] = input_type
|
||||
# input_type = IO.Combo.io_type
|
||||
return input_type, input_category, extra_info
|
||||
|
||||
class TopologicalSort:
|
||||
@@ -202,15 +207,15 @@ class ExecutionList(TopologicalSort):
|
||||
return self.output_cache.get(node_id) is not None
|
||||
|
||||
def cache_link(self, from_node_id, to_node_id):
|
||||
if not to_node_id in self.execution_cache:
|
||||
if to_node_id not in self.execution_cache:
|
||||
self.execution_cache[to_node_id] = {}
|
||||
self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id)
|
||||
if not from_node_id in self.execution_cache_listeners:
|
||||
if from_node_id not in self.execution_cache_listeners:
|
||||
self.execution_cache_listeners[from_node_id] = set()
|
||||
self.execution_cache_listeners[from_node_id].add(to_node_id)
|
||||
|
||||
def get_cache(self, from_node_id, to_node_id):
|
||||
if not to_node_id in self.execution_cache:
|
||||
if to_node_id not in self.execution_cache:
|
||||
return None
|
||||
value = self.execution_cache[to_node_id].get(from_node_id)
|
||||
if value is None:
|
||||
|
||||
@@ -14,8 +14,9 @@ class JobStatus:
|
||||
IN_PROGRESS = 'in_progress'
|
||||
COMPLETED = 'completed'
|
||||
FAILED = 'failed'
|
||||
CANCELLED = 'cancelled'
|
||||
|
||||
ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED]
|
||||
ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED, CANCELLED]
|
||||
|
||||
|
||||
# Media types that can be previewed in the frontend
|
||||
@@ -94,12 +95,6 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs:
|
||||
|
||||
status_info = history_item.get('status', {})
|
||||
status_str = status_info.get('status_str') if status_info else None
|
||||
if status_str == 'success':
|
||||
status = JobStatus.COMPLETED
|
||||
elif status_str == 'error':
|
||||
status = JobStatus.FAILED
|
||||
else:
|
||||
status = JobStatus.COMPLETED
|
||||
|
||||
outputs = history_item.get('outputs', {})
|
||||
outputs_count, preview_output = get_outputs_summary(outputs)
|
||||
@@ -107,6 +102,7 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs:
|
||||
execution_error = None
|
||||
execution_start_time = None
|
||||
execution_end_time = None
|
||||
was_interrupted = False
|
||||
if status_info:
|
||||
messages = status_info.get('messages', [])
|
||||
for entry in messages:
|
||||
@@ -119,6 +115,15 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs:
|
||||
execution_end_time = event_data.get('timestamp')
|
||||
if event_name == 'execution_error':
|
||||
execution_error = event_data
|
||||
elif event_name == 'execution_interrupted':
|
||||
was_interrupted = True
|
||||
|
||||
if status_str == 'success':
|
||||
status = JobStatus.COMPLETED
|
||||
elif status_str == 'error':
|
||||
status = JobStatus.CANCELLED if was_interrupted else JobStatus.FAILED
|
||||
else:
|
||||
status = JobStatus.COMPLETED
|
||||
|
||||
job = prune_dict({
|
||||
'id': prompt_id,
|
||||
@@ -268,13 +273,13 @@ def get_all_jobs(
|
||||
for item in queued:
|
||||
jobs.append(normalize_queue_item(item, JobStatus.PENDING))
|
||||
|
||||
include_completed = JobStatus.COMPLETED in status_filter
|
||||
include_failed = JobStatus.FAILED in status_filter
|
||||
if include_completed or include_failed:
|
||||
history_statuses = {JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED}
|
||||
requested_history_statuses = history_statuses & set(status_filter)
|
||||
if requested_history_statuses:
|
||||
for prompt_id, history_item in history.items():
|
||||
is_failed = history_item.get('status', {}).get('status_str') == 'error'
|
||||
if (is_failed and include_failed) or (not is_failed and include_completed):
|
||||
jobs.append(normalize_history_item(prompt_id, history_item))
|
||||
job = normalize_history_item(prompt_id, history_item)
|
||||
if job.get('status') in requested_history_statuses:
|
||||
jobs.append(job)
|
||||
|
||||
if workflow_id:
|
||||
jobs = [j for j in jobs if j.get('workflow_id') == workflow_id]
|
||||
|
||||
@@ -21,14 +21,24 @@ def validate_node_input(
|
||||
"""
|
||||
# If the types are exactly the same, we can return immediately
|
||||
# Use pre-union behaviour: inverse of `__ne__`
|
||||
# NOTE: this lets legacy '*' Any types work that override the __ne__ method of the str class.
|
||||
if not received_type != input_type:
|
||||
return True
|
||||
|
||||
# If one of the types is '*', we can return True immediately; this is the 'Any' type.
|
||||
if received_type == IO.AnyType.io_type or input_type == IO.AnyType.io_type:
|
||||
return True
|
||||
|
||||
# If the received type or input_type is a MatchType, we can return True immediately;
|
||||
# validation for this is handled by the frontend
|
||||
if received_type == IO.MatchType.io_type or input_type == IO.MatchType.io_type:
|
||||
return True
|
||||
|
||||
# This accounts for some custom nodes that output lists of options as the type;
|
||||
# if we ever want to break them on purpose, this can be removed
|
||||
if isinstance(received_type, list) and input_type == IO.Combo.io_type:
|
||||
return True
|
||||
|
||||
# Not equal, and not strings
|
||||
if not isinstance(received_type, str) or not isinstance(input_type, str):
|
||||
return False
|
||||
@@ -37,6 +47,10 @@ def validate_node_input(
|
||||
received_types = set(t.strip() for t in received_type.split(","))
|
||||
input_types = set(t.strip() for t in input_type.split(","))
|
||||
|
||||
# If any of the types is '*', we can return True immediately; this is the 'Any' type.
|
||||
if IO.AnyType.io_type in received_types or IO.AnyType.io_type in input_types:
|
||||
return True
|
||||
|
||||
if strict:
|
||||
# In strict mode, all received types must be in the input types
|
||||
return received_types.issubset(input_types)
|
||||
|
||||
@@ -55,7 +55,8 @@ class APG(io.ComfyNode):
|
||||
def pre_cfg_function(args):
|
||||
nonlocal running_avg, prev_sigma
|
||||
|
||||
if len(args["conds_out"]) == 1: return args["conds_out"]
|
||||
if len(args["conds_out"]) == 1:
|
||||
return args["conds_out"]
|
||||
|
||||
cond = args["conds_out"][0]
|
||||
uncond = args["conds_out"][1]
|
||||
|
||||
@@ -112,7 +112,7 @@ class VAEDecodeAudio(IO.ComfyNode):
|
||||
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
|
||||
std[std < 1.0] = 1.0
|
||||
audio /= std
|
||||
return IO.NodeOutput({"waveform": audio, "sample_rate": 44100})
|
||||
return IO.NodeOutput({"waveform": audio, "sample_rate": 44100 if "sample_rate" not in samples else samples["sample_rate"]})
|
||||
|
||||
decode = execute # TODO: remove
|
||||
|
||||
@@ -399,6 +399,58 @@ class SplitAudioChannels(IO.ComfyNode):
|
||||
|
||||
separate = execute # TODO: remove
|
||||
|
||||
class JoinAudioChannels(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="JoinAudioChannels",
|
||||
display_name="Join Audio Channels",
|
||||
description="Joins left and right mono audio channels into a stereo audio.",
|
||||
category="audio",
|
||||
inputs=[
|
||||
IO.Audio.Input("audio_left"),
|
||||
IO.Audio.Input("audio_right"),
|
||||
],
|
||||
outputs=[
|
||||
IO.Audio.Output(display_name="audio"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, audio_left, audio_right) -> IO.NodeOutput:
|
||||
waveform_left = audio_left["waveform"]
|
||||
sample_rate_left = audio_left["sample_rate"]
|
||||
waveform_right = audio_right["waveform"]
|
||||
sample_rate_right = audio_right["sample_rate"]
|
||||
|
||||
if waveform_left.shape[1] != 1 or waveform_right.shape[1] != 1:
|
||||
raise ValueError("AudioJoin: Both input audios must be mono.")
|
||||
|
||||
# Handle different sample rates by resampling to the higher rate
|
||||
waveform_left, waveform_right, output_sample_rate = match_audio_sample_rates(
|
||||
waveform_left, sample_rate_left, waveform_right, sample_rate_right
|
||||
)
|
||||
|
||||
# Handle different lengths by trimming to the shorter length
|
||||
length_left = waveform_left.shape[-1]
|
||||
length_right = waveform_right.shape[-1]
|
||||
|
||||
if length_left != length_right:
|
||||
min_length = min(length_left, length_right)
|
||||
if length_left > min_length:
|
||||
logging.info(f"JoinAudioChannels: Trimming left channel from {length_left} to {min_length} samples.")
|
||||
waveform_left = waveform_left[..., :min_length]
|
||||
if length_right > min_length:
|
||||
logging.info(f"JoinAudioChannels: Trimming right channel from {length_right} to {min_length} samples.")
|
||||
waveform_right = waveform_right[..., :min_length]
|
||||
|
||||
# Join the channels into stereo
|
||||
left_channel = waveform_left[..., 0:1, :]
|
||||
right_channel = waveform_right[..., 0:1, :]
|
||||
stereo_waveform = torch.cat([left_channel, right_channel], dim=1)
|
||||
|
||||
return IO.NodeOutput({"waveform": stereo_waveform, "sample_rate": output_sample_rate})
|
||||
|
||||
|
||||
def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2):
|
||||
if sample_rate_1 != sample_rate_2:
|
||||
@@ -616,6 +668,7 @@ class AudioExtension(ComfyExtension):
|
||||
RecordAudio,
|
||||
TrimAudioDuration,
|
||||
SplitAudioChannels,
|
||||
JoinAudioChannels,
|
||||
AudioConcat,
|
||||
AudioMerge,
|
||||
AudioAdjustVolume,
|
||||
|
||||
@@ -5,7 +5,9 @@ import comfy.model_management
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from comfy.ldm.hunyuan_video.upsampler import HunyuanVideo15SRModel
|
||||
from comfy.ldm.lightricks.latent_upsampler import LatentUpsampler
|
||||
import folder_paths
|
||||
import json
|
||||
|
||||
class CLIPTextEncodeHunyuanDiT(io.ComfyNode):
|
||||
@classmethod
|
||||
@@ -186,7 +188,7 @@ class LatentUpscaleModelLoader(io.ComfyNode):
|
||||
@classmethod
|
||||
def execute(cls, model_name) -> io.NodeOutput:
|
||||
model_path = folder_paths.get_full_path_or_raise("latent_upscale_models", model_name)
|
||||
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
|
||||
sd, metadata = comfy.utils.load_torch_file(model_path, safe_load=True, return_metadata=True)
|
||||
|
||||
if "blocks.0.block.0.conv.weight" in sd:
|
||||
config = {
|
||||
@@ -197,6 +199,8 @@ class LatentUpscaleModelLoader(io.ComfyNode):
|
||||
"global_residual": False,
|
||||
}
|
||||
model_type = "720p"
|
||||
model = HunyuanVideo15SRModel(model_type, config)
|
||||
model.load_sd(sd)
|
||||
elif "up.0.block.0.conv1.conv.weight" in sd:
|
||||
sd = {key.replace("nin_shortcut", "nin_shortcut.conv", 1): value for key, value in sd.items()}
|
||||
config = {
|
||||
@@ -205,9 +209,12 @@ class LatentUpscaleModelLoader(io.ComfyNode):
|
||||
"block_out_channels": tuple(sd[f"up.{i}.block.0.conv1.conv.weight"].shape[0] for i in range(len([k for k in sd.keys() if k.startswith("up.") and k.endswith(".block.0.conv1.conv.weight")]))),
|
||||
}
|
||||
model_type = "1080p"
|
||||
|
||||
model = HunyuanVideo15SRModel(model_type, config)
|
||||
model.load_sd(sd)
|
||||
model = HunyuanVideo15SRModel(model_type, config)
|
||||
model.load_sd(sd)
|
||||
elif "post_upsample_res_blocks.0.conv2.bias" in sd:
|
||||
config = json.loads(metadata["config"])
|
||||
model = LatentUpsampler.from_config(config).to(dtype=comfy.model_management.vae_dtype(allowed_dtypes=[torch.bfloat16, torch.float32]))
|
||||
model.load_state_dict(sd)
|
||||
|
||||
return io.NodeOutput(model)
|
||||
|
||||
|
||||
53
comfy_extras/nodes_image_compare.py
Normal file
53
comfy_extras/nodes_image_compare.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import nodes
|
||||
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
|
||||
|
||||
class ImageCompare(IO.ComfyNode):
|
||||
"""Compares two images with a slider interface."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageCompare",
|
||||
display_name="Image Compare",
|
||||
description="Compares two images side by side with a slider.",
|
||||
category="image",
|
||||
is_experimental=True,
|
||||
is_output_node=True,
|
||||
inputs=[
|
||||
IO.Image.Input("image_a", optional=True),
|
||||
IO.Image.Input("image_b", optional=True),
|
||||
IO.ImageCompare.Input("compare_view"),
|
||||
],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, image_a=None, image_b=None, compare_view=None) -> IO.NodeOutput:
|
||||
result = {"a_images": [], "b_images": []}
|
||||
|
||||
preview_node = nodes.PreviewImage()
|
||||
|
||||
if image_a is not None and len(image_a) > 0:
|
||||
saved = preview_node.save_images(image_a, "comfy.compare.a")
|
||||
result["a_images"] = saved["ui"]["images"]
|
||||
|
||||
if image_b is not None and len(image_b) > 0:
|
||||
saved = preview_node.save_images(image_b, "comfy.compare.b")
|
||||
result["b_images"] = saved["ui"]["images"]
|
||||
|
||||
return IO.NodeOutput(ui=result)
|
||||
|
||||
|
||||
class ImageCompareExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
ImageCompare,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> ImageCompareExtension:
|
||||
return ImageCompareExtension()
|
||||
@@ -255,6 +255,7 @@ class LatentBatch(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="LatentBatch",
|
||||
category="latent/batch",
|
||||
is_deprecated=True,
|
||||
inputs=[
|
||||
io.Latent.Input("samples1"),
|
||||
io.Latent.Input("samples2"),
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
from __future__ import annotations
|
||||
from typing import TypedDict
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from comfy_api.latest import _io
|
||||
|
||||
# sentinel for missing inputs
|
||||
MISSING = object()
|
||||
|
||||
|
||||
class SwitchNode(io.ComfyNode):
|
||||
@@ -14,6 +17,37 @@ class SwitchNode(io.ComfyNode):
|
||||
display_name="Switch",
|
||||
category="logic",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Boolean.Input("switch"),
|
||||
io.MatchType.Input("on_false", template=template, lazy=True),
|
||||
io.MatchType.Input("on_true", template=template, lazy=True),
|
||||
],
|
||||
outputs=[
|
||||
io.MatchType.Output(template=template, display_name="output"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def check_lazy_status(cls, switch, on_false=None, on_true=None):
|
||||
if switch and on_true is None:
|
||||
return ["on_true"]
|
||||
if not switch and on_false is None:
|
||||
return ["on_false"]
|
||||
|
||||
@classmethod
|
||||
def execute(cls, switch, on_true, on_false) -> io.NodeOutput:
|
||||
return io.NodeOutput(on_true if switch else on_false)
|
||||
|
||||
|
||||
class SoftSwitchNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
template = io.MatchType.Template("switch")
|
||||
return io.Schema(
|
||||
node_id="ComfySoftSwitchNode",
|
||||
display_name="Soft Switch",
|
||||
category="logic",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Boolean.Input("switch"),
|
||||
io.MatchType.Input("on_false", template=template, lazy=True, optional=True),
|
||||
@@ -25,14 +59,14 @@ class SwitchNode(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def check_lazy_status(cls, switch, on_false=..., on_true=...):
|
||||
# We use ... instead of None, as None is passed for connected-but-unevaluated inputs.
|
||||
def check_lazy_status(cls, switch, on_false=MISSING, on_true=MISSING):
|
||||
# We use MISSING instead of None, as None is passed for connected-but-unevaluated inputs.
|
||||
# This trick allows us to ignore the value of the switch and still be able to run execute().
|
||||
|
||||
# One of the inputs may be missing, in which case we need to evaluate the other input
|
||||
if on_false is ...:
|
||||
if on_false is MISSING:
|
||||
return ["on_true"]
|
||||
if on_true is ...:
|
||||
if on_true is MISSING:
|
||||
return ["on_false"]
|
||||
# Normal lazy switch operation
|
||||
if switch and on_true is None:
|
||||
@@ -41,22 +75,50 @@ class SwitchNode(io.ComfyNode):
|
||||
return ["on_false"]
|
||||
|
||||
@classmethod
|
||||
def validate_inputs(cls, switch, on_false=..., on_true=...):
|
||||
def validate_inputs(cls, switch, on_false=MISSING, on_true=MISSING):
|
||||
# This check happens before check_lazy_status(), so we can eliminate the case where
|
||||
# both inputs are missing.
|
||||
if on_false is ... and on_true is ...:
|
||||
if on_false is MISSING and on_true is MISSING:
|
||||
return "At least one of on_false or on_true must be connected to Switch node"
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def execute(cls, switch, on_true=..., on_false=...) -> io.NodeOutput:
|
||||
if on_true is ...:
|
||||
def execute(cls, switch, on_true=MISSING, on_false=MISSING) -> io.NodeOutput:
|
||||
if on_true is MISSING:
|
||||
return io.NodeOutput(on_false)
|
||||
if on_false is ...:
|
||||
if on_false is MISSING:
|
||||
return io.NodeOutput(on_true)
|
||||
return io.NodeOutput(on_true if switch else on_false)
|
||||
|
||||
|
||||
class CustomComboNode(io.ComfyNode):
|
||||
"""
|
||||
Frontend node that allows user to write their own options for a combo.
|
||||
This is here to make sure the node has a backend-representation to avoid some annoyances.
|
||||
"""
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CustomCombo",
|
||||
display_name="Custom Combo",
|
||||
category="utils",
|
||||
is_experimental=True,
|
||||
inputs=[io.Combo.Input("choice", options=[])],
|
||||
outputs=[io.String.Output()]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_inputs(cls, choice: io.Combo.Type) -> bool:
|
||||
# NOTE: DO NOT DO THIS unless you want to skip validation entirely on the node's inputs.
|
||||
# I am doing that here because the widgets (besides the combo dropdown) on this node are fully frontend defined.
|
||||
# I need to skip checking that the chosen combo option is in the options list, since those are defined by the user.
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def execute(cls, choice: io.Combo.Type) -> io.NodeOutput:
|
||||
return io.NodeOutput(choice)
|
||||
|
||||
|
||||
class DCTestNode(io.ComfyNode):
|
||||
class DCValues(TypedDict):
|
||||
combo: str
|
||||
@@ -72,14 +134,14 @@ class DCTestNode(io.ComfyNode):
|
||||
display_name="DCTest",
|
||||
category="logic",
|
||||
is_output_node=True,
|
||||
inputs=[_io.DynamicCombo.Input("combo", options=[
|
||||
_io.DynamicCombo.Option("option1", [io.String.Input("string")]),
|
||||
_io.DynamicCombo.Option("option2", [io.Int.Input("integer")]),
|
||||
_io.DynamicCombo.Option("option3", [io.Image.Input("image")]),
|
||||
_io.DynamicCombo.Option("option4", [
|
||||
_io.DynamicCombo.Input("subcombo", options=[
|
||||
_io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]),
|
||||
_io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]),
|
||||
inputs=[io.DynamicCombo.Input("combo", options=[
|
||||
io.DynamicCombo.Option("option1", [io.String.Input("string")]),
|
||||
io.DynamicCombo.Option("option2", [io.Int.Input("integer")]),
|
||||
io.DynamicCombo.Option("option3", [io.Image.Input("image")]),
|
||||
io.DynamicCombo.Option("option4", [
|
||||
io.DynamicCombo.Input("subcombo", options=[
|
||||
io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]),
|
||||
io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]),
|
||||
])
|
||||
])]
|
||||
)],
|
||||
@@ -141,14 +203,65 @@ class AutogrowPrefixTestNode(io.ComfyNode):
|
||||
combined = ",".join([str(x) for x in vals])
|
||||
return io.NodeOutput(combined)
|
||||
|
||||
class ComboOutputTestNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ComboOptionTestNode",
|
||||
display_name="ComboOptionTest",
|
||||
category="logic",
|
||||
inputs=[io.Combo.Input("combo", options=["option1", "option2", "option3"]),
|
||||
io.Combo.Input("combo2", options=["option4", "option5", "option6"])],
|
||||
outputs=[io.Combo.Output(), io.Combo.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, combo: io.Combo.Type, combo2: io.Combo.Type) -> io.NodeOutput:
|
||||
return io.NodeOutput(combo, combo2)
|
||||
|
||||
class ConvertStringToComboNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ConvertStringToComboNode",
|
||||
display_name="Convert String to Combo",
|
||||
category="logic",
|
||||
inputs=[io.String.Input("string")],
|
||||
outputs=[io.Combo.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, string: str) -> io.NodeOutput:
|
||||
return io.NodeOutput(string)
|
||||
|
||||
class InvertBooleanNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="InvertBooleanNode",
|
||||
display_name="Invert Boolean",
|
||||
category="logic",
|
||||
inputs=[io.Boolean.Input("boolean")],
|
||||
outputs=[io.Boolean.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, boolean: bool) -> io.NodeOutput:
|
||||
return io.NodeOutput(not boolean)
|
||||
|
||||
class LogicExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
# SwitchNode,
|
||||
SwitchNode,
|
||||
CustomComboNode,
|
||||
# SoftSwitchNode,
|
||||
# ConvertStringToComboNode,
|
||||
# DCTestNode,
|
||||
# AutogrowNamesTestNode,
|
||||
# AutogrowPrefixTestNode,
|
||||
# ComboOutputTestNode,
|
||||
# InvertBooleanNode,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> LogicExtension:
|
||||
|
||||
@@ -81,6 +81,59 @@ class LTXVImgToVideo(io.ComfyNode):
|
||||
generate = execute # TODO: remove
|
||||
|
||||
|
||||
class LTXVImgToVideoInplace(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LTXVImgToVideoInplace",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Vae.Input("vae"),
|
||||
io.Image.Input("image"),
|
||||
io.Latent.Input("latent"),
|
||||
io.Float.Input("strength", default=1.0, min=0.0, max=1.0),
|
||||
io.Boolean.Input("bypass", default=False, tooltip="Bypass the conditioning.")
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(display_name="latent"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, vae, image, latent, strength, bypass=False) -> io.NodeOutput:
|
||||
if bypass:
|
||||
return (latent,)
|
||||
|
||||
samples = latent["samples"]
|
||||
_, height_scale_factor, width_scale_factor = (
|
||||
vae.downscale_index_formula
|
||||
)
|
||||
|
||||
batch, _, latent_frames, latent_height, latent_width = samples.shape
|
||||
width = latent_width * width_scale_factor
|
||||
height = latent_height * height_scale_factor
|
||||
|
||||
if image.shape[1] != height or image.shape[2] != width:
|
||||
pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
else:
|
||||
pixels = image
|
||||
encode_pixels = pixels[:, :, :, :3]
|
||||
t = vae.encode(encode_pixels)
|
||||
|
||||
samples[:, :, :t.shape[2]] = t
|
||||
|
||||
conditioning_latent_frames_mask = torch.ones(
|
||||
(batch, 1, latent_frames, 1, 1),
|
||||
dtype=torch.float32,
|
||||
device=samples.device,
|
||||
)
|
||||
conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength
|
||||
|
||||
return io.NodeOutput({"samples": samples, "noise_mask": conditioning_latent_frames_mask})
|
||||
|
||||
generate = execute # TODO: remove
|
||||
|
||||
|
||||
def conditioning_get_any_value(conditioning, key, default=None):
|
||||
for t in conditioning:
|
||||
if key in t[1]:
|
||||
@@ -106,12 +159,12 @@ def get_keyframe_idxs(cond):
|
||||
keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None)
|
||||
if keyframe_idxs is None:
|
||||
return None, 0
|
||||
num_keyframes = torch.unique(keyframe_idxs[:, 0]).shape[0]
|
||||
# keyframe_idxs contains start/end positions (last dimension), checking for unqiue values only for start
|
||||
num_keyframes = torch.unique(keyframe_idxs[:, 0, :, 0]).shape[0]
|
||||
return keyframe_idxs, num_keyframes
|
||||
|
||||
class LTXVAddGuide(io.ComfyNode):
|
||||
NUM_PREFIX_FRAMES = 2
|
||||
PATCHIFIER = SymmetricPatchifier(1)
|
||||
PATCHIFIER = SymmetricPatchifier(1, start_end=True)
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@@ -182,26 +235,35 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})
|
||||
|
||||
@classmethod
|
||||
def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors):
|
||||
_, latent_idx = cls.get_latent_index(
|
||||
cond=positive,
|
||||
latent_length=latent_image.shape[2],
|
||||
guide_length=guiding_latent.shape[2],
|
||||
frame_idx=frame_idx,
|
||||
scale_factors=scale_factors,
|
||||
)
|
||||
noise_mask[:, :, latent_idx:latent_idx + guiding_latent.shape[2]] = 1.0
|
||||
def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128):
|
||||
if latent_image.shape[1] != in_channels or guiding_latent.shape[1] != in_channels:
|
||||
raise ValueError("Adding guide to a combined AV latent is not supported.")
|
||||
|
||||
positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
|
||||
negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
|
||||
|
||||
mask = torch.full(
|
||||
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
|
||||
1.0 - strength,
|
||||
dtype=noise_mask.dtype,
|
||||
device=noise_mask.device,
|
||||
)
|
||||
if guide_mask is not None:
|
||||
target_h = max(noise_mask.shape[3], guide_mask.shape[3])
|
||||
target_w = max(noise_mask.shape[4], guide_mask.shape[4])
|
||||
|
||||
if noise_mask.shape[3] == 1 or noise_mask.shape[4] == 1:
|
||||
noise_mask = noise_mask.expand(-1, -1, -1, target_h, target_w)
|
||||
|
||||
if guide_mask.shape[3] == 1 or guide_mask.shape[4] == 1:
|
||||
guide_mask = guide_mask.expand(-1, -1, -1, target_h, target_w)
|
||||
mask = guide_mask - strength
|
||||
else:
|
||||
mask = torch.full(
|
||||
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
|
||||
1.0 - strength,
|
||||
dtype=noise_mask.dtype,
|
||||
device=noise_mask.device,
|
||||
)
|
||||
# This solves audio video combined latent case where latent_image has audio latent concatenated
|
||||
# in channel dimension with video latent. The solution is to pad guiding latent accordingly.
|
||||
if latent_image.shape[1] > guiding_latent.shape[1]:
|
||||
pad_len = latent_image.shape[1] - guiding_latent.shape[1]
|
||||
guiding_latent = torch.nn.functional.pad(guiding_latent, pad=(0, 0, 0, 0, 0, 0, 0, pad_len), value=0)
|
||||
latent_image = torch.cat([latent_image, guiding_latent], dim=2)
|
||||
noise_mask = torch.cat([noise_mask, mask], dim=2)
|
||||
return positive, negative, latent_image, noise_mask
|
||||
@@ -238,33 +300,17 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
|
||||
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
|
||||
|
||||
num_prefix_frames = min(cls.NUM_PREFIX_FRAMES, t.shape[2])
|
||||
|
||||
positive, negative, latent_image, noise_mask = cls.append_keyframe(
|
||||
positive,
|
||||
negative,
|
||||
frame_idx,
|
||||
latent_image,
|
||||
noise_mask,
|
||||
t[:, :, :num_prefix_frames],
|
||||
t,
|
||||
strength,
|
||||
scale_factors,
|
||||
)
|
||||
|
||||
latent_idx += num_prefix_frames
|
||||
|
||||
t = t[:, :, num_prefix_frames:]
|
||||
if t.shape[2] == 0:
|
||||
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
|
||||
|
||||
latent_image, noise_mask = cls.replace_latent_frames(
|
||||
latent_image,
|
||||
noise_mask,
|
||||
t,
|
||||
latent_idx,
|
||||
strength,
|
||||
)
|
||||
|
||||
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
|
||||
|
||||
generate = execute # TODO: remove
|
||||
@@ -507,18 +553,90 @@ class LTXVPreprocess(io.ComfyNode):
|
||||
|
||||
preprocess = execute # TODO: remove
|
||||
|
||||
|
||||
import comfy.nested_tensor
|
||||
class LTXVConcatAVLatent(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LTXVConcatAVLatent",
|
||||
category="latent/video/ltxv",
|
||||
inputs=[
|
||||
io.Latent.Input("video_latent"),
|
||||
io.Latent.Input("audio_latent"),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(display_name="latent"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, video_latent, audio_latent) -> io.NodeOutput:
|
||||
output = {}
|
||||
output.update(video_latent)
|
||||
output.update(audio_latent)
|
||||
video_noise_mask = video_latent.get("noise_mask", None)
|
||||
audio_noise_mask = audio_latent.get("noise_mask", None)
|
||||
|
||||
if video_noise_mask is not None or audio_noise_mask is not None:
|
||||
if video_noise_mask is None:
|
||||
video_noise_mask = torch.ones_like(video_latent["samples"])
|
||||
if audio_noise_mask is None:
|
||||
audio_noise_mask = torch.ones_like(audio_latent["samples"])
|
||||
output["noise_mask"] = comfy.nested_tensor.NestedTensor((video_noise_mask, audio_noise_mask))
|
||||
|
||||
output["samples"] = comfy.nested_tensor.NestedTensor((video_latent["samples"], audio_latent["samples"]))
|
||||
|
||||
return io.NodeOutput(output)
|
||||
|
||||
|
||||
class LTXVSeparateAVLatent(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LTXVSeparateAVLatent",
|
||||
category="latent/video/ltxv",
|
||||
description="LTXV Separate AV Latent",
|
||||
inputs=[
|
||||
io.Latent.Input("av_latent"),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(display_name="video_latent"),
|
||||
io.Latent.Output(display_name="audio_latent"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, av_latent) -> io.NodeOutput:
|
||||
latents = av_latent["samples"].unbind()
|
||||
video_latent = av_latent.copy()
|
||||
video_latent["samples"] = latents[0]
|
||||
audio_latent = av_latent.copy()
|
||||
audio_latent["samples"] = latents[1]
|
||||
if "noise_mask" in av_latent:
|
||||
masks = av_latent["noise_mask"]
|
||||
if masks is not None:
|
||||
masks = masks.unbind()
|
||||
video_latent["noise_mask"] = masks[0]
|
||||
audio_latent["noise_mask"] = masks[1]
|
||||
return io.NodeOutput(video_latent, audio_latent)
|
||||
|
||||
|
||||
class LtxvExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
EmptyLTXVLatentVideo,
|
||||
LTXVImgToVideo,
|
||||
LTXVImgToVideoInplace,
|
||||
ModelSamplingLTXV,
|
||||
LTXVConditioning,
|
||||
LTXVScheduler,
|
||||
LTXVAddGuide,
|
||||
LTXVPreprocess,
|
||||
LTXVCropGuides,
|
||||
LTXVConcatAVLatent,
|
||||
LTXVSeparateAVLatent,
|
||||
]
|
||||
|
||||
|
||||
|
||||
224
comfy_extras/nodes_lt_audio.py
Normal file
224
comfy_extras/nodes_lt_audio.py
Normal file
@@ -0,0 +1,224 @@
|
||||
import folder_paths
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
import torch
|
||||
|
||||
from comfy.ldm.lightricks.vae.audio_vae import AudioVAE
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class LTXVAudioVAELoader(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="LTXVAudioVAELoader",
|
||||
display_name="LTXV Audio VAE Loader",
|
||||
category="audio",
|
||||
inputs=[
|
||||
io.Combo.Input(
|
||||
"ckpt_name",
|
||||
options=folder_paths.get_filename_list("checkpoints"),
|
||||
tooltip="Audio VAE checkpoint to load.",
|
||||
)
|
||||
],
|
||||
outputs=[io.Vae.Output(display_name="Audio VAE")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, ckpt_name: str) -> io.NodeOutput:
|
||||
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
|
||||
return io.NodeOutput(AudioVAE(sd, metadata))
|
||||
|
||||
|
||||
class LTXVAudioVAEEncode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="LTXVAudioVAEEncode",
|
||||
display_name="LTXV Audio VAE Encode",
|
||||
category="audio",
|
||||
inputs=[
|
||||
io.Audio.Input("audio", tooltip="The audio to be encoded."),
|
||||
io.Vae.Input(
|
||||
id="audio_vae",
|
||||
display_name="Audio VAE",
|
||||
tooltip="The Audio VAE model to use for encoding.",
|
||||
),
|
||||
],
|
||||
outputs=[io.Latent.Output(display_name="Audio Latent")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, audio, audio_vae: AudioVAE) -> io.NodeOutput:
|
||||
audio_latents = audio_vae.encode(audio)
|
||||
return io.NodeOutput(
|
||||
{
|
||||
"samples": audio_latents,
|
||||
"sample_rate": int(audio_vae.sample_rate),
|
||||
"type": "audio",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class LTXVAudioVAEDecode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="LTXVAudioVAEDecode",
|
||||
display_name="LTXV Audio VAE Decode",
|
||||
category="audio",
|
||||
inputs=[
|
||||
io.Latent.Input("samples", tooltip="The latent to be decoded."),
|
||||
io.Vae.Input(
|
||||
id="audio_vae",
|
||||
display_name="Audio VAE",
|
||||
tooltip="The Audio VAE model used for decoding the latent.",
|
||||
),
|
||||
],
|
||||
outputs=[io.Audio.Output(display_name="Audio")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, samples, audio_vae: AudioVAE) -> io.NodeOutput:
|
||||
audio_latent = samples["samples"]
|
||||
if audio_latent.is_nested:
|
||||
audio_latent = audio_latent.unbind()[-1]
|
||||
audio = audio_vae.decode(audio_latent).to(audio_latent.device)
|
||||
output_audio_sample_rate = audio_vae.output_sample_rate
|
||||
return io.NodeOutput(
|
||||
{
|
||||
"waveform": audio,
|
||||
"sample_rate": int(output_audio_sample_rate),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class LTXVEmptyLatentAudio(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="LTXVEmptyLatentAudio",
|
||||
display_name="LTXV Empty Latent Audio",
|
||||
category="latent/audio",
|
||||
inputs=[
|
||||
io.Int.Input(
|
||||
"frames_number",
|
||||
default=97,
|
||||
min=1,
|
||||
max=1000,
|
||||
step=1,
|
||||
display_mode=io.NumberDisplay.number,
|
||||
tooltip="Number of frames.",
|
||||
),
|
||||
io.Int.Input(
|
||||
"frame_rate",
|
||||
default=25,
|
||||
min=1,
|
||||
max=1000,
|
||||
step=1,
|
||||
display_mode=io.NumberDisplay.number,
|
||||
tooltip="Number of frames per second.",
|
||||
),
|
||||
io.Int.Input(
|
||||
"batch_size",
|
||||
default=1,
|
||||
min=1,
|
||||
max=4096,
|
||||
display_mode=io.NumberDisplay.number,
|
||||
tooltip="The number of latent audio samples in the batch.",
|
||||
),
|
||||
io.Vae.Input(
|
||||
id="audio_vae",
|
||||
display_name="Audio VAE",
|
||||
tooltip="The Audio VAE model to get configuration from.",
|
||||
),
|
||||
],
|
||||
outputs=[io.Latent.Output(display_name="Latent")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(
|
||||
cls,
|
||||
frames_number: int,
|
||||
frame_rate: int,
|
||||
batch_size: int,
|
||||
audio_vae: AudioVAE,
|
||||
) -> io.NodeOutput:
|
||||
"""Generate empty audio latents matching the reference pipeline structure."""
|
||||
|
||||
assert audio_vae is not None, "Audio VAE model is required"
|
||||
|
||||
z_channels = audio_vae.latent_channels
|
||||
audio_freq = audio_vae.latent_frequency_bins
|
||||
sampling_rate = int(audio_vae.sample_rate)
|
||||
|
||||
num_audio_latents = audio_vae.num_of_latents_from_frames(frames_number, frame_rate)
|
||||
|
||||
audio_latents = torch.zeros(
|
||||
(batch_size, z_channels, num_audio_latents, audio_freq),
|
||||
device=comfy.model_management.intermediate_device(),
|
||||
)
|
||||
|
||||
return io.NodeOutput(
|
||||
{
|
||||
"samples": audio_latents,
|
||||
"sample_rate": sampling_rate,
|
||||
"type": "audio",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class LTXAVTextEncoderLoader(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="LTXAVTextEncoderLoader",
|
||||
display_name="LTXV Audio Text Encoder Loader",
|
||||
category="advanced/loaders",
|
||||
description="[Recipes]\n\nltxav: gemma 3 12B",
|
||||
inputs=[
|
||||
io.Combo.Input(
|
||||
"text_encoder",
|
||||
options=folder_paths.get_filename_list("text_encoders"),
|
||||
),
|
||||
io.Combo.Input(
|
||||
"ckpt_name",
|
||||
options=folder_paths.get_filename_list("checkpoints"),
|
||||
),
|
||||
io.Combo.Input(
|
||||
"device",
|
||||
options=["default", "cpu"],
|
||||
)
|
||||
],
|
||||
outputs=[io.Clip.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, text_encoder, ckpt_name, device="default"):
|
||||
clip_type = comfy.sd.CLIPType.LTXV
|
||||
|
||||
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder)
|
||||
clip_path2 = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||
|
||||
model_options = {}
|
||||
if device == "cpu":
|
||||
model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
|
||||
|
||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
|
||||
return io.NodeOutput(clip)
|
||||
|
||||
|
||||
class LTXVAudioExtension(ComfyExtension):
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
LTXVAudioVAELoader,
|
||||
LTXVAudioVAEEncode,
|
||||
LTXVAudioVAEDecode,
|
||||
LTXVEmptyLatentAudio,
|
||||
LTXAVTextEncoderLoader,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> ComfyExtension:
|
||||
return LTXVAudioExtension()
|
||||
75
comfy_extras/nodes_lt_upsampler.py
Normal file
75
comfy_extras/nodes_lt_upsampler.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from comfy import model_management
|
||||
import math
|
||||
|
||||
class LTXVLatentUpsampler:
|
||||
"""
|
||||
Upsamples a video latent by a factor of 2.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"samples": ("LATENT",),
|
||||
"upscale_model": ("LATENT_UPSCALE_MODEL",),
|
||||
"vae": ("VAE",),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "upsample_latent"
|
||||
CATEGORY = "latent/video"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def upsample_latent(
|
||||
self,
|
||||
samples: dict,
|
||||
upscale_model,
|
||||
vae,
|
||||
) -> tuple:
|
||||
"""
|
||||
Upsample the input latent using the provided model.
|
||||
|
||||
Args:
|
||||
samples (dict): Input latent samples
|
||||
upscale_model (LatentUpsampler): Loaded upscale model
|
||||
vae: VAE model for normalization
|
||||
auto_tiling (bool): Whether to automatically tile the input for processing
|
||||
|
||||
Returns:
|
||||
tuple: Tuple containing the upsampled latent
|
||||
"""
|
||||
device = model_management.get_torch_device()
|
||||
memory_required = model_management.module_size(upscale_model)
|
||||
|
||||
model_dtype = next(upscale_model.parameters()).dtype
|
||||
latents = samples["samples"]
|
||||
input_dtype = latents.dtype
|
||||
|
||||
memory_required += math.prod(latents.shape) * 3000.0 # TODO: more accurate
|
||||
model_management.free_memory(memory_required, device)
|
||||
|
||||
try:
|
||||
upscale_model.to(device) # TODO: use the comfy model management system.
|
||||
|
||||
latents = latents.to(dtype=model_dtype, device=device)
|
||||
|
||||
"""Upsample latents without tiling."""
|
||||
latents = vae.first_stage_model.per_channel_statistics.un_normalize(latents)
|
||||
upsampled_latents = upscale_model(latents)
|
||||
finally:
|
||||
upscale_model.cpu()
|
||||
|
||||
upsampled_latents = vae.first_stage_model.per_channel_statistics.normalize(
|
||||
upsampled_latents
|
||||
)
|
||||
upsampled_latents = upsampled_latents.to(dtype=input_dtype, device=model_management.intermediate_device())
|
||||
return_dict = samples.copy()
|
||||
return_dict["samples"] = upsampled_latents
|
||||
return_dict.pop("noise_mask", None)
|
||||
return (return_dict,)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"LTXVLatentUpsampler": LTXVLatentUpsampler,
|
||||
}
|
||||
@@ -10,7 +10,7 @@ class Mahiro(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="Mahiro",
|
||||
display_name="Mahiro is so cute that she deserves a better guidance function!! (。・ω・。)",
|
||||
display_name="Mahiro CFG",
|
||||
category="_for_testing",
|
||||
description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.",
|
||||
inputs=[
|
||||
|
||||
@@ -4,11 +4,15 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
import math
|
||||
from enum import Enum
|
||||
from typing import TypedDict, Literal
|
||||
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
from comfy_extras.nodes_latent import reshape_latent_to
|
||||
import node_helpers
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from nodes import MAX_RESOLUTION
|
||||
|
||||
class Blend(io.ComfyNode):
|
||||
@classmethod
|
||||
@@ -241,6 +245,353 @@ class ImageScaleToTotalPixels(io.ComfyNode):
|
||||
s = s.movedim(1,-1)
|
||||
return io.NodeOutput(s)
|
||||
|
||||
class ResizeType(str, Enum):
|
||||
SCALE_BY = "scale by multiplier"
|
||||
SCALE_DIMENSIONS = "scale dimensions"
|
||||
SCALE_LONGER_DIMENSION = "scale longer dimension"
|
||||
SCALE_SHORTER_DIMENSION = "scale shorter dimension"
|
||||
SCALE_WIDTH = "scale width"
|
||||
SCALE_HEIGHT = "scale height"
|
||||
SCALE_TOTAL_PIXELS = "scale total pixels"
|
||||
MATCH_SIZE = "match size"
|
||||
|
||||
def is_image(input: torch.Tensor) -> bool:
|
||||
# images have 4 dimensions: [batch, height, width, channels]
|
||||
# masks have 3 dimensions: [batch, height, width]
|
||||
return len(input.shape) == 4
|
||||
|
||||
def init_image_mask_input(input: torch.Tensor, is_type_image: bool) -> torch.Tensor:
|
||||
if is_type_image:
|
||||
input = input.movedim(-1, 1)
|
||||
else:
|
||||
input = input.unsqueeze(1)
|
||||
return input
|
||||
|
||||
def finalize_image_mask_input(input: torch.Tensor, is_type_image: bool) -> torch.Tensor:
|
||||
if is_type_image:
|
||||
input = input.movedim(1, -1)
|
||||
else:
|
||||
input = input.squeeze(1)
|
||||
return input
|
||||
|
||||
def scale_by(input: torch.Tensor, multiplier: float, scale_method: str) -> torch.Tensor:
|
||||
is_type_image = is_image(input)
|
||||
input = init_image_mask_input(input, is_type_image)
|
||||
width = round(input.shape[-1] * multiplier)
|
||||
height = round(input.shape[-2] * multiplier)
|
||||
|
||||
input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled")
|
||||
input = finalize_image_mask_input(input, is_type_image)
|
||||
return input
|
||||
|
||||
def scale_dimensions(input: torch.Tensor, width: int, height: int, scale_method: str, crop: str="disabled") -> torch.Tensor:
|
||||
if width == 0 and height == 0:
|
||||
return input
|
||||
is_type_image = is_image(input)
|
||||
input = init_image_mask_input(input, is_type_image)
|
||||
|
||||
if width == 0:
|
||||
width = max(1, round(input.shape[-1] * height / input.shape[-2]))
|
||||
elif height == 0:
|
||||
height = max(1, round(input.shape[-2] * width / input.shape[-1]))
|
||||
|
||||
input = comfy.utils.common_upscale(input, width, height, scale_method, crop)
|
||||
input = finalize_image_mask_input(input, is_type_image)
|
||||
return input
|
||||
|
||||
def scale_longer_dimension(input: torch.Tensor, longer_size: int, scale_method: str) -> torch.Tensor:
|
||||
is_type_image = is_image(input)
|
||||
input = init_image_mask_input(input, is_type_image)
|
||||
width = input.shape[-1]
|
||||
height = input.shape[-2]
|
||||
|
||||
if height > width:
|
||||
width = round((width / height) * longer_size)
|
||||
height = longer_size
|
||||
elif width > height:
|
||||
height = round((height / width) * longer_size)
|
||||
width = longer_size
|
||||
else:
|
||||
height = longer_size
|
||||
width = longer_size
|
||||
|
||||
input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled")
|
||||
input = finalize_image_mask_input(input, is_type_image)
|
||||
return input
|
||||
|
||||
def scale_shorter_dimension(input: torch.Tensor, shorter_size: int, scale_method: str) -> torch.Tensor:
|
||||
is_type_image = is_image(input)
|
||||
input = init_image_mask_input(input, is_type_image)
|
||||
width = input.shape[-1]
|
||||
height = input.shape[-2]
|
||||
|
||||
if height < width:
|
||||
width = round((width / height) * shorter_size)
|
||||
height = shorter_size
|
||||
elif width > height:
|
||||
height = round((height / width) * shorter_size)
|
||||
width = shorter_size
|
||||
else:
|
||||
height = shorter_size
|
||||
width = shorter_size
|
||||
|
||||
input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled")
|
||||
input = finalize_image_mask_input(input, is_type_image)
|
||||
return input
|
||||
|
||||
def scale_total_pixels(input: torch.Tensor, megapixels: float, scale_method: str) -> torch.Tensor:
|
||||
is_type_image = is_image(input)
|
||||
input = init_image_mask_input(input, is_type_image)
|
||||
total = int(megapixels * 1024 * 1024)
|
||||
|
||||
scale_by = math.sqrt(total / (input.shape[-1] * input.shape[-2]))
|
||||
width = round(input.shape[-1] * scale_by)
|
||||
height = round(input.shape[-2] * scale_by)
|
||||
|
||||
input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled")
|
||||
input = finalize_image_mask_input(input, is_type_image)
|
||||
return input
|
||||
|
||||
def scale_match_size(input: torch.Tensor, match: torch.Tensor, scale_method: str, crop: str) -> torch.Tensor:
|
||||
is_type_image = is_image(input)
|
||||
input = init_image_mask_input(input, is_type_image)
|
||||
match = init_image_mask_input(match, is_image(match))
|
||||
|
||||
width = match.shape[-1]
|
||||
height = match.shape[-2]
|
||||
input = comfy.utils.common_upscale(input, width, height, scale_method, crop)
|
||||
input = finalize_image_mask_input(input, is_type_image)
|
||||
return input
|
||||
|
||||
class ResizeImageMaskNode(io.ComfyNode):
|
||||
|
||||
scale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
|
||||
crop_methods = ["disabled", "center"]
|
||||
|
||||
class ResizeTypedDict(TypedDict):
|
||||
resize_type: ResizeType
|
||||
scale_method: Literal["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
|
||||
crop: Literal["disabled", "center"]
|
||||
multiplier: float
|
||||
width: int
|
||||
height: int
|
||||
longer_size: int
|
||||
shorter_size: int
|
||||
megapixels: float
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
template = io.MatchType.Template("input_type", [io.Image, io.Mask])
|
||||
crop_combo = io.Combo.Input("crop", options=cls.crop_methods, default="center")
|
||||
return io.Schema(
|
||||
node_id="ResizeImageMaskNode",
|
||||
display_name="Resize Image/Mask",
|
||||
category="transform",
|
||||
inputs=[
|
||||
io.MatchType.Input("input", template=template),
|
||||
io.DynamicCombo.Input("resize_type", options=[
|
||||
io.DynamicCombo.Option(ResizeType.SCALE_BY, [
|
||||
io.Float.Input("multiplier", default=1.00, min=0.01, max=8.0, step=0.01),
|
||||
]),
|
||||
io.DynamicCombo.Option(ResizeType.SCALE_DIMENSIONS, [
|
||||
io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1),
|
||||
io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1),
|
||||
crop_combo,
|
||||
]),
|
||||
io.DynamicCombo.Option(ResizeType.SCALE_LONGER_DIMENSION, [
|
||||
io.Int.Input("longer_size", default=512, min=0, max=MAX_RESOLUTION, step=1),
|
||||
]),
|
||||
io.DynamicCombo.Option(ResizeType.SCALE_SHORTER_DIMENSION, [
|
||||
io.Int.Input("shorter_size", default=512, min=0, max=MAX_RESOLUTION, step=1),
|
||||
]),
|
||||
io.DynamicCombo.Option(ResizeType.SCALE_WIDTH, [
|
||||
io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1),
|
||||
]),
|
||||
io.DynamicCombo.Option(ResizeType.SCALE_HEIGHT, [
|
||||
io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1),
|
||||
]),
|
||||
io.DynamicCombo.Option(ResizeType.SCALE_TOTAL_PIXELS, [
|
||||
io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
|
||||
]),
|
||||
io.DynamicCombo.Option(ResizeType.MATCH_SIZE, [
|
||||
io.MultiType.Input("match", [io.Image, io.Mask]),
|
||||
crop_combo,
|
||||
]),
|
||||
]),
|
||||
io.Combo.Input("scale_method", options=cls.scale_methods, default="area"),
|
||||
],
|
||||
outputs=[io.MatchType.Output(template=template, display_name="resized")]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, input: io.Image.Type | io.Mask.Type, scale_method: io.Combo.Type, resize_type: ResizeTypedDict) -> io.NodeOutput:
|
||||
selected_type = resize_type["resize_type"]
|
||||
if selected_type == ResizeType.SCALE_BY:
|
||||
return io.NodeOutput(scale_by(input, resize_type["multiplier"], scale_method))
|
||||
elif selected_type == ResizeType.SCALE_DIMENSIONS:
|
||||
return io.NodeOutput(scale_dimensions(input, resize_type["width"], resize_type["height"], scale_method, resize_type["crop"]))
|
||||
elif selected_type == ResizeType.SCALE_LONGER_DIMENSION:
|
||||
return io.NodeOutput(scale_longer_dimension(input, resize_type["longer_size"], scale_method))
|
||||
elif selected_type == ResizeType.SCALE_SHORTER_DIMENSION:
|
||||
return io.NodeOutput(scale_shorter_dimension(input, resize_type["shorter_size"], scale_method))
|
||||
elif selected_type == ResizeType.SCALE_WIDTH:
|
||||
return io.NodeOutput(scale_dimensions(input, resize_type["width"], 0, scale_method))
|
||||
elif selected_type == ResizeType.SCALE_HEIGHT:
|
||||
return io.NodeOutput(scale_dimensions(input, 0, resize_type["height"], scale_method))
|
||||
elif selected_type == ResizeType.SCALE_TOTAL_PIXELS:
|
||||
return io.NodeOutput(scale_total_pixels(input, resize_type["megapixels"], scale_method))
|
||||
elif selected_type == ResizeType.MATCH_SIZE:
|
||||
return io.NodeOutput(scale_match_size(input, resize_type["match"], scale_method, resize_type["crop"]))
|
||||
raise ValueError(f"Unsupported resize type: {selected_type}")
|
||||
|
||||
def batch_images(images: list[torch.Tensor]) -> torch.Tensor | None:
|
||||
if len(images) == 0:
|
||||
return None
|
||||
# first, get the max channels count
|
||||
max_channels = max(image.shape[-1] for image in images)
|
||||
# then, pad all images to have the same channels count
|
||||
padded_images: list[torch.Tensor] = []
|
||||
for image in images:
|
||||
if image.shape[-1] < max_channels:
|
||||
padded_images.append(torch.nn.functional.pad(image, (0,1), mode='constant', value=1.0))
|
||||
else:
|
||||
padded_images.append(image)
|
||||
# resize all images to be the same size as the first image
|
||||
resized_images: list[torch.Tensor] = []
|
||||
first_image_shape = padded_images[0].shape
|
||||
for image in padded_images:
|
||||
if image.shape[1:] != first_image_shape[1:]:
|
||||
resized_images.append(comfy.utils.common_upscale(image.movedim(-1,1), first_image_shape[2], first_image_shape[1], "bilinear", "center").movedim(1,-1))
|
||||
else:
|
||||
resized_images.append(image)
|
||||
# batch the images in the format [b, h, w, c]
|
||||
return torch.cat(resized_images, dim=0)
|
||||
|
||||
def batch_masks(masks: list[torch.Tensor]) -> torch.Tensor | None:
|
||||
if len(masks) == 0:
|
||||
return None
|
||||
# resize all masks to be the same size as the first mask
|
||||
resized_masks: list[torch.Tensor] = []
|
||||
first_mask_shape = masks[0].shape
|
||||
for mask in masks:
|
||||
if mask.shape[1:] != first_mask_shape[1:]:
|
||||
mask = init_image_mask_input(mask, is_type_image=False)
|
||||
mask = comfy.utils.common_upscale(mask, first_mask_shape[2], first_mask_shape[1], "bilinear", "center")
|
||||
resized_masks.append(finalize_image_mask_input(mask, is_type_image=False))
|
||||
else:
|
||||
resized_masks.append(mask)
|
||||
# batch the masks in the format [b, h, w]
|
||||
return torch.cat(resized_masks, dim=0)
|
||||
|
||||
def batch_latents(latents: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor] | None:
|
||||
if len(latents) == 0:
|
||||
return None
|
||||
samples_out = latents[0].copy()
|
||||
samples_out["batch_index"] = []
|
||||
first_samples = latents[0]["samples"]
|
||||
tensors: list[torch.Tensor] = []
|
||||
for latent in latents:
|
||||
# first, deal with latent tensors
|
||||
tensors.append(reshape_latent_to(first_samples.shape, latent["samples"], repeat_batch=False))
|
||||
# next, deal with batch_index
|
||||
samples_out["batch_index"].extend(latent.get("batch_index", [x for x in range(0, latent["samples"].shape[0])]))
|
||||
samples_out["samples"] = torch.cat(tensors, dim=0)
|
||||
return samples_out
|
||||
|
||||
class BatchImagesNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
autogrow_template = io.Autogrow.TemplatePrefix(io.Image.Input("image"), prefix="image", min=2, max=50)
|
||||
return io.Schema(
|
||||
node_id="BatchImagesNode",
|
||||
display_name="Batch Images",
|
||||
category="image",
|
||||
inputs=[
|
||||
io.Autogrow.Input("images", template=autogrow_template)
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output()
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, images: io.Autogrow.Type) -> io.NodeOutput:
|
||||
return io.NodeOutput(batch_images(list(images.values())))
|
||||
|
||||
class BatchMasksNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
autogrow_template = io.Autogrow.TemplatePrefix(io.Mask.Input("mask"), prefix="mask", min=2, max=50)
|
||||
return io.Schema(
|
||||
node_id="BatchMasksNode",
|
||||
display_name="Batch Masks",
|
||||
category="mask",
|
||||
inputs=[
|
||||
io.Autogrow.Input("masks", template=autogrow_template)
|
||||
],
|
||||
outputs=[
|
||||
io.Mask.Output()
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, masks: io.Autogrow.Type) -> io.NodeOutput:
|
||||
return io.NodeOutput(batch_masks(list(masks.values())))
|
||||
|
||||
class BatchLatentsNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
autogrow_template = io.Autogrow.TemplatePrefix(io.Latent.Input("latent"), prefix="latent", min=2, max=50)
|
||||
return io.Schema(
|
||||
node_id="BatchLatentsNode",
|
||||
display_name="Batch Latents",
|
||||
category="latent",
|
||||
inputs=[
|
||||
io.Autogrow.Input("latents", template=autogrow_template)
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output()
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, latents: io.Autogrow.Type) -> io.NodeOutput:
|
||||
return io.NodeOutput(batch_latents(list(latents.values())))
|
||||
|
||||
class BatchImagesMasksLatentsNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
matchtype_template = io.MatchType.Template("input", allowed_types=[io.Image, io.Mask, io.Latent])
|
||||
autogrow_template = io.Autogrow.TemplatePrefix(
|
||||
io.MatchType.Input("input", matchtype_template),
|
||||
prefix="input", min=1, max=50)
|
||||
return io.Schema(
|
||||
node_id="BatchImagesMasksLatentsNode",
|
||||
display_name="Batch Images/Masks/Latents",
|
||||
category="util",
|
||||
inputs=[
|
||||
io.Autogrow.Input("inputs", template=autogrow_template)
|
||||
],
|
||||
outputs=[
|
||||
io.MatchType.Output(id=None, template=matchtype_template)
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, inputs: io.Autogrow.Type) -> io.NodeOutput:
|
||||
batched = None
|
||||
values = list(inputs.values())
|
||||
# latents
|
||||
if isinstance(values[0], dict):
|
||||
batched = batch_latents(values)
|
||||
# images
|
||||
elif is_image(values[0]):
|
||||
batched = batch_images(values)
|
||||
# masks
|
||||
else:
|
||||
batched = batch_masks(values)
|
||||
return io.NodeOutput(batched)
|
||||
|
||||
class PostProcessingExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
@@ -250,6 +601,11 @@ class PostProcessingExtension(ComfyExtension):
|
||||
Quantize,
|
||||
Sharpen,
|
||||
ImageScaleToTotalPixels,
|
||||
ResizeImageMaskNode,
|
||||
BatchImagesNode,
|
||||
BatchMasksNode,
|
||||
BatchLatentsNode,
|
||||
# BatchImagesMasksLatentsNode,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> PostProcessingExtension:
|
||||
|
||||
@@ -66,7 +66,7 @@ class Float(io.ComfyNode):
|
||||
display_name="Float",
|
||||
category="utils/primitive",
|
||||
inputs=[
|
||||
io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize),
|
||||
io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize, step=0.1),
|
||||
],
|
||||
outputs=[io.Float.Output()],
|
||||
)
|
||||
|
||||
@@ -78,18 +78,20 @@ class ImageUpscaleWithModel(io.ComfyNode):
|
||||
overlap = 32
|
||||
|
||||
oom = True
|
||||
while oom:
|
||||
try:
|
||||
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
|
||||
oom = False
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
tile //= 2
|
||||
if tile < 128:
|
||||
raise e
|
||||
try:
|
||||
while oom:
|
||||
try:
|
||||
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
|
||||
oom = False
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
tile //= 2
|
||||
if tile < 128:
|
||||
raise e
|
||||
finally:
|
||||
upscale_model.to("cpu")
|
||||
|
||||
upscale_model.to("cpu")
|
||||
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
|
||||
return io.NodeOutput(s)
|
||||
|
||||
|
||||
@@ -817,7 +817,7 @@ def get_sample_indices(original_fps,
|
||||
if required_duration > total_frames / original_fps:
|
||||
raise ValueError("required_duration must be less than video length")
|
||||
|
||||
if not fixed_start is None and fixed_start >= 0:
|
||||
if fixed_start is not None and fixed_start >= 0:
|
||||
start_frame = fixed_start
|
||||
else:
|
||||
max_start = total_frames - required_origin_frames
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.7.0"
|
||||
__version__ = "0.9.0"
|
||||
|
||||
46
execution.py
46
execution.py
@@ -79,7 +79,7 @@ class IsChangedCache:
|
||||
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
|
||||
input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None)
|
||||
try:
|
||||
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name)
|
||||
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name, v3_data=v3_data)
|
||||
is_changed = await resolve_map_node_over_list_results(is_changed)
|
||||
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
|
||||
except Exception as e:
|
||||
@@ -148,13 +148,12 @@ SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")
|
||||
def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}):
|
||||
is_v3 = issubclass(class_def, _ComfyNodeInternal)
|
||||
v3_data: io.V3Data = {}
|
||||
hidden_inputs_v3 = {}
|
||||
valid_inputs = class_def.INPUT_TYPES()
|
||||
if is_v3:
|
||||
valid_inputs, schema, v3_data = class_def.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs)
|
||||
else:
|
||||
valid_inputs = class_def.INPUT_TYPES()
|
||||
valid_inputs, hidden, v3_data = _io.get_finalized_class_inputs(valid_inputs, inputs)
|
||||
input_data_all = {}
|
||||
missing_keys = {}
|
||||
hidden_inputs_v3 = {}
|
||||
for x in inputs:
|
||||
input_data = inputs[x]
|
||||
_, input_category, input_info = get_input_info(class_def, x, valid_inputs)
|
||||
@@ -180,18 +179,18 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
|
||||
input_data_all[x] = [input_data]
|
||||
|
||||
if is_v3:
|
||||
if schema.hidden:
|
||||
if io.Hidden.prompt in schema.hidden:
|
||||
if hidden is not None:
|
||||
if io.Hidden.prompt.name in hidden:
|
||||
hidden_inputs_v3[io.Hidden.prompt] = dynprompt.get_original_prompt() if dynprompt is not None else {}
|
||||
if io.Hidden.dynprompt in schema.hidden:
|
||||
if io.Hidden.dynprompt.name in hidden:
|
||||
hidden_inputs_v3[io.Hidden.dynprompt] = dynprompt
|
||||
if io.Hidden.extra_pnginfo in schema.hidden:
|
||||
if io.Hidden.extra_pnginfo.name in hidden:
|
||||
hidden_inputs_v3[io.Hidden.extra_pnginfo] = extra_data.get('extra_pnginfo', None)
|
||||
if io.Hidden.unique_id in schema.hidden:
|
||||
if io.Hidden.unique_id.name in hidden:
|
||||
hidden_inputs_v3[io.Hidden.unique_id] = unique_id
|
||||
if io.Hidden.auth_token_comfy_org in schema.hidden:
|
||||
if io.Hidden.auth_token_comfy_org.name in hidden:
|
||||
hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None)
|
||||
if io.Hidden.api_key_comfy_org in schema.hidden:
|
||||
if io.Hidden.api_key_comfy_org.name in hidden:
|
||||
hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None)
|
||||
else:
|
||||
if "hidden" in valid_inputs:
|
||||
@@ -258,7 +257,7 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
|
||||
pre_execute_cb(index)
|
||||
# V3
|
||||
if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)):
|
||||
# if is just a class, then assign no resources or state, just create clone
|
||||
# if is just a class, then assign no state, just create clone
|
||||
if is_class(obj):
|
||||
type_obj = obj
|
||||
obj.VALIDATE_CLASS()
|
||||
@@ -481,7 +480,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
else:
|
||||
lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
|
||||
if lazy_status_present:
|
||||
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data)
|
||||
# for check_lazy_status, the returned data should include the original key of the input
|
||||
v3_data_lazy = v3_data.copy()
|
||||
v3_data_lazy["create_dynamic_tuple"] = True
|
||||
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data_lazy)
|
||||
required_inputs = await resolve_map_node_over_list_results(required_inputs)
|
||||
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
|
||||
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
|
||||
@@ -599,6 +601,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
|
||||
if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
|
||||
tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
|
||||
logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary()))
|
||||
logging.error("Got an OOM, unloading all loaded models.")
|
||||
comfy.model_management.unload_all_models()
|
||||
|
||||
@@ -756,10 +759,13 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
||||
errors = []
|
||||
valid = True
|
||||
|
||||
v3_data = None
|
||||
validate_function_inputs = []
|
||||
validate_has_kwargs = False
|
||||
if issubclass(obj_class, _ComfyNodeInternal):
|
||||
class_inputs, _, _ = obj_class.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs)
|
||||
obj_class: _io._ComfyNodeBaseInternal
|
||||
class_inputs = obj_class.INPUT_TYPES()
|
||||
class_inputs, _, v3_data = _io.get_finalized_class_inputs(class_inputs, inputs)
|
||||
validate_function_name = "validate_inputs"
|
||||
validate_function = first_real_override(obj_class, validate_function_name)
|
||||
else:
|
||||
@@ -779,10 +785,11 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
||||
assert extra_info is not None
|
||||
if x not in inputs:
|
||||
if input_category == "required":
|
||||
details = f"{x}" if not v3_data else x.split(".")[-1]
|
||||
error = {
|
||||
"type": "required_input_missing",
|
||||
"message": "Required input is missing",
|
||||
"details": f"{x}",
|
||||
"details": details,
|
||||
"extra_info": {
|
||||
"input_name": x
|
||||
}
|
||||
@@ -916,8 +923,11 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
if isinstance(input_type, list):
|
||||
combo_options = input_type
|
||||
if isinstance(input_type, list) or input_type == io.Combo.io_type:
|
||||
if input_type == io.Combo.io_type:
|
||||
combo_options = extra_info.get("options", [])
|
||||
else:
|
||||
combo_options = input_type
|
||||
if val not in combo_options:
|
||||
input_config = info
|
||||
list_info = ""
|
||||
|
||||
3
main.py
3
main.py
@@ -7,6 +7,7 @@ import folder_paths
|
||||
import time
|
||||
from comfy.cli_args import args
|
||||
from app.logger import setup_logger
|
||||
from app.assets.scanner import seed_assets
|
||||
import itertools
|
||||
import utils.extra_config
|
||||
import logging
|
||||
@@ -324,6 +325,8 @@ def setup_database():
|
||||
from app.database.db import init_db, dependencies_available
|
||||
if dependencies_available():
|
||||
init_db()
|
||||
if not args.disable_assets_autoscan:
|
||||
seed_assets(["models"], enable_logging=True)
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
comfyui_manager==4.0.4
|
||||
comfyui_manager==4.0.5
|
||||
|
||||
34
nodes.py
34
nodes.py
@@ -295,7 +295,11 @@ class VAEDecode:
|
||||
DESCRIPTION = "Decodes latent images back into pixel space images."
|
||||
|
||||
def decode(self, vae, samples):
|
||||
images = vae.decode(samples["samples"])
|
||||
latent = samples["samples"]
|
||||
if latent.is_nested:
|
||||
latent = latent.unbind()[0]
|
||||
|
||||
images = vae.decode(latent)
|
||||
if len(images.shape) == 5: #Combine batches
|
||||
images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
|
||||
return (images, )
|
||||
@@ -374,14 +378,15 @@ class VAEEncodeForInpaint:
|
||||
CATEGORY = "latent/inpaint"
|
||||
|
||||
def encode(self, vae, pixels, mask, grow_mask_by=6):
|
||||
x = (pixels.shape[1] // vae.downscale_ratio) * vae.downscale_ratio
|
||||
y = (pixels.shape[2] // vae.downscale_ratio) * vae.downscale_ratio
|
||||
downscale_ratio = vae.spacial_compression_encode()
|
||||
x = (pixels.shape[1] // downscale_ratio) * downscale_ratio
|
||||
y = (pixels.shape[2] // downscale_ratio) * downscale_ratio
|
||||
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
|
||||
|
||||
pixels = pixels.clone()
|
||||
if pixels.shape[1] != x or pixels.shape[2] != y:
|
||||
x_offset = (pixels.shape[1] % vae.downscale_ratio) // 2
|
||||
y_offset = (pixels.shape[2] % vae.downscale_ratio) // 2
|
||||
x_offset = (pixels.shape[1] % downscale_ratio) // 2
|
||||
y_offset = (pixels.shape[2] % downscale_ratio) // 2
|
||||
pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
|
||||
mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]
|
||||
|
||||
@@ -970,7 +975,7 @@ class DualCLIPLoader:
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "newbie"], ),
|
||||
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "ltxv", "newbie"], ),
|
||||
},
|
||||
"optional": {
|
||||
"device": (["default", "cpu"], {"advanced": True}),
|
||||
@@ -1663,8 +1668,6 @@ class LoadImage:
|
||||
output_masks = []
|
||||
w, h = None, None
|
||||
|
||||
excluded_formats = ['MPO']
|
||||
|
||||
for i in ImageSequence.Iterator(img):
|
||||
i = node_helpers.pillow(ImageOps.exif_transpose, i)
|
||||
|
||||
@@ -1692,7 +1695,10 @@ class LoadImage:
|
||||
output_images.append(image)
|
||||
output_masks.append(mask.unsqueeze(0))
|
||||
|
||||
if len(output_images) > 1 and img.format not in excluded_formats:
|
||||
if img.format == "MPO":
|
||||
break # ignore all frames except the first one for MPO format
|
||||
|
||||
if len(output_images) > 1:
|
||||
output_image = torch.cat(output_images, dim=0)
|
||||
output_mask = torch.cat(output_masks, dim=0)
|
||||
else:
|
||||
@@ -1863,6 +1869,7 @@ class ImageBatch:
|
||||
FUNCTION = "batch"
|
||||
|
||||
CATEGORY = "image"
|
||||
DEPRECATED = True
|
||||
|
||||
def batch(self, image1, image2):
|
||||
if image1.shape[-1] != image2.shape[-1]:
|
||||
@@ -2241,8 +2248,10 @@ async def init_external_custom_nodes():
|
||||
|
||||
for possible_module in possible_modules:
|
||||
module_path = os.path.join(custom_node_path, possible_module)
|
||||
if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
|
||||
if module_path.endswith(".disabled"): continue
|
||||
if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py":
|
||||
continue
|
||||
if module_path.endswith(".disabled"):
|
||||
continue
|
||||
if args.disable_all_custom_nodes and possible_module not in args.whitelist_custom_nodes:
|
||||
logging.info(f"Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes")
|
||||
continue
|
||||
@@ -2327,6 +2336,8 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_mochi.py",
|
||||
"nodes_slg.py",
|
||||
"nodes_mahiro.py",
|
||||
"nodes_lt_upsampler.py",
|
||||
"nodes_lt_audio.py",
|
||||
"nodes_lt.py",
|
||||
"nodes_hooks.py",
|
||||
"nodes_load_3d.py",
|
||||
@@ -2359,6 +2370,7 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_nop.py",
|
||||
"nodes_kandinsky5.py",
|
||||
"nodes_wanmove.py",
|
||||
"nodes_image_compare.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.7.0"
|
||||
version = "0.9.0"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.9"
|
||||
requires-python = ">=3.10"
|
||||
|
||||
[project.urls]
|
||||
homepage = "https://www.comfy.org/"
|
||||
@@ -15,12 +15,16 @@ lint.select = [
|
||||
"N805", # invalid-first-argument-name-for-method
|
||||
"S307", # suspicious-eval-usage
|
||||
"S102", # exec
|
||||
"E",
|
||||
"T", # print-usage
|
||||
"W",
|
||||
# The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names.
|
||||
# See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
|
||||
"F",
|
||||
]
|
||||
|
||||
lint.ignore = ["E501", "E722", "E731", "E712", "E402", "E741"]
|
||||
|
||||
exclude = ["*.ipynb", "**/generated/*.pyi"]
|
||||
|
||||
[tool.pylint]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
comfyui-frontend-package==1.35.9
|
||||
comfyui-workflow-templates==0.7.64
|
||||
comfyui-embedded-docs==0.3.1
|
||||
comfyui-frontend-package==1.36.14
|
||||
comfyui-workflow-templates==0.8.4
|
||||
comfyui-embedded-docs==0.4.0
|
||||
torch
|
||||
torchsde
|
||||
torchvision
|
||||
@@ -21,6 +21,7 @@ psutil
|
||||
alembic
|
||||
SQLAlchemy
|
||||
av>=14.2.0
|
||||
comfy-kitchen>=0.2.6
|
||||
|
||||
#non essential dependencies:
|
||||
kornia>=0.7.1
|
||||
|
||||
12
server.py
12
server.py
@@ -33,6 +33,8 @@ import node_helpers
|
||||
from comfyui_version import __version__
|
||||
from app.frontend_management import FrontendManager, parse_version
|
||||
from comfy_api.internal import _ComfyNodeInternal
|
||||
from app.assets.scanner import seed_assets
|
||||
from app.assets.api.routes import register_assets_system
|
||||
|
||||
from app.user_manager import UserManager
|
||||
from app.model_manager import ModelFileManager
|
||||
@@ -184,7 +186,7 @@ def create_block_external_middleware():
|
||||
else:
|
||||
response = await handler(request)
|
||||
|
||||
response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self'; connect-src 'self'; frame-src 'self'; object-src 'self';"
|
||||
response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self'; connect-src 'self' data:; frame-src 'self'; object-src 'self';"
|
||||
return response
|
||||
|
||||
return block_external_middleware
|
||||
@@ -235,6 +237,7 @@ class PromptServer():
|
||||
else args.front_end_root
|
||||
)
|
||||
logging.info(f"[Prompt Server] web root: {self.web_root}")
|
||||
register_assets_system(self.app, self.user_manager)
|
||||
routes = web.RouteTableDef()
|
||||
self.routes = routes
|
||||
self.last_node_id = None
|
||||
@@ -324,7 +327,7 @@ class PromptServer():
|
||||
@routes.get("/models/{folder}")
|
||||
async def get_models(request):
|
||||
folder = request.match_info.get("folder", None)
|
||||
if not folder in folder_paths.folder_names_and_paths:
|
||||
if folder not in folder_paths.folder_names_and_paths:
|
||||
return web.Response(status=404)
|
||||
files = folder_paths.get_filename_list(folder)
|
||||
return web.json_response(files)
|
||||
@@ -579,7 +582,7 @@ class PromptServer():
|
||||
folder_name = request.match_info.get("folder_name", None)
|
||||
if folder_name is None:
|
||||
return web.Response(status=404)
|
||||
if not "filename" in request.rel_url.query:
|
||||
if "filename" not in request.rel_url.query:
|
||||
return web.Response(status=404)
|
||||
|
||||
filename = request.rel_url.query["filename"]
|
||||
@@ -593,7 +596,7 @@ class PromptServer():
|
||||
if out is None:
|
||||
return web.Response(status=404)
|
||||
dt = json.loads(out)
|
||||
if not "__metadata__" in dt:
|
||||
if "__metadata__" not in dt:
|
||||
return web.Response(status=404)
|
||||
return web.json_response(dt["__metadata__"])
|
||||
|
||||
@@ -683,6 +686,7 @@ class PromptServer():
|
||||
|
||||
@routes.get("/object_info")
|
||||
async def get_object_info(request):
|
||||
seed_assets(["models"])
|
||||
with folder_paths.cache_helper:
|
||||
out = {}
|
||||
for x in nodes.NODE_CLASS_MAPPINGS:
|
||||
|
||||
@@ -103,18 +103,18 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
|
||||
# Verify weights are wrapped in QuantizedTensor
|
||||
self.assertIsInstance(model.layer1.weight, QuantizedTensor)
|
||||
self.assertEqual(model.layer1.weight._layout_type, "TensorCoreFP8Layout")
|
||||
self.assertEqual(model.layer1.weight._layout_cls, "TensorCoreFP8E4M3Layout")
|
||||
|
||||
# Layer 2 should NOT be quantized
|
||||
self.assertNotIsInstance(model.layer2.weight, QuantizedTensor)
|
||||
|
||||
# Layer 3 should be quantized
|
||||
self.assertIsInstance(model.layer3.weight, QuantizedTensor)
|
||||
self.assertEqual(model.layer3.weight._layout_type, "TensorCoreFP8Layout")
|
||||
self.assertEqual(model.layer3.weight._layout_cls, "TensorCoreFP8E4M3Layout")
|
||||
|
||||
# Verify scales were loaded
|
||||
self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0)
|
||||
self.assertEqual(model.layer3.weight._layout_params['scale'].item(), 1.5)
|
||||
self.assertEqual(model.layer1.weight._params.scale.item(), 2.0)
|
||||
self.assertEqual(model.layer3.weight._params.scale.item(), 1.5)
|
||||
|
||||
# Forward pass
|
||||
input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
|
||||
@@ -153,9 +153,9 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
state_dict2 = model.state_dict()
|
||||
|
||||
# Verify layer1.weight is a QuantizedTensor with scale preserved
|
||||
self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
|
||||
self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0)
|
||||
self.assertEqual(state_dict2["layer1.weight"]._layout_type, "TensorCoreFP8Layout")
|
||||
self.assertTrue(torch.equal(state_dict2["layer1.weight"].view(torch.uint8), fp8_weight.view(torch.uint8)))
|
||||
self.assertEqual(state_dict2["layer1.weight_scale"].item(), 3.0)
|
||||
self.assertEqual(model.layer1.weight._layout_cls, "TensorCoreFP8E4M3Layout")
|
||||
|
||||
# Verify non-quantized layers are standard tensors
|
||||
self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)
|
||||
|
||||
@@ -1,190 +0,0 @@
|
||||
import unittest
|
||||
import torch
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add comfy to path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
|
||||
def has_gpu():
|
||||
return torch.cuda.is_available()
|
||||
|
||||
from comfy.cli_args import args
|
||||
if not has_gpu():
|
||||
args.cpu = True
|
||||
|
||||
from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout
|
||||
|
||||
|
||||
class TestQuantizedTensor(unittest.TestCase):
|
||||
"""Test the QuantizedTensor subclass with FP8 layout"""
|
||||
|
||||
def test_creation(self):
|
||||
"""Test creating a QuantizedTensor with TensorCoreFP8Layout"""
|
||||
fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||
scale = torch.tensor(2.0)
|
||||
layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}
|
||||
|
||||
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
|
||||
|
||||
self.assertIsInstance(qt, QuantizedTensor)
|
||||
self.assertEqual(qt.shape, (256, 128))
|
||||
self.assertEqual(qt.dtype, torch.float8_e4m3fn)
|
||||
self.assertEqual(qt._layout_params['scale'], scale)
|
||||
self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
|
||||
self.assertEqual(qt._layout_type, "TensorCoreFP8Layout")
|
||||
|
||||
def test_dequantize(self):
|
||||
"""Test explicit dequantization"""
|
||||
|
||||
fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||
scale = torch.tensor(3.0)
|
||||
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
||||
|
||||
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
|
||||
dequantized = qt.dequantize()
|
||||
|
||||
self.assertEqual(dequantized.dtype, torch.float32)
|
||||
self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1))
|
||||
|
||||
def test_from_float(self):
|
||||
"""Test creating QuantizedTensor from float tensor"""
|
||||
float_tensor = torch.randn(64, 32, dtype=torch.float32)
|
||||
scale = torch.tensor(1.5)
|
||||
|
||||
qt = QuantizedTensor.from_float(
|
||||
float_tensor,
|
||||
"TensorCoreFP8Layout",
|
||||
scale=scale,
|
||||
dtype=torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
self.assertIsInstance(qt, QuantizedTensor)
|
||||
self.assertEqual(qt.dtype, torch.float8_e4m3fn)
|
||||
self.assertEqual(qt.shape, (64, 32))
|
||||
|
||||
# Verify dequantization gives approximately original values
|
||||
dequantized = qt.dequantize()
|
||||
mean_rel_error = ((dequantized - float_tensor).abs() / (float_tensor.abs() + 1e-6)).mean()
|
||||
self.assertLess(mean_rel_error, 0.1)
|
||||
|
||||
|
||||
class TestGenericUtilities(unittest.TestCase):
|
||||
"""Test generic utility operations"""
|
||||
|
||||
def test_detach(self):
|
||||
"""Test detach operation on quantized tensor"""
|
||||
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||
scale = torch.tensor(1.5)
|
||||
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
||||
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
|
||||
|
||||
# Detach should return a new QuantizedTensor
|
||||
qt_detached = qt.detach()
|
||||
|
||||
self.assertIsInstance(qt_detached, QuantizedTensor)
|
||||
self.assertEqual(qt_detached.shape, qt.shape)
|
||||
self.assertEqual(qt_detached._layout_type, "TensorCoreFP8Layout")
|
||||
|
||||
def test_clone(self):
|
||||
"""Test clone operation on quantized tensor"""
|
||||
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||
scale = torch.tensor(1.5)
|
||||
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
||||
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
|
||||
|
||||
# Clone should return a new QuantizedTensor
|
||||
qt_cloned = qt.clone()
|
||||
|
||||
self.assertIsInstance(qt_cloned, QuantizedTensor)
|
||||
self.assertEqual(qt_cloned.shape, qt.shape)
|
||||
self.assertEqual(qt_cloned._layout_type, "TensorCoreFP8Layout")
|
||||
|
||||
# Verify it's a deep copy
|
||||
self.assertIsNot(qt_cloned._qdata, qt._qdata)
|
||||
|
||||
@unittest.skipUnless(has_gpu(), "GPU not available")
|
||||
def test_to_device(self):
|
||||
"""Test device transfer"""
|
||||
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||
scale = torch.tensor(1.5)
|
||||
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
||||
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
|
||||
|
||||
# Moving to same device should work (CPU to CPU)
|
||||
qt_cpu = qt.to('cpu')
|
||||
|
||||
self.assertIsInstance(qt_cpu, QuantizedTensor)
|
||||
self.assertEqual(qt_cpu.device.type, 'cpu')
|
||||
self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu')
|
||||
|
||||
|
||||
class TestTensorCoreFP8Layout(unittest.TestCase):
|
||||
"""Test the TensorCoreFP8Layout implementation"""
|
||||
|
||||
def test_quantize(self):
|
||||
"""Test quantization method"""
|
||||
float_tensor = torch.randn(32, 64, dtype=torch.float32)
|
||||
scale = torch.tensor(1.5)
|
||||
|
||||
qdata, layout_params = TensorCoreFP8Layout.quantize(
|
||||
float_tensor,
|
||||
scale=scale,
|
||||
dtype=torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
self.assertEqual(qdata.dtype, torch.float8_e4m3fn)
|
||||
self.assertEqual(qdata.shape, float_tensor.shape)
|
||||
self.assertIn('scale', layout_params)
|
||||
self.assertIn('orig_dtype', layout_params)
|
||||
self.assertEqual(layout_params['orig_dtype'], torch.float32)
|
||||
|
||||
def test_dequantize(self):
|
||||
"""Test dequantization method"""
|
||||
float_tensor = torch.ones(10, 20, dtype=torch.float32) * 3.0
|
||||
scale = torch.tensor(1.0)
|
||||
|
||||
qdata, layout_params = TensorCoreFP8Layout.quantize(
|
||||
float_tensor,
|
||||
scale=scale,
|
||||
dtype=torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params)
|
||||
|
||||
# Should approximately match original
|
||||
self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1))
|
||||
|
||||
|
||||
class TestFallbackMechanism(unittest.TestCase):
|
||||
"""Test fallback for unsupported operations"""
|
||||
|
||||
def test_unsupported_op_dequantizes(self):
|
||||
"""Test that unsupported operations fall back to dequantization"""
|
||||
# Set seed for reproducibility
|
||||
torch.manual_seed(42)
|
||||
|
||||
# Create quantized tensor
|
||||
a_fp32 = torch.randn(10, 20, dtype=torch.float32)
|
||||
scale = torch.tensor(1.0)
|
||||
a_q = QuantizedTensor.from_float(
|
||||
a_fp32,
|
||||
"TensorCoreFP8Layout",
|
||||
scale=scale,
|
||||
dtype=torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
# Call an operation that doesn't have a registered handler
|
||||
# For example, torch.abs
|
||||
result = torch.abs(a_q)
|
||||
|
||||
# Should work via fallback (dequantize → abs → return)
|
||||
self.assertNotIsInstance(result, QuantizedTensor)
|
||||
expected = torch.abs(a_fp32)
|
||||
# FP8 introduces quantization error, so use loose tolerance
|
||||
mean_error = (result - expected).abs().mean()
|
||||
self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -19,6 +19,7 @@ class TestJobStatus:
|
||||
assert JobStatus.IN_PROGRESS == 'in_progress'
|
||||
assert JobStatus.COMPLETED == 'completed'
|
||||
assert JobStatus.FAILED == 'failed'
|
||||
assert JobStatus.CANCELLED == 'cancelled'
|
||||
|
||||
def test_all_contains_all_statuses(self):
|
||||
"""ALL should contain all status values."""
|
||||
@@ -26,7 +27,8 @@ class TestJobStatus:
|
||||
assert JobStatus.IN_PROGRESS in JobStatus.ALL
|
||||
assert JobStatus.COMPLETED in JobStatus.ALL
|
||||
assert JobStatus.FAILED in JobStatus.ALL
|
||||
assert len(JobStatus.ALL) == 4
|
||||
assert JobStatus.CANCELLED in JobStatus.ALL
|
||||
assert len(JobStatus.ALL) == 5
|
||||
|
||||
|
||||
class TestIsPreviewable:
|
||||
@@ -336,6 +338,40 @@ class TestNormalizeHistoryItem:
|
||||
assert job['execution_error']['node_type'] == 'KSampler'
|
||||
assert job['execution_error']['exception_message'] == 'CUDA out of memory'
|
||||
|
||||
def test_cancelled_job(self):
|
||||
"""Cancelled/interrupted history item should have cancelled status."""
|
||||
history_item = {
|
||||
'prompt': (
|
||||
5,
|
||||
'prompt-cancelled',
|
||||
{'nodes': {}},
|
||||
{'create_time': 1234567890000},
|
||||
['node1'],
|
||||
),
|
||||
'status': {
|
||||
'status_str': 'error',
|
||||
'completed': False,
|
||||
'messages': [
|
||||
('execution_start', {'prompt_id': 'prompt-cancelled', 'timestamp': 1234567890500}),
|
||||
('execution_interrupted', {
|
||||
'prompt_id': 'prompt-cancelled',
|
||||
'node_id': '5',
|
||||
'node_type': 'KSampler',
|
||||
'executed': ['1', '2', '3'],
|
||||
'timestamp': 1234567891000,
|
||||
})
|
||||
]
|
||||
},
|
||||
'outputs': {},
|
||||
}
|
||||
|
||||
job = normalize_history_item('prompt-cancelled', history_item)
|
||||
assert job['status'] == 'cancelled'
|
||||
assert job['execution_start_time'] == 1234567890500
|
||||
assert job['execution_end_time'] == 1234567891000
|
||||
# Cancelled jobs should not have execution_error set
|
||||
assert 'execution_error' not in job
|
||||
|
||||
def test_include_outputs(self):
|
||||
"""When include_outputs=True, should include full output data."""
|
||||
history_item = {
|
||||
|
||||
Reference in New Issue
Block a user