Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-16 13:10:02 +00:00)

Compare commits: feature/fr...glsl-cloud (165 commits)
Commit SHA1s:

5332ca60e4, 40b247f046, e1ede29d82, df1e5e8514, dc9822b7df, 712efb466b, 726af73867, 831351a29e, e1add563f9, 8902907d7a,
e03fe8b591, ae79e33345, 117e214354, 4a93a62371, 66c18522fb, e5ae670a40, 3fe61cedda, 2a4328d639, d297a749a2, 2b7cc7e3b6,
4993411fd9, 2c7cef4a23, 76a7fa96db, cdcf4119b3, dbe70b6821, 00fff6019e, 123a7874a9, f719f9c062, fe053ba5eb, 6648ab68bc,
6615db925c, 8ca842a8ed, c1b63a7e78, 349a636a2b, a4be04c5d7, baf8c87455, 62315fbb15, a0302cc6a8, f350a84261, 3760d74005,
9bf5aa54db, 5ff4fdedba, 17e7df43d1, 039955c527, 6a26328842, 204e65b8dc, a831c19b70, eba6c940fd, a1c101f861, c2d7f07dbf,
458292fef0, 6555dc65b8, 2b70ab9ad0, 00efcc6cd0, cb459573c8, 35183543e0, a246cc02b2, a50c32d63f, 6125b80979, c8fcbd66ee,
26dd7eb421, e77b34dfea, c2d229a786, ef73070ea4, d30c609f5a, 5087f1d497, a31681564d, 855849c658, fe2511468d, 3be0175166,
b8315e66cb, ab1050bec3, fb23935c11, 85fc35e8fa, 223364743c, affe881354, f5030e26fd, 66e1b07402, be4345d1c9, 3c1a1a2df8,
ba5bf3f1a8, c05a08ae66, de9ada6a41, 37f711d4a1, dd86b15521, 021ba20719, b60be02aaf, 2b5da3b72e, 794d05bdb1, 361b9a82a3,
667a1b8878, 32621c6a11, f8acd9c402, 873de5f37a, aa6f7a83bb, 3c40ee0f02, 43034b6881, bb048d4aaa, 7c1f02d1fa, 292a5918f4,
0050b66a0b, 0c313f5293, 1fcf9dca18, 3b790d24d6, 6ea8c128a3, 6e469a3f35, b8f848bfe3, 4064062e7d, 92b2b7198a, 8aabe2403e,
309c3e4ec0, 23591d4388, 0167653781, c3d07bec6d, 0a7993729c, bbe2c13a70, 59b955ff54, 1263d6fe88, 3aace5c8dc, b0d9708974,
23572c6314, d809ef8fb1, a4317314d2, c9b633d84f, aaea976f36, cee092213e, 3da0e9c367, 9fa8202620, b4438c9baf, 1711020904,
d9b8567547, 6c5f906bf2, 4f5bd39b1c, dcff27fe3f, 09725967cf, 5f62440fbb, ac91c340f4, 2db3b0ff90, 6516ab335d, ad53e78f11,
29011ba87e, cd4985e2f3, bfe31d0b9d, 2129e7d278, 7ee77ff038, 26c5bbb875, a97c98068f, 635406e283, ed6002cb60, bc72d7f8d1,
aef4e13588, 4e6a1b66a9, 9cf299a9f9, e89b22993a, 55bd606e92, cc30293d65, 866d863128, 79cdbc81cb, f443b9f2ca, 4e3038114a,
bbb8864778, d7f3241bf6, 09a2e67151, 0fd1b78736, 8490eedadf
.github/workflows/release-stable-all.yml (vendored, 6 changed lines)
```diff
@@ -20,7 +20,7 @@ jobs:
      git_tag: ${{ inputs.git_tag }}
      cache_tag: "cu130"
      python_minor: "13"
-      python_patch: "9"
      python_patch: "11"
      rel_name: "nvidia"
      rel_extra_name: ""
      test_release: true
@@ -65,11 +65,11 @@ jobs:
      contents: "write"
      packages: "write"
      pull-requests: "read"
-    name: "Release AMD ROCm 7.1.1"
    name: "Release AMD ROCm 7.2"
    uses: ./.github/workflows/stable-release.yml
    with:
      git_tag: ${{ inputs.git_tag }}
-      cache_tag: "rocm711"
      cache_tag: "rocm72"
      python_minor: "12"
      python_patch: "10"
      rel_name: "amd"
```
.github/workflows/release-webhook.yml (vendored, 36 changed lines)
```diff
@@ -7,6 +7,8 @@ on:
jobs:
  send-webhook:
    runs-on: ubuntu-latest
    env:
      DESKTOP_REPO_DISPATCH_TOKEN: ${{ secrets.DESKTOP_REPO_DISPATCH_TOKEN }}
    steps:
      - name: Send release webhook
        env:
@@ -106,3 +108,37 @@ jobs:
          --fail --silent --show-error

          echo "✅ Release webhook sent successfully"

      - name: Send repository dispatch to desktop
        env:
          DISPATCH_TOKEN: ${{ env.DESKTOP_REPO_DISPATCH_TOKEN }}
          RELEASE_TAG: ${{ github.event.release.tag_name }}
          RELEASE_URL: ${{ github.event.release.html_url }}
        run: |
          set -euo pipefail

          if [ -z "${DISPATCH_TOKEN:-}" ]; then
            echo "::error::DESKTOP_REPO_DISPATCH_TOKEN is required but not set."
            exit 1
          fi

          PAYLOAD="$(jq -n \
            --arg release_tag "$RELEASE_TAG" \
            --arg release_url "$RELEASE_URL" \
            '{
              event_type: "comfyui_release_published",
              client_payload: {
                release_tag: $release_tag,
                release_url: $release_url
              }
            }')"

          curl -fsSL \
            -X POST \
            -H "Accept: application/vnd.github+json" \
            -H "Content-Type: application/json" \
            -H "Authorization: Bearer ${DISPATCH_TOKEN}" \
            https://api.github.com/repos/Comfy-Org/desktop/dispatches \
            -d "$PAYLOAD"

          echo "✅ Dispatched ComfyUI release ${RELEASE_TAG} to Comfy-Org/desktop"
```
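The same dispatch can be exercised outside Actions for testing. A minimal sketch with the Python `requests` library; the token source, tag, and URL values are placeholders, while the endpoint, headers, and payload shape come from the step above.

```python
import os
import requests

# Placeholder: export the PAT locally before running this sketch.
token = os.environ["DESKTOP_REPO_DISPATCH_TOKEN"]

resp = requests.post(
    "https://api.github.com/repos/Comfy-Org/desktop/dispatches",
    headers={
        "Accept": "application/vnd.github+json",
        "Authorization": f"Bearer {token}",
    },
    json={
        "event_type": "comfyui_release_published",
        "client_payload": {
            "release_tag": "v0.0.0",                 # placeholder tag
            "release_url": "https://example.invalid",  # placeholder URL
        },
    },
    timeout=30,
)
resp.raise_for_status()  # GitHub answers 204 No Content on success
```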
```diff
@@ -29,7 +29,7 @@ on:
      description: 'python patch version'
      required: true
      type: string
-      default: "9"
      default: "11"
# push:
# branches:
# - master
```
````diff
@@ -208,7 +208,7 @@ comfy install

## Manual Install (Windows, Linux)

-Python 3.14 works but you may encounter issues with the torch compile node. The free threaded variant is still missing some dependencies.
Python 3.14 works but some custom nodes may have issues. The free threaded variant works but some dependencies will enable the GIL so it's not fully supported.

Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12

@@ -227,7 +227,7 @@ Put your VAE in: models/vae

AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:

-```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.1```

This is the command to install the nightly with ROCm 7.1 which might have some performance improvements:
````
```diff
@@ -1,5 +1,8 @@
import logging
import uuid
import urllib.parse
import os
import contextlib
from aiohttp import web

from pydantic import ValidationError
@@ -8,6 +11,9 @@ import app.assets.manager as manager
from app import user_manager
from app.assets.api import schemas_in
from app.assets.helpers import get_query_dict
from app.assets.scanner import seed_assets

import folder_paths

ROUTES = web.RouteTableDef()
USER_MANAGER: user_manager.UserManager | None = None
@@ -15,6 +21,9 @@ USER_MANAGER: user_manager.UserManager | None = None
# UUID regex (canonical hyphenated form, case-insensitive)
UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"

# Note to any custom node developers reading this code:
# The assets system is not yet fully implemented, do not rely on the code in /app/assets remaining the same.

def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None:
    global USER_MANAGER
    USER_MANAGER = user_manager_instance
@@ -28,6 +37,18 @@ def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
    return _error_response(400, code, "Validation failed.", {"errors": ve.json()})


@ROUTES.head("/api/assets/hash/{hash}")
async def head_asset_by_hash(request: web.Request) -> web.Response:
    hash_str = request.match_info.get("hash", "").strip().lower()
    if not hash_str or ":" not in hash_str:
        return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
    algo, digest = hash_str.split(":", 1)
    if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
        return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
    exists = manager.asset_exists(asset_hash=hash_str)
    return web.Response(status=200 if exists else 404)
```
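The HEAD route above gives clients a cheap dedup probe before uploading. A minimal sketch, assuming a local server on ComfyUI's default port 8188 and a placeholder digest:

```python
import requests

# Hypothetical BLAKE3 digest of a local file; the server only accepts the
# canonical lowercase "blake3:<hex>" form.
asset_hash = "blake3:" + "ab" * 32

resp = requests.head(f"http://127.0.0.1:8188/api/assets/hash/{asset_hash}")
if resp.status_code == 200:
    print("content already known to the server; the upload can be skipped")
elif resp.status_code == 404:
    print("content not present; proceed with a normal upload")
```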
```diff
@ROUTES.get("/api/assets")
async def list_assets(request: web.Request) -> web.Response:
    """
@@ -50,7 +71,7 @@ async def list_assets(request: web.Request) -> web.Response:
        order=q.order,
        owner_id=USER_MANAGER.get_request_user_id(request),
    )
-    return web.json_response(payload.model_dump(mode="json"))
    return web.json_response(payload.model_dump(mode="json", exclude_none=True))


@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
@@ -76,6 +97,314 @@ async def get_asset(request: web.Request) -> web.Response:
    return web.json_response(result.model_dump(mode="json"), status=200)


@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
async def download_asset_content(request: web.Request) -> web.Response:
    # question: do we need disposition? could we just stick with one of these?
    disposition = request.query.get("disposition", "attachment").lower().strip()
    if disposition not in {"inline", "attachment"}:
        disposition = "attachment"

    try:
        abs_path, content_type, filename = manager.resolve_asset_content_for_download(
            asset_info_id=str(uuid.UUID(request.match_info["id"])),
            owner_id=USER_MANAGER.get_request_user_id(request),
        )
    except ValueError as ve:
        return _error_response(404, "ASSET_NOT_FOUND", str(ve))
    except NotImplementedError as nie:
        return _error_response(501, "BACKEND_UNSUPPORTED", str(nie))
    except FileNotFoundError:
        return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.")

    quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'")
    cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}'

    file_size = os.path.getsize(abs_path)
    logging.info(
        "download_asset_content: path=%s, size=%d bytes (%.2f MB), content_type=%s, filename=%s",
        abs_path,
        file_size,
        file_size / (1024 * 1024),
        content_type,
        filename,
    )

    async def file_sender():
        chunk_size = 64 * 1024
        with open(abs_path, "rb") as f:
            while True:
                chunk = f.read(chunk_size)
                if not chunk:
                    break
                yield chunk

    return web.Response(
        body=file_sender(),
        content_type=content_type,
        headers={
            "Content-Disposition": cd,
            "Content-Length": str(file_size),
        },
    )
```
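A hedged sketch of a client pulling an asset's bytes through the streaming route above; the asset id, host, and output filename are placeholders:

```python
import requests

asset_id = "00000000-0000-0000-0000-000000000000"  # placeholder AssetInfo id
url = f"http://127.0.0.1:8188/api/assets/{asset_id}/content"

# disposition may be "inline" or "attachment"; the server falls back to attachment.
with requests.get(url, params={"disposition": "attachment"}, stream=True) as resp:
    resp.raise_for_status()
    with open("downloaded.bin", "wb") as f:
        for chunk in resp.iter_content(chunk_size=64 * 1024):
            f.write(chunk)
```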
```diff
@ROUTES.post("/api/assets/from-hash")
async def create_asset_from_hash(request: web.Request) -> web.Response:
    try:
        payload = await request.json()
        body = schemas_in.CreateFromHashBody.model_validate(payload)
    except ValidationError as ve:
        return _validation_error_response("INVALID_BODY", ve)
    except Exception:
        return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")

    result = manager.create_asset_from_hash(
        hash_str=body.hash,
        name=body.name,
        tags=body.tags,
        user_metadata=body.user_metadata,
        owner_id=USER_MANAGER.get_request_user_id(request),
    )
    if result is None:
        return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist")
    return web.json_response(result.model_dump(mode="json"), status=201)
```
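Sketch of the JSON body this from-hash route accepts; values are illustrative and the hash must already exist on the server for a 201:

```python
import requests

payload = {
    "hash": "blake3:" + "ab" * 32,        # placeholder digest of already-stored content
    "name": "my-model.safetensors",
    "tags": ["models", "checkpoints"],
    "user_metadata": {"source": "manual"},
}
resp = requests.post("http://127.0.0.1:8188/api/assets/from-hash", json=payload)
print(resp.status_code)  # 201 on success, 404 if the content hash is unknown
```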
```diff
@ROUTES.post("/api/assets")
async def upload_asset(request: web.Request) -> web.Response:
    """Multipart/form-data endpoint for Asset uploads."""
    if not (request.content_type or "").lower().startswith("multipart/"):
        return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.")

    reader = await request.multipart()

    file_present = False
    file_client_name: str | None = None
    tags_raw: list[str] = []
    provided_name: str | None = None
    user_metadata_raw: str | None = None
    provided_hash: str | None = None
    provided_hash_exists: bool | None = None

    file_written = 0
    tmp_path: str | None = None
    while True:
        field = await reader.next()
        if field is None:
            break

        fname = getattr(field, "name", "") or ""

        if fname == "hash":
            try:
                s = ((await field.text()) or "").strip().lower()
            except Exception:
                return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")

            if s:
                if ":" not in s:
                    return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
                algo, digest = s.split(":", 1)
                if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
                    return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
                provided_hash = f"{algo}:{digest}"
                try:
                    provided_hash_exists = manager.asset_exists(asset_hash=provided_hash)
                except Exception:
                    provided_hash_exists = None  # do not fail the whole request here

        elif fname == "file":
            file_present = True
            file_client_name = (field.filename or "").strip()

            if provided_hash and provided_hash_exists is True:
                # If client supplied a hash that we know exists, drain but do not write to disk
                try:
                    while True:
                        chunk = await field.read_chunk(8 * 1024 * 1024)
                        if not chunk:
                            break
                        file_written += len(chunk)
                except Exception:
                    return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file.")
                continue  # Do not create temp file; we will create AssetInfo from the existing content

            # Otherwise, store to temp for hashing/ingest
            uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
            unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
            os.makedirs(unique_dir, exist_ok=True)
            tmp_path = os.path.join(unique_dir, ".upload.part")

            try:
                with open(tmp_path, "wb") as f:
                    while True:
                        chunk = await field.read_chunk(8 * 1024 * 1024)
                        if not chunk:
                            break
                        f.write(chunk)
                        file_written += len(chunk)
            except Exception:
                try:
                    if os.path.exists(tmp_path or ""):
                        os.remove(tmp_path)
                finally:
                    return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.")
        elif fname == "tags":
            tags_raw.append((await field.text()) or "")
        elif fname == "name":
            provided_name = (await field.text()) or None
        elif fname == "user_metadata":
            user_metadata_raw = (await field.text()) or None

    # If client did not send file, and we are not doing a from-hash fast path -> error
    if not file_present and not (provided_hash and provided_hash_exists):
        return _error_response(400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'.")

    if file_present and file_written == 0 and not (provided_hash and provided_hash_exists):
        # Empty upload is only acceptable if we are fast-pathing from existing hash
        try:
            if tmp_path and os.path.exists(tmp_path):
                os.remove(tmp_path)
        finally:
            return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.")

    try:
        spec = schemas_in.UploadAssetSpec.model_validate({
            "tags": tags_raw,
            "name": provided_name,
            "user_metadata": user_metadata_raw,
            "hash": provided_hash,
        })
    except ValidationError as ve:
        try:
            if tmp_path and os.path.exists(tmp_path):
                os.remove(tmp_path)
        finally:
            return _validation_error_response("INVALID_BODY", ve)

    # Validate models category against configured folders (consistent with previous behavior)
    if spec.tags and spec.tags[0] == "models":
        if len(spec.tags) < 2 or spec.tags[1] not in folder_paths.folder_names_and_paths:
            if tmp_path and os.path.exists(tmp_path):
                os.remove(tmp_path)
            return _error_response(
                400, "INVALID_BODY", f"unknown models category '{spec.tags[1] if len(spec.tags) >= 2 else ''}'"
            )

    owner_id = USER_MANAGER.get_request_user_id(request)

    # Fast path: if a valid provided hash exists, create AssetInfo without writing anything
    if spec.hash and provided_hash_exists is True:
        try:
            result = manager.create_asset_from_hash(
                hash_str=spec.hash,
                name=spec.name or (spec.hash.split(":", 1)[1]),
                tags=spec.tags,
                user_metadata=spec.user_metadata or {},
                owner_id=owner_id,
            )
        except Exception:
            logging.exception("create_asset_from_hash failed for hash=%s, owner_id=%s", spec.hash, owner_id)
            return _error_response(500, "INTERNAL", "Unexpected server error.")

        if result is None:
            return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist")

        # Drain temp if we accidentally saved (e.g., hash field came after file)
        if tmp_path and os.path.exists(tmp_path):
            with contextlib.suppress(Exception):
                os.remove(tmp_path)

        status = 200 if (not result.created_new) else 201
        return web.json_response(result.model_dump(mode="json"), status=status)

    # Otherwise, we must have a temp file path to ingest
    if not tmp_path or not os.path.exists(tmp_path):
        # The only case we reach here without a temp file is: client sent a hash that does not exist and no file
        return _error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.")

    try:
        created = manager.upload_asset_from_temp_path(
            spec,
            temp_path=tmp_path,
            client_filename=file_client_name,
            owner_id=owner_id,
            expected_asset_hash=spec.hash,
        )
        status = 201 if created.created_new else 200
        return web.json_response(created.model_dump(mode="json"), status=status)
    except ValueError as e:
        if tmp_path and os.path.exists(tmp_path):
            os.remove(tmp_path)
        msg = str(e)
        if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH":
            return _error_response(
                400,
                "HASH_MISMATCH",
                "Uploaded file hash does not match provided hash.",
            )
        return _error_response(400, "BAD_REQUEST", "Invalid inputs.")
    except Exception:
        if tmp_path and os.path.exists(tmp_path):
            os.remove(tmp_path)
        logging.exception("upload_asset_from_temp_path failed for tmp_path=%s, owner_id=%s", tmp_path, owner_id)
        return _error_response(500, "INTERNAL", "Unexpected server error.")
```
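Sketch of the multipart form this upload route expects, using `requests`; the file name, tags, and metadata are illustrative, and the optional hash field enables the dedup fast path described above:

```python
import requests

url = "http://127.0.0.1:8188/api/assets"
with open("my-lora.safetensors", "rb") as f:
    resp = requests.post(
        url,
        files={"file": ("my-lora.safetensors", f)},
        data={
            "tags": "models,loras",                    # CSV or JSON array; first tag is the root
            "name": "my-lora.safetensors",
            "user_metadata": '{"source": "manual upload"}',
            # optional: "hash": "blake3:<hex>" to skip re-sending known content
        },
    )
print(resp.status_code, resp.json())
```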
```diff
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
async def update_asset(request: web.Request) -> web.Response:
    asset_info_id = str(uuid.UUID(request.match_info["id"]))
    try:
        body = schemas_in.UpdateAssetBody.model_validate(await request.json())
    except ValidationError as ve:
        return _validation_error_response("INVALID_BODY", ve)
    except Exception:
        return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")

    try:
        result = manager.update_asset(
            asset_info_id=asset_info_id,
            name=body.name,
            user_metadata=body.user_metadata,
            owner_id=USER_MANAGER.get_request_user_id(request),
        )
    except (ValueError, PermissionError) as ve:
        return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
    except Exception:
        logging.exception(
            "update_asset failed for asset_info_id=%s, owner_id=%s",
            asset_info_id,
            USER_MANAGER.get_request_user_id(request),
        )
        return _error_response(500, "INTERNAL", "Unexpected server error.")
    return web.json_response(result.model_dump(mode="json"), status=200)


@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
async def delete_asset(request: web.Request) -> web.Response:
    asset_info_id = str(uuid.UUID(request.match_info["id"]))
    delete_content = request.query.get("delete_content")
    delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"}

    try:
        deleted = manager.delete_asset_reference(
            asset_info_id=asset_info_id,
            owner_id=USER_MANAGER.get_request_user_id(request),
            delete_content_if_orphan=delete_content,
        )
    except Exception:
        logging.exception(
            "delete_asset_reference failed for asset_info_id=%s, owner_id=%s",
            asset_info_id,
            USER_MANAGER.get_request_user_id(request),
        )
        return _error_response(500, "INTERNAL", "Unexpected server error.")

    if not deleted:
        return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.")
    return web.Response(status=204)


@ROUTES.get("/api/tags")
async def get_tags(request: web.Request) -> web.Response:
    """
@@ -100,3 +429,86 @@ async def get_tags(request: web.Request) -> web.Response:
        owner_id=USER_MANAGER.get_request_user_id(request),
    )
    return web.json_response(result.model_dump(mode="json"))


@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
async def add_asset_tags(request: web.Request) -> web.Response:
    asset_info_id = str(uuid.UUID(request.match_info["id"]))
    try:
        payload = await request.json()
        data = schemas_in.TagsAdd.model_validate(payload)
    except ValidationError as ve:
        return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()})
    except Exception:
        return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")

    try:
        result = manager.add_tags_to_asset(
            asset_info_id=asset_info_id,
            tags=data.tags,
            origin="manual",
            owner_id=USER_MANAGER.get_request_user_id(request),
        )
    except (ValueError, PermissionError) as ve:
        return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
    except Exception:
        logging.exception(
            "add_tags_to_asset failed for asset_info_id=%s, owner_id=%s",
            asset_info_id,
            USER_MANAGER.get_request_user_id(request),
        )
        return _error_response(500, "INTERNAL", "Unexpected server error.")

    return web.json_response(result.model_dump(mode="json"), status=200)


@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
async def delete_asset_tags(request: web.Request) -> web.Response:
    asset_info_id = str(uuid.UUID(request.match_info["id"]))
    try:
        payload = await request.json()
        data = schemas_in.TagsRemove.model_validate(payload)
    except ValidationError as ve:
        return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()})
    except Exception:
        return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")

    try:
        result = manager.remove_tags_from_asset(
            asset_info_id=asset_info_id,
            tags=data.tags,
            owner_id=USER_MANAGER.get_request_user_id(request),
        )
    except ValueError as ve:
        return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
    except Exception:
        logging.exception(
            "remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s",
            asset_info_id,
            USER_MANAGER.get_request_user_id(request),
        )
        return _error_response(500, "INTERNAL", "Unexpected server error.")

    return web.json_response(result.model_dump(mode="json"), status=200)


@ROUTES.post("/api/assets/seed")
async def seed_assets_endpoint(request: web.Request) -> web.Response:
    """Trigger asset seeding for specified roots (models, input, output)."""
    try:
        payload = await request.json()
        roots = payload.get("roots", ["models", "input", "output"])
    except Exception:
        roots = ["models", "input", "output"]

    valid_roots = [r for r in roots if r in ("models", "input", "output")]
    if not valid_roots:
        return _error_response(400, "INVALID_BODY", "No valid roots specified")

    try:
        seed_assets(tuple(valid_roots))
    except Exception:
        logging.exception("seed_assets failed for roots=%s", valid_roots)
        return _error_response(500, "INTERNAL", "Seed operation failed")

    return web.json_response({"seeded": valid_roots}, status=200)
```
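Sketch of triggering a rescan through the seed route above; roots are limited to models, input, and output, and unknown values are dropped server-side:

```python
import requests

resp = requests.post(
    "http://127.0.0.1:8188/api/assets/seed",
    json={"roots": ["models", "input"]},
)
print(resp.status_code, resp.json())  # e.g. 200 {"seeded": ["models", "input"]}
```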
```diff
@@ -1,5 +1,4 @@
import json
-import uuid
from typing import Any, Literal

from pydantic import (
@@ -8,9 +7,9 @@ from pydantic import (
    Field,
    conint,
    field_validator,
    model_validator,
)


class ListAssetsQuery(BaseModel):
    include_tags: list[str] = Field(default_factory=list)
    exclude_tags: list[str] = Field(default_factory=list)
@@ -57,6 +56,57 @@
        return None


class UpdateAssetBody(BaseModel):
    name: str | None = None
    user_metadata: dict[str, Any] | None = None

    @model_validator(mode="after")
    def _at_least_one(self):
        if self.name is None and self.user_metadata is None:
            raise ValueError("Provide at least one of: name, user_metadata.")
        return self


class CreateFromHashBody(BaseModel):
    model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)

    hash: str
    name: str
    tags: list[str] = Field(default_factory=list)
    user_metadata: dict[str, Any] = Field(default_factory=dict)

    @field_validator("hash")
    @classmethod
    def _require_blake3(cls, v):
        s = (v or "").strip().lower()
        if ":" not in s:
            raise ValueError("hash must be 'blake3:<hex>'")
        algo, digest = s.split(":", 1)
        if algo != "blake3":
            raise ValueError("only canonical 'blake3:<hex>' is accepted here")
        if not digest or any(c for c in digest if c not in "0123456789abcdef"):
            raise ValueError("hash digest must be lowercase hex")
        return s

    @field_validator("tags", mode="before")
    @classmethod
    def _tags_norm(cls, v):
        if v is None:
            return []
        if isinstance(v, list):
            out = [str(t).strip().lower() for t in v if str(t).strip()]
            seen = set()
            dedup = []
            for t in out:
                if t not in seen:
                    seen.add(t)
                    dedup.append(t)
            return dedup
        if isinstance(v, str):
            return [t.strip().lower() for t in v.split(",") if t.strip()]
        return []


class TagsListQuery(BaseModel):
    model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)

@@ -75,20 +125,140 @@
        return v.lower() or None


-class SetPreviewBody(BaseModel):
-    """Set or clear the preview for an AssetInfo. Provide an Asset.id or null."""
-    preview_id: str | None = None
class TagsAdd(BaseModel):
    model_config = ConfigDict(extra="ignore")
    tags: list[str] = Field(..., min_length=1)

-    @field_validator("preview_id", mode="before")
    @field_validator("tags")
    @classmethod
-    def _norm_uuid(cls, v):
    def normalize_tags(cls, v: list[str]) -> list[str]:
        out = []
        for t in v:
            if not isinstance(t, str):
                raise TypeError("tags must be strings")
            tnorm = t.strip().lower()
            if tnorm:
                out.append(tnorm)
        seen = set()
        deduplicated = []
        for x in out:
            if x not in seen:
                seen.add(x)
                deduplicated.append(x)
        return deduplicated


class TagsRemove(TagsAdd):
    pass


class UploadAssetSpec(BaseModel):
    """Upload Asset operation.
    - tags: ordered; first is root ('models'|'input'|'output');
      if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths
    - name: display name
    - user_metadata: arbitrary JSON object (optional)
    - hash: optional canonical 'blake3:<hex>' provided by the client for validation / fast-path

    Files created via this endpoint are stored on disk using the **content hash** as the filename stem
    and the original extension is preserved when available.
    """
    model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)

    tags: list[str] = Field(..., min_length=1)
    name: str | None = Field(default=None, max_length=512, description="Display Name")
    user_metadata: dict[str, Any] = Field(default_factory=dict)
    hash: str | None = Field(default=None)

    @field_validator("hash", mode="before")
    @classmethod
    def _parse_hash(cls, v):
        if v is None:
            return None
-        s = str(v).strip()
        s = str(v).strip().lower()
        if not s:
            return None
-        try:
-            uuid.UUID(s)
-        except Exception:
-            raise ValueError("preview_id must be a UUID")
-        return s
        if ":" not in s:
            raise ValueError("hash must be 'blake3:<hex>'")
        algo, digest = s.split(":", 1)
        if algo != "blake3":
            raise ValueError("only canonical 'blake3:<hex>' is accepted here")
        if not digest or any(c for c in digest if c not in "0123456789abcdef"):
            raise ValueError("hash digest must be lowercase hex")
        return f"{algo}:{digest}"

    @field_validator("tags", mode="before")
    @classmethod
    def _parse_tags(cls, v):
        """
        Accepts a list of strings (possibly multiple form fields),
        where each string can be:
          - JSON array (e.g., '["models","loras","foo"]')
          - comma-separated ('models, loras, foo')
          - single token ('models')
        Returns a normalized, deduplicated, ordered list.
        """
        items: list[str] = []
        if v is None:
            return []
        if isinstance(v, str):
            v = [v]

        if isinstance(v, list):
            for item in v:
                if item is None:
                    continue
                s = str(item).strip()
                if not s:
                    continue
                if s.startswith("["):
                    try:
                        arr = json.loads(s)
                        if isinstance(arr, list):
                            items.extend(str(x) for x in arr)
                            continue
                    except Exception:
                        pass  # fallback to CSV parse below
                items.extend([p for p in s.split(",") if p.strip()])
        else:
            return []

        # normalize + dedupe
        norm = []
        seen = set()
        for t in items:
            tnorm = str(t).strip().lower()
            if tnorm and tnorm not in seen:
                seen.add(tnorm)
                norm.append(tnorm)
        return norm

    @field_validator("user_metadata", mode="before")
    @classmethod
    def _parse_metadata_json(cls, v):
        if v is None or isinstance(v, dict):
            return v or {}
        if isinstance(v, str):
            s = v.strip()
            if not s:
                return {}
            try:
                parsed = json.loads(s)
            except Exception as e:
                raise ValueError(f"user_metadata must be JSON: {e}") from e
            if not isinstance(parsed, dict):
                raise ValueError("user_metadata must be a JSON object")
            return parsed
        return {}

    @model_validator(mode="after")
    def _validate_order(self):
        if not self.tags:
            raise ValueError("tags must be provided and non-empty")
        root = self.tags[0]
        if root not in {"models", "input", "output"}:
            raise ValueError("first tag must be one of: models, input, output")
        if root == "models":
            if len(self.tags) < 2:
                raise ValueError("models uploads require a category tag as the second tag")
        return self
```
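A short sketch of how UploadAssetSpec normalizes form input, based on the validators above; the import path follows the `app.assets.api` package used by the routes, and all values are illustrative:

```python
from app.assets.api.schemas_in import UploadAssetSpec

# Tags may arrive as repeated form fields, a JSON array string, or CSV;
# the validators lowercase, trim, and deduplicate them in order.
spec = UploadAssetSpec.model_validate({
    "tags": ['["models", "loras"]', "Loras, my-folder"],
    "name": "example.safetensors",
    "user_metadata": '{"note": "illustrative only"}',
    "hash": "BLAKE3:" + "AB" * 32,   # normalized to lowercase "blake3:<hex>"
})
print(spec.tags)   # ['models', 'loras', 'my-folder']
print(spec.hash)   # 'blake3:abab...'
```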
```diff
@@ -29,6 +29,21 @@ class AssetsList(BaseModel):
    has_more: bool


class AssetUpdated(BaseModel):
    id: str
    name: str
    asset_hash: str | None = None
    tags: list[str] = Field(default_factory=list)
    user_metadata: dict[str, Any] = Field(default_factory=dict)
    updated_at: datetime | None = None

    model_config = ConfigDict(from_attributes=True)

    @field_serializer("updated_at")
    def _ser_updated(self, v: datetime | None, _info):
        return v.isoformat() if v else None


class AssetDetail(BaseModel):
    id: str
    name: str
@@ -48,6 +63,10 @@
        return v.isoformat() if v else None


class AssetCreated(AssetDetail):
    created_new: bool


class TagUsage(BaseModel):
    name: str
    count: int
@@ -58,3 +77,17 @@
    tags: list[TagUsage] = Field(default_factory=list)
    total: int
    has_more: bool


class TagsAdd(BaseModel):
    model_config = ConfigDict(str_strip_whitespace=True)
    added: list[str] = Field(default_factory=list)
    already_present: list[str] = Field(default_factory=list)
    total_tags: list[str] = Field(default_factory=list)


class TagsRemove(BaseModel):
    model_config = ConfigDict(str_strip_whitespace=True)
    removed: list[str] = Field(default_factory=list)
    not_present: list[str] = Field(default_factory=list)
    total_tags: list[str] = Field(default_factory=list)
```
```diff
@@ -1,9 +1,17 @@
import os
import logging
import sqlalchemy as sa
from collections import defaultdict
-from sqlalchemy import select, exists, func
from datetime import datetime
from typing import Iterable, Any
from sqlalchemy import select, delete, exists, func
from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, contains_eager, noload
-from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
-from app.assets.helpers import escape_like_prefix, normalize_tags
from app.assets.database.models import Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag
from app.assets.helpers import (
    compute_relative_filename, escape_like_prefix, normalize_tags, project_kv, utcnow
)
from typing import Sequence


@@ -15,6 +23,22 @@ def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
    return AssetInfo.owner_id.in_(["", owner_id])


def pick_best_live_path(states: Sequence[AssetCacheState]) -> str:
    """
    Return the best on-disk path among cache states:
      1) Prefer a path that exists with needs_verify == False (already verified).
      2) Otherwise, pick the first path that exists.
      3) Otherwise return empty string.
    """
    alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)]
    if not alive:
        return ""
    for s in alive:
        if not getattr(s, "needs_verify", False):
            return s.file_path
    return alive[0].file_path


def apply_tag_filters(
    stmt: sa.sql.Select,
    include_tags: Sequence[str] | None = None,
@@ -42,6 +66,7 @@
    )
    return stmt


def apply_metadata_filter(
    stmt: sa.sql.Select,
    metadata_filter: dict | None = None,
@@ -94,7 +119,11 @@
    return stmt


-def asset_exists_by_hash(session: Session, asset_hash: str) -> bool:
def asset_exists_by_hash(
    session: Session,
    *,
    asset_hash: str,
) -> bool:
    """
    Check if an asset with a given hash exists in database.
    """
@@ -105,9 +134,39 @@
    ).first()
    return row is not None

-def get_asset_info_by_id(session: Session, asset_info_id: str) -> AssetInfo | None:

def asset_info_exists_for_asset_id(
    session: Session,
    *,
    asset_id: str,
) -> bool:
    q = (
        select(sa.literal(True))
        .select_from(AssetInfo)
        .where(AssetInfo.asset_id == asset_id)
        .limit(1)
    )
    return (session.execute(q)).first() is not None


def get_asset_by_hash(
    session: Session,
    *,
    asset_hash: str,
) -> Asset | None:
    return (
        session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
    ).scalars().first()


def get_asset_info_by_id(
    session: Session,
    *,
    asset_info_id: str,
) -> AssetInfo | None:
    return session.get(AssetInfo, asset_info_id)


def list_asset_infos_page(
    session: Session,
    owner_id: str = "",
@@ -171,12 +230,14 @@
        select(AssetInfoTag.asset_info_id, Tag.name)
        .join(Tag, Tag.name == AssetInfoTag.tag_name)
        .where(AssetInfoTag.asset_info_id.in_(id_list))
        .order_by(AssetInfoTag.added_at)
    )
    for aid, tag_name in rows.all():
        tag_map[aid].append(tag_name)

    return infos, tag_map, total


def fetch_asset_info_asset_and_tags(
    session: Session,
    asset_info_id: str,
@@ -208,6 +269,494 @@
        tags.append(tag_name)
    return first_info, first_asset, tags


def fetch_asset_info_and_asset(
    session: Session,
    *,
    asset_info_id: str,
    owner_id: str = "",
) -> tuple[AssetInfo, Asset] | None:
    stmt = (
        select(AssetInfo, Asset)
        .join(Asset, Asset.id == AssetInfo.asset_id)
        .where(
            AssetInfo.id == asset_info_id,
            visible_owner_clause(owner_id),
        )
        .limit(1)
        .options(noload(AssetInfo.tags))
    )
    row = session.execute(stmt)
    pair = row.first()
    if not pair:
        return None
    return pair[0], pair[1]

def list_cache_states_by_asset_id(
    session: Session, *, asset_id: str
) -> Sequence[AssetCacheState]:
    return (
        session.execute(
            select(AssetCacheState)
            .where(AssetCacheState.asset_id == asset_id)
            .order_by(AssetCacheState.id.asc())
        )
    ).scalars().all()


def touch_asset_info_by_id(
    session: Session,
    *,
    asset_info_id: str,
    ts: datetime | None = None,
    only_if_newer: bool = True,
) -> None:
    ts = ts or utcnow()
    stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
    if only_if_newer:
        stmt = stmt.where(
            sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts)
        )
    session.execute(stmt.values(last_access_time=ts))


def create_asset_info_for_existing_asset(
    session: Session,
    *,
    asset_hash: str,
    name: str,
    user_metadata: dict | None = None,
    tags: Sequence[str] | None = None,
    tag_origin: str = "manual",
    owner_id: str = "",
) -> AssetInfo:
    """Create or return an existing AssetInfo for an Asset identified by asset_hash."""
    now = utcnow()
    asset = get_asset_by_hash(session, asset_hash=asset_hash)
    if not asset:
        raise ValueError(f"Unknown asset hash {asset_hash}")

    info = AssetInfo(
        owner_id=owner_id,
        name=name,
        asset_id=asset.id,
        preview_id=None,
        created_at=now,
        updated_at=now,
        last_access_time=now,
    )
    try:
        with session.begin_nested():
            session.add(info)
            session.flush()
    except IntegrityError:
        existing = (
            session.execute(
                select(AssetInfo)
                .options(noload(AssetInfo.tags))
                .where(
                    AssetInfo.asset_id == asset.id,
                    AssetInfo.name == name,
                    AssetInfo.owner_id == owner_id,
                )
                .limit(1)
            )
        ).unique().scalars().first()
        if not existing:
            raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
        return existing

    # metadata["filename"] hack
    new_meta = dict(user_metadata or {})
    computed_filename = None
    try:
        p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
        if p:
            computed_filename = compute_relative_filename(p)
    except Exception:
        computed_filename = None
    if computed_filename:
        new_meta["filename"] = computed_filename
    if new_meta:
        replace_asset_info_metadata_projection(
            session,
            asset_info_id=info.id,
            user_metadata=new_meta,
        )

    if tags is not None:
        set_asset_info_tags(
            session,
            asset_info_id=info.id,
            tags=tags,
            origin=tag_origin,
        )
    return info


def set_asset_info_tags(
    session: Session,
    *,
    asset_info_id: str,
    tags: Sequence[str],
    origin: str = "manual",
) -> dict:
    desired = normalize_tags(tags)

    current = set(
        tag_name for (tag_name,) in (
            session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
        ).all()
    )

    to_add = [t for t in desired if t not in current]
    to_remove = [t for t in current if t not in desired]

    if to_add:
        ensure_tags_exist(session, to_add, tag_type="user")
        session.add_all([
            AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow())
            for t in to_add
        ])
        session.flush()

    if to_remove:
        session.execute(
            delete(AssetInfoTag)
            .where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
        )
        session.flush()

    return {"added": to_add, "removed": to_remove, "total": desired}


def replace_asset_info_metadata_projection(
    session: Session,
    *,
    asset_info_id: str,
    user_metadata: dict | None = None,
) -> None:
    info = session.get(AssetInfo, asset_info_id)
    if not info:
        raise ValueError(f"AssetInfo {asset_info_id} not found")

    info.user_metadata = user_metadata or {}
    info.updated_at = utcnow()
    session.flush()

    session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
    session.flush()

    if not user_metadata:
        return

    rows: list[AssetInfoMeta] = []
    for k, v in user_metadata.items():
        for r in project_kv(k, v):
            rows.append(
                AssetInfoMeta(
                    asset_info_id=asset_info_id,
                    key=r["key"],
                    ordinal=int(r["ordinal"]),
                    val_str=r.get("val_str"),
                    val_num=r.get("val_num"),
                    val_bool=r.get("val_bool"),
                    val_json=r.get("val_json"),
                )
            )
    if rows:
        session.add_all(rows)
        session.flush()


def ingest_fs_asset(
    session: Session,
    *,
    asset_hash: str,
    abs_path: str,
    size_bytes: int,
    mtime_ns: int,
    mime_type: str | None = None,
    info_name: str | None = None,
    owner_id: str = "",
    preview_id: str | None = None,
    user_metadata: dict | None = None,
    tags: Sequence[str] = (),
    tag_origin: str = "manual",
    require_existing_tags: bool = False,
) -> dict:
    """
    Idempotently upsert:
      - Asset by content hash (create if missing)
      - AssetCacheState(file_path) pointing to asset_id
      - Optionally AssetInfo + tag links and metadata projection
    Returns flags and ids.
    """
    locator = os.path.abspath(abs_path)
    now = utcnow()

    if preview_id:
        if not session.get(Asset, preview_id):
            preview_id = None

    out: dict[str, Any] = {
        "asset_created": False,
        "asset_updated": False,
        "state_created": False,
        "state_updated": False,
        "asset_info_id": None,
    }

    # 1) Asset by hash
    asset = (
        session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
    ).scalars().first()
    if not asset:
        vals = {
            "hash": asset_hash,
            "size_bytes": int(size_bytes),
            "mime_type": mime_type,
            "created_at": now,
        }
        res = session.execute(
            sqlite.insert(Asset)
            .values(**vals)
            .on_conflict_do_nothing(index_elements=[Asset.hash])
        )
        if int(res.rowcount or 0) > 0:
            out["asset_created"] = True
        asset = (
            session.execute(
                select(Asset).where(Asset.hash == asset_hash).limit(1)
            )
        ).scalars().first()
        if not asset:
            raise RuntimeError("Asset row not found after upsert.")
    else:
        changed = False
        if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
            asset.size_bytes = int(size_bytes)
            changed = True
        if mime_type and asset.mime_type != mime_type:
            asset.mime_type = mime_type
            changed = True
        if changed:
            out["asset_updated"] = True

    # 2) AssetCacheState upsert by file_path (unique)
    vals = {
        "asset_id": asset.id,
        "file_path": locator,
        "mtime_ns": int(mtime_ns),
    }
    ins = (
        sqlite.insert(AssetCacheState)
        .values(**vals)
        .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
    )

    res = session.execute(ins)
    if int(res.rowcount or 0) > 0:
        out["state_created"] = True
    else:
        upd = (
            sa.update(AssetCacheState)
            .where(AssetCacheState.file_path == locator)
            .where(
                sa.or_(
                    AssetCacheState.asset_id != asset.id,
                    AssetCacheState.mtime_ns.is_(None),
                    AssetCacheState.mtime_ns != int(mtime_ns),
                )
            )
            .values(asset_id=asset.id, mtime_ns=int(mtime_ns))
        )
        res2 = session.execute(upd)
        if int(res2.rowcount or 0) > 0:
            out["state_updated"] = True

    # 3) Optional AssetInfo + tags + metadata
    if info_name:
        try:
            with session.begin_nested():
                info = AssetInfo(
                    owner_id=owner_id,
                    name=info_name,
                    asset_id=asset.id,
                    preview_id=preview_id,
                    created_at=now,
                    updated_at=now,
                    last_access_time=now,
                )
                session.add(info)
                session.flush()
                out["asset_info_id"] = info.id
        except IntegrityError:
            pass

        existing_info = (
            session.execute(
                select(AssetInfo)
                .where(
                    AssetInfo.asset_id == asset.id,
                    AssetInfo.name == info_name,
                    (AssetInfo.owner_id == owner_id),
                )
                .limit(1)
            )
        ).unique().scalar_one_or_none()
        if not existing_info:
            raise RuntimeError("Failed to update or insert AssetInfo.")

        if preview_id and existing_info.preview_id != preview_id:
            existing_info.preview_id = preview_id

        existing_info.updated_at = now
        if existing_info.last_access_time < now:
            existing_info.last_access_time = now
        session.flush()
        out["asset_info_id"] = existing_info.id

        norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
        if norm and out["asset_info_id"] is not None:
            if not require_existing_tags:
                ensure_tags_exist(session, norm, tag_type="user")

            existing_tag_names = set(
                name for (name,) in (session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
            )
            missing = [t for t in norm if t not in existing_tag_names]
            if missing and require_existing_tags:
                raise ValueError(f"Unknown tags: {missing}")

            existing_links = set(
                tag_name
                for (tag_name,) in (
                    session.execute(
                        select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
                    )
                ).all()
            )
            to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
            if to_add:
                session.add_all(
                    [
                        AssetInfoTag(
                            asset_info_id=out["asset_info_id"],
                            tag_name=t,
                            origin=tag_origin,
                            added_at=now,
                        )
                        for t in to_add
                    ]
                )
                session.flush()

        # metadata["filename"] hack
        if out["asset_info_id"] is not None:
            primary_path = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
            computed_filename = compute_relative_filename(primary_path) if primary_path else None

            current_meta = existing_info.user_metadata or {}
            new_meta = dict(current_meta)
            if user_metadata is not None:
                for k, v in user_metadata.items():
                    new_meta[k] = v
            if computed_filename:
                new_meta["filename"] = computed_filename

            if new_meta != current_meta:
                replace_asset_info_metadata_projection(
                    session,
                    asset_info_id=out["asset_info_id"],
                    user_metadata=new_meta,
                )

    try:
        remove_missing_tag_for_asset_id(session, asset_id=asset.id)
    except Exception:
        logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
    return out
```
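A hedged sketch of how a scanner-style caller might drive ingest_fs_asset for a single file; the SQLAlchemy engine and the compute_blake3 helper are placeholders and not part of this diff:

```python
import os
from sqlalchemy.orm import Session

def ingest_one_file(engine, path: str) -> dict:
    # `engine` and `compute_blake3` are assumed to exist elsewhere; the keyword
    # arguments mirror the ingest_fs_asset signature above.
    st = os.stat(path)
    with Session(engine) as session, session.begin():
        return ingest_fs_asset(
            session,
            asset_hash=compute_blake3(path),       # placeholder hashing helper
            abs_path=path,
            size_bytes=st.st_size,
            mtime_ns=st.st_mtime_ns,
            mime_type=None,
            info_name=os.path.basename(path),
            owner_id="",
            tags=("models", "checkpoints"),        # illustrative tag chain
        )
```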
```diff
def update_asset_info_full(
    session: Session,
    *,
    asset_info_id: str,
    name: str | None = None,
    tags: Sequence[str] | None = None,
    user_metadata: dict | None = None,
    tag_origin: str = "manual",
    asset_info_row: Any = None,
) -> AssetInfo:
    if not asset_info_row:
        info = session.get(AssetInfo, asset_info_id)
        if not info:
            raise ValueError(f"AssetInfo {asset_info_id} not found")
    else:
        info = asset_info_row

    touched = False
    if name is not None and name != info.name:
        info.name = name
        touched = True

    computed_filename = None
    try:
        p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=info.asset_id))
        if p:
            computed_filename = compute_relative_filename(p)
    except Exception:
        computed_filename = None

    if user_metadata is not None:
        new_meta = dict(user_metadata)
        if computed_filename:
            new_meta["filename"] = computed_filename
        replace_asset_info_metadata_projection(
            session, asset_info_id=asset_info_id, user_metadata=new_meta
        )
        touched = True
    else:
        if computed_filename:
            current_meta = info.user_metadata or {}
            if current_meta.get("filename") != computed_filename:
                new_meta = dict(current_meta)
                new_meta["filename"] = computed_filename
                replace_asset_info_metadata_projection(
                    session, asset_info_id=asset_info_id, user_metadata=new_meta
                )
                touched = True

    if tags is not None:
        set_asset_info_tags(
            session,
            asset_info_id=asset_info_id,
            tags=tags,
            origin=tag_origin,
        )
        touched = True

    if touched and user_metadata is None:
        info.updated_at = utcnow()
        session.flush()

    return info


def delete_asset_info_by_id(
    session: Session,
    *,
    asset_info_id: str,
    owner_id: str,
) -> bool:
    stmt = sa.delete(AssetInfo).where(
        AssetInfo.id == asset_info_id,
        visible_owner_clause(owner_id),
    )
    return int((session.execute(stmt)).rowcount or 0) > 0


def list_tags_with_usage(
    session: Session,
    prefix: str | None = None,
@@ -265,3 +814,163 @@

    rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
    return rows_norm, int(total or 0)


def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
    wanted = normalize_tags(list(names))
    if not wanted:
        return
    rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
    ins = (
        sqlite.insert(Tag)
        .values(rows)
        .on_conflict_do_nothing(index_elements=[Tag.name])
    )
    session.execute(ins)


def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]:
    return [
        tag_name for (tag_name,) in (
            session.execute(
                select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
            )
        ).all()
    ]


def add_tags_to_asset_info(
    session: Session,
    *,
    asset_info_id: str,
    tags: Sequence[str],
    origin: str = "manual",
    create_if_missing: bool = True,
    asset_info_row: Any = None,
) -> dict:
    if not asset_info_row:
        info = session.get(AssetInfo, asset_info_id)
        if not info:
            raise ValueError(f"AssetInfo {asset_info_id} not found")

    norm = normalize_tags(tags)
    if not norm:
        total = get_asset_tags(session, asset_info_id=asset_info_id)
        return {"added": [], "already_present": [], "total_tags": total}

    if create_if_missing:
        ensure_tags_exist(session, norm, tag_type="user")

    current = {
        tag_name
        for (tag_name,) in (
            session.execute(
                sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
            )
        ).all()
    }

    want = set(norm)
    to_add = sorted(want - current)

    if to_add:
        with session.begin_nested() as nested:
            try:
                session.add_all(
                    [
                        AssetInfoTag(
                            asset_info_id=asset_info_id,
                            tag_name=t,
                            origin=origin,
                            added_at=utcnow(),
                        )
                        for t in to_add
                    ]
                )
                session.flush()
            except IntegrityError:
                nested.rollback()

    after = set(get_asset_tags(session, asset_info_id=asset_info_id))
    return {
        "added": sorted(((after - current) & want)),
        "already_present": sorted(want & current),
        "total_tags": sorted(after),
    }


def remove_tags_from_asset_info(
    session: Session,
    *,
    asset_info_id: str,
    tags: Sequence[str],
) -> dict:
    info = session.get(AssetInfo, asset_info_id)
    if not info:
        raise ValueError(f"AssetInfo {asset_info_id} not found")

    norm = normalize_tags(tags)
    if not norm:
        total = get_asset_tags(session, asset_info_id=asset_info_id)
        return {"removed": [], "not_present": [], "total_tags": total}

    existing = {
        tag_name
        for (tag_name,) in (
            session.execute(
                sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
            )
        ).all()
    }

    to_remove = sorted(set(t for t in norm if t in existing))
    not_present = sorted(set(t for t in norm if t not in existing))

    if to_remove:
        session.execute(
            delete(AssetInfoTag)
            .where(
                AssetInfoTag.asset_info_id == asset_info_id,
                AssetInfoTag.tag_name.in_(to_remove),
            )
        )
        session.flush()

    total = get_asset_tags(session, asset_info_id=asset_info_id)
    return {"removed": to_remove, "not_present": not_present, "total_tags": total}


def remove_missing_tag_for_asset_id(
    session: Session,
    *,
    asset_id: str,
) -> None:
    session.execute(
        sa.delete(AssetInfoTag).where(
            AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
            AssetInfoTag.tag_name == "missing",
        )
    )


def set_asset_info_preview(
    session: Session,
    *,
    asset_info_id: str,
    preview_asset_id: str | None = None,
) -> None:
    """Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
    info = session.get(AssetInfo, asset_info_id)
    if not info:
        raise ValueError(f"AssetInfo {asset_info_id} not found")

    if preview_asset_id is None:
        info.preview_id = None
    else:
        # validate preview asset exists
        if not session.get(Asset, preview_asset_id):
            raise ValueError(f"Preview Asset {preview_asset_id} not found")
        info.preview_id = preview_asset_id

    info.updated_at = utcnow()
    session.flush()
```
@@ -1,5 +1,6 @@
|
||||
import contextlib
|
||||
import os
|
||||
from decimal import Decimal
|
||||
from aiohttp import web
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
@@ -87,6 +88,40 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
||||
targets.append((name, paths))
|
||||
return targets
|
||||
|
||||
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
|
||||
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
|
||||
root = tags[0]
|
||||
if root == "models":
|
||||
if len(tags) < 2:
|
||||
raise ValueError("at least two tags required for model asset")
|
||||
try:
|
||||
bases = folder_paths.folder_names_and_paths[tags[1]][0]
|
||||
except KeyError:
|
||||
raise ValueError(f"unknown model category '{tags[1]}'")
|
||||
if not bases:
|
||||
raise ValueError(f"no base path configured for category '{tags[1]}'")
|
||||
base_dir = os.path.abspath(bases[0])
|
||||
raw_subdirs = tags[2:]
|
||||
else:
|
||||
base_dir = os.path.abspath(
|
||||
folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory()
|
||||
)
|
||||
raw_subdirs = tags[1:]
|
||||
for i in raw_subdirs:
|
||||
if i in (".", ".."):
|
||||
raise ValueError("invalid path component in tags")
|
||||
|
||||
return base_dir, raw_subdirs if raw_subdirs else []
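For reference, a rough sketch of the mapping this helper performs; the concrete base directories depend on the local `folder_paths` configuration, so the paths below are hypothetical, not actual values.

```python
# Hypothetical results; real base paths come from folder_paths at runtime.
base_dir, subdirs = resolve_destination_from_tags(["models", "checkpoints", "sdxl"])
# base_dir -> first configured base for the "checkpoints" category,
#             e.g. "/ComfyUI/models/checkpoints"
# subdirs  -> ["sdxl"]

base_dir, subdirs = resolve_destination_from_tags(["output", "renders"])
# base_dir -> the configured output directory
# subdirs  -> ["renders"]
```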
|
||||
|
||||
def ensure_within_base(candidate: str, base: str) -> None:
    cand_abs = os.path.abspath(candidate)
    base_abs = os.path.abspath(base)
    try:
        common = os.path.commonpath([cand_abs, base_abs])
    except ValueError:
        # commonpath raises ValueError for mixed absolute/relative paths or paths on different drives
        raise ValueError("invalid destination path")
    if common != base_abs:
        raise ValueError("destination escapes base directory")
|
||||
|
||||
def compute_relative_filename(file_path: str) -> str | None:
|
||||
"""
|
||||
Return the model's path relative to the last well-known folder (the model category),
|
||||
@@ -113,7 +148,6 @@ def compute_relative_filename(file_path: str) -> str | None:
|
||||
return "/".join(inside)
|
||||
return "/".join(parts) # input/output: keep all parts
|
||||
|
||||
|
||||
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
|
||||
"""Given an absolute or relative file path, determine which root category the path belongs to:
|
||||
- 'input' if the file resides under `folder_paths.get_input_directory()`
|
||||
@@ -215,3 +249,64 @@ def collect_models_files() -> list[str]:
|
||||
if allowed:
|
||||
out.append(abs_path)
|
||||
return out
|
||||
|
||||
def is_scalar(v):
|
||||
if v is None:
|
||||
return True
|
||||
if isinstance(v, bool):
|
||||
return True
|
||||
if isinstance(v, (int, float, Decimal, str)):
|
||||
return True
|
||||
return False
|
||||
|
||||
def project_kv(key: str, value):
|
||||
"""
|
||||
Turn a metadata key/value into typed projection rows.
|
||||
Returns list[dict] with keys:
|
||||
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
|
||||
"""
|
||||
rows: list[dict] = []
|
||||
|
||||
def _null_row(ordinal: int) -> dict:
|
||||
return {
|
||||
"key": key, "ordinal": ordinal,
|
||||
"val_str": None, "val_num": None, "val_bool": None, "val_json": None
|
||||
}
|
||||
|
||||
if value is None:
|
||||
rows.append(_null_row(0))
|
||||
return rows
|
||||
|
||||
if is_scalar(value):
|
||||
if isinstance(value, bool):
|
||||
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
|
||||
elif isinstance(value, (int, float, Decimal)):
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
rows.append({"key": key, "ordinal": 0, "val_num": num})
|
||||
elif isinstance(value, str):
|
||||
rows.append({"key": key, "ordinal": 0, "val_str": value})
|
||||
else:
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
|
||||
if isinstance(value, list):
|
||||
if all(is_scalar(x) for x in value):
|
||||
for i, x in enumerate(value):
|
||||
if x is None:
|
||||
rows.append(_null_row(i))
|
||||
elif isinstance(x, bool):
|
||||
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
|
||||
elif isinstance(x, (int, float, Decimal)):
|
||||
num = x if isinstance(x, Decimal) else Decimal(str(x))
|
||||
rows.append({"key": key, "ordinal": i, "val_num": num})
|
||||
elif isinstance(x, str):
|
||||
rows.append({"key": key, "ordinal": i, "val_str": x})
|
||||
else:
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
for i, x in enumerate(value):
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
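A minimal sketch of the row shape this produces, assuming `project_kv` above is importable; the key name and values are illustrative only.

```python
from decimal import Decimal

# A list of scalars is projected element by element; a None element
# expands to a row with all four value columns set to None.
rows = project_kv("rating", [4, "good", None])
assert rows == [
    {"key": "rating", "ordinal": 0, "val_num": Decimal("4")},
    {"key": "rating", "ordinal": 1, "val_str": "good"},
    {"key": "rating", "ordinal": 2,
     "val_str": None, "val_num": None, "val_bool": None, "val_json": None},
]
```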
|
||||
|
||||
@@ -1,13 +1,33 @@
|
||||
import os
|
||||
import mimetypes
|
||||
import contextlib
|
||||
from typing import Sequence
|
||||
|
||||
from app.database.db import create_session
|
||||
from app.assets.api import schemas_out
|
||||
from app.assets.api import schemas_out, schemas_in
|
||||
from app.assets.database.queries import (
|
||||
asset_exists_by_hash,
|
||||
asset_info_exists_for_asset_id,
|
||||
get_asset_by_hash,
|
||||
get_asset_info_by_id,
|
||||
fetch_asset_info_asset_and_tags,
|
||||
fetch_asset_info_and_asset,
|
||||
create_asset_info_for_existing_asset,
|
||||
touch_asset_info_by_id,
|
||||
update_asset_info_full,
|
||||
delete_asset_info_by_id,
|
||||
list_cache_states_by_asset_id,
|
||||
list_asset_infos_page,
|
||||
list_tags_with_usage,
|
||||
get_asset_tags,
|
||||
add_tags_to_asset_info,
|
||||
remove_tags_from_asset_info,
|
||||
pick_best_live_path,
|
||||
ingest_fs_asset,
|
||||
set_asset_info_preview,
|
||||
)
|
||||
from app.assets.helpers import resolve_destination_from_tags, ensure_within_base
|
||||
from app.assets.database.models import Asset
|
||||
|
||||
|
||||
def _safe_sort_field(requested: str | None) -> str:
|
||||
@@ -19,11 +39,28 @@ def _safe_sort_field(requested: str | None) -> str:
|
||||
return "created_at"
|
||||
|
||||
|
||||
def asset_exists(asset_hash: str) -> bool:
|
||||
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
|
||||
st = os.stat(path, follow_symlinks=True)
|
||||
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
|
||||
|
||||
|
||||
def _safe_filename(name: str | None, fallback: str) -> str:
|
||||
n = os.path.basename((name or "").strip() or fallback)
|
||||
if n:
|
||||
return n
|
||||
return fallback
|
||||
|
||||
|
||||
def asset_exists(*, asset_hash: str) -> bool:
|
||||
"""
|
||||
Check if an asset with a given hash exists in database.
|
||||
"""
|
||||
with create_session() as session:
|
||||
return asset_exists_by_hash(session, asset_hash=asset_hash)
|
||||
|
||||
|
||||
def list_assets(
|
||||
*,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
@@ -63,7 +100,6 @@ def list_assets(
|
||||
size=int(asset.size_bytes) if asset else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tags,
|
||||
preview_url=f"/api/assets/{info.id}/content",
|
||||
created_at=info.created_at,
|
||||
updated_at=info.updated_at,
|
||||
last_access_time=info.last_access_time,
|
||||
@@ -76,7 +112,12 @@ def list_assets(
|
||||
has_more=(offset + len(summaries)) < total,
|
||||
)
|
||||
|
||||
def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail:
|
||||
|
||||
def get_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetDetail:
|
||||
with create_session() as session:
|
||||
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not res:
|
||||
@@ -97,6 +138,358 @@ def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
|
||||
|
||||
def resolve_asset_content_for_download(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> tuple[str, str, str]:
|
||||
with create_session() as session:
|
||||
pair = fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not pair:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
info, asset = pair
|
||||
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
|
||||
abs_path = pick_best_live_path(states)
|
||||
if not abs_path:
|
||||
raise FileNotFoundError
|
||||
|
||||
touch_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
session.commit()
|
||||
|
||||
ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
|
||||
download_name = info.name or os.path.basename(abs_path)
|
||||
return abs_path, ctype, download_name
|
||||
|
||||
|
||||
def upload_asset_from_temp_path(
|
||||
spec: schemas_in.UploadAssetSpec,
|
||||
*,
|
||||
temp_path: str,
|
||||
client_filename: str | None = None,
|
||||
owner_id: str = "",
|
||||
expected_asset_hash: str | None = None,
|
||||
) -> schemas_out.AssetCreated:
|
||||
"""
|
||||
Create a new asset from a temporary file path, or register an additional AssetInfo reference when an asset with the same hash already exists.
|
||||
"""
|
||||
try:
|
||||
# NOTE: blake3 is not a required dependency yet, so this import will fail if blake3 is not installed in the local environment
|
||||
import app.assets.hashing as hashing
|
||||
digest = hashing.blake3_hash(temp_path)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to hash uploaded file: {e}")
|
||||
asset_hash = "blake3:" + digest
|
||||
|
||||
if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower():
|
||||
raise ValueError("HASH_MISMATCH")
|
||||
|
||||
with create_session() as session:
|
||||
existing = get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
if existing is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
if temp_path and os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest)
|
||||
info = create_asset_info_for_existing_asset(
|
||||
session,
|
||||
asset_hash=asset_hash,
|
||||
name=display_name,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
tags=spec.tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
tag_names = get_asset_tags(session, asset_info_id=info.id)
|
||||
session.commit()
|
||||
|
||||
return schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=existing.hash,
|
||||
size=int(existing.size_bytes) if existing.size_bytes is not None else None,
|
||||
mime_type=existing.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=False,
|
||||
)
|
||||
|
||||
base_dir, subdirs = resolve_destination_from_tags(spec.tags)
|
||||
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
|
||||
src_for_ext = (client_filename or spec.name or "").strip()
|
||||
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
|
||||
ext = _ext if 0 < len(_ext) <= 16 else ""
|
||||
hashed_basename = f"{digest}{ext}"
|
||||
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
|
||||
ensure_within_base(dest_abs, base_dir)
|
||||
|
||||
content_type = (
|
||||
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
|
||||
or mimetypes.guess_type(hashed_basename, strict=False)[0]
|
||||
or "application/octet-stream"
|
||||
)
|
||||
|
||||
try:
|
||||
os.replace(temp_path, dest_abs)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to move uploaded file into place: {e}")
|
||||
|
||||
try:
|
||||
size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs)
|
||||
except OSError as e:
|
||||
raise RuntimeError(f"failed to stat destination file: {e}")
|
||||
|
||||
with create_session() as session:
|
||||
result = ingest_fs_asset(
|
||||
session,
|
||||
asset_hash=asset_hash,
|
||||
abs_path=dest_abs,
|
||||
size_bytes=size_bytes,
|
||||
mtime_ns=mtime_ns,
|
||||
mime_type=content_type,
|
||||
info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest),
|
||||
owner_id=owner_id,
|
||||
preview_id=None,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
tags=spec.tags,
|
||||
tag_origin="manual",
|
||||
require_existing_tags=False,
|
||||
)
|
||||
info_id = result["asset_info_id"]
|
||||
if not info_id:
|
||||
raise RuntimeError("failed to create asset metadata")
|
||||
|
||||
pair = fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id)
|
||||
if not pair:
|
||||
raise RuntimeError("inconsistent DB state after ingest")
|
||||
info, asset = pair
|
||||
tag_names = get_asset_tags(session, asset_info_id=info.id)
|
||||
created_result = schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash,
|
||||
size=int(asset.size_bytes),
|
||||
mime_type=asset.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=result["asset_created"],
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return created_result
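A hedged usage sketch of this entry point; the `UploadAssetSpec` constructor arguments are assumptions based only on the fields read above (`spec.name`, `spec.tags`, `spec.user_metadata`), and the file paths are hypothetical.

```python
# Hypothetical call; field names and paths are assumptions, not confirmed API.
spec = schemas_in.UploadAssetSpec(
    name="sdxl_base.safetensors",
    tags=["models", "checkpoints"],
    user_metadata={"source": "upload"},
)
created = upload_asset_from_temp_path(
    spec,
    temp_path="/tmp/upload-abc123",          # temp file written by the upload handler
    client_filename="sdxl_base.safetensors", # used for extension and display name
)
# created.created_new is False when an asset with the same blake3 hash already exists.
```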
|
||||
|
||||
|
||||
def update_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
name: str | None = None,
|
||||
tags: list[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetUpdated:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
info = update_asset_info_full(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
name=name,
|
||||
tags=tags,
|
||||
user_metadata=user_metadata,
|
||||
tag_origin="manual",
|
||||
asset_info_row=info_row,
|
||||
)
|
||||
|
||||
tag_names = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
result = schemas_out.AssetUpdated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=info.asset.hash if info.asset else None,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
updated_at=info.updated_at,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def set_asset_preview(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
preview_asset_id: str | None = None,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetDetail:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
set_asset_info_preview(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
preview_asset_id=preview_asset_id,
|
||||
)
|
||||
|
||||
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not res:
|
||||
raise RuntimeError("State changed during preview update")
|
||||
info, asset, tags = res
|
||||
result = schemas_out.AssetDetail(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tags,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
asset_id = info_row.asset_id if info_row else None
|
||||
deleted = delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not deleted:
|
||||
session.commit()
|
||||
return False
|
||||
|
||||
if not delete_content_if_orphan or not asset_id:
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id)
|
||||
if still_exists:
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
states = list_cache_states_by_asset_id(session, asset_id=asset_id)
|
||||
file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]
|
||||
|
||||
asset_row = session.get(Asset, asset_id)
|
||||
if asset_row is not None:
|
||||
session.delete(asset_row)
|
||||
|
||||
session.commit()
|
||||
for p in file_paths:
|
||||
with contextlib.suppress(Exception):
|
||||
if p and os.path.isfile(p):
|
||||
os.remove(p)
|
||||
return True
|
||||
|
||||
|
||||
def create_asset_from_hash(
|
||||
*,
|
||||
hash_str: str,
|
||||
name: str,
|
||||
tags: list[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetCreated | None:
|
||||
canonical = hash_str.strip().lower()
|
||||
with create_session() as session:
|
||||
asset = get_asset_by_hash(session, asset_hash=canonical)
|
||||
if not asset:
|
||||
return None
|
||||
|
||||
info = create_asset_info_for_existing_asset(
|
||||
session,
|
||||
asset_hash=canonical,
|
||||
name=_safe_filename(name, fallback=canonical.split(":", 1)[1]),
|
||||
user_metadata=user_metadata or {},
|
||||
tags=tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
tag_names = get_asset_tags(session, asset_info_id=info.id)
|
||||
result = schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash,
|
||||
size=int(asset.size_bytes),
|
||||
mime_type=asset.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=False,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def add_tags_to_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: list[str],
|
||||
origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsAdd:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
data = add_tags_to_asset_info(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
origin=origin,
|
||||
create_if_missing=True,
|
||||
asset_info_row=info_row,
|
||||
)
|
||||
session.commit()
|
||||
return schemas_out.TagsAdd(**data)
|
||||
|
||||
|
||||
def remove_tags_from_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: list[str],
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsRemove:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
data = remove_tags_from_asset_info(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
)
|
||||
session.commit()
|
||||
return schemas_out.TagsRemove(**data)
|
||||
|
||||
|
||||
def list_tags(
|
||||
prefix: str | None = None,
|
||||
limit: int = 100,
|
||||
|
||||
@@ -27,6 +27,7 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No
|
||||
t_start = time.perf_counter()
|
||||
created = 0
|
||||
skipped_existing = 0
|
||||
orphans_pruned = 0
|
||||
paths: list[str] = []
|
||||
try:
|
||||
existing_paths: set[str] = set()
|
||||
@@ -38,6 +39,11 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No
|
||||
except Exception as e:
|
||||
logging.exception("fast DB scan failed for %s: %s", r, e)
|
||||
|
||||
try:
|
||||
orphans_pruned = _prune_orphaned_assets(roots)
|
||||
except Exception as e:
|
||||
logging.exception("orphan pruning failed: %s", e)
|
||||
|
||||
if "models" in roots:
|
||||
paths.extend(collect_models_files())
|
||||
if "input" in roots:
|
||||
@@ -85,15 +91,43 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No
|
||||
finally:
|
||||
if enable_logging:
|
||||
logging.info(
|
||||
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, total_seen=%d)",
|
||||
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, orphans_pruned=%d, total_seen=%d)",
|
||||
roots,
|
||||
time.perf_counter() - t_start,
|
||||
created,
|
||||
skipped_existing,
|
||||
orphans_pruned,
|
||||
len(paths),
|
||||
)
|
||||
|
||||
|
||||
def _prune_orphaned_assets(roots: tuple[RootType, ...]) -> int:
|
||||
"""Prune cache states outside configured prefixes, then delete orphaned seed assets."""
|
||||
all_prefixes = [os.path.abspath(p) for r in roots for p in prefixes_for_root(r)]
|
||||
if not all_prefixes:
|
||||
return 0
|
||||
|
||||
def make_prefix_condition(prefix: str):
|
||||
base = prefix if prefix.endswith(os.sep) else prefix + os.sep
|
||||
escaped, esc = escape_like_prefix(base)
|
||||
return AssetCacheState.file_path.like(escaped + "%", escape=esc)
|
||||
|
||||
matches_valid_prefix = sqlalchemy.or_(*[make_prefix_condition(p) for p in all_prefixes])
|
||||
|
||||
orphan_subq = (
|
||||
sqlalchemy.select(Asset.id)
|
||||
.outerjoin(AssetCacheState, AssetCacheState.asset_id == Asset.id)
|
||||
.where(Asset.hash.is_(None), AssetCacheState.id.is_(None))
|
||||
).scalar_subquery()
|
||||
|
||||
with create_session() as sess:
|
||||
sess.execute(sqlalchemy.delete(AssetCacheState).where(~matches_valid_prefix))
|
||||
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id.in_(orphan_subq)))
|
||||
result = sess.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(orphan_subq)))
|
||||
sess.commit()
|
||||
return result.rowcount
|
||||
|
||||
|
||||
def _fast_db_consistency_pass(
|
||||
root: RootType,
|
||||
*,
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy_api.latest._node_replace import NodeReplace
|
||||
|
||||
REGISTERED_NODE_REPLACEMENTS: dict[str, list[NodeReplace]] = {}
|
||||
|
||||
def register_node_replacement(node_replace: NodeReplace):
|
||||
REGISTERED_NODE_REPLACEMENTS.setdefault(node_replace.old_node_id, []).append(node_replace)
|
||||
|
||||
def registered_as_dict():
|
||||
return {
|
||||
k: [v.as_dict() for v in v_list] for k, v_list in REGISTERED_NODE_REPLACEMENTS.items()
|
||||
}
|
||||
|
||||
class NodeReplaceManager:
|
||||
def add_routes(self, routes):
|
||||
@routes.get("/node_replacements")
|
||||
async def get_node_replacements(request):
|
||||
return web.json_response(registered_as_dict())
|
||||
blueprints/.glsl/Brightness_and_Contrast_1.frag (new file, 44 lines)
@@ -0,0 +1,44 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform float u_float0; // Brightness slider -100..100
|
||||
uniform float u_float1; // Contrast slider -100..100
|
||||
|
||||
in vec2 v_texCoord;
|
||||
out vec4 fragColor;
|
||||
|
||||
const float MID_GRAY = 0.18; // 18% reflectance
|
||||
|
||||
// sRGB gamma 2.2 approximation
|
||||
vec3 srgbToLinear(vec3 c) {
|
||||
return pow(max(c, 0.0), vec3(2.2));
|
||||
}
|
||||
|
||||
vec3 linearToSrgb(vec3 c) {
|
||||
return pow(max(c, 0.0), vec3(1.0/2.2));
|
||||
}
|
||||
|
||||
float mapBrightness(float b) {
|
||||
return clamp(b / 100.0, -1.0, 1.0);
|
||||
}
|
||||
|
||||
float mapContrast(float c) {
|
||||
return clamp(c / 100.0 + 1.0, 0.0, 2.0);
|
||||
}
|
||||
|
||||
void main() {
|
||||
vec4 orig = texture(u_image0, v_texCoord);
|
||||
|
||||
float brightness = mapBrightness(u_float0);
|
||||
float contrast = mapContrast(u_float1);
|
||||
|
||||
vec3 lin = srgbToLinear(orig.rgb);
|
||||
|
||||
lin = (lin - MID_GRAY) * contrast + brightness + MID_GRAY;
|
||||
|
||||
// Convert back to sRGB
|
||||
vec3 result = linearToSrgb(clamp(lin, 0.0, 1.0));
|
||||
|
||||
fragColor = vec4(result, orig.a);
|
||||
}
|
||||
blueprints/.glsl/Chromatic_Aberration_16.frag (new file, 72 lines)
@@ -0,0 +1,72 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform vec2 u_resolution;
|
||||
uniform int u_int0; // Mode
|
||||
uniform float u_float0; // Amount (0 to 100)
|
||||
|
||||
in vec2 v_texCoord;
|
||||
out vec4 fragColor;
|
||||
|
||||
const int MODE_LINEAR = 0;
|
||||
const int MODE_RADIAL = 1;
|
||||
const int MODE_BARREL = 2;
|
||||
const int MODE_SWIRL = 3;
|
||||
const int MODE_DIAGONAL = 4;
|
||||
|
||||
const float AMOUNT_SCALE = 0.0005;
|
||||
const float RADIAL_MULT = 4.0;
|
||||
const float BARREL_MULT = 8.0;
|
||||
const float INV_SQRT2 = 0.70710678118;
|
||||
|
||||
void main() {
|
||||
vec2 uv = v_texCoord;
|
||||
vec4 original = texture(u_image0, uv);
|
||||
|
||||
float amount = u_float0 * AMOUNT_SCALE;
|
||||
|
||||
if (amount < 0.000001) {
|
||||
fragColor = original;
|
||||
return;
|
||||
}
|
||||
|
||||
// Aspect-corrected coordinates for circular effects
|
||||
float aspect = u_resolution.x / u_resolution.y;
|
||||
vec2 centered = uv - 0.5;
|
||||
vec2 corrected = vec2(centered.x * aspect, centered.y);
|
||||
float r = length(corrected);
|
||||
vec2 dir = r > 0.0001 ? corrected / r : vec2(0.0);
|
||||
vec2 offset = vec2(0.0);
|
||||
|
||||
if (u_int0 == MODE_LINEAR) {
|
||||
// Horizontal shift (no aspect correction needed)
|
||||
offset = vec2(amount, 0.0);
|
||||
}
|
||||
else if (u_int0 == MODE_RADIAL) {
|
||||
// Outward from center, stronger at edges
|
||||
offset = dir * r * amount * RADIAL_MULT;
|
||||
offset.x /= aspect; // Convert back to UV space
|
||||
}
|
||||
else if (u_int0 == MODE_BARREL) {
|
||||
// Lens distortion simulation (r² falloff)
|
||||
offset = dir * r * r * amount * BARREL_MULT;
|
||||
offset.x /= aspect; // Convert back to UV space
|
||||
}
|
||||
else if (u_int0 == MODE_SWIRL) {
|
||||
// Perpendicular to radial (rotational aberration)
|
||||
vec2 perp = vec2(-dir.y, dir.x);
|
||||
offset = perp * r * amount * RADIAL_MULT;
|
||||
offset.x /= aspect; // Convert back to UV space
|
||||
}
|
||||
else if (u_int0 == MODE_DIAGONAL) {
|
||||
// 45° offset (no aspect correction needed)
|
||||
offset = vec2(amount, amount) * INV_SQRT2;
|
||||
}
|
||||
|
||||
float red = texture(u_image0, uv + offset).r;
|
||||
float green = original.g;
|
||||
float blue = texture(u_image0, uv - offset).b;
|
||||
|
||||
fragColor = vec4(red, green, blue, original.a);
|
||||
}
|
||||
blueprints/.glsl/Color_Adjustment_15.frag (new file, 78 lines)
@@ -0,0 +1,78 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform float u_float0; // temperature (-100 to 100)
|
||||
uniform float u_float1; // tint (-100 to 100)
|
||||
uniform float u_float2; // vibrance (-100 to 100)
|
||||
uniform float u_float3; // saturation (-100 to 100)
|
||||
|
||||
in vec2 v_texCoord;
|
||||
out vec4 fragColor;
|
||||
|
||||
const float INPUT_SCALE = 0.01;
|
||||
const float TEMP_TINT_PRIMARY = 0.3;
|
||||
const float TEMP_TINT_SECONDARY = 0.15;
|
||||
const float VIBRANCE_BOOST = 2.0;
|
||||
const float SATURATION_BOOST = 2.0;
|
||||
const float SKIN_PROTECTION = 0.5;
|
||||
const float EPSILON = 0.001;
|
||||
const vec3 LUMA_WEIGHTS = vec3(0.299, 0.587, 0.114);
|
||||
|
||||
void main() {
|
||||
vec4 tex = texture(u_image0, v_texCoord);
|
||||
vec3 color = tex.rgb;
|
||||
|
||||
// Scale inputs: -100/100 → -1/1
|
||||
float temperature = u_float0 * INPUT_SCALE;
|
||||
float tint = u_float1 * INPUT_SCALE;
|
||||
float vibrance = u_float2 * INPUT_SCALE;
|
||||
float saturation = u_float3 * INPUT_SCALE;
|
||||
|
||||
// Temperature (warm/cool): positive = warm, negative = cool
|
||||
color.r += temperature * TEMP_TINT_PRIMARY;
|
||||
color.b -= temperature * TEMP_TINT_PRIMARY;
|
||||
|
||||
// Tint (green/magenta): positive = green, negative = magenta
|
||||
color.g += tint * TEMP_TINT_PRIMARY;
|
||||
color.r -= tint * TEMP_TINT_SECONDARY;
|
||||
color.b -= tint * TEMP_TINT_SECONDARY;
|
||||
|
||||
// Single clamp after temperature/tint
|
||||
color = clamp(color, 0.0, 1.0);
|
||||
|
||||
// Vibrance with skin protection
|
||||
if (vibrance != 0.0) {
|
||||
float maxC = max(color.r, max(color.g, color.b));
|
||||
float minC = min(color.r, min(color.g, color.b));
|
||||
float sat = maxC - minC;
|
||||
float gray = dot(color, LUMA_WEIGHTS);
|
||||
|
||||
if (vibrance < 0.0) {
|
||||
// Desaturate: -100 → gray
|
||||
color = mix(vec3(gray), color, 1.0 + vibrance);
|
||||
} else {
|
||||
// Boost less saturated colors more
|
||||
float vibranceAmt = vibrance * (1.0 - sat);
|
||||
|
||||
// Branchless skin tone protection
|
||||
float isWarmTone = step(color.b, color.g) * step(color.g, color.r);
|
||||
float warmth = (color.r - color.b) / max(maxC, EPSILON);
|
||||
float skinTone = isWarmTone * warmth * sat * (1.0 - sat);
|
||||
vibranceAmt *= (1.0 - skinTone * SKIN_PROTECTION);
|
||||
|
||||
color = mix(vec3(gray), color, 1.0 + vibranceAmt * VIBRANCE_BOOST);
|
||||
}
|
||||
}
|
||||
|
||||
// Saturation
|
||||
if (saturation != 0.0) {
|
||||
float gray = dot(color, LUMA_WEIGHTS);
|
||||
float satMix = saturation < 0.0
|
||||
? 1.0 + saturation // -100 → gray
|
||||
: 1.0 + saturation * SATURATION_BOOST; // +100 → 3x boost
|
||||
color = mix(vec3(gray), color, satMix);
|
||||
}
|
||||
|
||||
fragColor = vec4(clamp(color, 0.0, 1.0), tex.a);
|
||||
}
|
||||
blueprints/.glsl/Edge-Preserving_Blur_128.frag (new file, 94 lines)
@@ -0,0 +1,94 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform float u_float0; // Blur radius (0–20, default ~5)
|
||||
uniform float u_float1; // Edge threshold (0–100, default ~30)
|
||||
uniform int u_int0; // Step size (0/1 = every pixel, 2+ = skip pixels)
|
||||
|
||||
in vec2 v_texCoord;
|
||||
out vec4 fragColor;
|
||||
|
||||
const int MAX_RADIUS = 20;
|
||||
const float EPSILON = 0.0001;
|
||||
|
||||
// Perceptual luminance
|
||||
float getLuminance(vec3 rgb) {
|
||||
return dot(rgb, vec3(0.299, 0.587, 0.114));
|
||||
}
|
||||
|
||||
vec4 bilateralFilter(vec2 uv, vec2 texelSize, int radius,
|
||||
float sigmaSpatial, float sigmaColor)
|
||||
{
|
||||
vec4 center = texture(u_image0, uv);
|
||||
vec3 centerRGB = center.rgb;
|
||||
|
||||
float invSpatial2 = -0.5 / (sigmaSpatial * sigmaSpatial);
|
||||
float invColor2 = -0.5 / (sigmaColor * sigmaColor + EPSILON);
|
||||
|
||||
vec3 sumRGB = vec3(0.0);
|
||||
float sumWeight = 0.0;
|
||||
|
||||
int step = max(u_int0, 1);
|
||||
float radius2 = float(radius * radius);
|
||||
|
||||
for (int dy = -MAX_RADIUS; dy <= MAX_RADIUS; dy++) {
|
||||
if (dy < -radius || dy > radius) continue;
|
||||
if (abs(dy) % step != 0) continue;
|
||||
|
||||
for (int dx = -MAX_RADIUS; dx <= MAX_RADIUS; dx++) {
|
||||
if (dx < -radius || dx > radius) continue;
|
||||
if (abs(dx) % step != 0) continue;
|
||||
|
||||
vec2 offset = vec2(float(dx), float(dy));
|
||||
float dist2 = dot(offset, offset);
|
||||
if (dist2 > radius2) continue;
|
||||
|
||||
vec3 sampleRGB = texture(u_image0, uv + offset * texelSize).rgb;
|
||||
|
||||
// Spatial Gaussian
|
||||
float spatialWeight = exp(dist2 * invSpatial2);
|
||||
|
||||
// Perceptual color distance (weighted RGB)
|
||||
vec3 diff = sampleRGB - centerRGB;
|
||||
float colorDist = dot(diff * diff, vec3(0.299, 0.587, 0.114));
|
||||
float colorWeight = exp(colorDist * invColor2);
|
||||
|
||||
float w = spatialWeight * colorWeight;
|
||||
sumRGB += sampleRGB * w;
|
||||
sumWeight += w;
|
||||
}
|
||||
}
|
||||
|
||||
vec3 resultRGB = sumRGB / max(sumWeight, EPSILON);
|
||||
return vec4(resultRGB, center.a); // preserve center alpha
|
||||
}
|
||||
|
||||
void main() {
|
||||
vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0));
|
||||
|
||||
float radiusF = clamp(u_float0, 0.0, float(MAX_RADIUS));
|
||||
int radius = int(radiusF + 0.5);
|
||||
|
||||
if (radius == 0) {
|
||||
fragColor = texture(u_image0, v_texCoord);
|
||||
return;
|
||||
}
|
||||
|
||||
// Edge threshold → color sigma
|
||||
// Squared curve for better low-end control
|
||||
float t = clamp(u_float1, 0.0, 100.0) / 100.0;
|
||||
t *= t;
|
||||
float sigmaColor = mix(0.01, 0.5, t);
|
||||
|
||||
// Spatial sigma tied to radius
|
||||
float sigmaSpatial = max(radiusF * 0.75, 0.5);
|
||||
|
||||
fragColor = bilateralFilter(
|
||||
v_texCoord,
|
||||
texelSize,
|
||||
radius,
|
||||
sigmaSpatial,
|
||||
sigmaColor
|
||||
);
|
||||
}
|
||||
blueprints/.glsl/Film_Grain_15.frag (new file, 124 lines)
@@ -0,0 +1,124 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform vec2 u_resolution;
|
||||
uniform float u_float0; // grain amount [0.0 – 1.0] typical: 0.2–0.8
|
||||
uniform float u_float1; // grain size [0.3 – 3.0] lower = finer grain
|
||||
uniform float u_float2; // color amount [0.0 – 1.0] 0 = monochrome, 1 = RGB grain
|
||||
uniform float u_float3; // luminance bias [0.0 – 1.0] 0 = uniform, 1 = shadows only
|
||||
uniform int u_int0; // noise mode [0 or 1] 0 = smooth, 1 = grainy
|
||||
|
||||
in vec2 v_texCoord;
|
||||
layout(location = 0) out vec4 fragColor0;
|
||||
|
||||
// High-quality integer hash (pcg-like)
|
||||
uint pcg(uint v) {
|
||||
uint state = v * 747796405u + 2891336453u;
|
||||
uint word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u;
|
||||
return (word >> 22u) ^ word;
|
||||
}
|
||||
|
||||
// 2D -> 1D hash input
|
||||
uint hash2d(uvec2 p) {
|
||||
return pcg(p.x + pcg(p.y));
|
||||
}
|
||||
|
||||
// Hash to float [0, 1]
|
||||
float hashf(uvec2 p) {
|
||||
return float(hash2d(p)) / float(0xffffffffu);
|
||||
}
|
||||
|
||||
// Hash to float with offset (for RGB channels)
|
||||
float hashf(uvec2 p, uint offset) {
|
||||
return float(pcg(hash2d(p) + offset)) / float(0xffffffffu);
|
||||
}
|
||||
|
||||
// Convert uniform [0,1] to roughly Gaussian distribution
|
||||
// Using simple approximation: average of multiple samples
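// (the sum of four U(0,1) samples has mean 2.0 and std ~0.577; subtracting 2.0 and scaling by 0.7 yields mean 0, std ~0.40)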
|
||||
float toGaussian(uvec2 p) {
|
||||
float sum = hashf(p, 0u) + hashf(p, 1u) + hashf(p, 2u) + hashf(p, 3u);
|
||||
return (sum - 2.0) * 0.7; // Centered, scaled
|
||||
}
|
||||
|
||||
float toGaussian(uvec2 p, uint offset) {
|
||||
float sum = hashf(p, offset) + hashf(p, offset + 1u)
|
||||
+ hashf(p, offset + 2u) + hashf(p, offset + 3u);
|
||||
return (sum - 2.0) * 0.7;
|
||||
}
|
||||
|
||||
// Smooth noise with better interpolation
|
||||
float smoothNoise(vec2 p) {
|
||||
vec2 i = floor(p);
|
||||
vec2 f = fract(p);
|
||||
|
||||
// Quintic interpolation (less banding than cubic)
|
||||
f = f * f * f * (f * (f * 6.0 - 15.0) + 10.0);
|
||||
|
||||
uvec2 ui = uvec2(i);
|
||||
float a = toGaussian(ui);
|
||||
float b = toGaussian(ui + uvec2(1u, 0u));
|
||||
float c = toGaussian(ui + uvec2(0u, 1u));
|
||||
float d = toGaussian(ui + uvec2(1u, 1u));
|
||||
|
||||
return mix(mix(a, b, f.x), mix(c, d, f.x), f.y);
|
||||
}
|
||||
|
||||
float smoothNoise(vec2 p, uint offset) {
|
||||
vec2 i = floor(p);
|
||||
vec2 f = fract(p);
|
||||
|
||||
f = f * f * f * (f * (f * 6.0 - 15.0) + 10.0);
|
||||
|
||||
uvec2 ui = uvec2(i);
|
||||
float a = toGaussian(ui, offset);
|
||||
float b = toGaussian(ui + uvec2(1u, 0u), offset);
|
||||
float c = toGaussian(ui + uvec2(0u, 1u), offset);
|
||||
float d = toGaussian(ui + uvec2(1u, 1u), offset);
|
||||
|
||||
return mix(mix(a, b, f.x), mix(c, d, f.x), f.y);
|
||||
}
|
||||
|
||||
void main() {
|
||||
vec4 color = texture(u_image0, v_texCoord);
|
||||
|
||||
// Luminance (Rec.709)
|
||||
float luma = dot(color.rgb, vec3(0.2126, 0.7152, 0.0722));
|
||||
|
||||
// Grain UV (resolution-independent)
|
||||
vec2 grainUV = v_texCoord * u_resolution / max(u_float1, 0.01);
|
||||
uvec2 grainPixel = uvec2(grainUV);
|
||||
|
||||
float g;
|
||||
vec3 grainRGB;
|
||||
|
||||
if (u_int0 == 1) {
|
||||
// Grainy mode: pure hash noise (no interpolation = no banding)
|
||||
g = toGaussian(grainPixel);
|
||||
grainRGB = vec3(
|
||||
toGaussian(grainPixel, 100u),
|
||||
toGaussian(grainPixel, 200u),
|
||||
toGaussian(grainPixel, 300u)
|
||||
);
|
||||
} else {
|
||||
// Smooth mode: interpolated with quintic curve
|
||||
g = smoothNoise(grainUV);
|
||||
grainRGB = vec3(
|
||||
smoothNoise(grainUV, 100u),
|
||||
smoothNoise(grainUV, 200u),
|
||||
smoothNoise(grainUV, 300u)
|
||||
);
|
||||
}
|
||||
|
||||
// Luminance weighting (less grain in highlights)
|
||||
float lumWeight = mix(1.0, 1.0 - luma, clamp(u_float3, 0.0, 1.0));
|
||||
|
||||
// Strength
|
||||
float strength = u_float0 * 0.15;
|
||||
|
||||
// Color vs monochrome grain
|
||||
vec3 grainColor = mix(vec3(g), grainRGB, clamp(u_float2, 0.0, 1.0));
|
||||
|
||||
color.rgb += grainColor * strength * lumWeight;
|
||||
fragColor0 = vec4(clamp(color.rgb, 0.0, 1.0), color.a);
|
||||
}
|
||||
blueprints/.glsl/Glow_30.frag (new file, 133 lines)
@@ -0,0 +1,133 @@
|
||||
#version 300 es
|
||||
precision mediump float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform vec2 u_resolution;
|
||||
uniform int u_int0; // Blend mode
|
||||
uniform int u_int1; // Color tint
|
||||
uniform float u_float0; // Intensity
|
||||
uniform float u_float1; // Radius
|
||||
uniform float u_float2; // Threshold
|
||||
|
||||
in vec2 v_texCoord;
|
||||
out vec4 fragColor;
|
||||
|
||||
const int BLEND_ADD = 0;
|
||||
const int BLEND_SCREEN = 1;
|
||||
const int BLEND_SOFT = 2;
|
||||
const int BLEND_OVERLAY = 3;
|
||||
const int BLEND_LIGHTEN = 4;
|
||||
|
||||
const float GOLDEN_ANGLE = 2.39996323;
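// GOLDEN_ANGLE ~= pi * (3 - sqrt(5)) radians; rotating successive samples by it spreads them evenly around the disc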
|
||||
const int MAX_SAMPLES = 48;
|
||||
const vec3 LUMA = vec3(0.299, 0.587, 0.114);
|
||||
|
||||
float hash(vec2 p) {
|
||||
p = fract(p * vec2(123.34, 456.21));
|
||||
p += dot(p, p + 45.32);
|
||||
return fract(p.x * p.y);
|
||||
}
|
||||
|
||||
vec3 hexToRgb(int h) {
|
||||
return vec3(
|
||||
float((h >> 16) & 255),
|
||||
float((h >> 8) & 255),
|
||||
float(h & 255)
|
||||
) * (1.0 / 255.0);
|
||||
}
|
||||
|
||||
vec3 blend(vec3 base, vec3 glow, int mode) {
|
||||
if (mode == BLEND_SCREEN) {
|
||||
return 1.0 - (1.0 - base) * (1.0 - glow);
|
||||
}
|
||||
if (mode == BLEND_SOFT) {
|
||||
return mix(
|
||||
base - (1.0 - 2.0 * glow) * base * (1.0 - base),
|
||||
base + (2.0 * glow - 1.0) * (sqrt(base) - base),
|
||||
step(0.5, glow)
|
||||
);
|
||||
}
|
||||
if (mode == BLEND_OVERLAY) {
|
||||
return mix(
|
||||
2.0 * base * glow,
|
||||
1.0 - 2.0 * (1.0 - base) * (1.0 - glow),
|
||||
step(0.5, base)
|
||||
);
|
||||
}
|
||||
if (mode == BLEND_LIGHTEN) {
|
||||
return max(base, glow);
|
||||
}
|
||||
return base + glow;
|
||||
}
|
||||
|
||||
void main() {
|
||||
vec4 original = texture(u_image0, v_texCoord);
|
||||
|
||||
float intensity = u_float0 * 0.05;
|
||||
float radius = u_float1 * u_float1 * 0.012;
|
||||
|
||||
if (intensity < 0.001 || radius < 0.1) {
|
||||
fragColor = original;
|
||||
return;
|
||||
}
|
||||
|
||||
float threshold = 1.0 - u_float2 * 0.01;
|
||||
float t0 = threshold - 0.15;
|
||||
float t1 = threshold + 0.15;
|
||||
|
||||
vec2 texelSize = 1.0 / u_resolution;
|
||||
float radius2 = radius * radius;
|
||||
|
||||
float sampleScale = clamp(radius * 0.75, 0.35, 1.0);
|
||||
int samples = int(float(MAX_SAMPLES) * sampleScale);
|
||||
|
||||
float noise = hash(gl_FragCoord.xy);
|
||||
float angleOffset = noise * GOLDEN_ANGLE;
|
||||
float radiusJitter = 0.85 + noise * 0.3;
|
||||
|
||||
float ca = cos(GOLDEN_ANGLE);
|
||||
float sa = sin(GOLDEN_ANGLE);
|
||||
vec2 dir = vec2(cos(angleOffset), sin(angleOffset));
|
||||
|
||||
vec3 glow = vec3(0.0);
|
||||
float totalWeight = 0.0;
|
||||
|
||||
// Center tap
|
||||
float centerMask = smoothstep(t0, t1, dot(original.rgb, LUMA));
|
||||
glow += original.rgb * centerMask * 2.0;
|
||||
totalWeight += 2.0;
|
||||
|
||||
for (int i = 1; i < MAX_SAMPLES; i++) {
|
||||
if (i >= samples) break;
|
||||
|
||||
float fi = float(i);
|
||||
float dist = sqrt(fi / float(samples)) * radius * radiusJitter;
|
||||
|
||||
vec2 offset = dir * dist * texelSize;
|
||||
vec3 c = texture(u_image0, v_texCoord + offset).rgb;
|
||||
float mask = smoothstep(t0, t1, dot(c, LUMA));
|
||||
|
||||
float w = 1.0 - (dist * dist) / (radius2 * 1.5);
|
||||
w = max(w, 0.0);
|
||||
w *= w;
|
||||
|
||||
glow += c * mask * w;
|
||||
totalWeight += w;
|
||||
|
||||
dir = vec2(
|
||||
dir.x * ca - dir.y * sa,
|
||||
dir.x * sa + dir.y * ca
|
||||
);
|
||||
}
|
||||
|
||||
glow *= intensity / max(totalWeight, 0.001);
|
||||
|
||||
if (u_int1 > 0) {
|
||||
glow *= hexToRgb(u_int1);
|
||||
}
|
||||
|
||||
vec3 result = blend(original.rgb, glow, u_int0);
|
||||
result += (noise - 0.5) * (1.0 / 255.0);
|
||||
|
||||
fragColor = vec4(clamp(result, 0.0, 1.0), original.a);
|
||||
}
|
||||
blueprints/.glsl/Hue_and_Saturation_1.frag (new file, 222 lines)
@@ -0,0 +1,222 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform int u_int0; // Mode: 0=Master, 1=Reds, 2=Yellows, 3=Greens, 4=Cyans, 5=Blues, 6=Magentas, 7=Colorize
|
||||
uniform int u_int1; // Color Space: 0=HSL, 1=HSB/HSV
|
||||
uniform float u_float0; // Hue (-180 to 180)
|
||||
uniform float u_float1; // Saturation (-100 to 100)
|
||||
uniform float u_float2; // Lightness/Brightness (-100 to 100)
|
||||
uniform float u_float3; // Overlap (0 to 100) - feathering between adjacent color ranges
|
||||
|
||||
in vec2 v_texCoord;
|
||||
out vec4 fragColor;
|
||||
|
||||
// Color range modes
|
||||
const int MODE_MASTER = 0;
|
||||
const int MODE_RED = 1;
|
||||
const int MODE_YELLOW = 2;
|
||||
const int MODE_GREEN = 3;
|
||||
const int MODE_CYAN = 4;
|
||||
const int MODE_BLUE = 5;
|
||||
const int MODE_MAGENTA = 6;
|
||||
const int MODE_COLORIZE = 7;
|
||||
|
||||
// Color space modes
|
||||
const int COLORSPACE_HSL = 0;
|
||||
const int COLORSPACE_HSB = 1;
|
||||
|
||||
const float EPSILON = 0.0001;
|
||||
|
||||
//=============================================================================
|
||||
// RGB <-> HSL Conversions
|
||||
//=============================================================================
|
||||
|
||||
vec3 rgb2hsl(vec3 c) {
|
||||
float maxC = max(max(c.r, c.g), c.b);
|
||||
float minC = min(min(c.r, c.g), c.b);
|
||||
float delta = maxC - minC;
|
||||
|
||||
float h = 0.0;
|
||||
float s = 0.0;
|
||||
float l = (maxC + minC) * 0.5;
|
||||
|
||||
if (delta > EPSILON) {
|
||||
s = l < 0.5
|
||||
? delta / (maxC + minC)
|
||||
: delta / (2.0 - maxC - minC);
|
||||
|
||||
if (maxC == c.r) {
|
||||
h = (c.g - c.b) / delta + (c.g < c.b ? 6.0 : 0.0);
|
||||
} else if (maxC == c.g) {
|
||||
h = (c.b - c.r) / delta + 2.0;
|
||||
} else {
|
||||
h = (c.r - c.g) / delta + 4.0;
|
||||
}
|
||||
h /= 6.0;
|
||||
}
|
||||
|
||||
return vec3(h, s, l);
|
||||
}
|
||||
|
||||
float hue2rgb(float p, float q, float t) {
|
||||
t = fract(t);
|
||||
if (t < 1.0/6.0) return p + (q - p) * 6.0 * t;
|
||||
if (t < 0.5) return q;
|
||||
if (t < 2.0/3.0) return p + (q - p) * (2.0/3.0 - t) * 6.0;
|
||||
return p;
|
||||
}
|
||||
|
||||
vec3 hsl2rgb(vec3 hsl) {
|
||||
if (hsl.y < EPSILON) return vec3(hsl.z);
|
||||
|
||||
float q = hsl.z < 0.5
|
||||
? hsl.z * (1.0 + hsl.y)
|
||||
: hsl.z + hsl.y - hsl.z * hsl.y;
|
||||
float p = 2.0 * hsl.z - q;
|
||||
|
||||
return vec3(
|
||||
hue2rgb(p, q, hsl.x + 1.0/3.0),
|
||||
hue2rgb(p, q, hsl.x),
|
||||
hue2rgb(p, q, hsl.x - 1.0/3.0)
|
||||
);
|
||||
}
|
||||
|
||||
vec3 rgb2hsb(vec3 c) {
|
||||
float maxC = max(max(c.r, c.g), c.b);
|
||||
float minC = min(min(c.r, c.g), c.b);
|
||||
float delta = maxC - minC;
|
||||
|
||||
float h = 0.0;
|
||||
float s = (maxC > EPSILON) ? delta / maxC : 0.0;
|
||||
float b = maxC;
|
||||
|
||||
if (delta > EPSILON) {
|
||||
if (maxC == c.r) {
|
||||
h = (c.g - c.b) / delta + (c.g < c.b ? 6.0 : 0.0);
|
||||
} else if (maxC == c.g) {
|
||||
h = (c.b - c.r) / delta + 2.0;
|
||||
} else {
|
||||
h = (c.r - c.g) / delta + 4.0;
|
||||
}
|
||||
h /= 6.0;
|
||||
}
|
||||
|
||||
return vec3(h, s, b);
|
||||
}
|
||||
|
||||
vec3 hsb2rgb(vec3 hsb) {
|
||||
vec3 rgb = clamp(abs(mod(hsb.x * 6.0 + vec3(0.0, 4.0, 2.0), 6.0) - 3.0) - 1.0, 0.0, 1.0);
|
||||
return hsb.z * mix(vec3(1.0), rgb, hsb.y);
|
||||
}
|
||||
|
||||
//=============================================================================
|
||||
// Color Range Weight Calculation
|
||||
//=============================================================================
|
||||
|
||||
float hueDistance(float a, float b) {
|
||||
float d = abs(a - b);
|
||||
return min(d, 1.0 - d);
|
||||
}
|
||||
|
||||
float getHueWeight(float hue, float center, float overlap) {
|
||||
float baseWidth = 1.0 / 6.0;
|
||||
float feather = baseWidth * overlap;
|
||||
|
||||
float d = hueDistance(hue, center);
|
||||
|
||||
float inner = baseWidth * 0.5;
|
||||
float outer = inner + feather;
|
||||
|
||||
return 1.0 - smoothstep(inner, outer, d);
|
||||
}
|
||||
|
||||
float getModeWeight(float hue, int mode, float overlap) {
|
||||
if (mode == MODE_MASTER || mode == MODE_COLORIZE) return 1.0;
|
||||
|
||||
if (mode == MODE_RED) {
|
||||
return max(
|
||||
getHueWeight(hue, 0.0, overlap),
|
||||
getHueWeight(hue, 1.0, overlap)
|
||||
);
|
||||
}
|
||||
|
||||
float center = float(mode - 1) / 6.0;
|
||||
return getHueWeight(hue, center, overlap);
|
||||
}
|
||||
|
||||
//=============================================================================
|
||||
// Adjustment Functions
|
||||
//=============================================================================
|
||||
|
||||
float adjustLightness(float l, float amount) {
|
||||
return amount > 0.0
|
||||
? l + (1.0 - l) * amount
|
||||
: l + l * amount;
|
||||
}
|
||||
|
||||
float adjustBrightness(float b, float amount) {
|
||||
return clamp(b + amount, 0.0, 1.0);
|
||||
}
|
||||
|
||||
float adjustSaturation(float s, float amount) {
|
||||
return amount > 0.0
|
||||
? s + (1.0 - s) * amount
|
||||
: s + s * amount;
|
||||
}
|
||||
|
||||
vec3 colorize(vec3 rgb, float hue, float sat, float light) {
|
||||
float lum = dot(rgb, vec3(0.299, 0.587, 0.114));
|
||||
float l = adjustLightness(lum, light);
|
||||
|
||||
vec3 hsl = vec3(fract(hue), clamp(abs(sat), 0.0, 1.0), clamp(l, 0.0, 1.0));
|
||||
return hsl2rgb(hsl);
|
||||
}
|
||||
|
||||
//=============================================================================
|
||||
// Main
|
||||
//=============================================================================
|
||||
|
||||
void main() {
|
||||
vec4 original = texture(u_image0, v_texCoord);
|
||||
|
||||
float hueShift = u_float0 / 360.0; // -180..180 -> -0.5..0.5
|
||||
float satAmount = u_float1 / 100.0; // -100..100 -> -1..1
|
||||
float lightAmount= u_float2 / 100.0; // -100..100 -> -1..1
|
||||
float overlap = u_float3 / 100.0; // 0..100 -> 0..1
|
||||
|
||||
vec3 result;
|
||||
|
||||
if (u_int0 == MODE_COLORIZE) {
|
||||
result = colorize(original.rgb, hueShift, satAmount, lightAmount);
|
||||
fragColor = vec4(result, original.a);
|
||||
return;
|
||||
}
|
||||
|
||||
vec3 hsx = (u_int1 == COLORSPACE_HSL)
|
||||
? rgb2hsl(original.rgb)
|
||||
: rgb2hsb(original.rgb);
|
||||
|
||||
float weight = getModeWeight(hsx.x, u_int0, overlap);
|
||||
|
||||
if (u_int0 != MODE_MASTER && hsx.y < EPSILON) {
|
||||
weight = 0.0;
|
||||
}
|
||||
|
||||
if (weight > EPSILON) {
|
||||
float h = fract(hsx.x + hueShift * weight);
|
||||
float s = clamp(adjustSaturation(hsx.y, satAmount * weight), 0.0, 1.0);
|
||||
float v = (u_int1 == COLORSPACE_HSL)
|
||||
? clamp(adjustLightness(hsx.z, lightAmount * weight), 0.0, 1.0)
|
||||
: clamp(adjustBrightness(hsx.z, lightAmount * weight), 0.0, 1.0);
|
||||
|
||||
vec3 adjusted = vec3(h, s, v);
|
||||
result = (u_int1 == COLORSPACE_HSL)
|
||||
? hsl2rgb(adjusted)
|
||||
: hsb2rgb(adjusted);
|
||||
} else {
|
||||
result = original.rgb;
|
||||
}
|
||||
|
||||
fragColor = vec4(result, original.a);
|
||||
}
|
||||
blueprints/.glsl/Image_Blur_1.frag (new file, 111 lines)
@@ -0,0 +1,111 @@
|
||||
#version 300 es
|
||||
#pragma passes 2
|
||||
precision highp float;
|
||||
|
||||
// Blur type constants
|
||||
const int BLUR_GAUSSIAN = 0;
|
||||
const int BLUR_BOX = 1;
|
||||
const int BLUR_RADIAL = 2;
|
||||
|
||||
// Radial blur config
|
||||
const int RADIAL_SAMPLES = 12;
|
||||
const float RADIAL_STRENGTH = 0.0003;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform vec2 u_resolution;
|
||||
uniform int u_int0; // Blur type (BLUR_GAUSSIAN, BLUR_BOX, BLUR_RADIAL)
|
||||
uniform float u_float0; // Blur radius/amount
|
||||
uniform int u_pass; // Pass index (0 = horizontal, 1 = vertical)
|
||||
|
||||
in vec2 v_texCoord;
|
||||
layout(location = 0) out vec4 fragColor0;
|
||||
|
||||
float gaussian(float x, float sigma) {
|
||||
return exp(-(x * x) / (2.0 * sigma * sigma));
|
||||
}
|
||||
|
||||
void main() {
|
||||
vec2 texelSize = 1.0 / u_resolution;
|
||||
float radius = max(u_float0, 0.0);
|
||||
|
||||
// Radial (angular) blur - single pass, doesn't use separable
|
||||
if (u_int0 == BLUR_RADIAL) {
|
||||
// Only execute on first pass
|
||||
if (u_pass > 0) {
|
||||
fragColor0 = texture(u_image0, v_texCoord);
|
||||
return;
|
||||
}
|
||||
|
||||
vec2 center = vec2(0.5);
|
||||
vec2 dir = v_texCoord - center;
|
||||
float dist = length(dir);
|
||||
|
||||
if (dist < 1e-4) {
|
||||
fragColor0 = texture(u_image0, v_texCoord);
|
||||
return;
|
||||
}
|
||||
|
||||
vec4 sum = vec4(0.0);
|
||||
float totalWeight = 0.0;
|
||||
float angleStep = radius * RADIAL_STRENGTH;
|
||||
|
||||
dir /= dist;
|
||||
|
||||
float cosStep = cos(angleStep);
|
||||
float sinStep = sin(angleStep);
|
||||
|
||||
float negAngle = -float(RADIAL_SAMPLES) * angleStep;
|
||||
vec2 rotDir = vec2(
|
||||
dir.x * cos(negAngle) - dir.y * sin(negAngle),
|
||||
dir.x * sin(negAngle) + dir.y * cos(negAngle)
|
||||
);
|
||||
|
||||
for (int i = -RADIAL_SAMPLES; i <= RADIAL_SAMPLES; i++) {
|
||||
vec2 uv = center + rotDir * dist;
|
||||
float w = 1.0 - abs(float(i)) / float(RADIAL_SAMPLES);
|
||||
sum += texture(u_image0, uv) * w;
|
||||
totalWeight += w;
|
||||
|
||||
rotDir = vec2(
|
||||
rotDir.x * cosStep - rotDir.y * sinStep,
|
||||
rotDir.x * sinStep + rotDir.y * cosStep
|
||||
);
|
||||
}
|
||||
|
||||
fragColor0 = sum / max(totalWeight, 0.001);
|
||||
return;
|
||||
}
|
||||
|
||||
// Separable Gaussian / Box blur
|
||||
int samples = int(ceil(radius));
|
||||
|
||||
if (samples == 0) {
|
||||
fragColor0 = texture(u_image0, v_texCoord);
|
||||
return;
|
||||
}
|
||||
|
||||
// Direction: pass 0 = horizontal, pass 1 = vertical
|
||||
vec2 dir = (u_pass == 0) ? vec2(1.0, 0.0) : vec2(0.0, 1.0);
|
||||
|
||||
vec4 color = vec4(0.0);
|
||||
float totalWeight = 0.0;
|
||||
float sigma = radius / 2.0;
|
||||
|
||||
for (int i = -samples; i <= samples; i++) {
|
||||
vec2 offset = dir * float(i) * texelSize;
|
||||
vec4 sample_color = texture(u_image0, v_texCoord + offset);
|
||||
|
||||
float weight;
|
||||
if (u_int0 == BLUR_GAUSSIAN) {
|
||||
weight = gaussian(float(i), sigma);
|
||||
} else {
|
||||
// BLUR_BOX
|
||||
weight = 1.0;
|
||||
}
|
||||
|
||||
color += sample_color * weight;
|
||||
totalWeight += weight;
|
||||
}
|
||||
|
||||
fragColor0 = color / totalWeight;
|
||||
}
|
||||
blueprints/.glsl/Image_Channels_23.frag (new file, 19 lines)
@@ -0,0 +1,19 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
|
||||
in vec2 v_texCoord;
|
||||
layout(location = 0) out vec4 fragColor0;
|
||||
layout(location = 1) out vec4 fragColor1;
|
||||
layout(location = 2) out vec4 fragColor2;
|
||||
layout(location = 3) out vec4 fragColor3;
|
||||
|
||||
void main() {
|
||||
vec4 color = texture(u_image0, v_texCoord);
|
||||
// Output each channel as grayscale to separate render targets
|
||||
fragColor0 = vec4(vec3(color.r), 1.0); // Red channel
|
||||
fragColor1 = vec4(vec3(color.g), 1.0); // Green channel
|
||||
fragColor2 = vec4(vec3(color.b), 1.0); // Blue channel
|
||||
fragColor3 = vec4(vec3(color.a), 1.0); // Alpha channel
|
||||
}
|
||||
blueprints/.glsl/Image_Levels_1.frag (new file, 71 lines)
@@ -0,0 +1,71 @@
#version 300 es
precision highp float;

// Levels Adjustment
// u_int0: channel (0=RGB, 1=R, 2=G, 3=B) default: 0
// u_float0: input black (0-255) default: 0
// u_float1: input white (0-255) default: 255
// u_float2: gamma (0.01-9.99) default: 1.0
// u_float3: output black (0-255) default: 0
// u_float4: output white (0-255) default: 255

uniform sampler2D u_image0;
uniform int u_int0;
uniform float u_float0;
uniform float u_float1;
uniform float u_float2;
uniform float u_float3;
uniform float u_float4;

in vec2 v_texCoord;
out vec4 fragColor;

vec3 applyLevels(vec3 color, float inBlack, float inWhite, float gamma, float outBlack, float outWhite) {
    float inRange = max(inWhite - inBlack, 0.0001);
    vec3 result = clamp((color - inBlack) / inRange, 0.0, 1.0);
    result = pow(result, vec3(1.0 / gamma));
    result = mix(vec3(outBlack), vec3(outWhite), result);
    return result;
}

float applySingleChannel(float value, float inBlack, float inWhite, float gamma, float outBlack, float outWhite) {
    float inRange = max(inWhite - inBlack, 0.0001);
    float result = clamp((value - inBlack) / inRange, 0.0, 1.0);
    result = pow(result, 1.0 / gamma);
    result = mix(outBlack, outWhite, result);
    return result;
}

void main() {
    vec4 texColor = texture(u_image0, v_texCoord);
    vec3 color = texColor.rgb;

    float inBlack = u_float0 / 255.0;
    float inWhite = u_float1 / 255.0;
    float gamma = u_float2;
    float outBlack = u_float3 / 255.0;
    float outWhite = u_float4 / 255.0;

    vec3 result;

    if (u_int0 == 0) {
        result = applyLevels(color, inBlack, inWhite, gamma, outBlack, outWhite);
    }
    else if (u_int0 == 1) {
        result = color;
        result.r = applySingleChannel(color.r, inBlack, inWhite, gamma, outBlack, outWhite);
    }
    else if (u_int0 == 2) {
        result = color;
        result.g = applySingleChannel(color.g, inBlack, inWhite, gamma, outBlack, outWhite);
    }
    else if (u_int0 == 3) {
        result = color;
        result.b = applySingleChannel(color.b, inBlack, inWhite, gamma, outBlack, outWhite);
    }
    else {
        result = color;
    }

    fragColor = vec4(result, texColor.a);
}
28
blueprints/.glsl/README.md
Normal file
@@ -0,0 +1,28 @@
# GLSL Shader Sources

This folder contains the GLSL fragment shaders extracted from blueprint JSON files for easier editing and version control.

## File Naming Convention

`{Blueprint_Name}_{node_id}.frag`

- **Blueprint_Name**: The JSON filename with spaces/special chars replaced by underscores
- **node_id**: The GLSLShader node ID within the subgraph
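For example, the fragment shader on node `23` inside `Image Channels.json` is stored in this folder as `Image_Channels_23.frag` (the space in the blueprint name becomes an underscore).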
## Usage

```bash
# Extract shaders from blueprint JSONs to this folder
python update_blueprints.py extract

# Patch edited shaders back into blueprint JSONs
python update_blueprints.py patch
```

## Workflow

1. Run `extract` to pull current shaders from JSONs
2. Edit `.frag` files
3. Run `patch` to update the blueprint JSONs
4. Test
5. Commit both `.frag` files and updated JSONs
28
blueprints/.glsl/Sharpen_23.frag
Normal file
@@ -0,0 +1,28 @@
#version 300 es
precision highp float;

uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform float u_float0; // strength [0.0 – 2.0] typical: 0.3–1.0

in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;

void main() {
    vec2 texel = 1.0 / u_resolution;

    // Sample center and neighbors
    vec4 center = texture(u_image0, v_texCoord);
    vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));
    vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));
    vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));
    vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));

    // Edge enhancement (Laplacian)
    vec4 edges = center * 4.0 - top - bottom - left - right;

    // Add edges back scaled by strength
    vec4 sharpened = center + edges * u_float0;

    fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);
}
61
blueprints/.glsl/Unsharp_Mask_26.frag
Normal file
@@ -0,0 +1,61 @@
#version 300 es
precision highp float;

uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform float u_float0; // amount [0.0 - 3.0] typical: 0.5-1.5
uniform float u_float1; // radius [0.5 - 10.0] blur radius in pixels
uniform float u_float2; // threshold [0.0 - 0.1] min difference to sharpen

in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;

float gaussian(float x, float sigma) {
    return exp(-(x * x) / (2.0 * sigma * sigma));
}

float getLuminance(vec3 color) {
    return dot(color, vec3(0.2126, 0.7152, 0.0722));
}

void main() {
    vec2 texel = 1.0 / u_resolution;
    float radius = max(u_float1, 0.5);
    float amount = u_float0;
    float threshold = u_float2;

    vec4 original = texture(u_image0, v_texCoord);

    // Gaussian blur for the "unsharp" mask
    int samples = int(ceil(radius));
    float sigma = radius / 2.0;

    vec4 blurred = vec4(0.0);
    float totalWeight = 0.0;

    for (int x = -samples; x <= samples; x++) {
        for (int y = -samples; y <= samples; y++) {
            vec2 offset = vec2(float(x), float(y)) * texel;
            vec4 sample_color = texture(u_image0, v_texCoord + offset);

            float dist = length(vec2(float(x), float(y)));
            float weight = gaussian(dist, sigma);
            blurred += sample_color * weight;
            totalWeight += weight;
        }
    }
    blurred /= totalWeight;

    // Unsharp mask = original - blurred
    vec3 mask = original.rgb - blurred.rgb;

    // Luminance-based threshold with smooth falloff
    float lumaDelta = abs(getLuminance(original.rgb) - getLuminance(blurred.rgb));
    float thresholdScale = smoothstep(0.0, threshold, lumaDelta);
    mask *= thresholdScale;

    // Sharpen: original + mask * amount
    vec3 sharpened = original.rgb + mask * amount;

    fragColor0 = vec4(clamp(sharpened, 0.0, 1.0), original.a);
}
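A tiny worked example of the arithmetic in the comments above (values assumed, purely illustrative):

```python
# Assumed sample values, illustrating the unsharp-mask formula only.
original, blurred = 0.80, 0.70
amount = 1.0
mask = original - blurred              # 0.10 of high-frequency detail
sharpened = original + mask * amount   # 0.90, clamped to [0, 1] by the shader
print(mask, sharpened)
```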
159
blueprints/.glsl/update_blueprints.py
Normal file
@@ -0,0 +1,159 @@
#!/usr/bin/env python3
"""
Shader Blueprint Updater

Syncs GLSL shader files between this folder and blueprint JSON files.

File naming convention:
    {Blueprint Name}_{node_id}.frag

Usage:
    python update_blueprints.py extract   # Extract shaders from JSONs to here
    python update_blueprints.py patch     # Patch shaders back into JSONs
    python update_blueprints.py           # Same as patch (default)
"""

import json
import logging
import sys
import re
from pathlib import Path

logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger(__name__)

GLSL_DIR = Path(__file__).parent
BLUEPRINTS_DIR = GLSL_DIR.parent


def get_blueprint_files():
    """Get all blueprint JSON files."""
    return sorted(BLUEPRINTS_DIR.glob("*.json"))


def sanitize_filename(name):
    """Convert blueprint name to safe filename."""
    return re.sub(r'[^\w\-]', '_', name)


def extract_shaders():
    """Extract all shaders from blueprint JSONs to this folder."""
    extracted = 0
    for json_path in get_blueprint_files():
        blueprint_name = json_path.stem

        try:
            with open(json_path, 'r') as f:
                data = json.load(f)
        except (json.JSONDecodeError, IOError) as e:
            logger.warning("Skipping %s: %s", json_path.name, e)
            continue

        # Find GLSLShader nodes in subgraphs
        for subgraph in data.get('definitions', {}).get('subgraphs', []):
            for node in subgraph.get('nodes', []):
                if node.get('type') == 'GLSLShader':
                    node_id = node.get('id')
                    widgets = node.get('widgets_values', [])

                    # Find shader code (first string that looks like GLSL)
                    for widget in widgets:
                        if isinstance(widget, str) and widget.startswith('#version'):
                            safe_name = sanitize_filename(blueprint_name)
                            frag_name = f"{safe_name}_{node_id}.frag"
                            frag_path = GLSL_DIR / frag_name

                            with open(frag_path, 'w') as f:
                                f.write(widget)

                            logger.info(" Extracted: %s", frag_name)
                            extracted += 1
                            break

    logger.info("\nExtracted %d shader(s)", extracted)


def patch_shaders():
    """Patch shaders from this folder back into blueprint JSONs."""
    # Build lookup: blueprint_name -> [(node_id, shader_code), ...]
    shader_updates = {}

    for frag_path in sorted(GLSL_DIR.glob("*.frag")):
        # Parse filename: {blueprint_name}_{node_id}.frag
        parts = frag_path.stem.rsplit('_', 1)
        if len(parts) != 2:
            logger.warning("Skipping %s: invalid filename format", frag_path.name)
            continue

        blueprint_name, node_id_str = parts

        try:
            node_id = int(node_id_str)
        except ValueError:
            logger.warning("Skipping %s: invalid node_id", frag_path.name)
            continue

        with open(frag_path, 'r') as f:
            shader_code = f.read()

        if blueprint_name not in shader_updates:
            shader_updates[blueprint_name] = []
        shader_updates[blueprint_name].append((node_id, shader_code))

    # Apply updates to JSON files
    patched = 0
    for json_path in get_blueprint_files():
        blueprint_name = sanitize_filename(json_path.stem)

        if blueprint_name not in shader_updates:
            continue

        try:
            with open(json_path, 'r') as f:
                data = json.load(f)
        except (json.JSONDecodeError, IOError) as e:
            logger.error("Error reading %s: %s", json_path.name, e)
            continue

        modified = False
        for node_id, shader_code in shader_updates[blueprint_name]:
            # Find the node and update
            for subgraph in data.get('definitions', {}).get('subgraphs', []):
                for node in subgraph.get('nodes', []):
                    if node.get('id') == node_id and node.get('type') == 'GLSLShader':
                        widgets = node.get('widgets_values', [])
                        if len(widgets) > 0 and widgets[0] != shader_code:
                            widgets[0] = shader_code
                            modified = True
                            logger.info(" Patched: %s (node %d)", json_path.name, node_id)
                            patched += 1

        if modified:
            with open(json_path, 'w') as f:
                json.dump(data, f)

    if patched == 0:
        logger.info("No changes to apply.")
    else:
        logger.info("\nPatched %d shader(s)", patched)


def main():
    if len(sys.argv) < 2:
        command = "patch"
    else:
        command = sys.argv[1].lower()

    if command == "extract":
        logger.info("Extracting shaders from blueprints...")
        extract_shaders()
    elif command in ("patch", "update", "apply"):
        logger.info("Patching shaders into blueprints...")
        patch_shaders()
    else:
        logger.info(__doc__)
        sys.exit(1)


if __name__ == "__main__":
    main()
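As a quick aside, the filename convention round-trips through `sanitize_filename` and the `rsplit('_', 1)` parse in `patch_shaders`. A minimal sketch (assuming the script above is importable from this folder; the example blueprint name and node id are taken from the files in this comparison):

```python
# Sketch only: verifies the naming convention round-trips through the helpers above.
from update_blueprints import sanitize_filename

blueprint = "Image Channels"                            # JSON stem as it appears on disk
frag_name = f"{sanitize_filename(blueprint)}_23.frag"   # -> "Image_Channels_23.frag"

stem = frag_name[:-len(".frag")]
name, node_id = stem.rsplit("_", 1)                     # -> ("Image_Channels", "23")
assert (name, int(node_id)) == ("Image_Channels", 23)
```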
1
blueprints/Brightness and Contrast.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Chromatic Aberration.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Color Adjustment.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Edge-Preserving Blur.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Film Grain.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Glow.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Hue and Saturation.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Blur.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Channels.json
Normal file
@@ -0,0 +1 @@
{"revision": 0, "last_node_id": 29, "last_link_id": 0, "nodes": [{"id": 29, "type": "4c9d6ea4-b912-40e5-8766-6793a9758c53", "pos": [1970, -230], "size": [180, 86], "flags": {}, "order": 5, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "R", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}, {"label": "G", "localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": []}, {"label": "B", "localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": []}, {"label": "A", "localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": []}], "title": "Image Channels", "properties": {"proxyWidgets": []}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "4c9d6ea4-b912-40e5-8766-6793a9758c53", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 28, "lastLinkId": 39, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Image Channels", "inputNode": {"id": -10, "bounding": [1820, -185, 120, 60]}, "outputNode": {"id": -20, "bounding": [2460, -215, 120, 120]}, "inputs": [{"id": "3522932b-2d86-4a1f-a02a-cb29f3a9d7fe", "name": "images.image0", "type": "IMAGE", "linkIds": [39], "localized_name": "images.image0", "label": "image", "pos": [1920, -165]}], "outputs": [{"id": "605cb9c3-b065-4d9b-81d2-3ec331889b2b", "name": "IMAGE0", "type": "IMAGE", "linkIds": [26], "localized_name": "IMAGE0", "label": "R", "pos": [2480, -195]}, {"id": "fb44a77e-0522-43e9-9527-82e7465b3596", "name": "IMAGE1", "type": "IMAGE", "linkIds": [27], "localized_name": "IMAGE1", "label": "G", "pos": [2480, -175]}, {"id": "81460ee6-0131-402a-874f-6bf3001fc4ff", "name": "IMAGE2", "type": "IMAGE", "linkIds": [28], "localized_name": "IMAGE2", "label": "B", "pos": [2480, -155]}, {"id": "ae690246-80d4-4951-b1d9-9306d8a77417", "name": "IMAGE3", "type": "IMAGE", "linkIds": [29], "localized_name": "IMAGE3", "label": "A", "pos": [2480, -135]}], "widgets": [], "nodes": [{"id": 23, "type": "GLSLShader", "pos": [2000, -330], "size": [400, 172], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 39}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}], "outputs": [{"label": "R", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [26]}, {"label": "G", "localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": [27]}, {"label": "B", "localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": [28]}, {"label": "A", "localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": [29]}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\nlayout(location = 1) out vec4 fragColor1;\nlayout(location = 2) out vec4 fragColor2;\nlayout(location = 3) out vec4 fragColor3;\n\nvoid main() {\n vec4 color = texture(u_image0, v_texCoord);\n // Output each channel as grayscale to separate render targets\n fragColor0 = 
vec4(vec3(color.r), 1.0); // Red channel\n fragColor1 = vec4(vec3(color.g), 1.0); // Green channel\n fragColor2 = vec4(vec3(color.b), 1.0); // Blue channel\n fragColor3 = vec4(vec3(color.a), 1.0); // Alpha channel\n}\n", "from_input"]}], "groups": [], "links": [{"id": 39, "origin_id": -10, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "IMAGE"}, {"id": 26, "origin_id": 23, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}, {"id": 27, "origin_id": 23, "origin_slot": 1, "target_id": -20, "target_slot": 1, "type": "IMAGE"}, {"id": 28, "origin_id": 23, "origin_slot": 2, "target_id": -20, "target_slot": 2, "type": "IMAGE"}, {"id": 29, "origin_id": 23, "origin_slot": 3, "target_id": -20, "target_slot": 3, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}}]}}
1
blueprints/Image Levels.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Sharpen.json
Normal file
@@ -0,0 +1 @@
{"revision":0,"last_node_id":25,"last_link_id":0,"nodes":[{"id":25,"type":"621ba4e2-22a8-482d-a369-023753198b7b","pos":[4610,-790],"size":[230,58],"flags":{},"order":4,"mode":0,"inputs":[{"label":"image","localized_name":"images.image0","name":"images.image0","type":"IMAGE","link":null}],"outputs":[{"label":"IMAGE","localized_name":"IMAGE0","name":"IMAGE0","type":"IMAGE","links":[]}],"title":"Sharpen","properties":{"proxyWidgets":[["24","value"]]},"widgets_values":[]}],"links":[],"version":0.4,"definitions":{"subgraphs":[{"id":"621ba4e2-22a8-482d-a369-023753198b7b","version":1,"state":{"lastGroupId":0,"lastNodeId":24,"lastLinkId":36,"lastRerouteId":0},"revision":0,"config":{},"name":"Sharpen","inputNode":{"id":-10,"bounding":[4090,-825,120,60]},"outputNode":{"id":-20,"bounding":[5150,-825,120,60]},"inputs":[{"id":"37011fb7-14b7-4e0e-b1a0-6a02e8da1fd7","name":"images.image0","type":"IMAGE","linkIds":[34],"localized_name":"images.image0","label":"image","pos":[4190,-805]}],"outputs":[{"id":"e9182b3f-635c-4cd4-a152-4b4be17ae4b9","name":"IMAGE0","type":"IMAGE","linkIds":[35],"localized_name":"IMAGE0","label":"IMAGE","pos":[5170,-805]}],"widgets":[],"nodes":[{"id":24,"type":"PrimitiveFloat","pos":[4280,-1240],"size":[270,58],"flags":{},"order":0,"mode":0,"inputs":[{"label":"strength","localized_name":"value","name":"value","type":"FLOAT","widget":{"name":"value"},"link":null}],"outputs":[{"localized_name":"FLOAT","name":"FLOAT","type":"FLOAT","links":[36]}],"properties":{"Node name for S&R":"PrimitiveFloat","min":0,"max":3,"precision":2,"step":0.05},"widgets_values":[0.5]},{"id":23,"type":"GLSLShader","pos":[4570,-1240],"size":[370,192],"flags":{},"order":1,"mode":0,"inputs":[{"label":"image0","localized_name":"images.image0","name":"images.image0","type":"IMAGE","link":34},{"label":"image1","localized_name":"images.image1","name":"images.image1","shape":7,"type":"IMAGE","link":null},{"label":"u_float0","localized_name":"floats.u_float0","name":"floats.u_float0","shape":7,"type":"FLOAT","link":36},{"label":"u_float1","localized_name":"floats.u_float1","name":"floats.u_float1","shape":7,"type":"FLOAT","link":null},{"label":"u_int0","localized_name":"ints.u_int0","name":"ints.u_int0","shape":7,"type":"INT","link":null},{"localized_name":"fragment_shader","name":"fragment_shader","type":"STRING","widget":{"name":"fragment_shader"},"link":null},{"localized_name":"size_mode","name":"size_mode","type":"COMFY_DYNAMICCOMBO_V3","widget":{"name":"size_mode"},"link":null}],"outputs":[{"localized_name":"IMAGE0","name":"IMAGE0","type":"IMAGE","links":[35]},{"localized_name":"IMAGE1","name":"IMAGE1","type":"IMAGE","links":null},{"localized_name":"IMAGE2","name":"IMAGE2","type":"IMAGE","links":null},{"localized_name":"IMAGE3","name":"IMAGE3","type":"IMAGE","links":null}],"properties":{"Node name for S&R":"GLSLShader"},"widgets_values":["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform float u_float0; // strength [0.0 – 2.0] typical: 0.3–1.0\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nvoid main() {\n vec2 texel = 1.0 / u_resolution;\n \n // Sample center and neighbors\n vec4 center = texture(u_image0, v_texCoord);\n vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));\n vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));\n vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));\n vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));\n \n // Edge enhancement (Laplacian)\n vec4 
edges = center * 4.0 - top - bottom - left - right;\n \n // Add edges back scaled by strength\n vec4 sharpened = center + edges * u_float0;\n \n fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);\n}","from_input"]}],"groups":[],"links":[{"id":36,"origin_id":24,"origin_slot":0,"target_id":23,"target_slot":2,"type":"FLOAT"},{"id":34,"origin_id":-10,"origin_slot":0,"target_id":23,"target_slot":0,"type":"IMAGE"},{"id":35,"origin_id":23,"origin_slot":0,"target_id":-20,"target_slot":0,"type":"IMAGE"}],"extra":{"workflowRendererVersion":"LG"}}]}}
1
blueprints/Unsharp Mask.json
Normal file
File diff suppressed because one or more lines are too long
@@ -25,11 +25,11 @@ class AudioEncoderModel():
        elif model_type == "whisper3":
            self.model = WhisperLargeV3(**model_config)
        self.model.eval()
        self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
        self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
        self.model_sample_rate = 16000

    def load_sd(self, sd):
        return self.model.load_state_dict(sd, strict=False)
        return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())

    def get_sd(self):
        return self.model.state_dict()
@@ -1,13 +0,0 @@
import pickle

load = pickle.load

class Empty:
    pass

class Unpickler(pickle.Unpickler):
    def find_class(self, module, name):
        #TODO: safe unpickle
        if module.startswith("pytorch_lightning"):
            return Empty
        return super().find_class(module, name)
@@ -159,6 +159,7 @@ class PerformanceFeature(enum.Enum):
    Fp8MatrixMultiplication = "fp8_matrix_mult"
    CublasOps = "cublas_ops"
    AutoTune = "autotune"
    DynamicVRAM = "dynamic_vram"

parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
@@ -257,3 +258,6 @@ elif args.fast == []:
# '--fast' is provided with a list of performance features, use that list
else:
    args.fast = set(args.fast)

def enables_dynamic_vram():
    return PerformanceFeature.DynamicVRAM in args.fast and not args.highvram and not args.gpu_only
@@ -47,10 +47,10 @@ class ClipVisionModel():
        self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
        self.model.eval()

        self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
        self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)

    def load_sd(self, sd):
        return self.model.load_state_dict(sd, strict=False)
        return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())

    def get_sd(self):
        return self.model.state_dict()
@@ -236,6 +236,8 @@ class ComfyNodeABC(ABC):
    """Flags a node as experimental, informing users that it may change or not work as expected."""
    DEPRECATED: bool
    """Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
    DEV_ONLY: bool
    """Flags a node as dev-only, hiding it from search/menus unless dev mode is enabled."""
    API_NODE: Optional[bool]
    """Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
@@ -203,7 +203,7 @@ class ControlNet(ControlBase):
|
||||
self.control_model = control_model
|
||||
self.load_device = load_device
|
||||
if control_model is not None:
|
||||
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
|
||||
self.control_model_wrapped = comfy.model_patcher.CoreModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
|
||||
|
||||
self.compression_ratio = compression_ratio
|
||||
self.global_average_pooling = global_average_pooling
|
||||
@@ -297,6 +297,30 @@ class ControlNet(ControlBase):
|
||||
self.model_sampling_current = None
|
||||
super().cleanup()
|
||||
|
||||
|
||||
class QwenFunControlNet(ControlNet):
|
||||
def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
|
||||
# Fun checkpoints are more sensitive to high strengths in the generic
|
||||
# ControlNet merge path. Use a soft response curve so strength=1.0 stays
|
||||
# unchanged while >1 grows more gently.
|
||||
original_strength = self.strength
|
||||
self.strength = math.sqrt(max(self.strength, 0.0))
|
||||
try:
|
||||
return super().get_control(x_noisy, t, cond, batched_number, transformer_options)
|
||||
finally:
|
||||
self.strength = original_strength
|
||||
|
||||
def pre_run(self, model, percent_to_timestep_function):
|
||||
super().pre_run(model, percent_to_timestep_function)
|
||||
self.set_extra_arg("base_model", model.diffusion_model)
|
||||
|
||||
def copy(self):
|
||||
c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
||||
c.control_model = self.control_model
|
||||
c.control_model_wrapped = self.control_model_wrapped
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
class ControlLoraOps:
|
||||
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
||||
@@ -560,6 +584,7 @@ def load_controlnet_hunyuandit(controlnet_data, model_options={}):
|
||||
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
|
||||
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
sd = model_config.process_unet_state_dict(sd)
|
||||
control_model = controlnet_load_state_dict(control_model, sd)
|
||||
extra_conds = ['y', 'guidance']
|
||||
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
@@ -605,6 +630,53 @@ def load_controlnet_qwen_instantx(sd, model_options={}):
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
return control
|
||||
|
||||
|
||||
def load_controlnet_qwen_fun(sd, model_options={}):
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||
unet_dtype = model_options.get("dtype", weight_dtype)
|
||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
|
||||
operations = model_options.get("custom_operations", None)
|
||||
if operations is None:
|
||||
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
|
||||
|
||||
in_features = sd["control_img_in.weight"].shape[1]
|
||||
inner_dim = sd["control_img_in.weight"].shape[0]
|
||||
|
||||
block_weight = sd["control_blocks.0.attn.to_q.weight"]
|
||||
attention_head_dim = sd["control_blocks.0.attn.norm_q.weight"].shape[0]
|
||||
num_attention_heads = max(1, block_weight.shape[0] // max(1, attention_head_dim))
|
||||
|
||||
model = comfy.ldm.qwen_image.controlnet.QwenImageFunControlNetModel(
|
||||
control_in_features=in_features,
|
||||
inner_dim=inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_control_blocks=5,
|
||||
main_model_double=60,
|
||||
injection_layers=(0, 12, 24, 36, 48),
|
||||
operations=operations,
|
||||
device=comfy.model_management.unet_offload_device(),
|
||||
dtype=unet_dtype,
|
||||
)
|
||||
model = controlnet_load_state_dict(model, sd)
|
||||
|
||||
latent_format = comfy.latent_formats.Wan21()
|
||||
control = QwenFunControlNet(
|
||||
model,
|
||||
compression_ratio=1,
|
||||
latent_format=latent_format,
|
||||
# Fun checkpoints already expect their own 33-channel context handling.
|
||||
# Enabling generic concat_mask injects an extra mask channel at apply-time
|
||||
# and breaks the intended fallback packing path.
|
||||
concat_mask=False,
|
||||
load_device=load_device,
|
||||
manual_cast_dtype=manual_cast_dtype,
|
||||
extra_conds=[],
|
||||
)
|
||||
return control
|
||||
|
||||
def convert_mistoline(sd):
|
||||
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
||||
|
||||
@@ -682,6 +754,8 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
||||
return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
|
||||
elif "controlnet_x_embedder.weight" in controlnet_data:
|
||||
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
|
||||
elif "control_blocks.0.after_proj.weight" in controlnet_data and "control_img_in.weight" in controlnet_data:
|
||||
return load_controlnet_qwen_fun(controlnet_data, model_options=model_options)
|
||||
|
||||
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
|
||||
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
|
||||
|
||||
@@ -5,7 +5,7 @@ from scipy import integrate
|
||||
import torch
|
||||
from torch import nn
|
||||
import torchsde
|
||||
from tqdm.auto import trange, tqdm
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from . import utils
|
||||
from . import deis
|
||||
@@ -13,6 +13,9 @@ from . import sa_solver
|
||||
import comfy.model_patcher
|
||||
import comfy.model_sampling
|
||||
|
||||
import comfy.memory_management
|
||||
from comfy.utils import model_trange as trange
|
||||
|
||||
def append_zero(x):
|
||||
return torch.cat([x, x.new_zeros([1])])
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ class LatentFormat:
|
||||
latent_rgb_factors_bias = None
|
||||
latent_rgb_factors_reshape = None
|
||||
taesd_decoder_name = None
|
||||
spacial_downscale_ratio = 8
|
||||
|
||||
def process_in(self, latent):
|
||||
return latent * self.scale_factor
|
||||
@@ -80,6 +81,7 @@ class SD_X4(LatentFormat):
|
||||
|
||||
class SC_Prior(LatentFormat):
|
||||
latent_channels = 16
|
||||
spacial_downscale_ratio = 42
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.0
|
||||
self.latent_rgb_factors = [
|
||||
@@ -102,6 +104,7 @@ class SC_Prior(LatentFormat):
|
||||
]
|
||||
|
||||
class SC_B(LatentFormat):
|
||||
spacial_downscale_ratio = 4
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.0 / 0.43
|
||||
self.latent_rgb_factors = [
|
||||
@@ -181,6 +184,7 @@ class Flux(SD3):
|
||||
|
||||
class Flux2(LatentFormat):
|
||||
latent_channels = 128
|
||||
spacial_downscale_ratio = 16
|
||||
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors =[
|
||||
@@ -272,6 +276,7 @@ class Mochi(LatentFormat):
|
||||
class LTXV(LatentFormat):
|
||||
latent_channels = 128
|
||||
latent_dimensions = 3
|
||||
spacial_downscale_ratio = 32
|
||||
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors = [
|
||||
@@ -515,6 +520,7 @@ class Wan21(LatentFormat):
|
||||
class Wan22(Wan21):
|
||||
latent_channels = 48
|
||||
latent_dimensions = 3
|
||||
spacial_downscale_ratio = 16
|
||||
|
||||
latent_rgb_factors = [
|
||||
[ 0.0119, 0.0103, 0.0046],
|
||||
@@ -592,6 +598,7 @@ class Wan22(Wan21):
|
||||
class HunyuanImage21(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 2
|
||||
spacial_downscale_ratio = 32
|
||||
scale_factor = 0.75289
|
||||
|
||||
latent_rgb_factors = [
|
||||
@@ -725,6 +732,7 @@ class HunyuanVideo15(LatentFormat):
|
||||
latent_rgb_factors_bias = [ 0.0456, -0.0202, -0.0644]
|
||||
latent_channels = 32
|
||||
latent_dimensions = 3
|
||||
spacial_downscale_ratio = 16
|
||||
scale_factor = 1.03682
|
||||
taesd_decoder_name = "lighttaehy1_5"
|
||||
|
||||
@@ -747,8 +755,13 @@ class ACEAudio(LatentFormat):
|
||||
latent_channels = 8
|
||||
latent_dimensions = 2
|
||||
|
||||
class ACEAudio15(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
|
||||
class ChromaRadiance(LatentFormat):
|
||||
latent_channels = 3
|
||||
spacial_downscale_ratio = 1
|
||||
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors = [
|
||||
|
||||
1155
comfy/ldm/ace/ace_step15.py
Normal file
File diff suppressed because it is too large
@@ -195,8 +195,20 @@ class Anima(MiniTrainDIT):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations"))
|
||||
|
||||
def preprocess_text_embeds(self, text_embeds, text_ids):
|
||||
def preprocess_text_embeds(self, text_embeds, text_ids, t5xxl_weights=None):
|
||||
if text_ids is not None:
|
||||
return self.llm_adapter(text_embeds, text_ids)
|
||||
out = self.llm_adapter(text_embeds, text_ids)
|
||||
if t5xxl_weights is not None:
|
||||
out = out * t5xxl_weights
|
||||
|
||||
if out.shape[1] < 512:
|
||||
out = torch.nn.functional.pad(out, (0, 0, 0, 512 - out.shape[1]))
|
||||
return out
|
||||
else:
|
||||
return text_embeds
|
||||
|
||||
def forward(self, x, timesteps, context, **kwargs):
|
||||
t5xxl_ids = kwargs.pop("t5xxl_ids", None)
|
||||
if t5xxl_ids is not None:
|
||||
context = self.preprocess_text_embeds(context, t5xxl_ids, t5xxl_weights=kwargs.pop("t5xxl_weights", None))
|
||||
return super().forward(x, timesteps, context, **kwargs)
|
||||
|
||||
@@ -3,7 +3,6 @@ from torch import Tensor, nn
|
||||
|
||||
from comfy.ldm.flux.layers import (
|
||||
MLPEmbedder,
|
||||
RMSNorm,
|
||||
ModulationOut,
|
||||
)
|
||||
|
||||
@@ -29,7 +28,7 @@ class Approximator(nn.Module):
|
||||
super().__init__()
|
||||
self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
|
||||
self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
|
||||
self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
|
||||
self.norms = nn.ModuleList([operations.RMSNorm(hidden_dim, dtype=dtype, device=device) for x in range( n_layers)])
|
||||
self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
|
||||
|
||||
@property
|
||||
|
||||
@@ -4,8 +4,6 @@ from functools import lru_cache
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from comfy.ldm.flux.layers import RMSNorm
|
||||
|
||||
|
||||
class NerfEmbedder(nn.Module):
|
||||
"""
|
||||
@@ -145,7 +143,7 @@ class NerfGLUBlock(nn.Module):
|
||||
# We now need to generate parameters for 3 matrices.
|
||||
total_params = 3 * hidden_size_x**2 * mlp_ratio
|
||||
self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device)
|
||||
self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations)
|
||||
self.norm = operations.RMSNorm(hidden_size_x, dtype=dtype, device=device)
|
||||
self.mlp_ratio = mlp_ratio
|
||||
|
||||
|
||||
@@ -178,7 +176,7 @@ class NerfGLUBlock(nn.Module):
|
||||
class NerfFinalLayer(nn.Module):
|
||||
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -190,7 +188,7 @@ class NerfFinalLayer(nn.Module):
|
||||
class NerfFinalLayerConv(nn.Module):
|
||||
def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
|
||||
self.conv = operations.Conv2d(
|
||||
in_channels=hidden_size,
|
||||
out_channels=out_channels,
|
||||
|
||||
@@ -13,6 +13,7 @@ from torchvision import transforms
|
||||
|
||||
import comfy.patcher_extension
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
def apply_rotary_pos_emb(
|
||||
t: torch.Tensor,
|
||||
@@ -334,7 +335,7 @@ class FinalLayer(nn.Module):
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.layer_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = operations.Linear(
|
||||
hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype
|
||||
)
|
||||
@@ -462,6 +463,8 @@ class Block(nn.Module):
|
||||
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[dict] = {},
|
||||
) -> torch.Tensor:
|
||||
residual_dtype = x_B_T_H_W_D.dtype
|
||||
compute_dtype = emb_B_T_D.dtype
|
||||
if extra_per_block_pos_emb is not None:
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
|
||||
|
||||
@@ -511,7 +514,7 @@ class Block(nn.Module):
|
||||
result_B_T_H_W_D = rearrange(
|
||||
self.self_attn(
|
||||
# normalized_x_B_T_HW_D,
|
||||
rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
|
||||
rearrange(normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
|
||||
None,
|
||||
rope_emb=rope_emb_L_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
@@ -521,7 +524,7 @@ class Block(nn.Module):
|
||||
h=H,
|
||||
w=W,
|
||||
)
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
|
||||
|
||||
def _x_fn(
|
||||
_x_B_T_H_W_D: torch.Tensor,
|
||||
@@ -535,7 +538,7 @@ class Block(nn.Module):
|
||||
)
|
||||
_result_B_T_H_W_D = rearrange(
|
||||
self.cross_attn(
|
||||
rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
|
||||
rearrange(_normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
|
||||
crossattn_emb,
|
||||
rope_emb=rope_emb_L_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
@@ -554,7 +557,7 @@ class Block(nn.Module):
|
||||
shift_cross_attn_B_T_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
|
||||
x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D
|
||||
|
||||
normalized_x_B_T_H_W_D = _fn(
|
||||
x_B_T_H_W_D,
|
||||
@@ -562,8 +565,8 @@ class Block(nn.Module):
|
||||
scale_mlp_B_T_1_1_D,
|
||||
shift_mlp_B_T_1_1_D,
|
||||
)
|
||||
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D)
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D
|
||||
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype))
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
|
||||
return x_B_T_H_W_D
|
||||
|
||||
|
||||
@@ -835,6 +838,8 @@ class MiniTrainDIT(nn.Module):
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
orig_shape = list(x.shape)
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_temporal, self.patch_spatial, self.patch_spatial))
|
||||
x_B_C_T_H_W = x
|
||||
timesteps_B_T = timesteps
|
||||
crossattn_emb = context
|
||||
@@ -873,6 +878,14 @@ class MiniTrainDIT(nn.Module):
|
||||
"extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
|
||||
"transformer_options": kwargs.get("transformer_options", {}),
|
||||
}
|
||||
|
||||
# The residual stream for this model has large values. To make fp16 compute_dtype work, we keep the residual stream
|
||||
# in fp32, but run attention and MLP modules in fp16.
|
||||
# An alternate method that clamps fp16 values "works" in the sense that it makes coherent images, but there is noticeable
|
||||
# quality degradation and visual artifacts.
|
||||
if x_B_T_H_W_D.dtype == torch.float16:
|
||||
x_B_T_H_W_D = x_B_T_H_W_D.float()
|
||||
|
||||
for block in self.blocks:
|
||||
x_B_T_H_W_D = block(
|
||||
x_B_T_H_W_D,
|
||||
@@ -881,6 +894,6 @@ class MiniTrainDIT(nn.Module):
|
||||
**block_kwargs,
|
||||
)
|
||||
|
||||
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
|
||||
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
|
||||
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D.to(crossattn_emb.dtype), t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
|
||||
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)[:, :, :orig_shape[-3], :orig_shape[-2], :orig_shape[-1]]
|
||||
return x_B_C_Tt_Hp_Wp
|
||||
|
||||
@@ -5,9 +5,9 @@ import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .math import attention, rope
|
||||
import comfy.ops
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
# Fix import for some custom nodes, TODO: delete eventually.
|
||||
RMSNorm = None
|
||||
|
||||
class EmbedND(nn.Module):
|
||||
def __init__(self, dim: int, theta: int, axes_dim: list):
|
||||
@@ -87,20 +87,12 @@ def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dt
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
|
||||
|
||||
|
||||
class QKNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
|
||||
self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
|
||||
self.query_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
|
||||
self.key_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
|
||||
q = self.query_norm(q)
|
||||
@@ -169,7 +161,7 @@ class SiLUActivation(nn.Module):
|
||||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
@@ -197,8 +189,6 @@ class DoubleStreamBlock(nn.Module):
|
||||
|
||||
self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.flipped_img_txt = flipped_img_txt
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
|
||||
if self.modulation:
|
||||
img_mod1, img_mod2 = self.img_mod(vec)
|
||||
@@ -224,32 +214,17 @@ class DoubleStreamBlock(nn.Module):
|
||||
del txt_qkv
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
|
||||
if self.flipped_img_txt:
|
||||
q = torch.cat((img_q, txt_q), dim=2)
|
||||
del img_q, txt_q
|
||||
k = torch.cat((img_k, txt_k), dim=2)
|
||||
del img_k, txt_k
|
||||
v = torch.cat((img_v, txt_v), dim=2)
|
||||
del img_v, txt_v
|
||||
# run actual attention
|
||||
attn = attention(q, k, v,
|
||||
pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
q = torch.cat((txt_q, img_q), dim=2)
|
||||
del txt_q, img_q
|
||||
k = torch.cat((txt_k, img_k), dim=2)
|
||||
del txt_k, img_k
|
||||
v = torch.cat((txt_v, img_v), dim=2)
|
||||
del txt_v, img_v
|
||||
# run actual attention
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
|
||||
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
|
||||
else:
|
||||
q = torch.cat((txt_q, img_q), dim=2)
|
||||
del txt_q, img_q
|
||||
k = torch.cat((txt_k, img_k), dim=2)
|
||||
del txt_k, img_k
|
||||
v = torch.cat((txt_v, img_v), dim=2)
|
||||
del txt_v, img_v
|
||||
# run actual attention
|
||||
attn = attention(q, k, v,
|
||||
pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||
|
||||
# calculate the img bloks
|
||||
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
|
||||
|
||||
@@ -29,19 +29,34 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
return out.to(dtype=torch.float32, device=pos.device)
|
||||
|
||||
|
||||
def _apply_rope1(x: Tensor, freqs_cis: Tensor):
|
||||
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
||||
|
||||
x_out = freqs_cis[..., 0] * x_[..., 0]
|
||||
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
|
||||
|
||||
return x_out.reshape(*x.shape).type_as(x)
|
||||
|
||||
|
||||
def _apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
||||
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||
|
||||
|
||||
try:
|
||||
import comfy.quant_ops
|
||||
apply_rope = comfy.quant_ops.ck.apply_rope
|
||||
apply_rope1 = comfy.quant_ops.ck.apply_rope1
|
||||
q_apply_rope = comfy.quant_ops.ck.apply_rope
|
||||
q_apply_rope1 = comfy.quant_ops.ck.apply_rope1
|
||||
def apply_rope(xq, xk, freqs_cis):
|
||||
if comfy.model_management.in_training:
|
||||
return _apply_rope(xq, xk, freqs_cis)
|
||||
else:
|
||||
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||
def apply_rope1(x, freqs_cis):
|
||||
if comfy.model_management.in_training:
|
||||
return _apply_rope1(x, freqs_cis)
|
||||
else:
|
||||
return q_apply_rope1(x, freqs_cis)
|
||||
except:
|
||||
logging.warning("No comfy kitchen, using old apply_rope functions.")
|
||||
def apply_rope1(x: Tensor, freqs_cis: Tensor):
|
||||
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
||||
|
||||
x_out = freqs_cis[..., 0] * x_[..., 0]
|
||||
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
|
||||
|
||||
return x_out.reshape(*x.shape).type_as(x)
|
||||
|
||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
||||
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||
apply_rope = _apply_rope
|
||||
apply_rope1 = _apply_rope1
|
||||
|
||||
@@ -16,7 +16,6 @@ from .layers import (
|
||||
SingleStreamBlock,
|
||||
timestep_embedding,
|
||||
Modulation,
|
||||
RMSNorm
|
||||
)
|
||||
|
||||
@dataclass
|
||||
@@ -81,7 +80,7 @@ class Flux(nn.Module):
|
||||
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
|
||||
|
||||
if params.txt_norm:
|
||||
self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.txt_norm = operations.RMSNorm(params.context_in_dim, dtype=dtype, device=device)
|
||||
else:
|
||||
self.txt_norm = None
|
||||
|
||||
|
||||
@@ -241,7 +241,6 @@ class HunyuanVideo(nn.Module):
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
flipped_img_txt=True,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
@@ -378,14 +377,14 @@ class HunyuanVideo(nn.Module):
|
||||
extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
|
||||
txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
|
||||
|
||||
ids = torch.cat((img_ids, txt_ids), dim=1)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
img_len = img.shape[1]
|
||||
if txt_mask is not None:
|
||||
attn_mask_len = img_len + txt.shape[1]
|
||||
attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
|
||||
attn_mask[:, 0, img_len:] = txt_mask
|
||||
attn_mask[:, 0, :txt.shape[1]] = txt_mask
|
||||
else:
|
||||
attn_mask = None
|
||||
|
||||
@@ -413,7 +412,7 @@ class HunyuanVideo(nn.Module):
|
||||
if add is not None:
|
||||
img += add
|
||||
|
||||
img = torch.cat((img, txt), 1)
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
@@ -435,9 +434,9 @@ class HunyuanVideo(nn.Module):
|
||||
if i < len(control_o):
|
||||
add = control_o[i]
|
||||
if add is not None:
|
||||
img[:, : img_len] += add
|
||||
img[:, txt.shape[1]: img_len + txt.shape[1]] += add
|
||||
|
||||
img = img[:, : img_len]
|
||||
img = img[:, txt.shape[1]: img_len + txt.shape[1]]
|
||||
if ref_latent is not None:
|
||||
img = img[:, ref_latent.shape[1]:]
|
||||
|
||||
|
||||
@@ -109,10 +109,10 @@ class HunyuanVideo15SRModel():
|
||||
self.model_class = UPSAMPLERS.get(model_type)
|
||||
self.model = self.model_class(**config).eval()
|
||||
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.model.load_state_dict(sd, strict=True)
|
||||
return self.model.load_state_dict(sd, strict=True, assign=self.patcher.is_dynamic())
|
||||
|
||||
def get_sd(self):
|
||||
return self.model.state_dict()
|
||||
|
||||
@@ -18,12 +18,12 @@ class CompressedTimestep:
|
||||
def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
|
||||
"""
|
||||
tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
|
||||
patches_per_frame: Number of spatial patches per frame (height * width in latent space)
|
||||
patches_per_frame: Number of spatial patches per frame (height * width in latent space), or None to disable compression
|
||||
"""
|
||||
self.batch_size, num_tokens, self.feature_dim = tensor.shape
|
||||
|
||||
# Check if compression is valid (num_tokens must be divisible by patches_per_frame)
|
||||
if num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
|
||||
if patches_per_frame is not None and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
|
||||
self.patches_per_frame = patches_per_frame
|
||||
self.num_frames = num_tokens // patches_per_frame
|
||||
|
||||
@@ -215,22 +215,9 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
return (*scale_shift_ada_values, *gate_ada_values)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tuple[torch.Tensor, torch.Tensor],
|
||||
v_context=None,
|
||||
a_context=None,
|
||||
attention_mask=None,
|
||||
v_timestep=None,
|
||||
a_timestep=None,
|
||||
v_pe=None,
|
||||
a_pe=None,
|
||||
v_cross_pe=None,
|
||||
a_cross_pe=None,
|
||||
v_cross_scale_shift_timestep=None,
|
||||
a_cross_scale_shift_timestep=None,
|
||||
v_cross_gate_timestep=None,
|
||||
a_cross_gate_timestep=None,
|
||||
transformer_options=None,
|
||||
self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
|
||||
v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
|
||||
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
run_vx = transformer_options.get("run_vx", True)
|
||||
run_ax = transformer_options.get("run_ax", True)
|
||||
@@ -240,144 +227,102 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0
|
||||
run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True)
|
||||
|
||||
# video
|
||||
if run_vx:
|
||||
vshift_msa, vscale_msa, vgate_msa = (
|
||||
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 3))
|
||||
)
|
||||
|
||||
# video self-attention
|
||||
vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2)))
|
||||
norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
|
||||
vx += self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) * vgate_msa
|
||||
vx += self.attn2(
|
||||
comfy.ldm.common_dit.rms_norm(vx),
|
||||
context=v_context,
|
||||
mask=attention_mask,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
del vshift_msa, vscale_msa, vgate_msa
|
||||
del vshift_msa, vscale_msa
|
||||
attn1_out = self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options)
|
||||
del norm_vx
|
||||
# video cross-attention
|
||||
vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
|
||||
vx.addcmul_(attn1_out, vgate_msa)
|
||||
del vgate_msa, attn1_out
|
||||
vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options))
|
||||
|
||||
# audio
|
||||
if run_ax:
|
||||
ashift_msa, ascale_msa, agate_msa = (
|
||||
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 3))
|
||||
)
|
||||
|
||||
# audio self-attention
|
||||
ashift_msa, ascale_msa = (self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 2)))
|
||||
norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa
|
||||
ax += (
|
||||
self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
|
||||
* agate_msa
|
||||
)
|
||||
ax += self.audio_attn2(
|
||||
comfy.ldm.common_dit.rms_norm(ax),
|
||||
context=a_context,
|
||||
mask=attention_mask,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
del ashift_msa, ascale_msa
|
||||
attn1_out = self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
|
||||
del norm_ax
|
||||
# audio cross-attention
|
||||
agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0]
|
||||
ax.addcmul_(attn1_out, agate_msa)
|
||||
del agate_msa, attn1_out
|
||||
ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options))
|
||||
|
||||
del ashift_msa, ascale_msa, agate_msa
|
||||
|
||||
# Audio - Video cross attention.
|
||||
# video - audio cross attention.
|
||||
if run_a2v or run_v2a:
|
||||
# norm3
|
||||
vx_norm3 = comfy.ldm.common_dit.rms_norm(vx)
|
||||
ax_norm3 = comfy.ldm.common_dit.rms_norm(ax)
|
||||
|
||||
(
|
||||
scale_ca_audio_hidden_states_a2v,
|
||||
shift_ca_audio_hidden_states_a2v,
|
||||
scale_ca_audio_hidden_states_v2a,
|
||||
shift_ca_audio_hidden_states_v2a,
|
||||
gate_out_v2a,
|
||||
) = self.get_av_ca_ada_values(
|
||||
self.scale_shift_table_a2v_ca_audio,
|
||||
ax.shape[0],
|
||||
a_cross_scale_shift_timestep,
|
||||
a_cross_gate_timestep,
|
||||
)
|
||||
|
||||
(
|
||||
scale_ca_video_hidden_states_a2v,
|
||||
shift_ca_video_hidden_states_a2v,
|
||||
scale_ca_video_hidden_states_v2a,
|
||||
shift_ca_video_hidden_states_v2a,
|
||||
gate_out_a2v,
|
||||
) = self.get_av_ca_ada_values(
|
||||
self.scale_shift_table_a2v_ca_video,
|
||||
vx.shape[0],
|
||||
v_cross_scale_shift_timestep,
|
||||
v_cross_gate_timestep,
|
||||
)
|
||||
|
||||
# audio to video cross attention
|
||||
if run_a2v:
|
||||
vx_scaled = (
|
||||
vx_norm3 * (1 + scale_ca_video_hidden_states_a2v)
|
||||
+ shift_ca_video_hidden_states_a2v
|
||||
)
|
||||
ax_scaled = (
|
||||
ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v)
|
||||
+ shift_ca_audio_hidden_states_a2v
|
||||
)
|
||||
vx += (
|
||||
self.audio_to_video_attn(
|
||||
vx_scaled,
|
||||
context=ax_scaled,
|
||||
pe=v_cross_pe,
|
||||
k_pe=a_cross_pe,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
* gate_out_a2v
|
||||
)
|
||||
scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v = self.get_ada_values(
|
||||
self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[:2]
|
||||
scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v = self.get_ada_values(
|
||||
self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[:2]
|
||||
|
||||
del gate_out_a2v
|
||||
del scale_ca_video_hidden_states_a2v,\
|
||||
shift_ca_video_hidden_states_a2v,\
|
||||
scale_ca_audio_hidden_states_a2v,\
|
||||
shift_ca_audio_hidden_states_a2v,\
|
||||
vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_a2v_v) + shift_ca_video_hidden_states_a2v_v
|
||||
ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + shift_ca_audio_hidden_states_a2v
|
||||
del scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v, scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v
|
||||
|
||||
a2v_out = self.audio_to_video_attn(vx_scaled, context=ax_scaled, pe=v_cross_pe, k_pe=a_cross_pe, transformer_options=transformer_options)
|
||||
del vx_scaled, ax_scaled
|
||||
|
||||
gate_out_a2v = self.get_ada_values(self.scale_shift_table_a2v_ca_video[4:, :], vx.shape[0], v_cross_gate_timestep)[0]
|
||||
vx.addcmul_(a2v_out, gate_out_a2v)
|
||||
del gate_out_a2v, a2v_out
|
||||
|
||||
# video to audio cross attention
|
||||
if run_v2a:
|
||||
ax_scaled = (
|
||||
ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a)
|
||||
+ shift_ca_audio_hidden_states_v2a
|
||||
)
|
||||
vx_scaled = (
|
||||
vx_norm3 * (1 + scale_ca_video_hidden_states_v2a)
|
||||
+ shift_ca_video_hidden_states_v2a
|
||||
)
|
||||
ax += (
|
||||
self.video_to_audio_attn(
|
||||
ax_scaled,
|
||||
context=vx_scaled,
|
||||
pe=a_cross_pe,
|
||||
k_pe=v_cross_pe,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
* gate_out_v2a
|
||||
)
|
||||
scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a = self.get_ada_values(
|
||||
self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[2:4]
|
||||
scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a = self.get_ada_values(
|
||||
self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[2:4]
|
||||
|
||||
del gate_out_v2a
|
||||
del scale_ca_video_hidden_states_v2a,\
|
||||
shift_ca_video_hidden_states_v2a,\
|
||||
scale_ca_audio_hidden_states_v2a,\
|
||||
shift_ca_audio_hidden_states_v2a
|
||||
ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + shift_ca_audio_hidden_states_v2a
|
||||
vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + shift_ca_video_hidden_states_v2a
|
||||
del scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a, scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a
|
||||
|
||||
v2a_out = self.video_to_audio_attn(ax_scaled, context=vx_scaled, pe=a_cross_pe, k_pe=v_cross_pe, transformer_options=transformer_options)
|
||||
del ax_scaled, vx_scaled
|
||||
|
||||
gate_out_v2a = self.get_ada_values(self.scale_shift_table_a2v_ca_audio[4:, :], ax.shape[0], a_cross_gate_timestep)[0]
|
||||
ax.addcmul_(v2a_out, gate_out_v2a)
|
||||
del gate_out_v2a, v2a_out
|
||||
|
||||
del vx_norm3, ax_norm3
|
||||
|
||||
# video feedforward
|
||||
if run_vx:
|
||||
vshift_mlp, vscale_mlp, vgate_mlp = (
|
||||
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, None))
|
||||
)
|
||||
|
||||
vshift_mlp, vscale_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, 5))
|
||||
vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp
|
||||
vx += self.ff(vx_scaled) * vgate_mlp
|
||||
del vshift_mlp, vscale_mlp, vgate_mlp
|
||||
del vshift_mlp, vscale_mlp
|
||||
|
||||
ff_out = self.ff(vx_scaled)
|
||||
del vx_scaled
|
||||
|
||||
vgate_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(5, 6))[0]
|
||||
vx.addcmul_(ff_out, vgate_mlp)
|
||||
del vgate_mlp, ff_out
|
||||
|
||||
# audio feedforward
|
||||
if run_ax:
|
||||
ashift_mlp, ascale_mlp, agate_mlp = (
|
||||
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, None))
|
||||
)
|
||||
|
||||
ashift_mlp, ascale_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, 5))
|
||||
ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp
|
||||
ax += self.audio_ff(ax_scaled) * agate_mlp
|
||||
del ashift_mlp, ascale_mlp
|
||||
|
||||
del ashift_mlp, ascale_mlp, agate_mlp
|
||||
ff_out = self.audio_ff(ax_scaled)
|
||||
del ax_scaled
|
||||
|
||||
agate_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(5, 6))[0]
|
||||
ax.addcmul_(ff_out, agate_mlp)
|
||||
del agate_mlp, ff_out
|
||||
|
||||
return vx, ax
|
||||
|
||||
@@ -589,9 +534,20 @@ class LTXAVModel(LTXVModel):
        audio_length = kwargs.get("audio_length", 0)
        # Separate audio and video latents
        vx, ax = self.separate_audio_and_video_latents(x, audio_length)

        has_spatial_mask = False
        if denoise_mask is not None:
            # check if any frame has spatial variation (inpainting)
            for frame_idx in range(denoise_mask.shape[2]):
                frame_mask = denoise_mask[0, 0, frame_idx]
                if frame_mask.numel() > 0 and frame_mask.min() != frame_mask.max():
                    has_spatial_mask = True
                    break

        [vx, v_pixel_coords, additional_args] = super()._process_input(
            vx, keyframe_idxs, denoise_mask, **kwargs
        )
        additional_args["has_spatial_mask"] = has_spatial_mask

        ax, a_latent_coords = self.a_patchifier.patchify(ax)
        ax = self.audio_patchify_proj(ax)
@@ -618,8 +574,9 @@ class LTXAVModel(LTXVModel):
        # Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
        # Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
        orig_shape = kwargs.get("orig_shape")
        has_spatial_mask = kwargs.get("has_spatial_mask", None)
        v_patches_per_frame = None
        if orig_shape is not None and len(orig_shape) == 5:
        if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5:
            # orig_shape[3] = height, orig_shape[4] = width (in latent space)
            v_patches_per_frame = orig_shape[3] * orig_shape[4]

@@ -662,10 +619,11 @@ class LTXAVModel(LTXVModel):
        )

        # Compress cross-attention timesteps (only video side, audio is too small to benefit)
        # v_patches_per_frame is None for spatial masks, set for temporal masks or no mask
        cross_av_timestep_ss = [
            av_ca_audio_scale_shift_timestep.view(batch_size, -1, av_ca_audio_scale_shift_timestep.shape[-1]),
            CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame),  # video - compressed
            CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame),  # video - compressed
            CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame),  # video - compressed if possible
            CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame),  # video - compressed if possible
            av_ca_v2a_gate_noise_timestep.view(batch_size, -1, av_ca_v2a_gate_noise_timestep.shape[-1]),
        ]
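The spatial-mask probe above only asks whether any latent frame of the denoise mask varies within the frame; when it does, the per-frame timestep compression later in the hunk is skipped. A vectorized equivalent of that check, as a standalone sketch (the [batch, 1, frames, h, w] layout is assumed from the indexing above; names are illustrative):

import torch

def has_spatial_variation(denoise_mask):
    # flatten each frame and compare its per-frame min and max
    frames = denoise_mask[0, 0].flatten(start_dim=1)  # [frames, h*w]
    if frames.numel() == 0:
        return False
    return bool((frames.amin(dim=1) != frames.amax(dim=1)).any())

temporal_only = torch.ones(1, 1, 3, 4, 4) * torch.tensor([1.0, 0.5, 0.0]).view(1, 1, 3, 1, 1)
inpaint = temporal_only.clone()
inpaint[0, 0, 1, :2, :2] = 1.0  # spatial hole in one frame

print(has_spatial_variation(temporal_only))  # False: mask only varies over time
print(has_spatial_variation(inpaint))        # True: inpainting-style spatial mask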
@@ -1,11 +1,11 @@
from typing import Tuple, Union

import threading
import torch
import torch.nn as nn
import comfy.ops
ops = comfy.ops.disable_weight_init


class CausalConv3d(nn.Module):
    def __init__(
        self,
@@ -42,23 +42,34 @@ class CausalConv3d(nn.Module):
            padding_mode=spatial_padding_mode,
            groups=groups,
        )
        self.temporal_cache_state = {}

    def forward(self, x, causal: bool = True):
        if causal:
            first_frame_pad = x[:, :, :1, :, :].repeat(
                (1, 1, self.time_kernel_size - 1, 1, 1)
            )
            x = torch.concatenate((first_frame_pad, x), dim=2)
        else:
            first_frame_pad = x[:, :, :1, :, :].repeat(
                (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
            )
            last_frame_pad = x[:, :, -1:, :, :].repeat(
                (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
            )
            x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
        x = self.conv(x)
        return x
        tid = threading.get_ident()

        cached, is_end = self.temporal_cache_state.get(tid, (None, False))
        if cached is None:
            padding_length = self.time_kernel_size - 1
            if not causal:
                padding_length = padding_length // 2
            if x.shape[2] == 0:
                return x
            cached = x[:, :, :1, :, :].repeat((1, 1, padding_length, 1, 1))
        pieces = [cached, x]
        if is_end and not causal:
            pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)))

        needs_caching = not is_end
        if needs_caching and x.shape[2] >= self.time_kernel_size - 1:
            needs_caching = False
            self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)

        x = torch.cat(pieces, dim=2)

        if needs_caching:
            self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)

        return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :]

    @property
    def weight(self):
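The rewritten forward keeps a per-thread lead-in of the last (time_kernel_size - 1) frames so a clip can be pushed through in temporal chunks and still match a single full pass. Below is a standalone toy, not the real CausalConv3d API: a plain Conv3d with an instance-level cache (the real code keys the cache by thread id), just to show why carrying the lead-in makes chunked and full outputs identical.

import torch
import torch.nn as nn

class ChunkedCausalConv3d(nn.Module):
    def __init__(self, channels=4, k=3):
        super().__init__()
        self.k = k
        self.conv = nn.Conv3d(channels, channels, kernel_size=(k, 1, 1))
        self.cache = None  # per-instance here; the real code keys this by thread id

    def forward_chunk(self, x):
        if self.cache is None:
            # first chunk: replicate the first frame (k - 1) times, like causal=True
            pad = x[:, :, :1].repeat(1, 1, self.k - 1, 1, 1)
        else:
            pad = self.cache
        x = torch.cat([pad, x], dim=2)
        self.cache = x[:, :, -(self.k - 1):]  # lead-in for the next chunk
        return self.conv(x)

torch.manual_seed(0)
m = ChunkedCausalConv3d()
clip = torch.randn(1, 4, 8, 2, 2)

# whole clip at once
m.cache = None
full = m.forward_chunk(clip)

# same clip in two temporal chunks
m.cache = None
parts = torch.cat([m.forward_chunk(clip[:, :, :5]), m.forward_chunk(clip[:, :, 5:])], dim=2)

print(torch.allclose(full, parts))  # True: chunked result matches the full pass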
@@ -1,4 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import threading
|
||||
import torch
|
||||
from torch import nn
|
||||
from functools import partial
|
||||
@@ -6,12 +7,35 @@ import math
|
||||
from einops import rearrange
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from .conv_nd_factory import make_conv_nd, make_linear_nd
|
||||
from .causal_conv3d import CausalConv3d
|
||||
from .pixel_norm import PixelNorm
|
||||
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
|
||||
import comfy.ops
|
||||
from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
def mark_conv3d_ended(module):
|
||||
tid = threading.get_ident()
|
||||
for _, m in module.named_modules():
|
||||
if isinstance(m, CausalConv3d):
|
||||
current = m.temporal_cache_state.get(tid, (None, False))
|
||||
m.temporal_cache_state[tid] = (current[0], True)
|
||||
|
||||
def split2(tensor, split_point, dim=2):
|
||||
return torch.split(tensor, [split_point, tensor.shape[dim] - split_point], dim=dim)
|
||||
|
||||
def add_exchange_cache(dest, cache_in, new_input, dim=2):
|
||||
if dest is not None:
|
||||
if cache_in is not None:
|
||||
cache_to_dest = min(dest.shape[dim], cache_in.shape[dim])
|
||||
lead_in_dest, dest = split2(dest, cache_to_dest, dim=dim)
|
||||
lead_in_source, cache_in = split2(cache_in, cache_to_dest, dim=dim)
|
||||
lead_in_dest.add_(lead_in_source)
|
||||
body, new_input = split2(new_input, dest.shape[dim], dim)
|
||||
dest.add_(body)
|
||||
return torch_cat_if_needed([cache_in, new_input], dim=dim)
|
||||
|
||||
class Encoder(nn.Module):
|
||||
r"""
|
||||
The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
|
||||
@@ -205,7 +229,7 @@ class Encoder(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Encoder` class."""
|
||||
|
||||
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
@@ -254,6 +278,22 @@ class Encoder(nn.Module):
|
||||
|
||||
return sample
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
# No encoder support so just flag the end so it doesn't use the cache.
|
||||
mark_conv3d_ended(self)
|
||||
try:
|
||||
return self.forward_orig(*args, **kwargs)
|
||||
finally:
|
||||
tid = threading.get_ident()
|
||||
for _, module in self.named_modules():
|
||||
# ComfyUI doesn't thread this kind of stuff today, but just in case
|
||||
# we key on the thread to make it thread safe.
|
||||
tid = threading.get_ident()
|
||||
if hasattr(module, "temporal_cache_state"):
|
||||
module.temporal_cache_state.pop(tid, None)
|
||||
|
||||
|
||||
MAX_CHUNK_SIZE=(128 * 1024 ** 2)
|
||||
|
||||
class Decoder(nn.Module):
|
||||
r"""
|
||||
@@ -341,18 +381,6 @@ class Decoder(nn.Module):
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "attn_res_x":
|
||||
block = UNetMidBlock3D(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
num_layers=block_params["num_layers"],
|
||||
resnet_groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
inject_noise=block_params.get("inject_noise", False),
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
attention_head_dim=block_params["attention_head_dim"],
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "res_x_y":
|
||||
output_channel = output_channel // block_params.get("multiplier", 2)
|
||||
block = ResnetBlock3D(
|
||||
@@ -428,8 +456,9 @@ class Decoder(nn.Module):
|
||||
)
|
||||
self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
|
||||
|
||||
|
||||
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
|
||||
def forward(
|
||||
def forward_orig(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Optional[torch.Tensor] = None,
|
||||
@@ -437,6 +466,7 @@ class Decoder(nn.Module):
|
||||
r"""The forward method of the `Decoder` class."""
|
||||
batch_size = sample.shape[0]
|
||||
|
||||
mark_conv3d_ended(self.conv_in)
|
||||
sample = self.conv_in(sample, causal=self.causal)
|
||||
|
||||
checkpoint_fn = (
|
||||
@@ -445,24 +475,12 @@ class Decoder(nn.Module):
|
||||
else lambda x: x
|
||||
)
|
||||
|
||||
scaled_timestep = None
|
||||
timestep_shift_scale = None
|
||||
if self.timestep_conditioning:
|
||||
assert (
|
||||
timestep is not None
|
||||
), "should pass timestep with timestep_conditioning=True"
|
||||
scaled_timestep = timestep * self.timestep_scale_multiplier.to(dtype=sample.dtype, device=sample.device)
|
||||
|
||||
for up_block in self.up_blocks:
|
||||
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
|
||||
sample = checkpoint_fn(up_block)(
|
||||
sample, causal=self.causal, timestep=scaled_timestep
|
||||
)
|
||||
else:
|
||||
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
|
||||
|
||||
sample = self.conv_norm_out(sample)
|
||||
|
||||
if self.timestep_conditioning:
|
||||
embedded_timestep = self.last_time_embedder(
|
||||
timestep=scaled_timestep.flatten(),
|
||||
resolution=None,
|
||||
@@ -483,16 +501,62 @@ class Decoder(nn.Module):
|
||||
embedded_timestep.shape[-2],
|
||||
embedded_timestep.shape[-1],
|
||||
)
|
||||
shift, scale = ada_values.unbind(dim=1)
|
||||
sample = sample * (1 + scale) + shift
|
||||
timestep_shift_scale = ada_values.unbind(dim=1)
|
||||
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample, causal=self.causal)
|
||||
output = []
|
||||
|
||||
def run_up(idx, sample, ended):
|
||||
if idx >= len(self.up_blocks):
|
||||
sample = self.conv_norm_out(sample)
|
||||
if timestep_shift_scale is not None:
|
||||
shift, scale = timestep_shift_scale
|
||||
sample = sample * (1 + scale) + shift
|
||||
sample = self.conv_act(sample)
|
||||
if ended:
|
||||
mark_conv3d_ended(self.conv_out)
|
||||
sample = self.conv_out(sample, causal=self.causal)
|
||||
if sample is not None and sample.shape[2] > 0:
|
||||
output.append(sample)
|
||||
return
|
||||
|
||||
up_block = self.up_blocks[idx]
|
||||
if (ended):
|
||||
mark_conv3d_ended(up_block)
|
||||
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
|
||||
sample = checkpoint_fn(up_block)(
|
||||
sample, causal=self.causal, timestep=scaled_timestep
|
||||
)
|
||||
else:
|
||||
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
|
||||
|
||||
if sample is None or sample.shape[2] == 0:
|
||||
return
|
||||
|
||||
total_bytes = sample.numel() * sample.element_size()
|
||||
num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE
|
||||
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
|
||||
|
||||
for chunk_idx, sample1 in enumerate(samples):
|
||||
run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1)
|
||||
|
||||
run_up(0, sample, True)
|
||||
sample = torch.cat(output, dim=2)
|
||||
|
||||
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
|
||||
return sample
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
try:
|
||||
return self.forward_orig(*args, **kwargs)
|
||||
finally:
|
||||
for _, module in self.named_modules():
|
||||
# ComfyUI doesn't thread this kind of stuff today, but just in case
|
||||
#we key on the thread to make it thread safe.
|
||||
tid = threading.get_ident()
|
||||
if hasattr(module, "temporal_cache_state"):
|
||||
module.temporal_cache_state.pop(tid, None)
|
||||
|
||||
|
||||
class UNetMidBlock3D(nn.Module):
|
||||
"""
|
||||
@@ -663,8 +727,22 @@ class DepthToSpaceUpsample(nn.Module):
|
||||
)
|
||||
self.residual = residual
|
||||
self.out_channels_reduction_factor = out_channels_reduction_factor
|
||||
self.temporal_cache_state = {}
|
||||
|
||||
def forward(self, x, causal: bool = True, timestep: Optional[torch.Tensor] = None):
|
||||
tid = threading.get_ident()
|
||||
cached, drop_first_conv, drop_first_res = self.temporal_cache_state.get(tid, (None, True, True))
|
||||
y = self.conv(x, causal=causal)
|
||||
y = rearrange(
|
||||
y,
|
||||
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
|
||||
p1=self.stride[0],
|
||||
p2=self.stride[1],
|
||||
p3=self.stride[2],
|
||||
)
|
||||
if self.stride[0] == 2 and y.shape[2] > 0 and drop_first_conv:
|
||||
y = y[:, :, 1:, :, :]
|
||||
drop_first_conv = False
|
||||
if self.residual:
|
||||
# Reshape and duplicate the input to match the output shape
|
||||
x_in = rearrange(
|
||||
@@ -676,21 +754,20 @@ class DepthToSpaceUpsample(nn.Module):
|
||||
)
|
||||
num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor
|
||||
x_in = x_in.repeat(1, num_repeat, 1, 1, 1)
|
||||
if self.stride[0] == 2:
|
||||
if self.stride[0] == 2 and x_in.shape[2] > 0 and drop_first_res:
|
||||
x_in = x_in[:, :, 1:, :, :]
|
||||
x = self.conv(x, causal=causal)
|
||||
x = rearrange(
|
||||
x,
|
||||
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
|
||||
p1=self.stride[0],
|
||||
p2=self.stride[1],
|
||||
p3=self.stride[2],
|
||||
)
|
||||
if self.stride[0] == 2:
|
||||
x = x[:, :, 1:, :, :]
|
||||
if self.residual:
|
||||
x = x + x_in
|
||||
return x
|
||||
drop_first_res = False
|
||||
|
||||
if y.shape[2] == 0:
|
||||
y = None
|
||||
|
||||
cached = add_exchange_cache(y, cached, x_in, dim=2)
|
||||
self.temporal_cache_state[tid] = (cached, drop_first_conv, drop_first_res)
|
||||
|
||||
else:
|
||||
self.temporal_cache_state[tid] = (None, drop_first_conv, False)
|
||||
|
||||
return y
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps, elementwise_affine=True) -> None:
|
||||
@@ -807,6 +884,8 @@ class ResnetBlock3D(nn.Module):
|
||||
torch.randn(4, in_channels) / in_channels**0.5
|
||||
)
|
||||
|
||||
self.temporal_cache_state={}
|
||||
|
||||
def _feed_spatial_noise(
|
||||
self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
@@ -880,9 +959,12 @@ class ResnetBlock3D(nn.Module):
|
||||
|
||||
input_tensor = self.conv_shortcut(input_tensor)
|
||||
|
||||
output_tensor = input_tensor + hidden_states
|
||||
tid = threading.get_ident()
|
||||
cached = self.temporal_cache_state.get(tid, None)
|
||||
cached = add_exchange_cache(hidden_states, cached, input_tensor, dim=2)
|
||||
self.temporal_cache_state[tid] = cached
|
||||
|
||||
return output_tensor
|
||||
return hidden_states
|
||||
|
||||
|
||||
def patchify(x, patch_size_hw, patch_size_t=1):
|
||||
|
||||
@@ -451,6 +451,7 @@ class NextDiT(nn.Module):
        device=None,
        dtype=None,
        operations=None,
        **kwargs,
    ) -> None:
        super().__init__()
        self.dtype = dtype
@@ -524,6 +524,9 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha

@wrap_attn
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
    if kwargs.get("low_precision_attention", True) is False:
        return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)

    exception_fallback = False
    if skip_reshape:
        b, _, _, dim_head = q.shape
@@ -14,10 +14,13 @@ if model_management.xformers_enabled_vae():
    import xformers.ops

def torch_cat_if_needed(xl, dim):
    xl = [x for x in xl if x is not None and x.shape[dim] > 0]
    if len(xl) > 1:
        return torch.cat(xl, dim)
    else:
    elif len(xl) == 1:
        return xl[0]
    else:
        return None

def get_timestep_embedding(timesteps, embedding_dim):
    """
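With the extra branch, the helper copes with lists that end up empty after filtering: None entries and zero-length tensors are dropped, a single survivor is returned as-is, and None is returned when nothing is left. A quick check using the same function body:

import torch

def torch_cat_if_needed(xl, dim):
    xl = [x for x in xl if x is not None and x.shape[dim] > 0]
    if len(xl) > 1:
        return torch.cat(xl, dim)
    elif len(xl) == 1:
        return xl[0]
    else:
        return None

a = torch.ones(1, 2, 3)
empty = torch.ones(1, 2, 0)

print(torch_cat_if_needed([a, None, a], dim=2).shape)      # torch.Size([1, 2, 6])
print(torch_cat_if_needed([None, a, empty], dim=2).shape)  # torch.Size([1, 2, 3]): single tensor returned as-is
print(torch_cat_if_needed([None, empty], dim=2))           # None: nothing left to concatenate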
@@ -2,6 +2,196 @@ import torch
|
||||
import math
|
||||
|
||||
from .model import QwenImageTransformer2DModel
|
||||
from .model import QwenImageTransformerBlock
|
||||
|
||||
|
||||
class QwenImageFunControlBlock(QwenImageTransformerBlock):
|
||||
def __init__(self, dim, num_attention_heads, attention_head_dim, has_before_proj=False, dtype=None, device=None, operations=None):
|
||||
super().__init__(
|
||||
dim=dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
self.has_before_proj = has_before_proj
|
||||
if has_before_proj:
|
||||
self.before_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
|
||||
self.after_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
|
||||
|
||||
|
||||
class QwenImageFunControlNetModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
control_in_features=132,
|
||||
inner_dim=3072,
|
||||
num_attention_heads=24,
|
||||
attention_head_dim=128,
|
||||
num_control_blocks=5,
|
||||
main_model_double=60,
|
||||
injection_layers=(0, 12, 24, 36, 48),
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.main_model_double = main_model_double
|
||||
self.injection_layers = tuple(injection_layers)
|
||||
# Keep base hint scaling at 1.0 so user-facing strength behaves similarly
|
||||
# to the reference Gen2/VideoX implementation around strength=1.
|
||||
self.hint_scale = 1.0
|
||||
self.control_img_in = operations.Linear(control_in_features, inner_dim, device=device, dtype=dtype)
|
||||
|
||||
self.control_blocks = torch.nn.ModuleList([])
|
||||
for i in range(num_control_blocks):
|
||||
self.control_blocks.append(
|
||||
QwenImageFunControlBlock(
|
||||
dim=inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
has_before_proj=(i == 0),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
)
|
||||
|
||||
def _process_hint_tokens(self, hint):
|
||||
if hint is None:
|
||||
return None
|
||||
if hint.ndim == 4:
|
||||
hint = hint.unsqueeze(2)
|
||||
|
||||
# Fun checkpoints are trained with 33 latent channels before 2x2 packing:
|
||||
# [control_latent(16), mask(1), inpaint_latent(16)] -> 132 features.
|
||||
# Default behavior (no inpaint input in stock Apply ControlNet) should use
|
||||
# zeros for mask/inpaint branches, matching VideoX fallback semantics.
|
||||
expected_c = self.control_img_in.weight.shape[1] // 4
|
||||
if hint.shape[1] == 16 and expected_c == 33:
|
||||
zeros_mask = torch.zeros_like(hint[:, :1])
|
||||
zeros_inpaint = torch.zeros_like(hint)
|
||||
hint = torch.cat([hint, zeros_mask, zeros_inpaint], dim=1)
|
||||
|
||||
bs, c, t, h, w = hint.shape
|
||||
hidden_states = torch.nn.functional.pad(hint, (0, w % 2, 0, h % 2))
|
||||
orig_shape = hidden_states.shape
|
||||
hidden_states = hidden_states.view(
|
||||
orig_shape[0],
|
||||
orig_shape[1],
|
||||
orig_shape[-3],
|
||||
orig_shape[-2] // 2,
|
||||
2,
|
||||
orig_shape[-1] // 2,
|
||||
2,
|
||||
)
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
|
||||
hidden_states = hidden_states.reshape(
|
||||
bs,
|
||||
t * ((h + 1) // 2) * ((w + 1) // 2),
|
||||
c * 4,
|
||||
)
|
||||
|
||||
expected_in = self.control_img_in.weight.shape[1]
|
||||
cur_in = hidden_states.shape[-1]
|
||||
if cur_in < expected_in:
|
||||
pad = torch.zeros(
|
||||
(hidden_states.shape[0], hidden_states.shape[1], expected_in - cur_in),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
hidden_states = torch.cat([hidden_states, pad], dim=-1)
|
||||
elif cur_in > expected_in:
|
||||
hidden_states = hidden_states[:, :, :expected_in]
|
||||
|
||||
return hidden_states
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
timesteps,
|
||||
context,
|
||||
attention_mask=None,
|
||||
guidance: torch.Tensor = None,
|
||||
hint=None,
|
||||
transformer_options={},
|
||||
base_model=None,
|
||||
**kwargs,
|
||||
):
|
||||
if base_model is None:
|
||||
raise RuntimeError("Qwen Fun ControlNet requires a QwenImage base model at runtime.")
|
||||
|
||||
encoder_hidden_states_mask = attention_mask
|
||||
# Keep attention mask disabled inside Fun control blocks to mirror
|
||||
# VideoX behavior (they rely on seq lengths for RoPE, not masked attention).
|
||||
encoder_hidden_states_mask = None
|
||||
|
||||
hidden_states, img_ids, _ = base_model.process_img(x)
|
||||
hint_tokens = self._process_hint_tokens(hint)
|
||||
if hint_tokens is None:
|
||||
raise RuntimeError("Qwen Fun ControlNet requires a control hint image.")
|
||||
|
||||
if hint_tokens.shape[1] != hidden_states.shape[1]:
|
||||
max_tokens = min(hint_tokens.shape[1], hidden_states.shape[1])
|
||||
hint_tokens = hint_tokens[:, :max_tokens]
|
||||
hidden_states = hidden_states[:, :max_tokens]
|
||||
img_ids = img_ids[:, :max_tokens]
|
||||
|
||||
txt_start = round(
|
||||
max(
|
||||
((x.shape[-1] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
|
||||
((x.shape[-2] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
|
||||
)
|
||||
)
|
||||
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
image_rotary_emb = base_model.pe_embedder(ids).to(x.dtype).contiguous()
|
||||
|
||||
hidden_states = base_model.img_in(hidden_states)
|
||||
encoder_hidden_states = base_model.txt_norm(context)
|
||||
encoder_hidden_states = base_model.txt_in(encoder_hidden_states)
|
||||
|
||||
if guidance is not None:
|
||||
guidance = guidance * 1000
|
||||
|
||||
temb = (
|
||||
base_model.time_text_embed(timesteps, hidden_states)
|
||||
if guidance is None
|
||||
else base_model.time_text_embed(timesteps, guidance, hidden_states)
|
||||
)
|
||||
|
||||
c = self.control_img_in(hint_tokens)
|
||||
|
||||
for i, block in enumerate(self.control_blocks):
|
||||
if i == 0:
|
||||
c_in = block.before_proj(c) + hidden_states
|
||||
all_c = []
|
||||
else:
|
||||
all_c = list(torch.unbind(c, dim=0))
|
||||
c_in = all_c.pop(-1)
|
||||
|
||||
encoder_hidden_states, c_out = block(
|
||||
hidden_states=c_in,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
c_skip = block.after_proj(c_out) * self.hint_scale
|
||||
all_c += [c_skip, c_out]
|
||||
c = torch.stack(all_c, dim=0)
|
||||
|
||||
hints = torch.unbind(c, dim=0)[:-1]
|
||||
|
||||
controlnet_block_samples = [None] * self.main_model_double
|
||||
for local_idx, base_idx in enumerate(self.injection_layers):
|
||||
if local_idx < len(hints) and base_idx < len(controlnet_block_samples):
|
||||
controlnet_block_samples[base_idx] = hints[local_idx]
|
||||
|
||||
return {"input": controlnet_block_samples}
|
||||
|
||||
|
||||
class QwenImageControlNetModel(QwenImageTransformer2DModel):
|
||||
|
||||
@@ -170,8 +170,14 @@ class Attention(nn.Module):
        joint_query = apply_rope1(joint_query, image_rotary_emb)
        joint_key = apply_rope1(joint_key, image_rotary_emb)

        if encoder_hidden_states_mask is not None:
            attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device)
            attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask
        else:
            attn_mask = None

        joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
                                                         attention_mask, transformer_options=transformer_options,
                                                         attn_mask, transformer_options=transformer_options,
                                                         skip_reshape=True)

        txt_attn_output = joint_hidden_states[:, :seq_txt, :]
@@ -430,6 +436,9 @@ class QwenImageTransformer2DModel(nn.Module):
        encoder_hidden_states = context
        encoder_hidden_states_mask = attention_mask

        if encoder_hidden_states_mask is not None and not torch.is_floating_point(encoder_hidden_states_mask):
            encoder_hidden_states_mask = (encoder_hidden_states_mask - 1).to(x.dtype) * torch.finfo(x.dtype).max

        hidden_states, img_ids, orig_shape = self.process_img(x)
        num_embeds = hidden_states.shape[1]
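The mask plumbing above converts an integer padding mask into an additive float mask before it reaches the attention call. A minimal sketch of that convention with hypothetical shapes (not the model's real tensors):

import torch

seq_txt, seq_img = 4, 6
x_dtype = torch.float32

# 1 = real text token, 0 = padding, as produced by the tokenizer
pad_mask = torch.tensor([[1, 1, 1, 0]])

# (mask - 1) * finfo.max -> 0 for kept tokens, a huge negative number for padding
additive = (pad_mask - 1).to(x_dtype) * torch.finfo(x_dtype).max

# joint mask covers text + image tokens; image positions stay fully visible
attn_mask = torch.zeros((1, 1, seq_txt + seq_img), dtype=x_dtype)
attn_mask[:, 0, :seq_txt] = additive

print(attn_mask)  # last text slot is ~ -3.4e38, everything else 0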
@@ -5,7 +5,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from comfy.ldm.modules.diffusionmodules.model import vae_attention
from comfy.ldm.modules.diffusionmodules.model import vae_attention, torch_cat_if_needed

import comfy.ops
ops = comfy.ops.disable_weight_init
@@ -20,22 +20,29 @@ class CausalConv3d(ops.Conv3d):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._padding = (self.padding[2], self.padding[2], self.padding[1],
                         self.padding[1], 2 * self.padding[0], 0)
        self.padding = (0, 0, 0)
        self._padding = 2 * self.padding[0]
        self.padding = (0, self.padding[1], self.padding[2])

    def forward(self, x, cache_x=None, cache_list=None, cache_idx=None):
        if cache_list is not None:
            cache_x = cache_list[cache_idx]
            cache_list[cache_idx] = None

        padding = list(self._padding)
        if cache_x is not None and self._padding[4] > 0:
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            padding[4] -= cache_x.shape[2]
        if cache_x is None and x.shape[2] == 1:
            # Fast path - the op will pad for us by truncating the weight
            # and save math on a pile of zeros.
            return super().forward(x, autopad="causal_zero")

        if self._padding > 0:
            padding_needed = self._padding
            if cache_x is not None:
                cache_x = cache_x.to(x.device)
                padding_needed = max(0, padding_needed - cache_x.shape[2])
            padding_shape = list(x.shape)
            padding_shape[2] = padding_needed
            padding = torch.zeros(padding_shape, device=x.device, dtype=x.dtype)
            x = torch_cat_if_needed([padding, cache_x, x], dim=2)
        del cache_x
        x = F.pad(x, padding)

        return super().forward(x)
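The old _padding tuple relied on F.pad's last-dim-first ordering for 5D input; the rewrite keeps only the temporal amount and hands spatial padding back to the conv itself. A small reminder of that ordering, using a toy tensor rather than the VAE's real shapes:

import torch
import torch.nn.functional as F

x = torch.zeros(1, 3, 4, 5, 5)  # (batch, channels, time, height, width)

# F.pad pads from the last dimension backwards:
# (w_left, w_right, h_top, h_bottom, t_front, t_back)
pad_w, pad_h, pad_t = 1, 1, 2
y = F.pad(x, (pad_w, pad_w, pad_h, pad_h, 2 * pad_t, 0))  # causal: all temporal pad in front

print(y.shape)  # torch.Size([1, 3, 8, 7, 7])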
@@ -472,10 +479,12 @@ class WanVAE(nn.Module):

    def encode(self, x):
        conv_idx = [0]
        feat_map = [None] * count_conv3d(self.decoder)
        ## cache
        t = x.shape[2]
        iter_ = 1 + (t - 1) // 4
        feat_map = None
        if iter_ > 1:
            feat_map = [None] * count_conv3d(self.decoder)
        ## Split the x passed to encode along time into chunks of 1, 4, 4, 4, ...
        for i in range(iter_):
            conv_idx = [0]
@@ -495,10 +504,11 @@ class WanVAE(nn.Module):

    def decode(self, z):
        conv_idx = [0]
        feat_map = [None] * count_conv3d(self.decoder)
        # z: [b,c,t,h,w]

        iter_ = z.shape[2]
        feat_map = None
        if iter_ > 1:
            feat_map = [None] * count_conv3d(self.decoder)
        x = self.conv2(z)
        for i in range(iter_):
            conv_idx = [0]
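The encode path now only allocates the feature cache when the clip is actually split, and the split follows the 1, 4, 4, 4, ... frame schedule implied by iter_ = 1 + (t - 1) // 4. A quick sketch of how that schedule partitions a clip, assuming t = 4k + 1 as in the usual Wan frame counts (illustrative only):

def wan_encode_chunks(t):
    # assumes t = 4k + 1, the usual Wan frame count
    iter_ = 1 + (t - 1) // 4
    return [1] + [4] * (iter_ - 1)

print(wan_encode_chunks(17))  # [1, 4, 4, 4, 4]
print(wan_encode_chunks(9))   # [1, 4, 4]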
@@ -260,6 +260,7 @@ def model_lora_keys_unet(model, key_map={}):
            key_map["transformer.{}".format(k[:-len(".weight")])] = to  # simpletrainer and probably regular diffusers flux lora format
            key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to  # simpletrainer lycoris
            key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to  # onetrainer
            key_map[k[:-len(".weight")]] = to  # DiffSynth lora format
        for k in sdk:
            hidden_size = model.model_config.unet_config.get("hidden_size", 0)
            if k.endswith(".weight") and ".linear1." in k:
@@ -331,6 +332,12 @@ def model_lora_keys_unet(model, key_map={}):
                key_map["{}".format(key_lora)] = k
                key_map["transformer.{}".format(key_lora)] = k

    if isinstance(model, comfy.model_base.ACEStep15):
        for k in sdk:
            if k.startswith("diffusion_model.decoder.") and k.endswith(".weight"):
                key_lora = k[len("diffusion_model.decoder."):-len(".weight")]
                key_map["base_model.model.{}".format(key_lora)] = k  # Official base model loras

    return key_map


@@ -5,7 +5,7 @@ import comfy.utils
def convert_lora_bfl_control(sd):  # BFL loras for Flux
    sd_out = {}
    for k in sd:
        k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
        k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.set_weight"))
        sd_out[k_to] = sd[k]

    sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]])
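The new ACEStep15 branch maps decoder weights to the base_model.model. prefix used by the official loras, alongside the existing trainer-specific spellings. The toy below is purely illustrative of those naming variants; the real model_lora_keys_unet handles prefixes and target keys differently per model family, so treat the helper and its prefix stripping as hypothetical:

def lora_key_aliases(k):
    # k is a diffusion model weight key, e.g. "diffusion_model.blocks.0.attn.qkv.weight"
    base = k[:-len(".weight")]
    short = base[len("diffusion_model."):]
    return {
        "transformer.{}".format(short): k,                          # diffusers-style
        "lycoris_{}".format(short.replace(".", "_")): k,            # simpletrainer lycoris
        "lora_transformer_{}".format(short.replace(".", "_")): k,   # onetrainer
        short: k,                                                   # DiffSynth
    }

print(list(lora_key_aliases("diffusion_model.blocks.0.attn.qkv.weight")))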
comfy/memory_management.py (new file, 81 lines)
@@ -0,0 +1,81 @@
import math
import torch
from typing import NamedTuple

from comfy.quant_ops import QuantizedTensor

class TensorGeometry(NamedTuple):
    shape: any
    dtype: torch.dtype

    def element_size(self):
        info = torch.finfo(self.dtype) if self.dtype.is_floating_point else torch.iinfo(self.dtype)
        return info.bits // 8

    def numel(self):
        return math.prod(self.shape)

def tensors_to_geometries(tensors, dtype=None):
    geometries = []
    for t in tensors:
        if t is None or isinstance(t, QuantizedTensor):
            geometries.append(t)
            continue
        tdtype = t.dtype
        if hasattr(t, "_model_dtype"):
            tdtype = t._model_dtype
        if dtype is not None:
            tdtype = dtype
        geometries.append(TensorGeometry(shape=t.shape, dtype=tdtype))
    return geometries

def vram_aligned_size(tensor):
    if isinstance(tensor, list):
        return sum([vram_aligned_size(t) for t in tensor])

    if isinstance(tensor, QuantizedTensor):
        inner_tensors, _ = tensor.__tensor_flatten__()
        return vram_aligned_size([getattr(tensor, attr) for attr in inner_tensors])

    if tensor is None:
        return 0

    size = tensor.numel() * tensor.element_size()
    alignment_req = 1024
    return (size + alignment_req - 1) // alignment_req * alignment_req

def interpret_gathered_like(tensors, gathered):
    offset = 0
    dest_views = []

    if gathered.dim() != 1 or gathered.element_size() != 1:
        raise ValueError(f"Buffer must be 1D and single-byte (got {gathered.dim()}D {gathered.dtype})")

    for tensor in tensors:

        if tensor is None:
            dest_views.append(None)
            continue

        if isinstance(tensor, QuantizedTensor):
            inner_tensors, qt_ctx = tensor.__tensor_flatten__()
            templates = {attr: getattr(tensor, attr) for attr in inner_tensors}
        else:
            templates = {"data": tensor}

        actuals = {}
        for attr, template in templates.items():
            size = template.numel() * template.element_size()
            if offset + size > gathered.numel():
                raise ValueError(f"Buffer too small: needs {offset + size} bytes, but only has {gathered.numel()}.")
            actuals[attr] = gathered[offset:offset+size].view(dtype=template.dtype).view(template.shape)
            offset += vram_aligned_size(template)

        if isinstance(tensor, QuantizedTensor):
            dest_views.append(QuantizedTensor.__tensor_unflatten__(actuals, qt_ctx, 0, 0))
        else:
            dest_views.append(actuals["data"])

    return dest_views

aimdo_allocator = None
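Taken together, vram_aligned_size and interpret_gathered_like let several weights be packed into one flat byte buffer and then viewed back out with their original shapes and dtypes. A minimal usage sketch with plain tensors (no QuantizedTensor), assuming the new module is importable under the path shown above:

import torch
from comfy.memory_management import vram_aligned_size, interpret_gathered_like

# hypothetical weights to pack
weights = [torch.randn(4, 4), None, torch.arange(6, dtype=torch.int32)]

total = sum(vram_aligned_size(w) for w in weights)
gathered = torch.empty(total, dtype=torch.int8)  # flat single-byte buffer

views = interpret_gathered_like(weights, gathered)
for w, v in zip(weights, views):
    if w is not None:
        v.copy_(w)  # fill the packed buffer through the views

print(views[0].shape, views[0].dtype)      # torch.Size([4, 4]) torch.float32
print(torch.equal(views[2], weights[2]))   # True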
@@ -50,6 +50,7 @@ import comfy.ldm.omnigen.omnigen2
|
||||
import comfy.ldm.qwen_image.model
|
||||
import comfy.ldm.kandinsky5.model
|
||||
import comfy.ldm.anima.model
|
||||
import comfy.ldm.ace.ace_step15
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@@ -146,6 +147,8 @@ class BaseModel(torch.nn.Module):
|
||||
self.diffusion_model.to(memory_format=torch.channels_last)
|
||||
logging.debug("using channels last mode for diffusion model")
|
||||
logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
|
||||
comfy.model_management.archive_model_dtypes(self.diffusion_model)
|
||||
|
||||
self.model_type = model_type
|
||||
self.model_sampling = model_sampling(model_config, model_type)
|
||||
|
||||
@@ -299,7 +302,7 @@ class BaseModel(torch.nn.Module):
|
||||
|
||||
return out
|
||||
|
||||
def load_model_weights(self, sd, unet_prefix=""):
|
||||
def load_model_weights(self, sd, unet_prefix="", assign=False):
|
||||
to_load = {}
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
@@ -307,7 +310,7 @@ class BaseModel(torch.nn.Module):
|
||||
to_load[k[len(unet_prefix):]] = sd.pop(k)
|
||||
|
||||
to_load = self.model_config.process_unet_state_dict(to_load)
|
||||
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
|
||||
m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=assign)
|
||||
if len(m) > 0:
|
||||
logging.warning("unet missing: {}".format(m))
|
||||
|
||||
@@ -322,7 +325,7 @@ class BaseModel(torch.nn.Module):
|
||||
def process_latent_out(self, latent):
|
||||
return self.latent_format.process_out(latent)
|
||||
|
||||
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
|
||||
def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
|
||||
extra_sds = []
|
||||
if clip_state_dict is not None:
|
||||
extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict))
|
||||
@@ -330,10 +333,7 @@ class BaseModel(torch.nn.Module):
|
||||
extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict))
|
||||
if clip_vision_state_dict is not None:
|
||||
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
|
||||
|
||||
unet_state_dict = self.diffusion_model.state_dict()
|
||||
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
|
||||
|
||||
if self.model_type == ModelType.V_PREDICTION:
|
||||
unet_state_dict["v_pred"] = torch.tensor([])
|
||||
|
||||
@@ -776,8 +776,8 @@ class StableAudio1(BaseModel):
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
||||
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
|
||||
sd = super().state_dict_for_saving(clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
|
||||
def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
|
||||
sd = super().state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
|
||||
d = {"conditioner.conditioners.seconds_start.": self.seconds_start_embedder.state_dict(), "conditioner.conditioners.seconds_total.": self.seconds_total_embedder.state_dict()}
|
||||
for k in d:
|
||||
s = d[k]
|
||||
@@ -1160,12 +1160,16 @@ class Anima(BaseModel):
|
||||
device = kwargs["device"]
|
||||
if cross_attn is not None:
|
||||
if t5xxl_ids is not None:
|
||||
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.unsqueeze(0).to(device=device))
|
||||
if t5xxl_weights is not None:
|
||||
cross_attn *= t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
|
||||
t5xxl_weights = t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
|
||||
t5xxl_ids = t5xxl_ids.unsqueeze(0)
|
||||
|
||||
if torch.is_inference_mode_enabled(): # if not we are training
|
||||
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype()))
|
||||
else:
|
||||
out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
|
||||
out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)
|
||||
|
||||
if cross_attn.shape[1] < 512:
|
||||
cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, 0, 512 - cross_attn.shape[1]))
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
||||
@@ -1541,6 +1545,49 @@ class ACEStep(BaseModel):
|
||||
out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0))
|
||||
return out
|
||||
|
||||
class ACEStep15(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.ace_step15.AceStepConditionGenerationModel)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
device = kwargs["device"]
|
||||
noise = kwargs["noise"]
|
||||
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
if torch.count_nonzero(cross_attn) == 0:
|
||||
out['replace_with_null_embeds'] = comfy.conds.CONDConstant(True)
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
conditioning_lyrics = kwargs.get("conditioning_lyrics", None)
|
||||
if cross_attn is not None:
|
||||
out['lyric_embed'] = comfy.conds.CONDRegular(conditioning_lyrics)
|
||||
|
||||
refer_audio = kwargs.get("reference_audio_timbre_latents", None)
|
||||
if refer_audio is None or len(refer_audio) == 0:
|
||||
refer_audio = comfy.ldm.ace.ace_step15.get_silence_latent(noise.shape[2], device)
|
||||
pass_audio_codes = True
|
||||
else:
|
||||
refer_audio = refer_audio[-1][:, :, :noise.shape[2]]
|
||||
out['is_covers'] = comfy.conds.CONDConstant(True)
|
||||
pass_audio_codes = False
|
||||
|
||||
if pass_audio_codes:
|
||||
audio_codes = kwargs.get("audio_codes", None)
|
||||
if audio_codes is not None:
|
||||
out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device))
|
||||
refer_audio = refer_audio[:, :, :750]
|
||||
else:
|
||||
out['is_covers'] = comfy.conds.CONDConstant(False)
|
||||
|
||||
if refer_audio.shape[2] < noise.shape[2]:
|
||||
pad = comfy.ldm.ace.ace_step15.get_silence_latent(noise.shape[2], device)
|
||||
refer_audio = torch.cat([refer_audio.to(pad), pad[:, :, refer_audio.shape[2]:]], dim=2)
|
||||
|
||||
out['refer_audio'] = comfy.conds.CONDRegular(refer_audio)
|
||||
return out
|
||||
|
||||
class Omnigen2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.omnigen.omnigen2.OmniGen2Transformer2DModel)
|
||||
@@ -1578,6 +1625,9 @@ class QwenImage(BaseModel):
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
@@ -19,6 +19,12 @@ def count_blocks(state_dict_keys, prefix_string):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def any_suffix_in(keys, prefix, main, suffix_list=[]):
|
||||
for x in suffix_list:
|
||||
if "{}{}{}".format(prefix, main, x) in keys:
|
||||
return True
|
||||
return False
|
||||
|
||||
def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
|
||||
context_dim = None
|
||||
use_linear_in_transformer = False
|
||||
@@ -186,7 +192,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["meanflow_sum"] = False
|
||||
return dit_config
|
||||
|
||||
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
|
||||
if any_suffix_in(state_dict_keys, key_prefix, 'double_blocks.0.img_attn.norm.key_norm.', ["weight", "scale"]) and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"])): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
|
||||
dit_config = {}
|
||||
if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["image_model"] = "flux2"
|
||||
@@ -241,7 +247,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
|
||||
if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
|
||||
|
||||
if any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.0.norms.0.', ["weight", "scale"]) or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"]): #Chroma
|
||||
dit_config["image_model"] = "chroma"
|
||||
dit_config["in_channels"] = 64
|
||||
dit_config["out_channels"] = 64
|
||||
@@ -249,7 +256,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["out_dim"] = 3072
|
||||
dit_config["hidden_dim"] = 5120
|
||||
dit_config["n_layers"] = 5
|
||||
if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance
|
||||
|
||||
if any_suffix_in(state_dict_keys, key_prefix, 'nerf_blocks.0.norm.', ["weight", "scale"]): #Chroma Radiance
|
||||
dit_config["image_model"] = "chroma_radiance"
|
||||
dit_config["in_channels"] = 3
|
||||
dit_config["out_channels"] = 3
|
||||
@@ -259,7 +267,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["nerf_depth"] = 4
|
||||
dit_config["nerf_max_freqs"] = 8
|
||||
dit_config["nerf_tile_size"] = 512
|
||||
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
|
||||
dit_config["nerf_final_head_type"] = "conv" if any_suffix_in(state_dict_keys, key_prefix, 'nerf_final_layer_conv.norm.', ["weight", "scale"]) else "linear"
|
||||
dit_config["nerf_embedder_dtype"] = torch.float32
|
||||
if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred
|
||||
dit_config["use_x0"] = True
|
||||
@@ -268,7 +276,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
else:
|
||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
|
||||
dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys
|
||||
dit_config["txt_norm"] = any_suffix_in(state_dict_keys, key_prefix, 'txt_norm.', ["weight", "scale"])
|
||||
if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
|
||||
dit_config["txt_ids_dims"] = [1, 2]
|
||||
|
||||
@@ -444,6 +452,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["ffn_dim_multiplier"] = (8.0 / 3.0)
|
||||
dit_config["z_image_modulation"] = True
|
||||
dit_config["time_scale"] = 1000.0
|
||||
try:
|
||||
dit_config["allow_fp16"] = torch.std(state_dict['{}layers.{}.ffn_norm1.weight'.format(key_prefix, dit_config["n_layers"] - 2)], unbiased=False).item() < 0.42
|
||||
except Exception:
|
||||
pass
|
||||
if '{}cap_pad_token'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["pad_tokens_multiple"] = 32
|
||||
sig_weight = state_dict.get('{}siglip_embedder.0.weight'.format(key_prefix), None)
|
||||
@@ -651,6 +663,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.')
|
||||
return dit_config
|
||||
|
||||
if '{}encoder.lyric_encoder.layers.0.input_layernorm.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config = {}
|
||||
dit_config["audio_model"] = "ace1.5"
|
||||
return dit_config
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
return None
|
||||
|
||||
|
||||
@@ -20,12 +20,20 @@ import psutil
|
||||
import logging
|
||||
from enum import Enum
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
import threading
|
||||
import torch
|
||||
import sys
|
||||
import platform
|
||||
import weakref
|
||||
import gc
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
import comfy.memory_management
|
||||
import comfy.utils
|
||||
import comfy.quant_ops
|
||||
|
||||
import comfy_aimdo.torch
|
||||
import comfy_aimdo.model_vbar
|
||||
|
||||
class VRAMState(Enum):
|
||||
DISABLED = 0 #No vram present: no need to move models to vram
|
||||
@@ -47,6 +55,11 @@ cpu_state = CPUState.GPU
|
||||
|
||||
total_vram = 0
|
||||
|
||||
|
||||
# Training Related State
|
||||
in_training = False
|
||||
|
||||
|
||||
def get_supported_float8_types():
|
||||
float8_types = []
|
||||
try:
|
||||
@@ -578,9 +591,15 @@ WINDOWS = any(platform.win32_ver())
|
||||
|
||||
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
|
||||
if WINDOWS:
|
||||
import comfy.windows
|
||||
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
|
||||
if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards
|
||||
EXTRA_RESERVED_VRAM += 100 * 1024 * 1024
|
||||
def get_free_ram():
|
||||
return comfy.windows.get_free_ram()
|
||||
else:
|
||||
def get_free_ram():
|
||||
return psutil.virtual_memory().available
|
||||
|
||||
if args.reserve_vram is not None:
|
||||
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
|
||||
@@ -592,7 +611,7 @@ def extra_reserved_memory():
|
||||
def minimum_inference_memory():
|
||||
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
|
||||
|
||||
def free_memory(memory_required, device, keep_loaded=[]):
|
||||
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_required=0):
|
||||
cleanup_models_gc()
|
||||
unloaded_model = []
|
||||
can_unload = []
|
||||
@@ -607,15 +626,23 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
||||
|
||||
for x in sorted(can_unload):
|
||||
i = x[-1]
|
||||
memory_to_free = None
|
||||
memory_to_free = 1e32
|
||||
ram_to_free = 1e32
|
||||
if not DISABLE_SMART_MEMORY:
|
||||
free_mem = get_free_memory(device)
|
||||
if free_mem > memory_required:
|
||||
break
|
||||
memory_to_free = memory_required - free_mem
|
||||
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
if current_loaded_models[i].model_unload(memory_to_free):
|
||||
memory_to_free = memory_required - get_free_memory(device)
|
||||
ram_to_free = ram_required - get_free_ram()
|
||||
|
||||
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
|
||||
#don't actually unload dynamic models for the sake of other dynamic models
|
||||
#as that works on-demand.
|
||||
memory_required -= current_loaded_models[i].model.loaded_size()
|
||||
memory_to_free = 0
|
||||
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
|
||||
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
unloaded_model.append(i)
|
||||
if ram_to_free > 0:
|
||||
logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
current_loaded_models[i].model.partially_unload_ram(ram_to_free)
|
||||
|
||||
for i in sorted(unloaded_model, reverse=True):
|
||||
unloaded_models.append(current_loaded_models.pop(i))
|
||||
@@ -650,7 +677,10 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
|
||||
models_to_load = []
|
||||
|
||||
free_for_dynamic=True
|
||||
for x in models:
|
||||
if not x.is_dynamic():
|
||||
free_for_dynamic = False
|
||||
loaded_model = LoadedModel(x)
|
||||
try:
|
||||
loaded_model_index = current_loaded_models.index(loaded_model)
|
||||
@@ -676,19 +706,25 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
model_to_unload.model.detach(unpatch_all=False)
|
||||
model_to_unload.model_finalizer.detach()
|
||||
|
||||
|
||||
total_memory_required = {}
|
||||
total_ram_required = {}
|
||||
for loaded_model in models_to_load:
|
||||
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
|
||||
#x2, one to make sure the OS can fit the model for loading in disk cache, and for us to do any pinning we
|
||||
#want to do.
|
||||
#FIXME: This should subtract off the to_load current pin consumption.
|
||||
total_ram_required[loaded_model.device] = total_ram_required.get(loaded_model.device, 0) + loaded_model.model_memory() * 2
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
free_memory(total_memory_required[device] * 1.1 + extra_mem, device)
|
||||
free_memory(total_memory_required[device] * 1.1 + extra_mem, device, for_dynamic=free_for_dynamic, ram_required=total_ram_required[device])
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
free_mem = get_free_memory(device)
|
||||
if free_mem < minimum_memory_required:
|
||||
models_l = free_memory(minimum_memory_required, device)
|
||||
models_l = free_memory(minimum_memory_required, device, for_dynamic=free_for_dynamic)
|
||||
logging.info("{} models unloaded.".format(len(models_l)))
|
||||
|
||||
for loaded_model in models_to_load:
|
||||
@@ -732,6 +768,9 @@ def loaded_models(only_currently_used=False):
|
||||
|
||||
def cleanup_models_gc():
|
||||
do_gc = False
|
||||
|
||||
reset_cast_buffers()
|
||||
|
||||
for i in range(len(current_loaded_models)):
|
||||
cur = current_loaded_models[i]
|
||||
if cur.is_dead():
|
||||
@@ -749,6 +788,11 @@ def cleanup_models_gc():
|
||||
logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))
|
||||
|
||||
|
||||
def archive_model_dtypes(model):
|
||||
for name, module in model.named_modules():
|
||||
for param_name, param in module.named_parameters(recurse=False):
|
||||
setattr(module, f"{param_name}_comfy_model_dtype", param.dtype)
|
||||
|
||||
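# A small usage sketch of the dtype archive above (assumption: a toy standalone module;
# the attribute name mirrors the setattr pattern in archive_model_dtypes):
import torch
m = torch.nn.Linear(2, 2).to(torch.bfloat16)
archive_model_dtypes(m)
assert m.weight_comfy_model_dtype == torch.bfloat16  # preserved even if .weight is later cast or offloaded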
|
||||
def cleanup_models():
|
||||
to_delete = []
|
||||
@@ -792,7 +836,7 @@ def unet_inital_load_device(parameters, dtype):
|
||||
|
||||
mem_dev = get_free_memory(torch_dev)
|
||||
mem_cpu = get_free_memory(cpu_dev)
|
||||
if mem_dev > mem_cpu and model_size < mem_dev:
|
||||
if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_allocator is None:
|
||||
return torch_dev
|
||||
else:
|
||||
return cpu_dev
|
||||
@@ -1051,6 +1095,51 @@ def current_stream(device):
|
||||
return None
|
||||
|
||||
stream_counters = {}
|
||||
|
||||
STREAM_CAST_BUFFERS = {}
|
||||
LARGEST_CASTED_WEIGHT = (None, 0)
|
||||
|
||||
def get_cast_buffer(offload_stream, device, size, ref):
|
||||
global LARGEST_CASTED_WEIGHT
|
||||
|
||||
if offload_stream is not None:
|
||||
wf_context = offload_stream
|
||||
if hasattr(wf_context, "as_context"):
|
||||
wf_context = wf_context.as_context(offload_stream)
|
||||
else:
|
||||
wf_context = nullcontext()
|
||||
|
||||
cast_buffer = STREAM_CAST_BUFFERS.get(offload_stream, None)
|
||||
if cast_buffer is None or cast_buffer.numel() < size:
|
||||
if ref is LARGEST_CASTED_WEIGHT[0]:
|
||||
#If there is one giant weight we do not want both streams to
|
||||
#allocate a buffer for it. It's up to the caster to get the other
|
||||
#offload stream in this corner case
|
||||
return None
|
||||
if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2):
|
||||
#I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now
|
||||
synchronize()
|
||||
del STREAM_CAST_BUFFERS[offload_stream]
|
||||
del cast_buffer
|
||||
#FIXME: This doesn't work in Aimdo because mempool cant clear cache
|
||||
soft_empty_cache()
|
||||
with wf_context:
|
||||
cast_buffer = torch.empty((size), dtype=torch.int8, device=device)
|
||||
STREAM_CAST_BUFFERS[offload_stream] = cast_buffer
|
||||
|
||||
if size > LARGEST_CASTED_WEIGHT[1]:
|
||||
LARGEST_CASTED_WEIGHT = (ref, size)
|
||||
|
||||
return cast_buffer
|
||||
|
||||
def reset_cast_buffers():
|
||||
global LARGEST_CASTED_WEIGHT
|
||||
LARGEST_CASTED_WEIGHT = (None, 0)
|
||||
for offload_stream in STREAM_CAST_BUFFERS:
|
||||
offload_stream.synchronize()
|
||||
STREAM_CAST_BUFFERS.clear()
|
||||
soft_empty_cache()
|
||||
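# A minimal sketch of the per-stream scratch buffer pattern above, assuming one offload
# stream per transfer path: a grow-only int8 staging buffer keyed by stream, reallocated
# only when a larger weight needs to be staged.
import torch

_scratch = {}

def get_scratch(stream, device, nbytes):
    buf = _scratch.get(stream)
    if buf is None or buf.numel() < nbytes:
        buf = torch.empty((nbytes,), dtype=torch.int8, device=device)
        _scratch[stream] = buf
    return buf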
|
||||
def get_offload_stream(device):
|
||||
stream_counter = stream_counters.get(device, 0)
|
||||
if NUM_STREAMS == 0:
|
||||
@@ -1093,7 +1182,61 @@ def sync_stream(device, stream):
|
||||
return
|
||||
current_stream(device).wait_stream(stream)
|
||||
|
||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
|
||||
|
||||
def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
|
||||
wf_context = nullcontext()
|
||||
if stream is not None:
|
||||
wf_context = stream
|
||||
if hasattr(wf_context, "as_context"):
|
||||
wf_context = wf_context.as_context(stream)
|
||||
|
||||
dest_views = comfy.memory_management.interpret_gathered_like(tensors, r)
|
||||
with wf_context:
|
||||
for tensor in tensors:
|
||||
dest_view = dest_views.pop(0)
|
||||
if tensor is None:
|
||||
continue
|
||||
dest_view.copy_(tensor, non_blocking=non_blocking)
|
||||
|
||||
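# A self-contained sketch of the gather-into-one-buffer idea behind cast_to_gathered
# (the real view logic lives in comfy.memory_management.interpret_gathered_like; the
# flat uint8 staging buffer and the pack_into helper here are illustrative assumptions):
import torch

def pack_into(tensors, flat_uint8):
    views, offset = [], 0
    for t in tensors:
        nbytes = t.numel() * t.element_size()
        view = flat_uint8[offset:offset + nbytes].view(t.dtype).view(t.shape)
        view.copy_(t)
        views.append(view)
        offset += nbytes
    return views

buf = torch.empty(1024, dtype=torch.uint8)
w, b = torch.randn(8, 8), torch.randn(8)
w_view, b_view = pack_into([w, b], buf)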
|
||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
|
||||
if hasattr(weight, "_v"):
|
||||
#Unexpected usage patterns. There is no reason these wouldn't work, but they
#have no test coverage and no callers do this.
|
||||
assert r is None
|
||||
assert stream is None
|
||||
|
||||
cast_geometry = comfy.memory_management.tensors_to_geometries([ weight ])
|
||||
|
||||
if dtype is None:
|
||||
dtype = weight._model_dtype
|
||||
|
||||
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
|
||||
if signature is not None:
|
||||
if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
|
||||
v_tensor = weight._v_tensor
|
||||
else:
|
||||
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
|
||||
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
|
||||
weight._v_tensor = v_tensor
|
||||
weight._v_signature = signature
|
||||
#Send it over
|
||||
v_tensor.copy_(weight, non_blocking=non_blocking)
|
||||
return v_tensor.to(dtype=dtype)
|
||||
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
|
||||
if weight.dtype != r.dtype and weight.dtype != weight._model_dtype:
|
||||
#Offloaded casting could skip this, however it would make the quantizations
|
||||
#inconsistent between loaded and offloaded weights. So force the double casting
|
||||
#that would happen in regular flow to make offload deterministic.
|
||||
cast_buffer = torch.empty_like(weight, dtype=weight._model_dtype, device=device)
|
||||
cast_buffer.copy_(weight, non_blocking=non_blocking)
|
||||
weight = cast_buffer
|
||||
r.copy_(weight, non_blocking=non_blocking)
|
||||
|
||||
return r
|
||||
|
||||
if device is None or weight.device == device:
|
||||
if not copy:
|
||||
if dtype is None or weight.dtype == dtype:
|
||||
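# A self-contained sketch of the deterministic double cast described above (assumptions:
# the offloaded weight is stored in fp32, its model dtype is bf16, and compute wants fp16):
import torch
w_cpu = torch.randn(4, dtype=torch.float32)
double = w_cpu.to(torch.bfloat16).to(torch.float16)  # matches what an already-loaded bf16 weight would produce
single = w_cpu.to(torch.float16)                      # may differ from `double` in the low-order bits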
@@ -1112,10 +1255,12 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
|
||||
if hasattr(wf_context, "as_context"):
|
||||
wf_context = wf_context.as_context(stream)
|
||||
with wf_context:
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
if r is None:
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
r.copy_(weight, non_blocking=non_blocking)
|
||||
else:
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
if r is None:
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
r.copy_(weight, non_blocking=non_blocking)
|
||||
return r
|
||||
|
||||
@@ -1135,14 +1280,14 @@ if not args.disable_pinned_memory:
|
||||
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
|
||||
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
|
||||
|
||||
PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
|
||||
PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"])
|
||||
|
||||
def discard_cuda_async_error():
|
||||
try:
|
||||
a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
|
||||
b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
|
||||
_ = a + b
|
||||
torch.cuda.synchronize()
|
||||
synchronize()
|
||||
except torch.AcceleratorError:
|
||||
#Dump it! We already know about it from the synchronous return
|
||||
pass
|
||||
@@ -1546,6 +1691,12 @@ def lora_compute_dtype(device):
|
||||
LORA_COMPUTE_DTYPES[device] = dtype
|
||||
return dtype
|
||||
|
||||
def synchronize():
|
||||
if is_intel_xpu():
|
||||
torch.xpu.synchronize()
|
||||
elif torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def soft_empty_cache(force=False):
|
||||
global cpu_state
|
||||
if cpu_state == CPUState.MPS:
|
||||
@@ -1557,6 +1708,7 @@ def soft_empty_cache(force=False):
|
||||
elif is_mlu():
|
||||
torch.mlu.empty_cache()
|
||||
elif torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
@@ -1568,9 +1720,6 @@ def debug_memory_summary():
|
||||
return torch.cuda.memory.memory_summary()
|
||||
return ""
|
||||
|
||||
#TODO: might be cleaner to put this somewhere else
|
||||
import threading
|
||||
|
||||
class InterruptProcessingException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import copy
|
||||
import inspect
|
||||
import logging
|
||||
import math
|
||||
@@ -38,19 +37,7 @@ from comfy.comfy_types import UnetWrapperFunction
|
||||
from comfy.quant_ops import QuantizedTensor
|
||||
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
|
||||
|
||||
|
||||
def string_to_seed(data):
    crc = 0xFFFFFFFF
    for byte in data:
        if isinstance(byte, str):
            byte = ord(byte)
        crc ^= byte
        for _ in range(8):
            if crc & 1:
                crc = (crc >> 1) ^ 0xEDB88320
            else:
                crc >>= 1
    return crc ^ 0xFFFFFFFF
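# The loop above is the standard reflected CRC-32 (polynomial 0xEDB88320), so for ASCII
# key strings it matches zlib.crc32 of the encoded bytes; a compact way to sanity-check
# the relocated comfy.utils.string_to_seed (assumption: keys are ASCII):
import zlib
assert string_to_seed("diffusion_model.out.weight") == zlib.crc32(b"diffusion_model.out.weight")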
import comfy_aimdo.model_vbar
|
||||
|
||||
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
|
||||
to = model_options["transformer_options"].copy()
|
||||
@@ -123,6 +110,10 @@ def move_weight_functions(m, device):
|
||||
memory += f.move_to(device=device)
|
||||
return memory
|
||||
|
||||
def string_to_seed(data):
|
||||
logging.warning("WARNING: string_to_seed has moved from comfy.model_patcher to comfy.utils")
|
||||
return comfy.utils.string_to_seed(data)
|
||||
|
||||
class LowVramPatch:
|
||||
def __init__(self, key, patches, convert_func=None, set_func=None):
|
||||
self.key = key
|
||||
@@ -169,6 +160,11 @@ def get_key_weight(model, key):
|
||||
|
||||
return weight, set_func, convert_func
|
||||
|
||||
def key_param_name_to_key(key, param):
|
||||
if len(key) == 0:
|
||||
return param
|
||||
return "{}.{}".format(key, param)
|
||||
|
||||
class AutoPatcherEjector:
|
||||
def __init__(self, model: 'ModelPatcher', skip_and_inject_on_exit_only=False):
|
||||
self.model = model
|
||||
@@ -212,6 +208,27 @@ class MemoryCounter:
|
||||
def decrement(self, used: int):
|
||||
self.value -= used
|
||||
|
||||
CustomTorchDevice = collections.namedtuple("FakeDevice", ["type", "index"])("comfy-lazy-caster", 0)
|
||||
|
||||
class LazyCastingParam(torch.nn.Parameter):
|
||||
def __new__(cls, model, key, tensor):
|
||||
return super().__new__(cls, tensor)
|
||||
|
||||
def __init__(self, model, key, tensor):
|
||||
self.model = model
|
||||
self.key = key
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return CustomTorchDevice
|
||||
|
||||
#safetensors will .to() us to the cpu which we catch here to cast on demand. The returned tensor is
|
||||
#then just a short lived thing in the safetensors serialization logic inside its big for loop over
|
||||
#all weights getting garbage collected per-weight
|
||||
def to(self, *args, **kwargs):
|
||||
return self.model.patch_weight_to_device(self.key, device_to=self.model.load_device, return_weight=True).to("cpu")
|
||||
|
||||
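# A minimal, self-contained sketch of the .to() interception trick used by LazyCastingParam
# (the producer callable and placeholder tensor are illustrative assumptions):
import torch

class LazyParam(torch.nn.Parameter):
    def __new__(cls, producer, placeholder):
        return super().__new__(cls, placeholder)

    def __init__(self, producer, placeholder):
        self.producer = producer  # callable that materializes the real tensor on demand

    def to(self, *args, **kwargs):
        return self.producer().to(*args, **kwargs)

p = LazyParam(lambda: torch.ones(2, 2), torch.empty(2, 2))
print(p.to("cpu"))  # serialization-style .to() calls get the materialized value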
|
||||
class ModelPatcher:
|
||||
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
||||
self.size = size
|
||||
@@ -269,6 +286,9 @@ class ModelPatcher:
|
||||
if not hasattr(self.model, 'model_offload_buffer_memory'):
|
||||
self.model.model_offload_buffer_memory = 0
|
||||
|
||||
def is_dynamic(self):
|
||||
return False
|
||||
|
||||
def model_size(self):
|
||||
if self.size > 0:
|
||||
return self.size
|
||||
@@ -284,6 +304,9 @@ class ModelPatcher:
|
||||
def lowvram_patch_counter(self):
|
||||
return self.model.lowvram_patch_counter
|
||||
|
||||
def get_free_memory(self, device):
|
||||
return comfy.model_management.get_free_memory(device)
|
||||
|
||||
def clone(self):
|
||||
n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
|
||||
n.patches = {}
|
||||
@@ -293,7 +316,7 @@ class ModelPatcher:
|
||||
|
||||
n.object_patches = self.object_patches.copy()
|
||||
n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
|
||||
n.model_options = copy.deepcopy(self.model_options)
|
||||
n.model_options = comfy.utils.deepcopy_list_dict(self.model_options)
|
||||
n.backup = self.backup
|
||||
n.object_patches_backup = self.object_patches_backup
|
||||
n.parent = self
|
||||
@@ -611,14 +634,14 @@ class ModelPatcher:
|
||||
sd.pop(k)
|
||||
return sd
|
||||
|
||||
def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
|
||||
if key not in self.patches:
|
||||
return
|
||||
|
||||
def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False):
|
||||
weight, set_func, convert_func = get_key_weight(self.model, key)
|
||||
if key not in self.patches:
|
||||
return weight
|
||||
|
||||
inplace_update = self.weight_inplace_update or inplace_update
|
||||
|
||||
if key not in self.backup:
|
||||
if key not in self.backup and not return_weight:
|
||||
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
|
||||
|
||||
temp_dtype = comfy.model_management.lora_compute_dtype(device_to)
|
||||
@@ -631,13 +654,15 @@ class ModelPatcher:
|
||||
|
||||
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
|
||||
if set_func is None:
|
||||
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
|
||||
if inplace_update:
|
||||
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key))
|
||||
if return_weight:
|
||||
return out_weight
|
||||
elif inplace_update:
|
||||
comfy.utils.copy_to_param(self.model, key, out_weight)
|
||||
else:
|
||||
comfy.utils.set_attr_param(self.model, key, out_weight)
|
||||
else:
|
||||
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
|
||||
return set_func(out_weight, inplace_update=inplace_update, seed=comfy.utils.string_to_seed(key), return_weight=return_weight)
|
||||
|
||||
def pin_weight_to_device(self, key):
|
||||
weight, set_func, convert_func = get_key_weight(self.model, key)
|
||||
@@ -654,18 +679,19 @@ class ModelPatcher:
|
||||
for key in list(self.pinned):
|
||||
self.unpin_weight(key)
|
||||
|
||||
def _load_list(self):
|
||||
def _load_list(self, prio_comfy_cast_weights=False, default_device=None):
|
||||
loading = []
|
||||
for n, m in self.model.named_modules():
|
||||
params = []
|
||||
skip = False
|
||||
for name, param in m.named_parameters(recurse=False):
|
||||
params.append(name)
|
||||
default = False
|
||||
params = { name: param for name, param in m.named_parameters(recurse=False) }
|
||||
for name, param in m.named_parameters(recurse=True):
|
||||
if name not in params:
|
||||
skip = True # skip random weights in non leaf modules
|
||||
default = True # default random weights in non leaf modules
|
||||
break
|
||||
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
|
||||
if default and default_device is not None:
|
||||
for param in params.values():
|
||||
param.data = param.data.to(device=default_device)
|
||||
if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
|
||||
module_mem = comfy.model_management.module_size(m)
|
||||
module_offload_mem = module_mem
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
@@ -681,7 +707,8 @@ class ModelPatcher:
|
||||
return 0
|
||||
module_offload_mem += check_module_offload_mem("{}.weight".format(n))
|
||||
module_offload_mem += check_module_offload_mem("{}.bias".format(n))
|
||||
loading.append((module_offload_mem, module_mem, n, m, params))
|
||||
prepend = (not hasattr(m, "comfy_cast_weights"),) if prio_comfy_cast_weights else ()
|
||||
loading.append(prepend + (module_offload_mem, module_mem, n, m, params))
|
||||
return loading
|
||||
|
||||
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
|
||||
@@ -773,7 +800,7 @@ class ModelPatcher:
|
||||
continue
|
||||
|
||||
for param in params:
|
||||
key = "{}.{}".format(n, param)
|
||||
key = key_param_name_to_key(n, param)
|
||||
self.unpin_weight(key)
|
||||
self.patch_weight_to_device(key, device_to=device_to)
|
||||
if comfy.model_management.is_device_cuda(device_to):
|
||||
@@ -789,7 +816,7 @@ class ModelPatcher:
|
||||
n = x[1]
|
||||
params = x[3]
|
||||
for param in params:
|
||||
self.pin_weight_to_device("{}.{}".format(n, param))
|
||||
self.pin_weight_to_device(key_param_name_to_key(n, param))
|
||||
|
||||
usable_stat = "{:.2f} MB usable,".format(lowvram_model_memory / (1024 * 1024)) if lowvram_model_memory < 1e32 else ""
|
||||
if lowvram_counter > 0:
|
||||
@@ -895,7 +922,7 @@ class ModelPatcher:
|
||||
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
|
||||
move_weight = True
|
||||
for param in params:
|
||||
key = "{}.{}".format(n, param)
|
||||
key = key_param_name_to_key(n, param)
|
||||
bk = self.backup.get(key, None)
|
||||
if bk is not None:
|
||||
if not lowvram_possible:
|
||||
@@ -946,7 +973,7 @@ class ModelPatcher:
|
||||
logging.debug("freed {}".format(n))
|
||||
|
||||
for param in params:
|
||||
self.pin_weight_to_device("{}.{}".format(n, param))
|
||||
self.pin_weight_to_device(key_param_name_to_key(n, param))
|
||||
|
||||
|
||||
self.model.model_lowvram = True
|
||||
@@ -984,6 +1011,9 @@ class ModelPatcher:
|
||||
|
||||
return self.model.model_loaded_weight_memory - current_used
|
||||
|
||||
def partially_unload_ram(self, ram_to_unload):
|
||||
pass
|
||||
|
||||
def detach(self, unpatch_all=True):
|
||||
self.eject_model()
|
||||
self.model_patches_to(self.offload_device)
|
||||
@@ -1317,10 +1347,10 @@ class ModelPatcher:
|
||||
key, original_weights=original_weights)
|
||||
del original_weights[key]
|
||||
if set_func is None:
|
||||
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
|
||||
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key))
|
||||
comfy.utils.copy_to_param(self.model, key, out_weight)
|
||||
else:
|
||||
set_func(out_weight, inplace_update=True, seed=string_to_seed(key))
|
||||
set_func(out_weight, inplace_update=True, seed=comfy.utils.string_to_seed(key))
|
||||
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
|
||||
# TODO: disable caching if not enough system RAM to do so
|
||||
target_device = self.offload_device
|
||||
@@ -1355,7 +1385,256 @@ class ModelPatcher:
|
||||
self.unpatch_hooks()
|
||||
self.clear_cached_hook_weights()
|
||||
|
||||
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
|
||||
unet_state_dict = self.model.diffusion_model.state_dict()
|
||||
for k, v in unet_state_dict.items():
|
||||
op_keys = k.rsplit('.', 1)
|
||||
if (len(op_keys) < 2) or op_keys[1] not in ["weight", "bias"]:
|
||||
continue
|
||||
try:
|
||||
op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0])
|
||||
except:
|
||||
continue
|
||||
if not op or not hasattr(op, "comfy_cast_weights") or \
|
||||
(hasattr(op, "comfy_patched_weights") and op.comfy_patched_weights == True):
|
||||
continue
|
||||
key = "diffusion_model." + k
|
||||
unet_state_dict[k] = LazyCastingParam(self, key, comfy.utils.get_attr(self.model, key))
|
||||
return self.model.state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
|
||||
|
||||
def __del__(self):
|
||||
self.unpin_all_weights()
|
||||
self.detach(unpatch_all=False)
|
||||
|
||||
class ModelPatcherDynamic(ModelPatcher):
|
||||
|
||||
def __new__(cls, model=None, load_device=None, offload_device=None, size=0, weight_inplace_update=False):
|
||||
if load_device is not None and comfy.model_management.is_device_cpu(load_device):
|
||||
#reroute to default MP for CPUs
|
||||
return ModelPatcher(model, load_device, offload_device, size, weight_inplace_update)
|
||||
return super().__new__(cls)
|
||||
|
||||
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
||||
super().__init__(model, load_device, offload_device, size, weight_inplace_update)
|
||||
#this is now way more dynamic and we don't support the same base model for both Dynamic
|
||||
#and non-dynamic patchers.
|
||||
if hasattr(self.model, "model_loaded_weight_memory"):
|
||||
del self.model.model_loaded_weight_memory
|
||||
if not hasattr(self.model, "dynamic_vbars"):
|
||||
self.model.dynamic_vbars = {}
|
||||
assert load_device is not None
|
||||
|
||||
def is_dynamic(self):
|
||||
return True
|
||||
|
||||
def _vbar_get(self, create=False):
|
||||
if self.load_device == torch.device("cpu"):
|
||||
return None
|
||||
vbar = self.model.dynamic_vbars.get(self.load_device, None)
|
||||
if create and vbar is None:
|
||||
# x10. We don't know what model-defined type casts we have in the vbar, but virtual address
|
||||
# space is pretty free. This will cover someone casting an entire model from FP4 to FP32
|
||||
# with some left over.
|
||||
vbar = comfy_aimdo.model_vbar.ModelVBAR(self.model_size() * 10, self.load_device.index)
|
||||
self.model.dynamic_vbars[self.load_device] = vbar
|
||||
return vbar
|
||||
|
||||
def loaded_size(self):
|
||||
vbar = self._vbar_get()
|
||||
if vbar is None:
|
||||
return 0
|
||||
return vbar.loaded_size()
|
||||
|
||||
def get_free_memory(self, device):
|
||||
#NOTE: on high condition / batch counts, estimate should have already vacated
|
||||
#all non-dynamic models so this is safe even if it's not 100% true that this
#would all be available for inference use.
|
||||
return comfy.model_management.get_total_memory(device) - self.model_size()
|
||||
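# Worked example of the estimate above: on a 24 GiB device holding a 10 GiB model, batching
# logic always sees 14 GiB free, independent of what happens to be resident at the time.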
|
||||
#Pinning is deferred to ops time. Assert against this API to avoid pin leaks.
|
||||
|
||||
def pin_weight_to_device(self, key):
|
||||
raise RuntimeError("pin_weight_to_device invalid for dymamic weight loading")
|
||||
|
||||
def unpin_weight(self, key):
|
||||
raise RuntimeError("unpin_weight invalid for dymamic weight loading")
|
||||
|
||||
def unpin_all_weights(self):
|
||||
self.partially_unload_ram(1e32)
|
||||
|
||||
def memory_required(self, input_shape):
|
||||
#Pad this significantly. We are trying to get away from precise estimates. This
|
||||
#estimate is only used when using the ModelPatcherDynamic after ModelPatcher. If you
|
||||
#use all ModelPatcherDynamic this is ignored and it's all done dynamically.
|
||||
return super().memory_required(input_shape=input_shape) * 1.3 + (1024 ** 3)
|
||||
|
||||
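# Worked example of the padding above: a 6 GiB base estimate becomes 6 GiB * 1.3 + 1 GiB = 8.8 GiB.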
|
||||
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False, dirty=False):
|
||||
|
||||
#Force patching doesn't make sense in Dynamic loading, as you don't know what does and
|
||||
#doesn't need to be forced at this stage. The only thing you could do would be patch
|
||||
#it all on CPU which consumes huge RAM.
|
||||
assert not force_patch_weights
|
||||
|
||||
#Full load doesn't make sense as we dont actually have any loader capability here and
|
||||
#now.
|
||||
assert not full_load
|
||||
|
||||
assert device_to == self.load_device
|
||||
|
||||
num_patches = 0
|
||||
allocated_size = 0
|
||||
|
||||
with self.use_ejected():
|
||||
self.unpatch_hooks()
|
||||
|
||||
vbar = self._vbar_get(create=True)
|
||||
if vbar is not None:
|
||||
vbar.prioritize()
|
||||
|
||||
#We force reserve VRAM for the non comfy-cast weights so we don't have to deal
#with pin and unpin synchronization which can be expensive for small weights
|
||||
#with a high layer rate (e.g. autoregressive LLMs).
|
||||
#prioritize the non-comfy weights (note the order reverse).
|
||||
loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to)
|
||||
loading.sort(reverse=True)
|
||||
|
||||
for x in loading:
|
||||
_, _, _, n, m, params = x
|
||||
|
||||
def set_dirty(item, dirty):
|
||||
if dirty or not hasattr(item, "_v_signature"):
|
||||
item._v_signature = None
|
||||
|
||||
def setup_param(self, m, n, param_key):
|
||||
nonlocal num_patches
|
||||
key = key_param_name_to_key(n, param_key)
|
||||
|
||||
weight_function = []
|
||||
|
||||
weight, _, _ = get_key_weight(self.model, key)
|
||||
if weight is None:
|
||||
return 0
|
||||
if key in self.patches:
|
||||
setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
|
||||
num_patches += 1
|
||||
else:
|
||||
setattr(m, param_key + "_lowvram_function", None)
|
||||
|
||||
if key in self.weight_wrapper_patches:
|
||||
weight_function.extend(self.weight_wrapper_patches[key])
|
||||
setattr(m, param_key + "_function", weight_function)
|
||||
geometry = weight
|
||||
if not isinstance(weight, QuantizedTensor):
|
||||
model_dtype = getattr(m, param_key + "_comfy_model_dtype", None) or weight.dtype
|
||||
weight._model_dtype = model_dtype
|
||||
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
|
||||
return comfy.memory_management.vram_aligned_size(geometry)
|
||||
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
m.comfy_cast_weights = True
|
||||
m.pin_failed = False
|
||||
m.seed_key = n
|
||||
set_dirty(m, dirty)
|
||||
|
||||
v_weight_size = 0
|
||||
v_weight_size += setup_param(self, m, n, "weight")
|
||||
v_weight_size += setup_param(self, m, n, "bias")
|
||||
|
||||
if vbar is not None and not hasattr(m, "_v"):
|
||||
m._v = vbar.alloc(v_weight_size)
|
||||
allocated_size += v_weight_size
|
||||
|
||||
else:
|
||||
for param in params:
|
||||
key = key_param_name_to_key(n, param)
|
||||
weight, _, _ = get_key_weight(self.model, key)
|
||||
weight.seed_key = key
|
||||
set_dirty(weight, dirty)
|
||||
geometry = weight
|
||||
model_dtype = getattr(m, param + "_comfy_model_dtype", None) or weight.dtype
|
||||
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
|
||||
weight_size = geometry.numel() * geometry.element_size()
|
||||
if vbar is not None and not hasattr(weight, "_v"):
|
||||
weight._v = vbar.alloc(weight_size)
|
||||
weight._model_dtype = model_dtype
|
||||
allocated_size += weight_size
|
||||
vbar.set_watermark_limit(allocated_size)
|
||||
|
||||
move_weight_functions(m, device_to)
|
||||
|
||||
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
|
||||
|
||||
self.model.device = device_to
|
||||
self.model.current_weight_patches_uuid = self.patches_uuid
|
||||
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
|
||||
#These are all super dangerous. Who knows what the custom nodes actually do here...
|
||||
callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load)
|
||||
|
||||
self.apply_hooks(self.forced_hooks, force_apply=True)
|
||||
|
||||
def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
|
||||
assert not force_patch_weights #See above
|
||||
assert self.load_device != torch.device("cpu")
|
||||
|
||||
vbar = self._vbar_get()
|
||||
return 0 if vbar is None else vbar.free_memory(memory_to_free)
|
||||
|
||||
def partially_unload_ram(self, ram_to_unload):
|
||||
loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device)
|
||||
for x in loading:
|
||||
_, _, _, _, m, _ = x
|
||||
ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
|
||||
if ram_to_unload <= 0:
|
||||
return
|
||||
|
||||
def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
|
||||
#This isn't used by the core at all and can only be used to load a model out of
#the control of proper model_management. If you are a custom node author reading
|
||||
#this, the correct pattern is to call load_models_gpu() to get a proper
|
||||
#managed load of your model.
|
||||
assert not load_weights
|
||||
return super().patch_model(load_weights=load_weights, force_patch_weights=force_patch_weights)
|
||||
|
||||
def unpatch_model(self, device_to=None, unpatch_weights=True):
|
||||
super().unpatch_model(device_to=None, unpatch_weights=False)
|
||||
|
||||
if unpatch_weights:
|
||||
self.partially_unload_ram(1e32)
|
||||
self.partially_unload(None, 1e32)
|
||||
for m in self.model.modules():
|
||||
move_weight_functions(m, device_to)
|
||||
|
||||
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
|
||||
assert not force_patch_weights #See above
|
||||
with self.use_ejected(skip_and_inject_on_exit_only=True):
|
||||
dirty = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid)
|
||||
|
||||
self.unpatch_model(self.offload_device, unpatch_weights=False)
|
||||
self.patch_model(load_weights=False)
|
||||
|
||||
try:
|
||||
self.load(device_to, dirty=dirty)
|
||||
except Exception as e:
|
||||
self.detach()
|
||||
raise e
|
||||
#ModelPatcher::partially_load returns a number on what got loaded but
|
||||
#nothing in core uses this and we have no data in the Dynamic world. Hit
|
||||
#the custom node devs with a None rather than a 0 that would mislead any
|
||||
#logic they might have.
|
||||
return None
|
||||
|
||||
def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
|
||||
assert False #Should be unreachable - we dont ever cache in the new implementation
|
||||
|
||||
def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter):
|
||||
if key not in combined_patches:
|
||||
return
|
||||
|
||||
raise RuntimeError("Hooks not implemented in ModelPatcherDynamic. Please remove --fast arguments form ComfyUI startup")
|
||||
|
||||
def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
|
||||
pass
|
||||
|
||||
CoreModelPatcher = ModelPatcher
|
||||
|
||||
comfy/ops.py
@@ -19,10 +19,16 @@
|
||||
import torch
|
||||
import logging
|
||||
import comfy.model_management
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
|
||||
import comfy.float
|
||||
import comfy.rmsnorm
|
||||
import json
|
||||
import comfy.memory_management
|
||||
import comfy.pinned_memory
|
||||
import comfy.utils
|
||||
|
||||
import comfy_aimdo.model_vbar
|
||||
import comfy_aimdo.torch
|
||||
|
||||
def run_every_op():
|
||||
if torch.compiler.is_compiling():
|
||||
@@ -48,6 +54,8 @@ try:
|
||||
SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
|
||||
|
||||
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
|
||||
if q.nelement() < 1024 * 128: # arbitrary number, for small inputs cudnn attention seems slower
|
||||
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
|
||||
with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
|
||||
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
|
||||
else:
|
||||
@@ -72,7 +80,122 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||
|
||||
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
|
||||
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
|
||||
offload_stream = None
|
||||
xfer_dest = None
|
||||
|
||||
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
|
||||
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
|
||||
if signature is not None:
|
||||
if resident:
|
||||
weight = s._v_weight
|
||||
bias = s._v_bias
|
||||
else:
|
||||
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
|
||||
|
||||
if not resident:
|
||||
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
|
||||
cast_dest = None
|
||||
|
||||
xfer_source = [ s.weight, s.bias ]
|
||||
|
||||
pin = comfy.pinned_memory.get_pin(s)
|
||||
if pin is not None:
|
||||
xfer_source = [ pin ]
|
||||
|
||||
for data, geometry in zip([ s.weight, s.bias ], cast_geometry):
|
||||
if data is None:
|
||||
continue
|
||||
if data.dtype != geometry.dtype:
|
||||
cast_dest = xfer_dest
|
||||
if cast_dest is None:
|
||||
cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device)
|
||||
xfer_dest = None
|
||||
break
|
||||
|
||||
dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
if xfer_dest is None and offload_stream is not None:
|
||||
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
|
||||
if xfer_dest is None:
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
|
||||
if xfer_dest is None:
|
||||
xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device)
|
||||
offload_stream = None
|
||||
|
||||
if signature is None and pin is None:
|
||||
comfy.pinned_memory.pin_memory(s)
|
||||
pin = comfy.pinned_memory.get_pin(s)
|
||||
else:
|
||||
pin = None
|
||||
|
||||
if pin is not None:
|
||||
comfy.model_management.cast_to_gathered(xfer_source, pin)
|
||||
xfer_source = [ pin ]
|
||||
#send it over
|
||||
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
|
||||
if cast_dest is not None:
|
||||
for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest),
|
||||
comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)):
|
||||
if post_cast is not None:
|
||||
post_cast.copy_(pre_cast)
|
||||
xfer_dest = cast_dest
|
||||
|
||||
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
|
||||
weight = params[0]
|
||||
bias = params[1]
|
||||
if signature is not None:
|
||||
s._v_weight = weight
|
||||
s._v_bias = bias
|
||||
s._v_signature=signature
|
||||
|
||||
def post_cast(s, param_key, x, dtype, resident, update_weight):
|
||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||
fns = getattr(s, param_key + "_function", [])
|
||||
|
||||
orig = x
|
||||
|
||||
def to_dequant(tensor, dtype):
|
||||
tensor = tensor.to(dtype=dtype)
|
||||
if isinstance(tensor, QuantizedTensor):
|
||||
tensor = tensor.dequantize()
|
||||
return tensor
|
||||
|
||||
if orig.dtype != dtype or len(fns) > 0:
|
||||
x = to_dequant(x, dtype)
|
||||
if not resident and lowvram_fn is not None:
|
||||
x = to_dequant(x, dtype if compute_dtype is None else compute_dtype)
|
||||
#FIXME: this is not accurate, we need to be sensitive to the compute dtype
|
||||
x = lowvram_fn(x)
|
||||
if (isinstance(orig, QuantizedTensor) and
|
||||
(orig.dtype == dtype and len(fns) == 0 or update_weight)):
|
||||
seed = comfy.utils.string_to_seed(s.seed_key)
|
||||
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
|
||||
if orig.dtype == dtype and len(fns) == 0:
|
||||
#The layer actually wants our freshly saved QT
|
||||
x = y
|
||||
elif update_weight:
|
||||
y = comfy.float.stochastic_rounding(x, orig.dtype, seed = comfy.utils.string_to_seed(s.seed_key))
|
||||
if update_weight:
|
||||
orig.copy_(y)
|
||||
for f in fns:
|
||||
x = f(x)
|
||||
return x
|
||||
|
||||
update_weight = signature is not None
|
||||
|
||||
weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
|
||||
if s.bias is not None:
|
||||
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
|
||||
|
||||
#FIXME: weird offload return protocol
|
||||
return weight, bias, (offload_stream, device if signature is not None else None, None)
|
||||
|
||||
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None):
|
||||
# NOTE: offloadable=False is a legacy default; if you are a custom node author reading this, please pass
|
||||
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
|
||||
# will add async-offload support to your cast and improve performance.
|
||||
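# A usage sketch of the offloadable protocol described above (assumption: called from a
# comfy.ops layer's forward, mirroring forward_comfy_cast_weights later in this diff):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.linear(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)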
@@ -87,22 +210,38 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
||||
if device is None:
|
||||
device = input.device
|
||||
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||
|
||||
if hasattr(s, "_v"):
|
||||
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype)
|
||||
|
||||
if offloadable and (device != s.weight.device or
|
||||
(s.bias is not None and device != s.bias.device)):
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
else:
|
||||
offload_stream = None
|
||||
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||
bias = None
|
||||
weight = None
|
||||
|
||||
if offload_stream is not None and not args.cuda_malloc:
|
||||
cast_buffer_size = comfy.memory_management.vram_aligned_size([ s.weight, s.bias ])
|
||||
cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s)
|
||||
#The streams can be uneven in buffer capability and reject us. Retry to get the other stream
|
||||
if cast_buffer is None:
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s)
|
||||
params = comfy.memory_management.interpret_gathered_like([ s.weight, s.bias ], cast_buffer)
|
||||
weight = params[0]
|
||||
bias = params[1]
|
||||
|
||||
weight_has_function = len(s.weight_function) > 0
|
||||
bias_has_function = len(s.bias_function) > 0
|
||||
|
||||
weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
|
||||
weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream, r=weight)
|
||||
|
||||
bias = None
|
||||
if s.bias is not None:
|
||||
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
|
||||
bias = comfy.model_management.cast_to(s.bias, None, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream, r=bias)
|
||||
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
|
||||
@@ -110,6 +249,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
||||
weight_a = weight
|
||||
|
||||
if s.bias is not None:
|
||||
bias = bias.to(dtype=bias_dtype)
|
||||
for f in s.bias_function:
|
||||
bias = f(bias)
|
||||
|
||||
@@ -131,14 +271,20 @@ def uncast_bias_weight(s, weight, bias, offload_stream):
|
||||
if offload_stream is None:
|
||||
return
|
||||
os, weight_a, bias_a = offload_stream
|
||||
device=None
|
||||
#FIXME: This is not good RTTI
|
||||
if not isinstance(weight_a, torch.Tensor):
|
||||
comfy_aimdo.model_vbar.vbar_unpin(s._v)
|
||||
device = weight_a
|
||||
if os is None:
|
||||
return
|
||||
if weight_a is not None:
|
||||
device = weight_a.device
|
||||
else:
|
||||
if bias_a is None:
|
||||
return
|
||||
device = bias_a.device
|
||||
if device is None:
|
||||
if weight_a is not None:
|
||||
device = weight_a.device
|
||||
else:
|
||||
if bias_a is None:
|
||||
return
|
||||
device = bias_a.device
|
||||
os.wait_stream(comfy.model_management.current_stream(device))
|
||||
|
||||
|
||||
@@ -149,6 +295,57 @@ class CastWeightBiasOp:
|
||||
|
||||
class disable_weight_init:
|
||||
class Linear(torch.nn.Linear, CastWeightBiasOp):
|
||||
|
||||
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
|
||||
if not comfy.model_management.WINDOWS or not enables_dynamic_vram():
|
||||
super().__init__(in_features, out_features, bias, device, dtype)
|
||||
return
|
||||
|
||||
# Issue is with `torch.empty` still reserving the full memory for the layer.
|
||||
# Windows doesn't over-commit memory, so without this we are momentarily commit
|
||||
# charged for the weight even though we might zero-copy it when we load the
|
||||
# state dict. If the commit charge exceeds the ceiling we can destabilize the
|
||||
# system.
|
||||
torch.nn.Module.__init__(self)
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.weight = None
|
||||
self.bias = None
|
||||
self.comfy_need_lazy_init_bias=bias
|
||||
self.weight_comfy_model_dtype = dtype
|
||||
self.bias_comfy_model_dtype = dtype
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
strict, missing_keys, unexpected_keys, error_msgs):
|
||||
|
||||
if not comfy.model_management.WINDOWS or not enables_dynamic_vram():
|
||||
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
|
||||
missing_keys, unexpected_keys, error_msgs)
|
||||
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
|
||||
prefix_len = len(prefix)
|
||||
for k,v in state_dict.items():
|
||||
if k[prefix_len:] == "weight":
|
||||
if not assign_to_params_buffers:
|
||||
v = v.clone()
|
||||
self.weight = torch.nn.Parameter(v, requires_grad=False)
|
||||
elif k[prefix_len:] == "bias" and v is not None:
|
||||
if not assign_to_params_buffers:
|
||||
v = v.clone()
|
||||
self.bias = torch.nn.Parameter(v, requires_grad=False)
|
||||
else:
|
||||
unexpected_keys.append(k)
|
||||
|
||||
#Reconcile default construction of the weight if it's missing.
|
||||
if self.weight is None:
|
||||
v = torch.zeros(self.in_features, self.out_features)
|
||||
self.weight = torch.nn.Parameter(v, requires_grad=False)
|
||||
missing_keys.append(prefix+"weight")
|
||||
if self.bias is None and self.comfy_need_lazy_init_bias:
|
||||
v = torch.zeros(self.out_features,)
|
||||
self.bias = torch.nn.Parameter(v, requires_grad=False)
|
||||
missing_keys.append(prefix+"bias")
|
||||
|
||||
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
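# A small sketch of the zero-copy loading idea behind the lazy init above, assuming a
# PyTorch version that supports the assign flag: checkpoint tensors are adopted directly
# instead of being copied into pre-allocated parameter storage.
import torch
lin = torch.nn.Linear(4, 4)
sd = {"weight": torch.zeros(4, 4), "bias": torch.zeros(4)}
lin.load_state_dict(sd, assign=True)  # parameters now reference the state-dict tensors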
@@ -203,7 +400,9 @@ class disable_weight_init:
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def _conv_forward(self, input, weight, bias, *args, **kwargs):
|
||||
def _conv_forward(self, input, weight, bias, autopad=None, *args, **kwargs):
|
||||
if autopad == "causal_zero":
|
||||
weight = weight[:, :, -input.shape[2]:, :, :]
|
||||
if NVIDIA_MEMORY_CONV_BUG_WORKAROUND and weight.dtype in (torch.float16, torch.bfloat16):
|
||||
out = torch.cudnn_convolution(input, weight, self.padding, self.stride, self.dilation, self.groups, benchmark=False, deterministic=False, allow_tf32=True)
|
||||
if bias is not None:
|
||||
@@ -212,15 +411,15 @@ class disable_weight_init:
|
||||
else:
|
||||
return super()._conv_forward(input, weight, bias, *args, **kwargs)
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
def forward_comfy_cast_weights(self, input, autopad=None):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = self._conv_forward(input, weight, bias)
|
||||
x = self._conv_forward(input, weight, bias, autopad=autopad)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or "autopad" in kwargs:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
@@ -653,8 +852,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
def _forward(self, input, weight, bias):
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
def forward_comfy_cast_weights(self, input, compute_dtype=None):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype)
|
||||
x = self._forward(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
@@ -664,6 +863,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
|
||||
input_shape = input.shape
|
||||
reshaped_3d = False
|
||||
#If cast needs to apply lora, it should be done in the compute dtype
|
||||
compute_dtype = input.dtype
|
||||
|
||||
if (getattr(self, 'layout_type', None) is not None and
|
||||
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
|
||||
@@ -682,7 +883,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
scale = comfy.model_management.cast_to_device(scale, input.device, None)
|
||||
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
|
||||
|
||||
output = self.forward_comfy_cast_weights(input)
|
||||
|
||||
output = self.forward_comfy_cast_weights(input, compute_dtype)
|
||||
|
||||
# Reshape output back to 3D if input was 3D
|
||||
if reshaped_3d:
|
||||
|
||||
comfy/pinned_memory.py (new file)
@@ -0,0 +1,29 @@
|
||||
import torch
import comfy.model_management
import comfy.memory_management

from comfy.cli_args import args

def get_pin(module):
    return getattr(module, "_pin", None)

def pin_memory(module):
    if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
        return
    #FIXME: This is a RAM cache trigger event
    size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
    pin = torch.empty((size,), dtype=torch.uint8)
    if comfy.model_management.pin_memory(pin):
        module._pin = pin
    else:
        module.pin_failed = True
        return False
    return True

def unpin_memory(module):
    if get_pin(module) is None:
        return 0
    size = module._pin.numel() * module._pin.element_size()
    comfy.model_management.unpin_memory(module._pin)
    del module._pin
    return size
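# A hedged usage sketch for the helpers above (assumption: a comfy.ops module `m` carrying
# .weight/.bias and a pin_failed flag, as prepared by ModelPatcherDynamic.load):
import comfy.pinned_memory
comfy.pinned_memory.pin_memory(m)             # stage weight+bias into one pinned host buffer
pin = comfy.pinned_memory.get_pin(m)          # None if pinning failed or is disabled
# ... issue non_blocking host->device copies out of `pin` ...
freed = comfy.pinned_memory.unpin_memory(m)   # bytes of pinned RAM released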
@@ -37,12 +37,18 @@ def prepare_noise(latent_image, seed, noise_inds=None):
|
||||
|
||||
return noises
|
||||
|
||||
def fix_empty_latent_channels(model, latent_image):
|
||||
def fix_empty_latent_channels(model, latent_image, downscale_ratio_spacial=None):
|
||||
if latent_image.is_nested:
|
||||
return latent_image
|
||||
latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
|
||||
if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
|
||||
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
|
||||
if torch.count_nonzero(latent_image) == 0:
|
||||
if latent_format.latent_channels != latent_image.shape[1]:
|
||||
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
|
||||
if downscale_ratio_spacial is not None:
|
||||
if downscale_ratio_spacial != latent_format.spacial_downscale_ratio:
|
||||
ratio = downscale_ratio_spacial / latent_format.spacial_downscale_ratio
|
||||
latent_image = comfy.utils.common_upscale(latent_image, round(latent_image.shape[-1] * ratio), round(latent_image.shape[-2] * ratio), "nearest-exact", crop="disabled")
|
||||
|
||||
if latent_format.latent_dimensions == 3 and latent_image.ndim == 4:
|
||||
latent_image = latent_image.unsqueeze(2)
|
||||
return latent_image
|
||||
|
||||
@@ -122,20 +122,26 @@ def estimate_memory(model, noise_shape, conds):
|
||||
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
|
||||
return memory_required, minimum_memory_required
|
||||
|
||||
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
|
||||
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
|
||||
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
|
||||
_prepare_sampling,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
|
||||
)
|
||||
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load)
|
||||
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload)
|
||||
|
||||
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
|
||||
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
|
||||
real_model: BaseModel = None
|
||||
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
||||
models += get_additional_models_from_model_options(model_options)
|
||||
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
||||
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
|
||||
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load)
|
||||
if force_offload: # In training + offload enabled, we want to force prepare sampling to trigger partial load
|
||||
memory_required = 1e20
|
||||
minimum_memory_required = None
|
||||
else:
|
||||
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
|
||||
memory_required += inference_memory
|
||||
minimum_memory_required += inference_memory
|
||||
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
|
||||
real_model = model.model
|
||||
|
||||
return real_model, conds, models
|
||||
|
||||
@@ -9,7 +9,6 @@ if TYPE_CHECKING:
|
||||
import torch
|
||||
from functools import partial
|
||||
import collections
|
||||
from comfy import model_management
|
||||
import math
|
||||
import logging
|
||||
import comfy.sampler_helpers
|
||||
@@ -260,7 +259,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
to_batch_temp.reverse()
|
||||
to_batch = to_batch_temp[:1]
|
||||
|
||||
free_memory = model_management.get_free_memory(x_in.device)
|
||||
free_memory = model.current_patcher.get_free_memory(x_in.device)
|
||||
for i in range(1, len(to_batch_temp) + 1):
|
||||
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
|
||||
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
|
||||
|
||||
comfy/sd.py
@@ -20,6 +20,7 @@ import comfy.ldm.ace.vae.music_dcae_pipeline
|
||||
import comfy.ldm.hunyuan_video.vae
|
||||
import comfy.ldm.mmaudio.vae.autoencoder
|
||||
import comfy.pixel_space_convert
|
||||
import comfy.weight_adapter
|
||||
import yaml
|
||||
import math
|
||||
import os
|
||||
@@ -58,6 +59,7 @@ import comfy.text_encoders.kandinsky5
|
||||
import comfy.text_encoders.jina_clip_2
|
||||
import comfy.text_encoders.newbie
|
||||
import comfy.text_encoders.anima
|
||||
import comfy.text_encoders.ace15
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@@ -101,6 +103,105 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||
return (new_modelpatcher, new_clip)
|
||||
|
||||
|
||||
def load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||
"""
|
||||
Load LoRA in bypass mode without modifying base model weights.
|
||||
|
||||
Instead of patching weights, this injects the LoRA computation into the
|
||||
forward pass: output = base_forward(x) + lora_path(x)
|
||||
|
||||
Non-adapter patches (bias diff, weight diff, etc.) are applied as regular patches.
|
||||
|
||||
This is useful for training and when model weights are offloaded.
|
||||
"""
|
||||
key_map = {}
|
||||
if model is not None:
|
||||
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
||||
if clip is not None:
|
||||
key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
|
||||
|
||||
logging.debug(f"[BypassLoRA] key_map has {len(key_map)} entries")
|
||||
|
||||
lora = comfy.lora_convert.convert_lora(lora)
|
||||
loaded = comfy.lora.load_lora(lora, key_map)
|
||||
|
||||
logging.debug(f"[BypassLoRA] loaded has {len(loaded)} entries")
|
||||
|
||||
# Separate adapters (for bypass) from other patches (for regular patching)
|
||||
bypass_patches = {} # WeightAdapterBase instances -> bypass mode
|
||||
regular_patches = {} # diff, set, bias patches -> regular weight patching
|
||||
|
||||
for key, patch_data in loaded.items():
|
||||
if isinstance(patch_data, comfy.weight_adapter.WeightAdapterBase):
|
||||
bypass_patches[key] = patch_data
|
||||
else:
|
||||
regular_patches[key] = patch_data
|
||||
|
||||
logging.debug(f"[BypassLoRA] {len(bypass_patches)} bypass adapters, {len(regular_patches)} regular patches")
|
||||
|
||||
k = set()
|
||||
k1 = set()
|
||||
|
||||
if model is not None:
|
||||
new_modelpatcher = model.clone()
|
||||
|
||||
# Apply regular patches (bias diff, weight diff, etc.) via normal patching
|
||||
if regular_patches:
|
||||
patched_keys = new_modelpatcher.add_patches(regular_patches, strength_model)
|
||||
k.update(patched_keys)
|
||||
|
||||
# Apply adapter patches via bypass injection
|
||||
manager = comfy.weight_adapter.BypassInjectionManager()
|
||||
model_sd_keys = set(new_modelpatcher.model.state_dict().keys())
|
||||
|
||||
for key, adapter in bypass_patches.items():
|
||||
if key in model_sd_keys:
|
||||
manager.add_adapter(key, adapter, strength=strength_model)
|
||||
k.add(key)
|
||||
else:
|
||||
logging.warning(f"[BypassLoRA] Adapter key not in model state_dict: {key}")
|
||||
|
||||
injections = manager.create_injections(new_modelpatcher.model)
|
||||
|
||||
if manager.get_hook_count() > 0:
|
||||
new_modelpatcher.set_injections("bypass_lora", injections)
|
||||
else:
|
||||
new_modelpatcher = None
|
||||
|
||||
if clip is not None:
|
||||
new_clip = clip.clone()
|
||||
|
||||
# Apply regular patches to clip
|
||||
if regular_patches:
|
||||
patched_keys = new_clip.add_patches(regular_patches, strength_clip)
|
||||
k1.update(patched_keys)
|
||||
|
||||
# Apply adapter patches via bypass injection
|
||||
clip_manager = comfy.weight_adapter.BypassInjectionManager()
|
||||
clip_sd_keys = set(new_clip.cond_stage_model.state_dict().keys())
|
||||
|
||||
for key, adapter in bypass_patches.items():
|
||||
if key in clip_sd_keys:
|
||||
clip_manager.add_adapter(key, adapter, strength=strength_clip)
|
||||
k1.add(key)
|
||||
|
||||
clip_injections = clip_manager.create_injections(new_clip.cond_stage_model)
|
||||
if clip_manager.get_hook_count() > 0:
|
||||
new_clip.patcher.set_injections("bypass_lora", clip_injections)
|
||||
else:
|
||||
new_clip = None
|
||||
|
||||
for x in loaded:
|
||||
if (x not in k) and (x not in k1):
|
||||
patch_data = loaded[x]
|
||||
patch_type = type(patch_data).__name__
|
||||
if isinstance(patch_data, tuple):
|
||||
patch_type = f"tuple({patch_data[0]})"
|
||||
logging.warning(f"NOT LOADED: {x} (type={patch_type})")
|
||||
|
||||
return (new_modelpatcher, new_clip)
|
||||
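# A minimal sketch of the bypass idea in the docstring above (assumption: a plain rank-r
# LoRA pair with an alpha/rank scale), keeping the frozen base weight untouched:
import torch

class BypassLinear(torch.nn.Module):
    def __init__(self, base: torch.nn.Linear, rank=8, alpha=8.0):
        super().__init__()
        self.base = base                      # frozen, possibly offloaded base layer
        self.down = torch.nn.Linear(base.in_features, rank, bias=False)
        self.up = torch.nn.Linear(rank, base.out_features, bias=False)
        torch.nn.init.zeros_(self.up.weight)  # start as a no-op
        self.scale = alpha / rank

    def forward(self, x):
        return self.base(x) + self.scale * self.up(self.down(x))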
|
||||
|
||||
class CLIP:
|
||||
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}):
|
||||
if no_init:
|
||||
@@ -128,8 +229,10 @@ class CLIP:
|
||||
self.cond_stage_model.to(offload_device)
|
||||
logging.warning("Had to shift TE back.")
|
||||
|
||||
model_management.archive_model_dtypes(self.cond_stage_model)
|
||||
|
||||
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
||||
self.patcher = comfy.model_patcher.CoreModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
||||
#Match the torch.float32 hardcoded upcast in the TE implementation
|
||||
self.patcher.set_model_compute_dtype(torch.float32)
|
||||
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
|
||||
@@ -289,8 +392,18 @@ class CLIP:
|
||||
|
||||
def load_sd(self, sd, full_model=False):
|
||||
if full_model:
|
||||
return self.cond_stage_model.load_state_dict(sd, strict=False)
|
||||
return self.cond_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
|
||||
else:
|
||||
can_assign = self.patcher.is_dynamic()
|
||||
self.cond_stage_model.can_assign_sd = can_assign
|
||||
|
||||
# The CLIP models are a pretty complex web of wrappers and it's
|
||||
# a bit of an API change to plumb this all the way through.
|
||||
# So spray paint the model with this flag that the loading
|
||||
# nn.Module can then inspect for itself.
|
||||
for m in self.cond_stage_model.modules():
|
||||
m.can_assign_sd = can_assign
|
||||
|
||||
return self.cond_stage_model.load_sd(sd)
|
||||
|
||||
def get_sd(self):
|
||||
@@ -340,6 +453,8 @@ class VAE:
|
||||
self.extra_1d_channel = None
|
||||
self.crop_input = True
|
||||
|
||||
self.audio_sample_rate = 44100
|
||||
|
||||
if config is None:
|
||||
if "decoder.mid.block_1.mix_factor" in sd:
|
||||
encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||
@@ -437,14 +552,27 @@ class VAE:
|
||||
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
|
||||
decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
|
||||
elif "decoder.layers.1.layers.0.beta" in sd:
|
||||
self.first_stage_model = AudioOobleckVAE()
|
||||
config = {}
|
||||
param_key = None
|
||||
self.upscale_ratio = 2048
|
||||
self.downscale_ratio = 2048
|
||||
if "decoder.layers.2.layers.1.weight_v" in sd:
|
||||
param_key = "decoder.layers.2.layers.1.weight_v"
|
||||
if "decoder.layers.2.layers.1.parametrizations.weight.original1" in sd:
|
||||
param_key = "decoder.layers.2.layers.1.parametrizations.weight.original1"
|
||||
if param_key is not None:
|
||||
if sd[param_key].shape[-1] == 12:
|
||||
config["strides"] = [2, 4, 4, 6, 10]
|
||||
self.audio_sample_rate = 48000
|
||||
self.upscale_ratio = 1920
|
||||
self.downscale_ratio = 1920
|
||||
|
||||
self.first_stage_model = AudioOobleckVAE(**config)
|
||||
self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype)
|
||||
self.latent_channels = 64
|
||||
self.output_channels = 2
|
||||
self.pad_channel_value = "replicate"
|
||||
self.upscale_ratio = 2048
|
||||
self.downscale_ratio = 2048
|
||||
self.latent_dim = 1
|
||||
self.process_output = lambda audio: audio
|
||||
self.process_input = lambda audio: audio
|
||||
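Aside: the shape check above distinguishes the two Oobleck variants by their decoder strides; illustrative arithmetic below, assuming the stock 44.1 kHz strides of [2, 4, 4, 8, 8].

# the total hop of the decoder is the product of its strides, which is what
# upscale_ratio / downscale_ratio encode
strides_44k = [2, 4, 4, 8, 8]    # default checkpoint -> ratio 2048, 44100 Hz
strides_48k = [2, 4, 4, 6, 10]   # detected via shape[-1] == 12 -> ratio 1920, 48000 Hz
assert 2 * 4 * 4 * 8 * 8 == 2048
assert 2 * 4 * 4 * 6 * 10 == 1920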
@@ -665,13 +793,6 @@ class VAE:
|
||||
self.first_stage_model = AutoencoderKL(**(config['params']))
|
||||
self.first_stage_model = self.first_stage_model.eval()
|
||||
|
||||
m, u = self.first_stage_model.load_state_dict(sd, strict=False)
|
||||
if len(m) > 0:
|
||||
logging.warning("Missing VAE keys {}".format(m))
|
||||
|
||||
if len(u) > 0:
|
||||
logging.debug("Leftover VAE keys {}".format(u))
|
||||
|
||||
if device is None:
|
||||
device = model_management.vae_device()
|
||||
self.device = device
|
||||
@@ -680,9 +801,21 @@ class VAE:
|
||||
dtype = model_management.vae_dtype(self.device, self.working_dtypes)
|
||||
self.vae_dtype = dtype
|
||||
self.first_stage_model.to(self.vae_dtype)
|
||||
model_management.archive_model_dtypes(self.first_stage_model)
|
||||
self.output_device = model_management.intermediate_device()
|
||||
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
|
||||
mp = comfy.model_patcher.CoreModelPatcher
|
||||
if self.disable_offload:
|
||||
mp = comfy.model_patcher.ModelPatcher
|
||||
self.patcher = mp(self.first_stage_model, load_device=self.device, offload_device=offload_device)
|
||||
|
||||
m, u = self.first_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
|
||||
if len(m) > 0:
|
||||
logging.warning("Missing VAE keys {}".format(m))
|
||||
|
||||
if len(u) > 0:
|
||||
logging.debug("Leftover VAE keys {}".format(u))
|
||||
|
||||
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
|
||||
self.model_size()
|
||||
|
||||
@@ -738,7 +871,7 @@ class VAE:
|
||||
/ 3.0)
|
||||
return output
|
||||
|
||||
def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
|
||||
def decode_tiled_1d(self, samples, tile_x=256, overlap=32):
|
||||
if samples.ndim == 3:
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||
else:
|
||||
@@ -797,7 +930,7 @@ class VAE:
|
||||
try:
|
||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
free_memory = model_management.get_free_memory(self.device)
|
||||
free_memory = self.patcher.get_free_memory(self.device)
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = max(1, batch_number)
|
||||
|
||||
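Aside: the batch size for non-tiled decoding simply divides free device memory by the per-sample estimate, never dropping below one; a toy example with made-up numbers:

free_memory = 6 * 1024**3          # ~6 GiB reported free on the load device
memory_used = 2.5 * 1024**3        # ~2.5 GiB estimated per sample
batch_number = max(1, int(free_memory / memory_used))
assert batch_number == 2           # latents get decoded two at a time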
@@ -842,7 +975,7 @@ class VAE:
|
||||
if overlap is not None:
|
||||
args["overlap"] = overlap
|
||||
|
||||
if dims == 1:
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
args.pop("tile_y")
|
||||
output = self.decode_tiled_1d(samples, **args)
|
||||
elif dims == 2:
|
||||
@@ -871,7 +1004,7 @@ class VAE:
|
||||
try:
|
||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
free_memory = model_management.get_free_memory(self.device)
|
||||
free_memory = self.patcher.get_free_memory(self.device)
|
||||
batch_number = int(free_memory / max(1, memory_used))
|
||||
batch_number = max(1, batch_number)
|
||||
samples = None
|
||||
@@ -1309,6 +1442,14 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_data_jina = clip_data[0]
|
||||
tokenizer_data["gemma_spiece_model"] = clip_data_gemma.get("spiece_model", None)
|
||||
tokenizer_data["jina_spiece_model"] = clip_data_jina.get("spiece_model", None)
|
||||
elif clip_type == CLIPType.ACE:
|
||||
te_models = [detect_te_model(clip_data[0]), detect_te_model(clip_data[1])]
|
||||
if TEModel.QWEN3_4B in te_models:
|
||||
model_type = "qwen3_4b"
|
||||
else:
|
||||
model_type = "qwen3_2b"
|
||||
clip_target.clip = comfy.text_encoders.ace15.te(lm_model=model_type, **llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.ace15.ACE15Tokenizer
|
||||
else:
|
||||
clip_target.clip = sdxl_clip.SDXLClipModel
|
||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||
@@ -1332,7 +1473,7 @@ def load_gligen(ckpt_path):
|
||||
model = gligen.load_gligen(data)
|
||||
if model_management.should_use_fp16():
|
||||
model = model.half()
|
||||
return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
|
||||
return comfy.model_patcher.CoreModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
|
||||
|
||||
def model_detection_error_hint(path, state_dict):
|
||||
filename = os.path.basename(path)
|
||||
@@ -1420,7 +1561,8 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
if output_model:
|
||||
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
|
||||
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
|
||||
model.load_model_weights(sd, diffusion_model_prefix)
|
||||
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
|
||||
model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())
|
||||
|
||||
if output_vae:
|
||||
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
|
||||
@@ -1463,7 +1605,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
logging.debug("left over keys: {}".format(left_over))
|
||||
|
||||
if output_model:
|
||||
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
|
||||
if inital_load_device != torch.device("cpu"):
|
||||
logging.info("loaded diffusion model directly to GPU")
|
||||
model_management.load_models_gpu([model_patcher], force_full_load=True)
|
||||
@@ -1555,13 +1696,14 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
|
||||
model_config.optimizations["fp8"] = True
|
||||
|
||||
model = model_config.get_model(new_sd, "")
|
||||
model = model.to(offload_device)
|
||||
model.load_model_weights(new_sd, "")
|
||||
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=offload_device)
|
||||
if not model_management.is_device_cpu(offload_device):
|
||||
model.to(offload_device)
|
||||
model.load_model_weights(new_sd, "", assign=model_patcher.is_dynamic())
|
||||
left_over = sd.keys()
|
||||
if len(left_over) > 0:
|
||||
logging.info("left over keys in diffusion model: {}".format(left_over))
|
||||
return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
|
||||
|
||||
return model_patcher
|
||||
|
||||
def load_diffusion_model(unet_path, model_options={}):
|
||||
sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
|
||||
@@ -1592,9 +1734,9 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
|
||||
model_management.load_models_gpu(load_models, force_patch_weights=True)
|
||||
model_management.load_models_gpu(load_models)
|
||||
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
|
||||
sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
|
||||
sd = model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
|
||||
for k in extra_keys:
|
||||
sd[k] = extra_keys[k]
|
||||
|
||||
|
||||
@@ -155,6 +155,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
self.execution_device = options.get("execution_device", self.execution_device)
|
||||
if isinstance(self.layer, list) or self.layer == "all":
|
||||
pass
|
||||
elif isinstance(layer_idx, list):
|
||||
self.layer = layer_idx
|
||||
elif layer_idx is None or abs(layer_idx) > self.num_layers:
|
||||
self.layer = "last"
|
||||
else:
|
||||
@@ -169,8 +171,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
|
||||
def process_tokens(self, tokens, device):
|
||||
end_token = self.special_tokens.get("end", None)
|
||||
pad_token = self.special_tokens.get("pad", -1)
|
||||
if end_token is None:
|
||||
cmp_token = self.special_tokens.get("pad", -1)
|
||||
cmp_token = pad_token
|
||||
else:
|
||||
cmp_token = end_token
|
||||
|
||||
@@ -184,15 +187,21 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
other_embeds = []
|
||||
eos = False
|
||||
index = 0
|
||||
left_pad = False
|
||||
for y in x:
|
||||
if isinstance(y, numbers.Integral):
|
||||
if eos:
|
||||
token = int(y)
|
||||
if index == 0 and token == pad_token:
|
||||
left_pad = True
|
||||
|
||||
if eos or (left_pad and token == pad_token):
|
||||
attention_mask.append(0)
|
||||
else:
|
||||
attention_mask.append(1)
|
||||
token = int(y)
|
||||
left_pad = False
|
||||
|
||||
tokens_temp += [token]
|
||||
if not eos and token == cmp_token:
|
||||
if not eos and token == cmp_token and not left_pad:
|
||||
if end_token is None:
|
||||
attention_mask[-1] = 0
|
||||
eos = True
|
||||
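Aside: a simplified, runnable sketch of the masking rule introduced above, using hypothetical token ids and the case where no end token exists (so the pad id doubles as the comparison token). Leading pads are masked out without tripping the early end-of-sequence that the old code hit on left-padded prompts.

def toy_mask(tokens, pad):
    mask, eos, left_pad = [], False, False
    for index, token in enumerate(tokens):
        if index == 0 and token == pad:
            left_pad = True
        if eos or (left_pad and token == pad):
            mask.append(0)
        else:
            mask.append(1)
            left_pad = False
        if not eos and token == pad and not left_pad:
            mask[-1] = 0
            eos = True
    return mask

assert toy_mask([0, 0, 101, 102, 103], pad=0) == [0, 0, 1, 1, 1]  # left-padded prompt
assert toy_mask([101, 102, 103, 0, 0], pad=0) == [1, 1, 1, 0, 0]  # right-padded prompt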
@@ -297,7 +306,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
return self(tokens)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.transformer.load_state_dict(sd, strict=False)
|
||||
return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))
|
||||
|
||||
def parse_parentheses(string):
|
||||
result = []
|
||||
@@ -466,7 +475,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
|
||||
return embed_out
|
||||
|
||||
class SDTokenizer:
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, disable_weights=False, tokenizer_data={}, tokenizer_args={}):
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, start_token=None, min_padding=None, pad_left=False, disable_weights=False, tokenizer_data={}, tokenizer_args={}):
|
||||
if tokenizer_path is None:
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
||||
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
|
||||
@@ -479,8 +488,15 @@ class SDTokenizer:
|
||||
empty = self.tokenizer('')["input_ids"]
|
||||
self.tokenizer_adds_end_token = has_end_token
|
||||
if has_start_token:
|
||||
self.tokens_start = 1
|
||||
self.start_token = empty[0]
|
||||
if len(empty) > 0:
|
||||
self.tokens_start = 1
|
||||
self.start_token = empty[0]
|
||||
else:
|
||||
self.tokens_start = 0
|
||||
self.start_token = start_token
|
||||
if start_token is None:
|
||||
logging.warning("WARNING: There's something wrong with your tokenizers.'")
|
||||
|
||||
if end_token is not None:
|
||||
self.end_token = end_token
|
||||
else:
|
||||
@@ -488,7 +504,7 @@ class SDTokenizer:
|
||||
self.end_token = empty[1]
|
||||
else:
|
||||
self.tokens_start = 0
|
||||
self.start_token = None
|
||||
self.start_token = start_token
|
||||
if end_token is not None:
|
||||
self.end_token = end_token
|
||||
else:
|
||||
|
||||
@@ -24,6 +24,7 @@ import comfy.text_encoders.hunyuan_image
|
||||
import comfy.text_encoders.kandinsky5
|
||||
import comfy.text_encoders.z_image
|
||||
import comfy.text_encoders.anima
|
||||
import comfy.text_encoders.ace15
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
@@ -709,6 +710,15 @@ class Flux(supported_models_base.BASE):
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
out_sd = {}
|
||||
for k in list(state_dict.keys()):
|
||||
key_out = k
|
||||
if key_out.endswith("_norm.scale"):
|
||||
key_out = "{}.weight".format(key_out[:-len(".scale")])
|
||||
out_sd[key_out] = state_dict[k]
|
||||
return out_sd
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
@@ -771,10 +781,24 @@ class Flux2(Flux):
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return None # TODO
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect))
|
||||
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
|
||||
if len(detect) > 0:
|
||||
detect["model_type"] = "qwen3_4b"
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.flux.KleinTokenizer, comfy.text_encoders.flux.klein_te(**detect))
|
||||
|
||||
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_8b.transformer.".format(pref))
|
||||
if len(detect) > 0:
|
||||
detect["model_type"] = "qwen3_8b"
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.flux.KleinTokenizer8B, comfy.text_encoders.flux.klein_te(**detect))
|
||||
|
||||
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}mistral3_24b.transformer.".format(pref))
|
||||
if len(detect) > 0:
|
||||
if "{}mistral3_24b.transformer.model.layers.39.post_attention_layernorm.weight".format(pref) not in state_dict:
|
||||
detect["pruned"] = True
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.flux.Flux2Tokenizer, comfy.text_encoders.flux.flux2_te(**detect))
|
||||
|
||||
return None
|
||||
|
||||
class GenmoMochi(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
@@ -883,11 +907,13 @@ class HunyuanVideo(supported_models_base.BASE):
|
||||
key_out = key_out.replace("txt_in.c_embedder.linear_1.", "txt_in.c_embedder.in_layer.").replace("txt_in.c_embedder.linear_2.", "txt_in.c_embedder.out_layer.")
|
||||
key_out = key_out.replace("_mod.linear.", "_mod.lin.").replace("_attn_qkv.", "_attn.qkv.")
|
||||
key_out = key_out.replace("mlp.fc1.", "mlp.0.").replace("mlp.fc2.", "mlp.2.")
|
||||
key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.scale").replace("_attn_k_norm.weight", "_attn.norm.key_norm.scale")
|
||||
key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.scale").replace(".k_norm.weight", ".norm.key_norm.scale")
|
||||
key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.weight").replace("_attn_k_norm.weight", "_attn.norm.key_norm.weight")
|
||||
key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.weight").replace(".k_norm.weight", ".norm.key_norm.weight")
|
||||
key_out = key_out.replace("_attn_proj.", "_attn.proj.")
|
||||
key_out = key_out.replace(".modulation.linear.", ".modulation.lin.")
|
||||
key_out = key_out.replace("_in.mlp.2.", "_in.out_layer.").replace("_in.mlp.0.", "_in.in_layer.")
|
||||
if key_out.endswith(".scale"):
|
||||
key_out = "{}.weight".format(key_out[:-len(".scale")])
|
||||
out_sd[key_out] = state_dict[k]
|
||||
return out_sd
|
||||
|
||||
@@ -978,7 +1004,7 @@ class CosmosT2IPredict2(supported_models_base.BASE):
|
||||
|
||||
memory_usage_factor = 1.0
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
@@ -1008,11 +1034,7 @@ class Anima(supported_models_base.BASE):
|
||||
|
||||
memory_usage_factor = 1.0
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Anima(self, device=device)
|
||||
@@ -1023,6 +1045,12 @@ class Anima(supported_models_base.BASE):
|
||||
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_06b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.anima.AnimaTokenizer, comfy.text_encoders.anima.te(**detect))
|
||||
|
||||
def set_inference_dtype(self, dtype, manual_cast_dtype, **kwargs):
|
||||
self.memory_usage_factor = (self.unet_config.get("model_channels", 2048) / 2048) * 0.95
|
||||
if dtype is torch.float16:
|
||||
self.memory_usage_factor *= 1.4
|
||||
return super().set_inference_dtype(dtype, manual_cast_dtype, **kwargs)
|
||||
|
||||
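Aside: a worked example of the memory factor above, with a hypothetical 3072-channel configuration:

factor = (3072 / 2048) * 0.95      # 1.425
factor_fp16 = factor * 1.4         # 1.995; fp16 pays the extra upcast overhead
assert abs(factor - 1.425) < 1e-9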
class CosmosI2VPredict2(CosmosT2IPredict2):
|
||||
unet_config = {
|
||||
"image_model": "cosmos_predict2",
|
||||
@@ -1079,7 +1107,7 @@ class ZImage(Lumina2):
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
if comfy.model_management.extended_fp16_support():
|
||||
if comfy.model_management.extended_fp16_support() and unet_config.get("allow_fp16", False):
|
||||
self.supported_inference_dtypes = self.supported_inference_dtypes.copy()
|
||||
self.supported_inference_dtypes.insert(1, torch.float16)
|
||||
|
||||
@@ -1247,6 +1275,15 @@ class Hunyuan3Dv2(supported_models_base.BASE):
|
||||
|
||||
latent_format = latent_formats.Hunyuan3Dv2
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
out_sd = {}
|
||||
for k in list(state_dict.keys()):
|
||||
key_out = k
|
||||
if key_out.endswith(".scale"):
|
||||
key_out = "{}.weight".format(key_out[:-len(".scale")])
|
||||
out_sd[key_out] = state_dict[k]
|
||||
return out_sd
|
||||
|
||||
def process_unet_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {"": "model."}
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
@@ -1324,6 +1361,14 @@ class Chroma(supported_models_base.BASE):
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
out_sd = {}
|
||||
for k in list(state_dict.keys()):
|
||||
key_out = k
|
||||
if key_out.endswith(".scale"):
|
||||
key_out = "{}.weight".format(key_out[:-len(".scale")])
|
||||
out_sd[key_out] = state_dict[k]
|
||||
return out_sd
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Chroma(self, device=device)
|
||||
@@ -1582,6 +1627,46 @@ class Kandinsky5Image(Kandinsky5):
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
|
||||
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
|
||||
class ACEStep15(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"audio_model": "ace1.5",
|
||||
}
|
||||
|
||||
unet_extra_config = {
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"multiplier": 1.0,
|
||||
"shift": 3.0,
|
||||
}
|
||||
|
||||
latent_format = comfy.latent_formats.ACEAudio15
|
||||
|
||||
memory_usage_factor = 4.7
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.ACEStep15(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
detect_2b = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_2b.transformer.".format(pref))
|
||||
detect_4b = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
|
||||
if "dtype_llama" in detect_2b:
|
||||
detect = detect_2b
|
||||
detect["lm_model"] = "qwen3_2b"
|
||||
elif "dtype_llama" in detect_4b:
|
||||
detect = detect_4b
|
||||
detect["lm_model"] = "qwen3_4b"
|
||||
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect))
|
||||
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
348
comfy/text_encoders/ace15.py
Normal file
@@ -0,0 +1,348 @@
|
||||
from .anima import Qwen3Tokenizer
|
||||
import comfy.text_encoders.llama
|
||||
from comfy import sd1_clip
|
||||
import torch
|
||||
import math
|
||||
import yaml
|
||||
import comfy.utils
|
||||
|
||||
|
||||
def sample_manual_loop_no_classes(
|
||||
model,
|
||||
ids=None,
|
||||
execution_dtype=None,
|
||||
cfg_scale: float = 2.0,
|
||||
temperature: float = 0.85,
|
||||
top_p: float = 0.9,
|
||||
top_k: int = None,
|
||||
min_p: float = 0.000,
|
||||
seed: int = 1,
|
||||
min_tokens: int = 1,
|
||||
max_new_tokens: int = 2048,
|
||||
audio_start_id: int = 151669, # The cutoff ID for audio codes
|
||||
audio_end_id: int = 215669,
|
||||
eos_token_id: int = 151645,
|
||||
):
|
||||
if ids is None:
|
||||
return []
|
||||
device = model.execution_device
|
||||
|
||||
if execution_dtype is None:
|
||||
if comfy.model_management.should_use_bf16(device):
|
||||
execution_dtype = torch.bfloat16
|
||||
else:
|
||||
execution_dtype = torch.float32
|
||||
|
||||
embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device)
|
||||
embeds_batch = embeds.shape[0]
|
||||
|
||||
output_audio_codes = []
|
||||
past_key_values = []
|
||||
generator = torch.Generator(device=device)
|
||||
generator.manual_seed(seed)
|
||||
model_config = model.transformer.model.config
|
||||
past_kv_shape = [embeds_batch, model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim]
|
||||
|
||||
for x in range(model_config.num_hidden_layers):
|
||||
past_key_values.append((torch.empty(past_kv_shape, device=device, dtype=execution_dtype), torch.empty(past_kv_shape, device=device, dtype=execution_dtype), 0))
|
||||
|
||||
progress_bar = comfy.utils.ProgressBar(max_new_tokens)
|
||||
|
||||
for step in comfy.utils.model_trange(max_new_tokens, desc="LM sampling"):
|
||||
outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values)
|
||||
next_token_logits = model.transformer.logits(outputs[0])[:, -1]
|
||||
past_key_values = outputs[2]
|
||||
|
||||
if cfg_scale != 1.0:
|
||||
cond_logits = next_token_logits[0:1]
|
||||
uncond_logits = next_token_logits[1:2]
|
||||
cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
|
||||
else:
|
||||
cfg_logits = next_token_logits[0:1]
|
||||
|
||||
use_eos_score = eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step
|
||||
if use_eos_score:
|
||||
eos_score = cfg_logits[:, eos_token_id].clone()
|
||||
|
||||
remove_logit_value = torch.finfo(cfg_logits.dtype).min
|
||||
# Only generate audio tokens
|
||||
cfg_logits[:, :audio_start_id] = remove_logit_value
|
||||
cfg_logits[:, audio_end_id:] = remove_logit_value
|
||||
|
||||
if use_eos_score:
|
||||
cfg_logits[:, eos_token_id] = eos_score
|
||||
|
||||
if top_k is not None and top_k > 0:
|
||||
top_k_vals, _ = torch.topk(cfg_logits, top_k)
|
||||
min_val = top_k_vals[..., -1, None]
|
||||
cfg_logits[cfg_logits < min_val] = remove_logit_value
|
||||
|
||||
if min_p is not None and min_p > 0:
|
||||
probs = torch.softmax(cfg_logits, dim=-1)
|
||||
p_max = probs.max(dim=-1, keepdim=True).values
|
||||
indices_to_remove = probs < (min_p * p_max)
|
||||
cfg_logits[indices_to_remove] = remove_logit_value
|
||||
|
||||
if top_p is not None and top_p < 1.0:
|
||||
sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True)
|
||||
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
sorted_indices_to_remove = cumulative_probs > top_p
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
||||
cfg_logits[indices_to_remove] = remove_logit_value
|
||||
|
||||
if temperature > 0:
|
||||
cfg_logits = cfg_logits / temperature
|
||||
next_token = torch.multinomial(torch.softmax(cfg_logits, dim=-1), num_samples=1, generator=generator).squeeze(1)
|
||||
else:
|
||||
next_token = torch.argmax(cfg_logits, dim=-1)
|
||||
|
||||
token = next_token.item()
|
||||
|
||||
if token == eos_token_id:
|
||||
break
|
||||
|
||||
embed, _, _, _ = model.process_tokens([[token]], device)
|
||||
embeds = embed.repeat(embeds_batch, 1, 1)
|
||||
attention_mask = torch.cat([attention_mask, torch.ones((embeds_batch, 1), device=device, dtype=attention_mask.dtype)], dim=1)
|
||||
|
||||
output_audio_codes.append(token - audio_start_id)
|
||||
progress_bar.update_absolute(step)
|
||||
|
||||
return output_audio_codes
|
||||
|
||||
|
||||
def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0, min_p=0.000):
|
||||
positive = [[token for token, _ in inner_list] for inner_list in positive]
|
||||
positive = positive[0]
|
||||
|
||||
if cfg_scale != 1.0:
|
||||
negative = [[token for token, _ in inner_list] for inner_list in negative]
|
||||
negative = negative[0]
|
||||
|
||||
neg_pad = 0
|
||||
if len(negative) < len(positive):
|
||||
neg_pad = (len(positive) - len(negative))
|
||||
negative = [model.special_tokens["pad"]] * neg_pad + negative
|
||||
|
||||
pos_pad = 0
|
||||
if len(negative) > len(positive):
|
||||
pos_pad = (len(negative) - len(positive))
|
||||
positive = [model.special_tokens["pad"]] * pos_pad + positive
|
||||
|
||||
ids = [positive, negative]
|
||||
else:
|
||||
ids = [positive]
|
||||
|
||||
return sample_manual_loop_no_classes(model, ids, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
|
||||
|
||||
|
||||
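Aside: the two prompts above are left-padded to the same length so they can be batched as [cond, uncond] for classifier-free guidance; a toy example with made-up token ids and the Qwen pad id used above:

pad = 151643
positive = [1, 2, 3, 4, 5]
negative = [9, 9]
if len(negative) < len(positive):
    negative = [pad] * (len(positive) - len(negative)) + negative
elif len(negative) > len(positive):
    positive = [pad] * (len(negative) - len(positive)) + positive
ids = [positive, negative]
assert len(ids[0]) == len(ids[1]) == 5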
class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_06b", tokenizer=Qwen3Tokenizer)
|
||||
|
||||
def _metas_to_cot(self, *, return_yaml: bool = False, **kwargs) -> str:
|
||||
user_metas = {
|
||||
k: kwargs.pop(k)
|
||||
for k in ("bpm", "duration", "keyscale", "timesignature")
|
||||
if k in kwargs
|
||||
}
|
||||
timesignature = user_metas.get("timesignature")
|
||||
if isinstance(timesignature, str) and timesignature.endswith("/4"):
|
||||
user_metas["timesignature"] = timesignature[:-2]
|
||||
user_metas = {
|
||||
k: v if not isinstance(v, str) or not v.isdigit() else int(v)
|
||||
for k, v in user_metas.items()
|
||||
if v not in {"unspecified", None}
|
||||
}
|
||||
if len(user_metas):
|
||||
meta_yaml = yaml.dump(user_metas, allow_unicode=True, sort_keys=True).strip()
|
||||
else:
|
||||
meta_yaml = ""
|
||||
return f"<think>\n{meta_yaml}\n</think>" if not return_yaml else meta_yaml
|
||||
|
||||
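Aside: illustrative input/output for the helper above (values made up): "4/4" is trimmed to "4", digit strings become ints, and "unspecified"/None entries drop out of the YAML.

#   _metas_to_cot(bpm="120", duration=30, keyscale="C major", timesignature="4/4")
# returns:
#   <think>
#   bpm: 120
#   duration: 30
#   keyscale: C major
#   timesignature: 4
#   </think>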
def _metas_to_cap(self, **kwargs) -> str:
|
||||
use_keys = ("bpm", "timesignature", "keyscale", "duration")
|
||||
user_metas = { k: kwargs.pop(k, "N/A") for k in use_keys }
|
||||
timesignature = user_metas.get("timesignature")
|
||||
if isinstance(timesignature, str) and timesignature.endswith("/4"):
|
||||
user_metas["timesignature"] = timesignature[:-2]
|
||||
duration = user_metas["duration"]
|
||||
if duration == "N/A":
|
||||
user_metas["duration"] = "30 seconds"
|
||||
elif isinstance(duration, (str, int, float)):
|
||||
user_metas["duration"] = f"{math.ceil(float(duration))} seconds"
|
||||
else:
|
||||
raise TypeError("Unexpected type for duration key, must be str, int or float")
|
||||
return "\n".join(f"- {k}: {user_metas[k]}" for k in use_keys)
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
|
||||
text = text.strip()
|
||||
text_negative = kwargs.get("caption_negative", text).strip()
|
||||
lyrics = kwargs.get("lyrics", "")
|
||||
lyrics_negative = kwargs.get("lyrics_negative", lyrics)
|
||||
duration = kwargs.get("duration", 120)
|
||||
if isinstance(duration, str):
|
||||
duration = float(duration.split(None, 1)[0])
|
||||
language = kwargs.get("language")
|
||||
seed = kwargs.get("seed", 0)
|
||||
|
||||
generate_audio_codes = kwargs.get("generate_audio_codes", True)
|
||||
cfg_scale = kwargs.get("cfg_scale", 2.0)
|
||||
temperature = kwargs.get("temperature", 0.85)
|
||||
top_p = kwargs.get("top_p", 0.9)
|
||||
top_k = kwargs.get("top_k", 0.0)
|
||||
min_p = kwargs.get("min_p", 0.000)
|
||||
|
||||
duration = math.ceil(duration)
|
||||
kwargs["duration"] = duration
|
||||
tokens_duration = duration * 5
|
||||
min_tokens = int(kwargs.get("min_tokens", tokens_duration))
|
||||
max_tokens = int(kwargs.get("max_tokens", tokens_duration))
|
||||
|
||||
metas_negative = {
|
||||
k.rsplit("_", 1)[0]: kwargs.pop(k)
|
||||
for k in ("bpm_negative", "duration_negative", "keyscale_negative", "timesignature_negative", "language_negative", "caption_negative")
|
||||
if k in kwargs
|
||||
}
|
||||
if not kwargs.get("use_negative_caption"):
|
||||
_ = metas_negative.pop("caption", None)
|
||||
|
||||
cot_text = self._metas_to_cot(caption=text, **kwargs)
|
||||
cot_text_negative = "<think>\n\n</think>" if not metas_negative else self._metas_to_cot(**metas_negative)
|
||||
meta_cap = self._metas_to_cap(**kwargs)
|
||||
|
||||
lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n\n<|im_end|>\n"
|
||||
lyrics_template = "# Languages\n{}\n\n# Lyric\n{}<|endoftext|><|endoftext|>"
|
||||
qwen3_06b_template = "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}\n\n# Metas\n{}\n<|endoftext|>\n<|endoftext|>"
|
||||
|
||||
llm_prompts = {
|
||||
"lm_prompt": lm_template.format(text, lyrics.strip(), cot_text),
|
||||
"lm_prompt_negative": lm_template.format(text_negative, lyrics_negative.strip(), cot_text_negative),
|
||||
"lyrics": lyrics_template.format(language if language is not None else "", lyrics),
|
||||
"qwen3_06b": qwen3_06b_template.format(text, meta_cap),
|
||||
}
|
||||
|
||||
out = {
|
||||
prompt_key: self.qwen3_06b.tokenize_with_weights(
|
||||
prompt,
|
||||
prompt_key == "qwen3_06b" and return_word_ids,
|
||||
disable_weights = True,
|
||||
**kwargs,
|
||||
)
|
||||
for prompt_key, prompt in llm_prompts.items()
|
||||
}
|
||||
out["lm_metadata"] = {"min_tokens": min_tokens,
|
||||
"max_tokens": max_tokens,
|
||||
"seed": seed,
|
||||
"generate_audio_codes": generate_audio_codes,
|
||||
"cfg_scale": cfg_scale,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"top_k": top_k,
|
||||
"min_p": min_p,
|
||||
}
|
||||
return out
|
||||
|
||||
|
||||
class Qwen3_06BModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_06B_ACE15, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
class Qwen3_2B_ACE15(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_2B_ACE15_lm, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
class Qwen3_4B_ACE15(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B_ACE15_lm, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
class ACE15TEModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None, dtype_llama=None, lm_model=None, model_options={}):
|
||||
super().__init__()
|
||||
if dtype_llama is None:
|
||||
dtype_llama = dtype
|
||||
|
||||
model = None
|
||||
self.constant = 0.4375
|
||||
if lm_model == "qwen3_4b":
|
||||
model = Qwen3_4B_ACE15
|
||||
self.constant = 0.5625
|
||||
elif lm_model == "qwen3_2b":
|
||||
model = Qwen3_2B_ACE15
|
||||
|
||||
self.lm_model = lm_model
|
||||
self.qwen3_06b = Qwen3_06BModel(device=device, dtype=dtype, model_options=model_options)
|
||||
if model is not None:
|
||||
setattr(self, self.lm_model, model(device=device, dtype=dtype_llama, model_options=model_options))
|
||||
|
||||
self.dtypes = set([dtype, dtype_llama])
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
token_weight_pairs_base = token_weight_pairs["qwen3_06b"]
|
||||
token_weight_pairs_lyrics = token_weight_pairs["lyrics"]
|
||||
|
||||
self.qwen3_06b.set_clip_options({"layer": None})
|
||||
base_out, _, extra = self.qwen3_06b.encode_token_weights(token_weight_pairs_base)
|
||||
self.qwen3_06b.set_clip_options({"layer": [0]})
|
||||
lyrics_embeds, _, extra_l = self.qwen3_06b.encode_token_weights(token_weight_pairs_lyrics)
|
||||
|
||||
out = {"conditioning_lyrics": lyrics_embeds[:, 0]}
|
||||
|
||||
lm_metadata = token_weight_pairs["lm_metadata"]
|
||||
if lm_metadata["generate_audio_codes"]:
|
||||
audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["max_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"], min_p=lm_metadata["min_p"])
|
||||
out["audio_codes"] = [audio_codes]
|
||||
|
||||
return base_out, None, out
|
||||
|
||||
def set_clip_options(self, options):
|
||||
self.qwen3_06b.set_clip_options(options)
|
||||
lm_model = getattr(self, self.lm_model, None)
|
||||
if lm_model is not None:
|
||||
lm_model.set_clip_options(options)
|
||||
|
||||
def reset_clip_options(self):
|
||||
self.qwen3_06b.reset_clip_options()
|
||||
lm_model = getattr(self, self.lm_model, None)
|
||||
if lm_model is not None:
|
||||
lm_model.reset_clip_options()
|
||||
|
||||
def load_sd(self, sd):
|
||||
if "model.layers.0.post_attention_layernorm.weight" in sd:
|
||||
shape = sd["model.layers.0.post_attention_layernorm.weight"].shape
|
||||
if shape[0] == 1024:
|
||||
return self.qwen3_06b.load_sd(sd)
|
||||
else:
|
||||
return getattr(self, self.lm_model).load_sd(sd)
|
||||
|
||||
def memory_estimation_function(self, token_weight_pairs, device=None):
|
||||
lm_metadata = token_weight_pairs["lm_metadata"]
|
||||
constant = self.constant
|
||||
if comfy.model_management.should_use_bf16(device):
|
||||
constant *= 0.5
|
||||
|
||||
token_weight_pairs = token_weight_pairs.get("lm_prompt", [])
|
||||
num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
|
||||
num_tokens += lm_metadata['min_tokens']
|
||||
return num_tokens * constant * 1024 * 1024
|
||||
|
||||
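Aside: a rough worked example of the estimate above, with illustrative numbers: a 120 s request produces min_tokens = 120 * 5 = 600 audio tokens; with ~200 prompt tokens, the qwen3_2b constant of 0.4375, and the bf16 halving, the estimate lands around 175 MiB.

num_tokens = 200 + 120 * 5
estimate = num_tokens * (0.4375 * 0.5) * 1024 * 1024
assert estimate == 183_500_800   # bytes, about 175 MiB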
def te(dtype_llama=None, llama_quantization_metadata=None, lm_model="qwen3_2b"):
|
||||
class ACE15TEModel_(ACE15TEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(device=device, dtype_llama=dtype_llama, lm_model=lm_model, dtype=dtype, model_options=model_options)
|
||||
return ACE15TEModel_
|
||||
@@ -8,7 +8,7 @@ import torch
|
||||
class Qwen3Tokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='qwen3_06b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024, embedding_key='qwen3_06b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
@@ -23,7 +23,7 @@ class AnimaTokenizer:
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
out = {}
|
||||
qwen_ids = self.qwen3_06b.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
out["qwen3_06b"] = [[(token, 1.0) for token, _ in inner_list] for inner_list in qwen_ids] # Set weights to 1.0
|
||||
out["qwen3_06b"] = [[(k[0], 1.0, k[2]) if return_word_ids else (k[0], 1.0) for k in inner_list] for inner_list in qwen_ids] # Set weights to 1.0
|
||||
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
return out
|
||||
|
||||
|
||||
@@ -118,7 +118,7 @@ class MistralTokenizerClass:
|
||||
class Mistral3Tokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
self.tekken_data = tokenizer_data.get("tekken_model", None)
|
||||
super().__init__("", pad_with_end=False, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data)
|
||||
super().__init__("", pad_with_end=False, embedding_directory=embedding_directory, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, start_token=1, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data)
|
||||
|
||||
def state_dict(self):
|
||||
return {"tekken_model": self.tekken_data}
|
||||
@@ -176,12 +176,12 @@ def flux2_te(dtype_llama=None, llama_quantization_metadata=None, pruned=False):
|
||||
class Qwen3Tokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data)
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data)
|
||||
|
||||
class Qwen3Tokenizer8B(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='qwen3_8b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data)
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=4096, embedding_key='qwen3_8b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data)
|
||||
|
||||
class KleinTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}, name="qwen3_4b"):
|
||||
|
||||
@@ -10,9 +10,11 @@ import comfy.utils
|
||||
|
||||
def llama_detect(state_dict, prefix=""):
|
||||
out = {}
|
||||
t5_key = "{}model.norm.weight".format(prefix)
|
||||
if t5_key in state_dict:
|
||||
out["dtype_llama"] = state_dict[t5_key].dtype
|
||||
norm_keys = ["{}model.norm.weight".format(prefix), "{}model.layers.0.input_layernorm.weight".format(prefix)]
|
||||
for norm_key in norm_keys:
|
||||
if norm_key in state_dict:
|
||||
out["dtype_llama"] = state_dict[norm_key].dtype
|
||||
break
|
||||
|
||||
quant = comfy.utils.detect_layer_quantization(state_dict, prefix)
|
||||
if quant is not None:
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Any
|
||||
from typing import Optional, Any, Tuple
|
||||
import math
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
import comfy.model_management
|
||||
import comfy.ops
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.clip_model
|
||||
|
||||
@@ -32,6 +33,7 @@ class Llama2Config:
|
||||
k_norm = None
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Mistral3Small24BConfig:
|
||||
@@ -54,6 +56,7 @@ class Mistral3Small24BConfig:
|
||||
k_norm = None
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Qwen25_3BConfig:
|
||||
@@ -76,6 +79,7 @@ class Qwen25_3BConfig:
|
||||
k_norm = None
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Qwen3_06BConfig:
|
||||
@@ -98,6 +102,76 @@ class Qwen3_06BConfig:
|
||||
k_norm = "gemma3"
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Qwen3_06B_ACE15_Config:
|
||||
vocab_size: int = 151669
|
||||
hidden_size: int = 1024
|
||||
intermediate_size: int = 3072
|
||||
num_hidden_layers: int = 28
|
||||
num_attention_heads: int = 16
|
||||
num_key_value_heads: int = 8
|
||||
max_position_embeddings: int = 32768
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 1000000.0
|
||||
transformer_type: str = "llama"
|
||||
head_dim = 128
|
||||
rms_norm_add = False
|
||||
mlp_activation = "silu"
|
||||
qkv_bias = False
|
||||
rope_dims = None
|
||||
q_norm = "gemma3"
|
||||
k_norm = "gemma3"
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Qwen3_2B_ACE15_lm_Config:
|
||||
vocab_size: int = 217204
|
||||
hidden_size: int = 2048
|
||||
intermediate_size: int = 6144
|
||||
num_hidden_layers: int = 28
|
||||
num_attention_heads: int = 16
|
||||
num_key_value_heads: int = 8
|
||||
max_position_embeddings: int = 40960
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 1000000.0
|
||||
transformer_type: str = "llama"
|
||||
head_dim = 128
|
||||
rms_norm_add = False
|
||||
mlp_activation = "silu"
|
||||
qkv_bias = False
|
||||
rope_dims = None
|
||||
q_norm = "gemma3"
|
||||
k_norm = "gemma3"
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Qwen3_4B_ACE15_lm_Config:
|
||||
vocab_size: int = 217204
|
||||
hidden_size: int = 2560
|
||||
intermediate_size: int = 9728
|
||||
num_hidden_layers: int = 36
|
||||
num_attention_heads: int = 32
|
||||
num_key_value_heads: int = 8
|
||||
max_position_embeddings: int = 40960
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 1000000.0
|
||||
transformer_type: str = "llama"
|
||||
head_dim = 128
|
||||
rms_norm_add = False
|
||||
mlp_activation = "silu"
|
||||
qkv_bias = False
|
||||
rope_dims = None
|
||||
q_norm = "gemma3"
|
||||
k_norm = "gemma3"
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Qwen3_4BConfig:
|
||||
@@ -120,6 +194,7 @@ class Qwen3_4BConfig:
|
||||
k_norm = "gemma3"
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Qwen3_8BConfig:
|
||||
@@ -142,6 +217,7 @@ class Qwen3_8BConfig:
|
||||
k_norm = "gemma3"
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Ovis25_2BConfig:
|
||||
@@ -164,6 +240,7 @@ class Ovis25_2BConfig:
|
||||
k_norm = "gemma3"
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Qwen25_7BVLI_Config:
|
||||
@@ -186,6 +263,7 @@ class Qwen25_7BVLI_Config:
|
||||
k_norm = None
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Gemma2_2B_Config:
|
||||
@@ -209,6 +287,7 @@ class Gemma2_2B_Config:
|
||||
sliding_attention = None
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Gemma3_4B_Config:
|
||||
@@ -232,6 +311,7 @@ class Gemma3_4B_Config:
|
||||
sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
|
||||
rope_scale = [8.0, 1.0]
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Gemma3_12B_Config:
|
||||
@@ -255,6 +335,7 @@ class Gemma3_12B_Config:
|
||||
sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
|
||||
rope_scale = [8.0, 1.0]
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
|
||||
mm_tokens_per_image = 256
|
||||
|
||||
@@ -274,13 +355,6 @@ class RMSNorm(nn.Module):
|
||||
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
|
||||
if not isinstance(theta, list):
|
||||
theta = [theta]
|
||||
@@ -309,20 +383,30 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_di
|
||||
else:
|
||||
cos = cos.unsqueeze(1)
|
||||
sin = sin.unsqueeze(1)
|
||||
out.append((cos, sin))
|
||||
sin_split = sin.shape[-1] // 2
|
||||
out.append((cos, sin[..., : sin_split], -sin[..., sin_split :]))
|
||||
|
||||
if len(out) == 1:
|
||||
return out[0]
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def apply_rope(xq, xk, freqs_cis):
|
||||
org_dtype = xq.dtype
|
||||
cos = freqs_cis[0]
|
||||
sin = freqs_cis[1]
|
||||
q_embed = (xq * cos) + (rotate_half(xq) * sin)
|
||||
k_embed = (xk * cos) + (rotate_half(xk) * sin)
|
||||
nsin = freqs_cis[2]
|
||||
|
||||
q_embed = (xq * cos)
|
||||
q_split = q_embed.shape[-1] // 2
|
||||
q_embed[..., : q_split].addcmul_(xq[..., q_split :], nsin)
|
||||
q_embed[..., q_split :].addcmul_(xq[..., : q_split], sin)
|
||||
|
||||
k_embed = (xk * cos)
|
||||
k_split = k_embed.shape[-1] // 2
|
||||
k_embed[..., : k_split].addcmul_(xk[..., k_split :], nsin)
|
||||
k_embed[..., k_split :].addcmul_(xk[..., : k_split], sin)
|
||||
|
||||
return q_embed.to(org_dtype), k_embed.to(org_dtype)
|
||||
|
||||
|
||||
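Aside: a minimal self-check that the split addcmul_ form above matches the previous rotate_half formulation, assuming the usual duplicated-half cos/sin layout produced by precompute_freqs_cis; all tensors below are made up for the comparison.

import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

torch.manual_seed(0)
head_dim = 8
q = torch.randn(1, 1, 4, head_dim)
angles = torch.rand(1, 1, 4, head_dim // 2)
cos = torch.cat((angles.cos(), angles.cos()), dim=-1)
sin = torch.cat((angles.sin(), angles.sin()), dim=-1)

# reference form used before this change
ref = q * cos + rotate_half(q) * sin

# split / addcmul_ form used after this change
half = head_dim // 2
out = q * cos
out[..., :half].addcmul_(q[..., half:], -sin[..., half:])
out[..., half:].addcmul_(q[..., :half], sin[..., :half])

assert torch.allclose(ref, out, atol=1e-6)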
@@ -356,6 +440,7 @@ class Attention(nn.Module):
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
freqs_cis: Optional[torch.Tensor] = None,
|
||||
optimized_attention=None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
):
|
||||
batch_size, seq_length, _ = hidden_states.shape
|
||||
xq = self.q_proj(hidden_states)
|
||||
@@ -373,11 +458,30 @@ class Attention(nn.Module):
|
||||
|
||||
xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis)
|
||||
|
||||
present_key_value = None
|
||||
if past_key_value is not None:
|
||||
index = 0
|
||||
num_tokens = xk.shape[2]
|
||||
if len(past_key_value) > 0:
|
||||
past_key, past_value, index = past_key_value
|
||||
if past_key.shape[2] >= (index + num_tokens):
|
||||
past_key[:, :, index:index + xk.shape[2]] = xk
|
||||
past_value[:, :, index:index + xv.shape[2]] = xv
|
||||
xk = past_key[:, :, :index + xk.shape[2]]
|
||||
xv = past_value[:, :, :index + xv.shape[2]]
|
||||
present_key_value = (past_key, past_value, index + num_tokens)
|
||||
else:
|
||||
xk = torch.cat((past_key[:, :, :index], xk), dim=2)
|
||||
xv = torch.cat((past_value[:, :, :index], xv), dim=2)
|
||||
present_key_value = (xk, xv, index + num_tokens)
|
||||
else:
|
||||
present_key_value = (xk, xv, index + num_tokens)
|
||||
|
||||
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
|
||||
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
|
||||
|
||||
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
|
||||
return self.o_proj(output)
|
||||
return self.o_proj(output), present_key_value
|
||||
|
||||
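Aside: a small sketch (made-up shapes) of the preallocated key/value cache scheme used above: a fixed [batch, kv_heads, max_len, head_dim] buffer is filled in place and only the first index + new_tokens positions are attended over; once the buffer is exhausted, the code above falls back to concatenation.

import torch

batch, kv_heads, head_dim, max_len = 1, 8, 128, 16
past_key = torch.empty(batch, kv_heads, max_len, head_dim)
past_value = torch.empty(batch, kv_heads, max_len, head_dim)
index = 0

for step_tokens in (4, 1, 1):  # prefill, then two decode steps
    xk = torch.randn(batch, kv_heads, step_tokens, head_dim)
    xv = torch.randn(batch, kv_heads, step_tokens, head_dim)
    past_key[:, :, index:index + step_tokens] = xk      # write new entries in place
    past_value[:, :, index:index + step_tokens] = xv
    index += step_tokens
    k = past_key[:, :, :index]                          # keys visible this step
    v = past_value[:, :, :index]
    assert k.shape[2] == index and v.shape[2] == index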
class MLP(nn.Module):
|
||||
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
|
||||
@@ -408,15 +512,17 @@ class TransformerBlock(nn.Module):
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
freqs_cis: Optional[torch.Tensor] = None,
|
||||
optimized_attention=None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
):
|
||||
# Self Attention
|
||||
residual = x
|
||||
x = self.input_layernorm(x)
|
||||
x = self.self_attn(
|
||||
x, present_key_value = self.self_attn(
|
||||
hidden_states=x,
|
||||
attention_mask=attention_mask,
|
||||
freqs_cis=freqs_cis,
|
||||
optimized_attention=optimized_attention,
|
||||
past_key_value=past_key_value,
|
||||
)
|
||||
x = residual + x
|
||||
|
||||
@@ -426,7 +532,7 @@ class TransformerBlock(nn.Module):
|
||||
x = self.mlp(x)
|
||||
x = residual + x
|
||||
|
||||
return x
|
||||
return x, present_key_value
|
||||
|
||||
class TransformerBlockGemma2(nn.Module):
|
||||
def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None):
|
||||
@@ -451,6 +557,7 @@ class TransformerBlockGemma2(nn.Module):
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
freqs_cis: Optional[torch.Tensor] = None,
|
||||
optimized_attention=None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
):
|
||||
if self.transformer_type == 'gemma3':
|
||||
if self.sliding_attention:
|
||||
@@ -468,11 +575,12 @@ class TransformerBlockGemma2(nn.Module):
|
||||
# Self Attention
|
||||
residual = x
|
||||
x = self.input_layernorm(x)
|
||||
x = self.self_attn(
|
||||
x, present_key_value = self.self_attn(
|
||||
hidden_states=x,
|
||||
attention_mask=attention_mask,
|
||||
freqs_cis=freqs_cis,
|
||||
optimized_attention=optimized_attention,
|
||||
past_key_value=past_key_value,
|
||||
)
|
||||
|
||||
x = self.post_attention_layernorm(x)
|
||||
@@ -485,7 +593,7 @@ class TransformerBlockGemma2(nn.Module):
|
||||
x = self.post_feedforward_layernorm(x)
|
||||
x = residual + x
|
||||
|
||||
return x
|
||||
return x, present_key_value
|
||||
|
||||
class Llama2_(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
@@ -516,9 +624,10 @@ class Llama2_(nn.Module):
|
||||
else:
|
||||
self.norm = None
|
||||
|
||||
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
|
||||
if config.lm_head:
|
||||
self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]):
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None):
|
||||
if embeds is not None:
|
||||
x = embeds
|
||||
else:
|
||||
@@ -527,8 +636,13 @@ class Llama2_(nn.Module):
|
||||
if self.normalize_in:
|
||||
x *= self.config.hidden_size ** 0.5
|
||||
|
||||
seq_len = x.shape[1]
|
||||
past_len = 0
|
||||
if past_key_values is not None and len(past_key_values) > 0:
|
||||
past_len = past_key_values[0][2]
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(0, x.shape[1], device=x.device).unsqueeze(0)
|
||||
position_ids = torch.arange(past_len, past_len + seq_len, device=x.device).unsqueeze(0)
|
||||
|
||||
freqs_cis = precompute_freqs_cis(self.config.head_dim,
|
||||
position_ids,
|
||||
@@ -539,14 +653,16 @@ class Llama2_(nn.Module):
|
||||
|
||||
mask = None
|
||||
if attention_mask is not None:
|
||||
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
|
||||
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
|
||||
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1])
|
||||
mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min / 4)
|
||||
|
||||
if seq_len > 1:
|
||||
causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(torch.finfo(x.dtype).min / 4).triu_(1)
|
||||
if mask is not None:
|
||||
mask += causal_mask
|
||||
else:
|
||||
mask = causal_mask
|
||||
|
||||
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
|
||||
if mask is not None:
|
||||
mask += causal_mask
|
||||
else:
|
||||
mask = causal_mask
|
||||
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
|
||||
|
||||
intermediate = None
|
||||
@@ -562,16 +678,27 @@ class Llama2_(nn.Module):
|
||||
elif intermediate_output < 0:
|
||||
intermediate_output = len(self.layers) + intermediate_output
|
||||
|
||||
next_key_values = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
if all_intermediate is not None:
|
||||
if only_layers is None or (i in only_layers):
|
||||
all_intermediate.append(x.unsqueeze(1).clone())
|
||||
x = layer(
|
||||
|
||||
past_kv = None
|
||||
if past_key_values is not None:
|
||||
past_kv = past_key_values[i] if len(past_key_values) > 0 else []
|
||||
|
||||
x, current_kv = layer(
|
||||
x=x,
|
||||
attention_mask=mask,
|
||||
freqs_cis=freqs_cis,
|
||||
optimized_attention=optimized_attention,
|
||||
past_key_value=past_kv,
|
||||
)
|
||||
|
||||
if current_kv is not None:
|
||||
next_key_values.append(current_kv)
|
||||
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
|
||||
@@ -588,7 +715,10 @@ class Llama2_(nn.Module):
|
||||
if intermediate is not None and final_layer_norm_intermediate and self.norm is not None:
|
||||
intermediate = self.norm(intermediate)
|
||||
|
||||
return x, intermediate
|
||||
if len(next_key_values) > 0:
|
||||
return x, intermediate, next_key_values
|
||||
else:
|
||||
return x, intermediate
|
||||
|
||||
|
||||
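# With past_key_values threaded through, Llama2_.forward can be driven
# incrementally: pass an empty list on the prefill call to request caching,
# then feed the returned cache (and only the new token embedding) on each
# subsequent call. A hedged usage sketch -- `model`, `prompt_embeds`,
# `new_tokens` and `pick_next_token_embed` are hypothetical names:
#
#     past_key_values = []                                   # request caching
#     x, _, past_key_values = model(None, embeds=prompt_embeds,
#                                   past_key_values=past_key_values)
#     for _ in range(new_tokens):
#         next_embed = pick_next_token_embed(x)              # hypothetical helper
#         x, _, past_key_values = model(None, embeds=next_embed,
#                                       past_key_values=past_key_values)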
class Gemma3MultiModalProjector(torch.nn.Module):
|
||||
@@ -635,6 +765,21 @@ class BaseLlama:
|
||||
def forward(self, input_ids, *args, **kwargs):
|
||||
return self.model(input_ids, *args, **kwargs)
|
||||
|
||||
class BaseQwen3:
|
||||
def logits(self, x):
|
||||
input = x[:, -1:]
|
||||
module = self.model.embed_tokens
|
||||
|
||||
offload_stream = None
|
||||
if module.comfy_cast_weights:
|
||||
weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True)
|
||||
else:
|
||||
weight = self.model.embed_tokens.weight.to(x)
|
||||
|
||||
x = torch.nn.functional.linear(input, weight, None)
|
||||
|
||||
comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
|
||||
return x
|
||||
|
||||
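# BaseQwen3.logits reuses the token embedding matrix as a tied LM head,
# projecting only the final hidden state. A standalone sketch of the same idea
# (hypothetical sizes, ignoring the weight-offloading path):
#
#     import torch
#     import torch.nn.functional as F
#
#     vocab, hidden = 1000, 64
#     embed_weight = torch.randn(vocab, hidden)       # embed_tokens.weight
#     hidden_states = torch.randn(2, 10, hidden)      # (batch, seq, hidden)
#
#     logits = F.linear(hidden_states[:, -1:], embed_weight)
#     assert logits.shape == (2, 1, vocab)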
class Llama2(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
@@ -663,7 +808,7 @@ class Qwen25_3B(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen3_06B(BaseLlama, torch.nn.Module):
|
||||
class Qwen3_06B(BaseLlama, BaseQwen3, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen3_06BConfig(**config_dict)
|
||||
@@ -672,7 +817,25 @@ class Qwen3_06B(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen3_4B(BaseLlama, torch.nn.Module):
|
||||
class Qwen3_06B_ACE15(BaseLlama, BaseQwen3, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen3_06B_ACE15_Config(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen3_2B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen3_2B_ACE15_lm_Config(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen3_4B(BaseLlama, BaseQwen3, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen3_4BConfig(**config_dict)
|
||||
@@ -681,7 +844,16 @@ class Qwen3_4B(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen3_8B(BaseLlama, torch.nn.Module):
|
||||
class Qwen3_4B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen3_4B_ACE15_lm_Config(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen3_8B(BaseLlama, BaseQwen3, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen3_8BConfig(**config_dict)
|
||||
|
||||
@@ -25,7 +25,7 @@ def ltxv_te(*args, **kwargs):
|
||||
class Gemma3_12BTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
|
||||
|
||||
def state_dict(self):
|
||||
return {"spiece_model": self.tokenizer.serialize_model()}
|
||||
@@ -97,6 +97,7 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
token_weight_pairs = token_weight_pairs["gemma3_12b"]
|
||||
|
||||
out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs)
|
||||
out = out[:, :, -torch.sum(extra["attention_mask"]).item():]
|
||||
out_device = out.device
|
||||
if comfy.model_management.should_use_bf16(self.execution_device):
|
||||
out = out.to(device=self.execution_device, dtype=torch.bfloat16)
|
||||
@@ -125,7 +126,7 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
for prefix, component in [("text_embedding_projection.", self.text_embedding_projection), ("video_embeddings_connector.", self.video_embeddings_connector), ("audio_embeddings_connector.", self.audio_embeddings_connector)]:
|
||||
component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)}
|
||||
if component_sd:
|
||||
missing, unexpected = component.load_state_dict(component_sd, strict=False)
|
||||
missing, unexpected = component.load_state_dict(component_sd, strict=False, assign=getattr(self, "can_assign_sd", False))
|
||||
missing_all.extend([f"{prefix}{k}" for k in missing])
|
||||
unexpected_all.extend([f"{prefix}{k}" for k in unexpected])
|
||||
|
||||
@@ -138,6 +139,7 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
|
||||
token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
|
||||
num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
|
||||
num_tokens = max(num_tokens, 64)
|
||||
return num_tokens * constant * 1024 * 1024
|
||||
|
||||
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
|
||||
@@ -6,7 +6,7 @@ import os
|
||||
class Qwen3Tokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
|
||||
|
||||
|
||||
class ZImageTokenizer(sd1_clip.SD1Tokenizer):
|
||||
|
||||
170  comfy/utils.py
@@ -20,41 +20,92 @@
|
||||
import torch
|
||||
import math
|
||||
import struct
|
||||
import comfy.checkpoint_pickle
|
||||
import comfy.memory_management
|
||||
import safetensors.torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import logging
|
||||
import itertools
|
||||
from torch.nn.functional import interpolate
|
||||
from tqdm.auto import trange
|
||||
from einops import rearrange
|
||||
from comfy.cli_args import args
|
||||
from comfy.cli_args import args, enables_dynamic_vram
|
||||
import json
|
||||
import time
|
||||
import mmap
|
||||
import warnings
|
||||
|
||||
MMAP_TORCH_FILES = args.mmap_torch_files
|
||||
DISABLE_MMAP = args.disable_mmap
|
||||
|
||||
ALWAYS_SAFE_LOAD = False
|
||||
if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
|
||||
|
||||
if True: # ckpt/pt file whitelist for safe loading of old sd files
|
||||
class ModelCheckpoint:
|
||||
pass
|
||||
ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"
|
||||
|
||||
def scalar(*args, **kwargs):
|
||||
from numpy.core.multiarray import scalar as sc
|
||||
return sc(*args, **kwargs)
|
||||
return None
|
||||
scalar.__module__ = "numpy.core.multiarray"
|
||||
|
||||
from numpy import dtype
|
||||
from numpy.dtypes import Float64DType
|
||||
from _codecs import encode
|
||||
|
||||
def encode(*args, **kwargs): # no longer necessary on newer torch
|
||||
return None
|
||||
encode.__module__ = "_codecs"
|
||||
|
||||
torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode])
|
||||
ALWAYS_SAFE_LOAD = True
|
||||
logging.info("Checkpoint files will always be loaded safely.")
|
||||
else:
|
||||
logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")
|
||||
|
||||
|
||||
# Current as of safetensors 0.7.0
|
||||
_TYPES = {
|
||||
"F64": torch.float64,
|
||||
"F32": torch.float32,
|
||||
"F16": torch.float16,
|
||||
"BF16": torch.bfloat16,
|
||||
"I64": torch.int64,
|
||||
"I32": torch.int32,
|
||||
"I16": torch.int16,
|
||||
"I8": torch.int8,
|
||||
"U8": torch.uint8,
|
||||
"BOOL": torch.bool,
|
||||
"F8_E4M3": torch.float8_e4m3fn,
|
||||
"F8_E5M2": torch.float8_e5m2,
|
||||
"C64": torch.complex64,
|
||||
|
||||
"U64": torch.uint64,
|
||||
"U32": torch.uint32,
|
||||
"U16": torch.uint16,
|
||||
}
|
||||
|
||||
def load_safetensors(ckpt):
|
||||
f = open(ckpt, "rb")
|
||||
mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
|
||||
mv = memoryview(mapping)
|
||||
|
||||
header_size = struct.unpack("<Q", mapping[:8])[0]
|
||||
header = json.loads(mapping[8:8+header_size].decode("utf-8"))
|
||||
|
||||
mv = mv[8 + header_size:]
|
||||
|
||||
sd = {}
|
||||
for name, info in header.items():
|
||||
if name == "__metadata__":
|
||||
continue
|
||||
|
||||
start, end = info["data_offsets"]
|
||||
if start == end:
|
||||
sd[name] = torch.empty(info["shape"], dtype=_TYPES[info["dtype"]])
|
||||
else:
|
||||
with warnings.catch_warnings():
|
||||
#We are working with read-only RAM by design
|
||||
warnings.filterwarnings("ignore", message="The given buffer is not writable")
|
||||
sd[name] = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"])
|
||||
|
||||
return sd, header.get("__metadata__", {}),
|
||||
|
||||
|
||||
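# load_safetensors above parses the 8-byte little-endian header length, the
# JSON header, and then builds zero-copy tensors over the mmapped byte ranges.
# A hedged usage sketch (file path and key are hypothetical):
#
#     sd, metadata = load_safetensors("/path/to/model.safetensors")
#     for name, t in list(sd.items())[:3]:
#         print(name, t.dtype, tuple(t.shape))    # tensors are read-only views
#     # copy before mutating, since the buffers are backed by read-only mmap:
#     w = sd["some.key"].clone()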
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
if device is None:
|
||||
@@ -62,15 +113,20 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
metadata = None
|
||||
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
|
||||
try:
|
||||
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
|
||||
sd = {}
|
||||
for k in f.keys():
|
||||
tensor = f.get_tensor(k)
|
||||
if DISABLE_MMAP: # TODO: Not sure if this is the best way to bypass the mmap issues
|
||||
tensor = tensor.to(device=device, copy=True)
|
||||
sd[k] = tensor
|
||||
if return_metadata:
|
||||
metadata = f.metadata()
|
||||
if enables_dynamic_vram():
|
||||
sd, metadata = load_safetensors(ckpt)
|
||||
if not return_metadata:
|
||||
metadata = None
|
||||
else:
|
||||
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
|
||||
sd = {}
|
||||
for k in f.keys():
|
||||
tensor = f.get_tensor(k)
|
||||
if DISABLE_MMAP: # TODO: Not sure if this is the best way to bypass the mmap issues
|
||||
tensor = tensor.to(device=device, copy=True)
|
||||
sd[k] = tensor
|
||||
if return_metadata:
|
||||
metadata = f.metadata()
|
||||
except Exception as e:
|
||||
if len(e.args) > 0:
|
||||
message = e.args[0]
|
||||
@@ -84,11 +140,8 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
if MMAP_TORCH_FILES:
|
||||
torch_args["mmap"] = True
|
||||
|
||||
if safe_load or ALWAYS_SAFE_LOAD:
|
||||
pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
|
||||
else:
|
||||
logging.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt))
|
||||
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
|
||||
pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
|
||||
|
||||
if "state_dict" in pl_sd:
|
||||
sd = pl_sd["state_dict"]
|
||||
else:
|
||||
@@ -619,10 +672,10 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
|
||||
"ff_context.linear_in.bias": "txt_mlp.0.bias",
|
||||
"ff_context.linear_out.weight": "txt_mlp.2.weight",
|
||||
"ff_context.linear_out.bias": "txt_mlp.2.bias",
|
||||
"attn.norm_q.weight": "img_attn.norm.query_norm.scale",
|
||||
"attn.norm_k.weight": "img_attn.norm.key_norm.scale",
|
||||
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
|
||||
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
|
||||
"attn.norm_q.weight": "img_attn.norm.query_norm.weight",
|
||||
"attn.norm_k.weight": "img_attn.norm.key_norm.weight",
|
||||
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.weight",
|
||||
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.weight",
|
||||
}
|
||||
|
||||
for k in block_map:
|
||||
@@ -645,8 +698,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
|
||||
"norm.linear.bias": "modulation.lin.bias",
|
||||
"proj_out.weight": "linear2.weight",
|
||||
"proj_out.bias": "linear2.bias",
|
||||
"attn.norm_q.weight": "norm.query_norm.scale",
|
||||
"attn.norm_k.weight": "norm.key_norm.scale",
|
||||
"attn.norm_q.weight": "norm.query_norm.weight",
|
||||
"attn.norm_k.weight": "norm.key_norm.weight",
|
||||
"attn.to_qkv_mlp_proj.weight": "linear1.weight", # Flux 2
|
||||
"attn.to_out.weight": "linear2.weight", # Flux 2
|
||||
}
|
||||
@@ -1100,6 +1153,32 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
||||
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
||||
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
|
||||
|
||||
def model_trange(*args, **kwargs):
|
||||
if comfy.memory_management.aimdo_allocator is None:
|
||||
return trange(*args, **kwargs)
|
||||
|
||||
pbar = trange(*args, **kwargs, smoothing=1.0)
|
||||
pbar._i = 0
|
||||
pbar.set_postfix_str(" Model Initializing ... ")
|
||||
|
||||
_update = pbar.update
|
||||
|
||||
def warmup_update(n=1):
|
||||
pbar._i += 1
|
||||
if pbar._i == 1:
|
||||
pbar.i1_time = time.time()
|
||||
pbar.set_postfix_str(" Model Initialization complete! ")
|
||||
elif pbar._i == 2:
|
||||
#bring forward the effective start time based on the diff between the first and second iteration
|
||||
#to attempt to remove load overhead from the final step rate estimate.
|
||||
pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
|
||||
pbar.set_postfix_str("")
|
||||
|
||||
_update(n)
|
||||
|
||||
pbar.update = warmup_update
|
||||
return pbar
|
||||
|
||||
PROGRESS_BAR_ENABLED = True
|
||||
def set_progress_bar_enabled(enabled):
|
||||
global PROGRESS_BAR_ENABLED
|
||||
@@ -1308,3 +1387,34 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
|
||||
state_dict["{}.comfy_quant".format(k)] = torch.tensor(list(json.dumps(v).encode('utf-8')), dtype=torch.uint8)
|
||||
|
||||
return state_dict, metadata
|
||||
|
||||
def string_to_seed(data):
|
||||
crc = 0xFFFFFFFF
|
||||
for byte in data:
|
||||
if isinstance(byte, str):
|
||||
byte = ord(byte)
|
||||
crc ^= byte
|
||||
for _ in range(8):
|
||||
if crc & 1:
|
||||
crc = (crc >> 1) ^ 0xEDB88320
|
||||
else:
|
||||
crc >>= 1
|
||||
return crc ^ 0xFFFFFFFF
|
||||
|
||||
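# string_to_seed implements the reflected CRC-32 (polynomial 0xEDB88320) over
# the characters' ordinals, so for plain ASCII input it should line up with
# zlib.crc32. A small sketch of that assumption:
#
#     import zlib
#     s = "example prompt"
#     assert string_to_seed(s) == zlib.crc32(s.encode("ascii"))
#     # bytes input works too, byte by byte:
#     assert string_to_seed(b"example prompt") == zlib.crc32(b"example prompt")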
def deepcopy_list_dict(obj, memo=None):
|
||||
if memo is None:
|
||||
memo = {}
|
||||
|
||||
obj_id = id(obj)
|
||||
if obj_id in memo:
|
||||
return memo[obj_id]
|
||||
|
||||
if isinstance(obj, dict):
|
||||
res = {deepcopy_list_dict(k, memo): deepcopy_list_dict(v, memo) for k, v in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
res = [deepcopy_list_dict(i, memo) for i in obj]
|
||||
else:
|
||||
res = obj
|
||||
|
||||
memo[obj_id] = res
|
||||
return res
|
||||
|
||||
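# deepcopy_list_dict copies only dict/list containers and leaves every other
# object (tensors, modules, adapters) shared by reference, which is the point:
# the structure can be edited independently without duplicating weights.
# A small sketch of that behaviour:
#
#     import torch
#     t = torch.zeros(2)
#     src = {"patches": [t, {"strength": 1.0}]}
#     dst = deepcopy_list_dict(src)
#
#     assert dst is not src and dst["patches"] is not src["patches"]
#     assert dst["patches"][0] is t            # tensor itself is shared
#     dst["patches"].append("new")             # container edits stay local
#     assert len(src["patches"]) == 2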
@@ -5,6 +5,11 @@ from .lokr import LoKrAdapter
|
||||
from .glora import GLoRAAdapter
|
||||
from .oft import OFTAdapter
|
||||
from .boft import BOFTAdapter
|
||||
from .bypass import (
|
||||
BypassInjectionManager,
|
||||
BypassForwardHook,
|
||||
create_bypass_injections_from_patches,
|
||||
)
|
||||
|
||||
|
||||
adapters: list[type[WeightAdapterBase]] = [
|
||||
@@ -31,4 +36,7 @@ __all__ = [
|
||||
"WeightAdapterTrainBase",
|
||||
"adapters",
|
||||
"adapter_maps",
|
||||
"BypassInjectionManager",
|
||||
"BypassForwardHook",
|
||||
"create_bypass_injections_from_patches",
|
||||
] + [a.__name__ for a in adapters]
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -7,12 +7,35 @@ import comfy.model_management
|
||||
|
||||
|
||||
class WeightAdapterBase:
|
||||
"""
|
||||
Base class for weight adapters (LoRA, LoHa, LoKr, OFT, etc.)
|
||||
|
||||
Bypass Mode:
|
||||
All adapters follow the pattern: bypass(f)(x) = g(f(x) + h(x))
|
||||
|
||||
- h(x): Additive component (LoRA path). Returns delta to add to base output.
|
||||
- g(y): Output transformation. Applied after base + h(x).
|
||||
|
||||
For LoRA/LoHa/LoKr: g = identity, h = adapter(x)
|
||||
For OFT/BOFT: g = transform, h = 0
|
||||
"""
|
||||
|
||||
name: str
|
||||
loaded_keys: set[str]
|
||||
weights: list[torch.Tensor]
|
||||
|
||||
# Attributes set by bypass system
|
||||
multiplier: float = 1.0
|
||||
shape: tuple = None # (out_features, in_features) or (out_ch, in_ch, *kernel)
|
||||
|
||||
@classmethod
|
||||
def load(cls, x: str, lora: dict[str, torch.Tensor], alpha: float, dora_scale: torch.Tensor) -> Optional["WeightAdapterBase"]:
|
||||
def load(
|
||||
cls,
|
||||
x: str,
|
||||
lora: dict[str, torch.Tensor],
|
||||
alpha: float,
|
||||
dora_scale: torch.Tensor,
|
||||
) -> Optional["WeightAdapterBase"]:
|
||||
raise NotImplementedError
|
||||
|
||||
def to_train(self) -> "WeightAdapterTrainBase":
|
||||
@@ -39,18 +62,202 @@ class WeightAdapterBase:
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
# ===== Bypass Mode Methods =====
|
||||
#
|
||||
# IMPORTANT: Bypass mode is designed for quantized models where original weights
|
||||
# may not be accessible in a usable format. Therefore, h() and bypass_forward()
|
||||
# do NOT take org_weight as a parameter. All necessary information (out_channels,
|
||||
# in_channels, conv params, etc.) is provided via attributes set by BypassForwardHook.
|
||||
|
||||
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Additive bypass component: h(x, base_out)
|
||||
|
||||
Computes the adapter's contribution to be added to base forward output.
|
||||
For adapters that only transform output (OFT/BOFT), returns zeros.
|
||||
|
||||
Note:
|
||||
This method does NOT access original model weights. Bypass mode is
|
||||
designed for quantized models where weights may not be in a usable format.
|
||||
All shape info comes from module attributes set by BypassForwardHook.
|
||||
|
||||
Args:
|
||||
x: Input tensor
|
||||
base_out: Output from base forward f(x), can be used for shape reference
|
||||
|
||||
Returns:
|
||||
Delta tensor to add to base output. Shape matches base output.
|
||||
|
||||
Reference: LyCORIS LoConModule.bypass_forward_diff
|
||||
"""
|
||||
# Default: no additive component (for OFT/BOFT)
|
||||
# Simply return zeros matching base_out shape
|
||||
return torch.zeros_like(base_out)
|
||||
|
||||
def g(self, y: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Output transformation: g(y)
|
||||
|
||||
Applied after base forward + h(x). For most adapters this is identity.
|
||||
OFT/BOFT override this to apply orthogonal transformation.
|
||||
|
||||
Args:
|
||||
y: Combined output (base + h(x))
|
||||
|
||||
Returns:
|
||||
Transformed output
|
||||
|
||||
Reference: LyCORIS OFTModule applies orthogonal transform here
|
||||
"""
|
||||
# Default: identity (for LoRA/LoHa/LoKr)
|
||||
return y
|
||||
|
||||
def bypass_forward(
|
||||
self,
|
||||
org_forward: Callable,
|
||||
x: torch.Tensor,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Full bypass forward: g(f(x) + h(x, f(x)))
|
||||
|
||||
Note:
|
||||
This method does NOT take org_weight/org_bias parameters. Bypass mode
|
||||
is designed for quantized models where weights may not be accessible.
|
||||
The original forward function handles weight access internally.
|
||||
|
||||
Args:
|
||||
org_forward: Original module forward function
|
||||
x: Input tensor
|
||||
*args, **kwargs: Additional arguments for org_forward
|
||||
|
||||
Returns:
|
||||
Output with adapter applied in bypass mode
|
||||
|
||||
Reference: LyCORIS LoConModule.bypass_forward
|
||||
"""
|
||||
# Base forward: f(x)
|
||||
base_out = org_forward(x, *args, **kwargs)
|
||||
|
||||
# Additive component: h(x, base_out) - base_out provided for shape reference
|
||||
h_out = self.h(x, base_out)
|
||||
|
||||
# Output transformation: g(base + h)
|
||||
return self.g(base_out + h_out)
|
||||
|
||||
|
||||
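# A hedged sketch of what h() looks like for a plain LoRA-style adapter in the
# linear case: the down/up projection scaled by alpha/rank and the multiplier
# set by BypassForwardHook. The names (`down`, `up`, `alpha`) are illustrative
# attributes, not the exact layout of any adapter's self.weights:
#
#     class _SketchLoRA(WeightAdapterBase):
#         def h(self, x, base_out):
#             down, up, alpha = self.down, self.up, self.alpha   # hypothetical attrs
#             rank = down.shape[0]
#             scale = (alpha / rank) * getattr(self, "multiplier", 1.0)
#             delta = torch.nn.functional.linear(
#                 torch.nn.functional.linear(x.to(down.dtype), down), up)
#             return (scale * delta).to(base_out.dtype)
#         # g() is inherited: identity, so bypass_forward returns f(x) + h(x)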
class WeightAdapterTrainBase(nn.Module):
|
||||
# We follow the scheme of PR #7032
|
||||
"""
|
||||
Base class for trainable weight adapters (LoRA, LoHa, LoKr, OFT, etc.)
|
||||
|
||||
Bypass Mode:
|
||||
All adapters follow the pattern: bypass(f)(x) = g(f(x) + h(x))
|
||||
|
||||
- h(x): Additive component (LoRA path). Returns delta to add to base output.
|
||||
- g(y): Output transformation. Applied after base + h(x).
|
||||
|
||||
For LoRA/LoHa/LoKr: g = identity, h = adapter(x)
|
||||
For OFT: g = transform, h = 0
|
||||
|
||||
Note:
|
||||
Unlike WeightAdapterBase, TrainBase classes have simplified weight formats
|
||||
with fewer branches (e.g., LoKr only has w1/w2, not w1_a/w1_b decomposition).
|
||||
|
||||
We follow the scheme of PR #7032
|
||||
"""
|
||||
|
||||
# Attributes set by bypass system (BypassForwardHook)
|
||||
# These are set before h()/g()/bypass_forward() are called
|
||||
multiplier: float = 1.0
|
||||
is_conv: bool = False
|
||||
conv_dim: int = 0 # 0=linear, 1=conv1d, 2=conv2d, 3=conv3d
|
||||
kw_dict: dict = {} # Conv kwargs: stride, padding, dilation, groups
|
||||
kernel_size: tuple = ()
|
||||
in_channels: int = None
|
||||
out_channels: int = None
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def __call__(self, w):
|
||||
"""
|
||||
w: The original weight tensor to be modified.
|
||||
Weight modification mode: returns modified weight.
|
||||
|
||||
Args:
|
||||
w: The original weight tensor to be modified.
|
||||
|
||||
Returns:
|
||||
Modified weight tensor.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
# ===== Bypass Mode Methods =====
|
||||
|
||||
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Additive bypass component: h(x, base_out)
|
||||
|
||||
Computes the adapter's contribution to be added to base forward output.
|
||||
For adapters that only transform output (OFT), returns zeros.
|
||||
|
||||
Args:
|
||||
x: Input tensor
|
||||
base_out: Output from base forward f(x), can be used for shape reference
|
||||
|
||||
Returns:
|
||||
Delta tensor to add to base output. Shape matches base output.
|
||||
|
||||
Subclasses should override this method.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__}.h() not implemented. "
|
||||
"Subclasses must implement h() for bypass mode."
|
||||
)
|
||||
|
||||
def g(self, y: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Output transformation: g(y)
|
||||
|
||||
Applied after base forward + h(x). For most adapters this is identity.
|
||||
OFT overrides this to apply orthogonal transformation.
|
||||
|
||||
Args:
|
||||
y: Combined output (base + h(x))
|
||||
|
||||
Returns:
|
||||
Transformed output
|
||||
"""
|
||||
# Default: identity (for LoRA/LoHa/LoKr)
|
||||
return y
|
||||
|
||||
def bypass_forward(
|
||||
self,
|
||||
org_forward: Callable,
|
||||
x: torch.Tensor,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Full bypass forward: g(f(x) + h(x, f(x)))
|
||||
|
||||
Args:
|
||||
org_forward: Original module forward function
|
||||
x: Input tensor
|
||||
*args, **kwargs: Additional arguments for org_forward
|
||||
|
||||
Returns:
|
||||
Output with adapter applied in bypass mode
|
||||
"""
|
||||
# Base forward: f(x)
|
||||
base_out = org_forward(x, *args, **kwargs)
|
||||
|
||||
# Additive component: h(x, base_out) - base_out provided for shape reference
|
||||
h_out = self.h(x, base_out)
|
||||
|
||||
# Output transformation: g(base + h)
|
||||
return self.g(base_out + h_out)
|
||||
|
||||
def passive_memory_usage(self):
|
||||
raise NotImplementedError("passive_memory_usage is not implemented")
|
||||
|
||||
@@ -59,8 +266,12 @@ class WeightAdapterTrainBase(nn.Module):
|
||||
return self.passive_memory_usage()
|
||||
|
||||
|
||||
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
|
||||
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
|
||||
def weight_decompose(
|
||||
dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function
|
||||
):
|
||||
dora_scale = comfy.model_management.cast_to_device(
|
||||
dora_scale, weight.device, intermediate_dtype
|
||||
)
|
||||
lora_diff *= alpha
|
||||
weight_calc = weight + function(lora_diff).type(weight.dtype)
|
||||
|
||||
@@ -106,10 +317,14 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten
|
||||
the original tensor will be truncated in that dimension.
|
||||
"""
|
||||
if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
|
||||
raise ValueError("The new shape must be larger than the original tensor in all dimensions")
|
||||
raise ValueError(
|
||||
"The new shape must be larger than the original tensor in all dimensions"
|
||||
)
|
||||
|
||||
if len(new_shape) != len(tensor.shape):
|
||||
raise ValueError("The new shape must have the same number of dimensions as the original tensor")
|
||||
raise ValueError(
|
||||
"The new shape must have the same number of dimensions as the original tensor"
|
||||
)
|
||||
|
||||
# Create a new tensor filled with zeros
|
||||
padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)
|
||||
|
||||
@@ -62,9 +62,13 @@ class BOFTAdapter(WeightAdapterBase):
|
||||
alpha = v[2]
|
||||
dora_scale = v[3]
|
||||
|
||||
blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
|
||||
blocks = comfy.model_management.cast_to_device(
|
||||
blocks, weight.device, intermediate_dtype
|
||||
)
|
||||
if rescale is not None:
|
||||
rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype)
|
||||
rescale = comfy.model_management.cast_to_device(
|
||||
rescale, weight.device, intermediate_dtype
|
||||
)
|
||||
|
||||
boft_m, block_num, boft_b, *_ = blocks.shape
|
||||
|
||||
@@ -74,7 +78,7 @@ class BOFTAdapter(WeightAdapterBase):
|
||||
# for Q = -Q^T
|
||||
q = blocks - blocks.transpose(-1, -2)
|
||||
normed_q = q
|
||||
if alpha > 0: # alpha in boft/bboft is for constraint
|
||||
if alpha > 0: # alpha in boft/bboft is for constraint
|
||||
q_norm = torch.norm(q) + 1e-8
|
||||
if q_norm > alpha:
|
||||
normed_q = q * alpha / q_norm
|
||||
@@ -83,13 +87,13 @@ class BOFTAdapter(WeightAdapterBase):
|
||||
r = r.to(weight)
|
||||
inp = org = weight
|
||||
|
||||
r_b = boft_b//2
|
||||
r_b = boft_b // 2
|
||||
for i in range(boft_m):
|
||||
bi = r[i]
|
||||
g = 2
|
||||
k = 2**i * r_b
|
||||
if strength != 1:
|
||||
bi = bi * strength + (1-strength) * I
|
||||
bi = bi * strength + (1 - strength) * I
|
||||
inp = (
|
||||
inp.unflatten(0, (-1, g, k))
|
||||
.transpose(1, 2)
|
||||
@@ -98,18 +102,117 @@ class BOFTAdapter(WeightAdapterBase):
|
||||
)
|
||||
inp = torch.einsum("b i j, b j ...-> b i ...", bi, inp)
|
||||
inp = (
|
||||
inp.flatten(0, 1).unflatten(0, (-1, k, g)).transpose(1, 2).flatten(0, 2)
|
||||
inp.flatten(0, 1)
|
||||
.unflatten(0, (-1, k, g))
|
||||
.transpose(1, 2)
|
||||
.flatten(0, 2)
|
||||
)
|
||||
|
||||
if rescale is not None:
|
||||
inp = inp * rescale
|
||||
|
||||
lora_diff = inp - org
|
||||
lora_diff = comfy.model_management.cast_to_device(lora_diff, weight.device, intermediate_dtype)
|
||||
lora_diff = comfy.model_management.cast_to_device(
|
||||
lora_diff, weight.device, intermediate_dtype
|
||||
)
|
||||
if dora_scale is not None:
|
||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||
weight = weight_decompose(
|
||||
dora_scale,
|
||||
weight,
|
||||
lora_diff,
|
||||
alpha,
|
||||
strength,
|
||||
intermediate_dtype,
|
||||
function,
|
||||
)
|
||||
else:
|
||||
weight += function((strength * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||
return weight
|
||||
|
||||
def _get_orthogonal_matrices(self, device, dtype):
|
||||
"""Compute the orthogonal rotation matrices R from BOFT blocks."""
|
||||
v = self.weights
|
||||
blocks = v[0].to(device=device, dtype=dtype)
|
||||
alpha = v[2]
|
||||
if alpha is None:
|
||||
alpha = 0
|
||||
|
||||
boft_m, block_num, boft_b, _ = blocks.shape
|
||||
I = torch.eye(boft_b, device=device, dtype=dtype)
|
||||
|
||||
# Q = blocks - blocks^T (skew-symmetric)
|
||||
q = blocks - blocks.transpose(-1, -2)
|
||||
normed_q = q
|
||||
|
||||
# Apply constraint if alpha > 0
|
||||
if alpha > 0:
|
||||
q_norm = torch.norm(q) + 1e-8
|
||||
if q_norm > alpha:
|
||||
normed_q = q * alpha / q_norm
|
||||
|
||||
# Cayley transform: R = (I + Q)(I - Q)^-1
|
||||
r = (I + normed_q) @ (I - normed_q).float().inverse()
|
||||
return r, boft_m, boft_b
|
||||
|
||||
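# _get_orthogonal_matrices builds R via the Cayley transform of a
# skew-symmetric Q, so each block of R should be orthogonal. A small
# standalone sketch of that property (random block, float64 for a tight check):
#
#     import torch
#     b = 8
#     blocks = torch.randn(b, b, dtype=torch.float64)
#     q = blocks - blocks.T                      # skew-symmetric
#     I = torch.eye(b, dtype=torch.float64)
#     r = (I + q) @ torch.linalg.inv(I - q)      # Cayley transform
#     assert torch.allclose(r @ r.T, I, atol=1e-8)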
def g(self, y: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Output transformation for BOFT: applies butterfly orthogonal transform.
|
||||
|
||||
BOFT uses multiple stages of butterfly-structured orthogonal transforms.
|
||||
|
||||
Reference: LyCORIS ButterflyOFTModule._bypass_forward
|
||||
"""
|
||||
v = self.weights
|
||||
rescale = v[1]
|
||||
|
||||
r, boft_m, boft_b = self._get_orthogonal_matrices(y.device, y.dtype)
|
||||
r_b = boft_b // 2
|
||||
|
||||
# Apply multiplier
|
||||
multiplier = getattr(self, "multiplier", 1.0)
|
||||
I = torch.eye(boft_b, device=y.device, dtype=y.dtype)
|
||||
|
||||
# Use module info from bypass injection to determine conv vs linear
|
||||
is_conv = getattr(self, "is_conv", y.dim() > 2)
|
||||
|
||||
if is_conv:
|
||||
# Conv output: (N, C, H, W, ...) -> transpose to (N, H, W, ..., C)
|
||||
y = y.transpose(1, -1)
|
||||
|
||||
# Apply butterfly transform stages
|
||||
inp = y
|
||||
for i in range(boft_m):
|
||||
bi = r[i] # (block_num, boft_b, boft_b)
|
||||
g = 2
|
||||
k = 2**i * r_b
|
||||
|
||||
# Interpolate with identity based on multiplier
|
||||
if multiplier != 1:
|
||||
bi = bi * multiplier + (1 - multiplier) * I
|
||||
|
||||
# Reshape for butterfly: unflatten last dim, transpose, flatten, unflatten
|
||||
inp = (
|
||||
inp.unflatten(-1, (-1, g, k))
|
||||
.transpose(-2, -1)
|
||||
.flatten(-3)
|
||||
.unflatten(-1, (-1, boft_b))
|
||||
)
|
||||
# Apply block-diagonal orthogonal transform
|
||||
inp = torch.einsum("b i j, ... b j -> ... b i", bi, inp)
|
||||
# Reshape back
|
||||
inp = (
|
||||
inp.flatten(-2).unflatten(-1, (-1, k, g)).transpose(-2, -1).flatten(-3)
|
||||
)
|
||||
|
||||
# Apply rescale if present
|
||||
if rescale is not None:
|
||||
rescale = rescale.to(device=y.device, dtype=y.dtype)
|
||||
inp = inp * rescale.transpose(0, -1)
|
||||
|
||||
if is_conv:
|
||||
# Transpose back: (N, H, W, ..., C) -> (N, C, H, W, ...)
|
||||
inp = inp.transpose(1, -1)
|
||||
|
||||
return inp
|
||||
|
||||
441  comfy/weight_adapter/bypass.py (new file)
@@ -0,0 +1,441 @@
|
||||
"""
|
||||
Bypass mode implementation for weight adapters (LoRA, LoKr, LoHa, etc.)
|
||||
|
||||
Bypass mode applies adapters during forward pass without modifying base weights:
|
||||
bypass(f)(x) = g(f(x) + h(x))
|
||||
|
||||
Where:
|
||||
- f(x): Original layer forward
|
||||
- h(x): Additive component from adapter (LoRA path)
|
||||
- g(y): Output transformation (identity for most adapters)
|
||||
|
||||
This is useful for:
|
||||
- Training with gradient checkpointing
|
||||
- Avoiding weight modifications when weights are offloaded
|
||||
- Supporting multiple adapters with different strengths dynamically
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import comfy.model_management
|
||||
from .base import WeightAdapterBase, WeightAdapterTrainBase
|
||||
from comfy.patcher_extension import PatcherInjection
|
||||
|
||||
# Type alias for adapters that support bypass mode
|
||||
BypassAdapter = Union[WeightAdapterBase, WeightAdapterTrainBase]
|
||||
|
||||
|
||||
def get_module_type_info(module: nn.Module) -> dict:
|
||||
"""
|
||||
Determine module type and extract conv parameters from module class.
|
||||
|
||||
This is more reliable than checking weight.ndim, especially for quantized layers
|
||||
where weight shape might be different.
|
||||
|
||||
Returns:
|
||||
dict with keys: is_conv, conv_dim, stride, padding, dilation, groups
|
||||
"""
|
||||
info = {
|
||||
"is_conv": False,
|
||||
"conv_dim": 0,
|
||||
"stride": (1,),
|
||||
"padding": (0,),
|
||||
"dilation": (1,),
|
||||
"groups": 1,
|
||||
"kernel_size": (1,),
|
||||
"in_channels": None,
|
||||
"out_channels": None,
|
||||
}
|
||||
|
||||
# Determine conv type
|
||||
if isinstance(module, nn.Conv1d):
|
||||
info["is_conv"] = True
|
||||
info["conv_dim"] = 1
|
||||
elif isinstance(module, nn.Conv2d):
|
||||
info["is_conv"] = True
|
||||
info["conv_dim"] = 2
|
||||
elif isinstance(module, nn.Conv3d):
|
||||
info["is_conv"] = True
|
||||
info["conv_dim"] = 3
|
||||
elif isinstance(module, nn.Linear):
|
||||
info["is_conv"] = False
|
||||
info["conv_dim"] = 0
|
||||
else:
|
||||
# Try to infer from class name for custom/quantized layers
|
||||
class_name = type(module).__name__.lower()
|
||||
if "conv3d" in class_name:
|
||||
info["is_conv"] = True
|
||||
info["conv_dim"] = 3
|
||||
elif "conv2d" in class_name:
|
||||
info["is_conv"] = True
|
||||
info["conv_dim"] = 2
|
||||
elif "conv1d" in class_name:
|
||||
info["is_conv"] = True
|
||||
info["conv_dim"] = 1
|
||||
elif "conv" in class_name:
|
||||
info["is_conv"] = True
|
||||
info["conv_dim"] = 2
|
||||
|
||||
# Extract conv parameters if it's a conv layer
|
||||
if info["is_conv"]:
|
||||
# Try to get stride, padding, dilation, groups, kernel_size from module
|
||||
info["stride"] = getattr(module, "stride", (1,) * info["conv_dim"])
|
||||
info["padding"] = getattr(module, "padding", (0,) * info["conv_dim"])
|
||||
info["dilation"] = getattr(module, "dilation", (1,) * info["conv_dim"])
|
||||
info["groups"] = getattr(module, "groups", 1)
|
||||
info["kernel_size"] = getattr(module, "kernel_size", (1,) * info["conv_dim"])
|
||||
info["in_channels"] = getattr(module, "in_channels", None)
|
||||
info["out_channels"] = getattr(module, "out_channels", None)
|
||||
|
||||
# Ensure they're tuples
|
||||
if isinstance(info["stride"], int):
|
||||
info["stride"] = (info["stride"],) * info["conv_dim"]
|
||||
if isinstance(info["padding"], int):
|
||||
info["padding"] = (info["padding"],) * info["conv_dim"]
|
||||
if isinstance(info["dilation"], int):
|
||||
info["dilation"] = (info["dilation"],) * info["conv_dim"]
|
||||
if isinstance(info["kernel_size"], int):
|
||||
info["kernel_size"] = (info["kernel_size"],) * info["conv_dim"]
|
||||
|
||||
return info
|
||||
|
||||
|
||||
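# A quick sketch of what get_module_type_info reports for common layers:
#
#     import torch.nn as nn
#
#     info = get_module_type_info(nn.Conv2d(4, 8, kernel_size=3, stride=2, padding=1))
#     # info["is_conv"] == True, info["conv_dim"] == 2,
#     # info["kernel_size"] == (3, 3), info["stride"] == (2, 2), info["padding"] == (1, 1)
#
#     info = get_module_type_info(nn.Linear(16, 32))
#     # info["is_conv"] == False, info["conv_dim"] == 0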
class BypassForwardHook:
|
||||
"""
|
||||
Hook that wraps a layer's forward to apply adapter in bypass mode.
|
||||
|
||||
Stores the original forward and replaces it with bypass version.
|
||||
|
||||
Supports both:
|
||||
- WeightAdapterBase: Inference adapters (uses self.weights tuple)
|
||||
- WeightAdapterTrainBase: Training adapters (nn.Module with parameters)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
module: nn.Module,
|
||||
adapter: BypassAdapter,
|
||||
multiplier: float = 1.0,
|
||||
):
|
||||
self.module = module
|
||||
self.adapter = adapter
|
||||
self.multiplier = multiplier
|
||||
self.original_forward = None
|
||||
|
||||
# Determine layer type and conv params from module class (works for quantized layers)
|
||||
module_info = get_module_type_info(module)
|
||||
|
||||
# Set multiplier and layer type info on adapter for use in h()
|
||||
adapter.multiplier = multiplier
|
||||
adapter.is_conv = module_info["is_conv"]
|
||||
adapter.conv_dim = module_info["conv_dim"]
|
||||
adapter.kernel_size = module_info["kernel_size"]
|
||||
adapter.in_channels = module_info["in_channels"]
|
||||
adapter.out_channels = module_info["out_channels"]
|
||||
# Store kw_dict for conv operations (like LyCORIS extra_args)
|
||||
if module_info["is_conv"]:
|
||||
adapter.kw_dict = {
|
||||
"stride": module_info["stride"],
|
||||
"padding": module_info["padding"],
|
||||
"dilation": module_info["dilation"],
|
||||
"groups": module_info["groups"],
|
||||
}
|
||||
else:
|
||||
adapter.kw_dict = {}
|
||||
|
||||
def _bypass_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
||||
"""Bypass forward: uses adapter's bypass_forward or default g(f(x) + h(x))
|
||||
|
||||
Note:
|
||||
Bypass mode does NOT access original model weights (org_weight).
|
||||
This is intentional - bypass mode is designed for quantized models
|
||||
where weights may not be in a usable format. All necessary shape
|
||||
information is provided via adapter attributes set during inject().
|
||||
"""
|
||||
# Check if adapter has custom bypass_forward (e.g., GLoRA)
|
||||
adapter_bypass = getattr(self.adapter, "bypass_forward", None)
|
||||
if adapter_bypass is not None:
|
||||
# Check if it's overridden (not the base class default)
|
||||
# Need to check both base classes since adapter could be either type
|
||||
adapter_type = type(self.adapter)
|
||||
is_default_bypass = (
|
||||
adapter_type.bypass_forward is WeightAdapterBase.bypass_forward
|
||||
or adapter_type.bypass_forward is WeightAdapterTrainBase.bypass_forward
|
||||
)
|
||||
if not is_default_bypass:
|
||||
return adapter_bypass(self.original_forward, x, *args, **kwargs)
|
||||
|
||||
# Default bypass: g(f(x) + h(x, f(x)))
|
||||
base_out = self.original_forward(x, *args, **kwargs)
|
||||
h_out = self.adapter.h(x, base_out)
|
||||
return self.adapter.g(base_out + h_out)
|
||||
|
||||
def inject(self):
|
||||
"""Replace module forward with bypass version."""
|
||||
if self.original_forward is not None:
|
||||
logging.debug(
|
||||
f"[BypassHook] Already injected for {type(self.module).__name__}"
|
||||
)
|
||||
return # Already injected
|
||||
|
||||
# Move adapter weights to compute device (GPU)
|
||||
# Use get_torch_device() instead of module.weight.device because
|
||||
# with offloading, module weights may be on CPU while compute happens on GPU
|
||||
device = comfy.model_management.get_torch_device()
|
||||
|
||||
# Get dtype from module weight if available
|
||||
dtype = None
|
||||
if hasattr(self.module, "weight") and self.module.weight is not None:
|
||||
dtype = self.module.weight.dtype
|
||||
|
||||
# Only use dtype if it's a standard float type, not quantized
|
||||
if dtype is not None and dtype not in (torch.float32, torch.float16, torch.bfloat16):
|
||||
dtype = None
|
||||
|
||||
self._move_adapter_weights_to_device(device, dtype)
|
||||
|
||||
self.original_forward = self.module.forward
|
||||
self.module.forward = self._bypass_forward
|
||||
logging.debug(
|
||||
f"[BypassHook] Injected bypass forward for {type(self.module).__name__} (adapter={type(self.adapter).__name__})"
|
||||
)
|
||||
|
||||
def _move_adapter_weights_to_device(self, device, dtype=None):
|
||||
"""Move adapter weights to specified device to avoid per-forward transfers.
|
||||
|
||||
Handles both:
|
||||
- WeightAdapterBase: has self.weights tuple of tensors
|
||||
- WeightAdapterTrainBase: nn.Module with parameters, uses .to() method
|
||||
"""
|
||||
adapter = self.adapter
|
||||
|
||||
# Check if adapter is an nn.Module (WeightAdapterTrainBase)
|
||||
if isinstance(adapter, nn.Module):
|
||||
# In training mode we don't touch dtype as the trainer will handle it
|
||||
adapter.to(device=device)
|
||||
logging.debug(
|
||||
f"[BypassHook] Moved training adapter (nn.Module) to {device}"
|
||||
)
|
||||
return
|
||||
|
||||
# WeightAdapterBase: handle self.weights tuple
|
||||
if not hasattr(adapter, "weights") or adapter.weights is None:
|
||||
return
|
||||
|
||||
weights = adapter.weights
|
||||
if isinstance(weights, (list, tuple)):
|
||||
new_weights = []
|
||||
for w in weights:
|
||||
if isinstance(w, torch.Tensor):
|
||||
if dtype is not None:
|
||||
new_weights.append(w.to(device=device, dtype=dtype))
|
||||
else:
|
||||
new_weights.append(w.to(device=device))
|
||||
else:
|
||||
new_weights.append(w)
|
||||
adapter.weights = (
|
||||
tuple(new_weights) if isinstance(weights, tuple) else new_weights
|
||||
)
|
||||
elif isinstance(weights, torch.Tensor):
|
||||
if dtype is not None:
|
||||
adapter.weights = weights.to(device=device, dtype=dtype)
|
||||
else:
|
||||
adapter.weights = weights.to(device=device)
|
||||
|
||||
logging.debug(f"[BypassHook] Moved adapter weights to {device}")
|
||||
|
||||
def eject(self):
|
||||
"""Restore original module forward."""
|
||||
if self.original_forward is None:
|
||||
logging.debug(f"[BypassHook] Not injected for {type(self.module).__name__}")
|
||||
return # Not injected
|
||||
|
||||
self.module.forward = self.original_forward
|
||||
self.original_forward = None
|
||||
logging.debug(
|
||||
f"[BypassHook] Ejected bypass forward for {type(self.module).__name__}"
|
||||
)
|
||||
|
||||
|
||||
class BypassInjectionManager:
|
||||
"""
|
||||
Manages bypass mode injection for a collection of adapters.
|
||||
|
||||
Creates PatcherInjection objects that can be used with ModelPatcher.
|
||||
|
||||
Supports both inference adapters (WeightAdapterBase) and training adapters
|
||||
(WeightAdapterTrainBase).
|
||||
|
||||
Usage:
|
||||
manager = BypassInjectionManager()
|
||||
manager.add_adapter("model.layers.0.self_attn.q_proj", lora_adapter, strength=0.8)
|
||||
manager.add_adapter("model.layers.0.self_attn.k_proj", lora_adapter, strength=0.8)
|
||||
|
||||
injections = manager.create_injections(model)
|
||||
model_patcher.set_injections("bypass_lora", injections)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.adapters: dict[str, tuple[BypassAdapter, float]] = {}
|
||||
self.hooks: list[BypassForwardHook] = []
|
||||
|
||||
def add_adapter(
|
||||
self,
|
||||
key: str,
|
||||
adapter: BypassAdapter,
|
||||
strength: float = 1.0,
|
||||
):
|
||||
"""
|
||||
Add an adapter for a specific weight key.
|
||||
|
||||
Args:
|
||||
key: Weight key (e.g., "model.layers.0.self_attn.q_proj.weight")
|
||||
adapter: The weight adapter (LoRAAdapter, LoKrAdapter, etc.)
|
||||
strength: Multiplier for adapter effect
|
||||
"""
|
||||
# Remove .weight suffix if present for module lookup
|
||||
module_key = key
|
||||
if module_key.endswith(".weight"):
|
||||
module_key = module_key[:-7]
|
||||
logging.debug(
|
||||
f"[BypassManager] Stripped .weight suffix: {key} -> {module_key}"
|
||||
)
|
||||
|
||||
self.adapters[module_key] = (adapter, strength)
|
||||
logging.debug(
|
||||
f"[BypassManager] Added adapter: {module_key} (type={type(adapter).__name__}, strength={strength})"
|
||||
)
|
||||
|
||||
def clear_adapters(self):
|
||||
"""Remove all adapters."""
|
||||
self.adapters.clear()
|
||||
|
||||
def _get_module_by_key(self, model: nn.Module, key: str) -> Optional[nn.Module]:
|
||||
"""Get a submodule by dot-separated key."""
|
||||
parts = key.split(".")
|
||||
module = model
|
||||
try:
|
||||
for i, part in enumerate(parts):
|
||||
if part.isdigit():
|
||||
module = module[int(part)]
|
||||
else:
|
||||
module = getattr(module, part)
|
||||
logging.debug(
|
||||
f"[BypassManager] Found module for key {key}: {type(module).__name__}"
|
||||
)
|
||||
return module
|
||||
except (AttributeError, IndexError, KeyError) as e:
|
||||
logging.error(f"[BypassManager] Failed to find module for key {key}: {e}")
|
||||
logging.error(
|
||||
f"[BypassManager] Failed at part index {i}, part={part}, current module type={type(module).__name__}"
|
||||
)
|
||||
return None
|
||||
|
||||
def create_injections(self, model: nn.Module) -> list[PatcherInjection]:
|
||||
"""
|
||||
Create PatcherInjection objects for all registered adapters.
|
||||
|
||||
Args:
|
||||
model: The model to inject into (e.g., model_patcher.model)
|
||||
|
||||
Returns:
|
||||
List of PatcherInjection objects to use with model_patcher.set_injections()
|
||||
"""
|
||||
self.hooks.clear()
|
||||
|
||||
logging.debug(
|
||||
f"[BypassManager] create_injections called with {len(self.adapters)} adapters"
|
||||
)
|
||||
logging.debug(f"[BypassManager] Model type: {type(model).__name__}")
|
||||
|
||||
for key, (adapter, strength) in self.adapters.items():
|
||||
logging.debug(f"[BypassManager] Looking for module: {key}")
|
||||
module = self._get_module_by_key(model, key)
|
||||
|
||||
if module is None:
|
||||
logging.warning(f"[BypassManager] Module not found for key {key}")
|
||||
continue
|
||||
|
||||
if not hasattr(module, "weight"):
|
||||
logging.warning(
|
||||
f"[BypassManager] Module {key} has no weight attribute (type={type(module).__name__})"
|
||||
)
|
||||
continue
|
||||
|
||||
logging.debug(
|
||||
f"[BypassManager] Creating hook for {key} (module type={type(module).__name__}, weight shape={module.weight.shape})"
|
||||
)
|
||||
hook = BypassForwardHook(module, adapter, multiplier=strength)
|
||||
self.hooks.append(hook)
|
||||
|
||||
logging.debug(f"[BypassManager] Created {len(self.hooks)} hooks")
|
||||
|
||||
# Create single injection that manages all hooks
|
||||
def inject_all(model_patcher):
|
||||
logging.debug(
|
||||
f"[BypassManager] inject_all called, injecting {len(self.hooks)} hooks"
|
||||
)
|
||||
for hook in self.hooks:
|
||||
hook.inject()
|
||||
logging.debug(
|
||||
f"[BypassManager] Injected hook for {type(hook.module).__name__}"
|
||||
)
|
||||
|
||||
def eject_all(model_patcher):
|
||||
logging.debug(
|
||||
f"[BypassManager] eject_all called, ejecting {len(self.hooks)} hooks"
|
||||
)
|
||||
for hook in self.hooks:
|
||||
hook.eject()
|
||||
|
||||
return [PatcherInjection(inject=inject_all, eject=eject_all)]
|
||||
|
||||
def get_hook_count(self) -> int:
|
||||
"""Return number of hooks that will be/are injected."""
|
||||
return len(self.hooks)
|
||||
|
||||
|
||||
def create_bypass_injections_from_patches(
|
||||
model: nn.Module,
|
||||
patches: dict,
|
||||
strength: float = 1.0,
|
||||
) -> list[PatcherInjection]:
|
||||
"""
|
||||
Convenience function to create bypass injections from a patches dict.
|
||||
|
||||
This is useful when you have patches in the format used by model_patcher.add_patches()
|
||||
and want to apply them in bypass mode instead.
|
||||
|
||||
Args:
|
||||
model: The model to inject into
|
||||
patches: Dict mapping weight keys to adapter data
|
||||
strength: Global strength multiplier
|
||||
|
||||
Returns:
|
||||
List of PatcherInjection objects
|
||||
"""
|
||||
manager = BypassInjectionManager()
|
||||
|
||||
for key, patch_list in patches.items():
|
||||
if not patch_list:
|
||||
continue
|
||||
|
||||
# patches format: list of (strength_patch, patch_data, strength_model, offset, function)
|
||||
for patch in patch_list:
|
||||
patch_strength, patch_data, strength_model, offset, function = patch
|
||||
|
||||
# patch_data should be a WeightAdapterBase/WeightAdapterTrainBase or tuple
|
||||
if isinstance(patch_data, (WeightAdapterBase, WeightAdapterTrainBase)):
|
||||
adapter = patch_data
|
||||
else:
|
||||
# Skip non-adapter patches
|
||||
continue
|
||||
|
||||
combined_strength = strength * patch_strength
|
||||
manager.add_adapter(key, adapter, strength=combined_strength)
|
||||
|
||||
return manager.create_injections(model)
|
||||
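# A hedged end-to-end sketch of wiring the convenience helper into a
# ModelPatcher (`model_patcher` and its patches dict are assumed to already
# exist; "bypass_lora" is just an illustrative injection key):
#
#     injections = create_bypass_injections_from_patches(
#         model_patcher.model, model_patcher.patches, strength=1.0)
#     model_patcher.set_injections("bypass_lora", injections)
#     # hooks are injected/ejected by the patcher when the model is (un)loaded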
@@ -1,7 +1,8 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import comfy.model_management
|
||||
from .base import WeightAdapterBase, weight_decompose
|
||||
|
||||
@@ -29,7 +30,14 @@ class GLoRAAdapter(WeightAdapterBase):
|
||||
b1_name = "{}.b1.weight".format(x)
|
||||
b2_name = "{}.b2.weight".format(x)
|
||||
if a1_name in lora:
|
||||
weights = (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale)
|
||||
weights = (
|
||||
lora[a1_name],
|
||||
lora[a2_name],
|
||||
lora[b1_name],
|
||||
lora[b2_name],
|
||||
alpha,
|
||||
dora_scale,
|
||||
)
|
||||
loaded_keys.add(a1_name)
|
||||
loaded_keys.add(a2_name)
|
||||
loaded_keys.add(b1_name)
|
||||
@@ -58,16 +66,28 @@ class GLoRAAdapter(WeightAdapterBase):
|
||||
old_glora = True
|
||||
|
||||
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
|
||||
if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
|
||||
if (
|
||||
old_glora
|
||||
and v[1].shape[0] == weight.shape[0]
|
||||
and weight.shape[0] == weight.shape[1]
|
||||
):
|
||||
pass
|
||||
else:
|
||||
old_glora = False
|
||||
rank = v[1].shape[0]
|
||||
|
||||
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||
a1 = comfy.model_management.cast_to_device(
|
||||
v[0].flatten(start_dim=1), weight.device, intermediate_dtype
|
||||
)
|
||||
a2 = comfy.model_management.cast_to_device(
|
||||
v[1].flatten(start_dim=1), weight.device, intermediate_dtype
|
||||
)
|
||||
b1 = comfy.model_management.cast_to_device(
|
||||
v[2].flatten(start_dim=1), weight.device, intermediate_dtype
|
||||
)
|
||||
b2 = comfy.model_management.cast_to_device(
|
||||
v[3].flatten(start_dim=1), weight.device, intermediate_dtype
|
||||
)
|
||||
|
||||
if v[4] is not None:
|
||||
alpha = v[4] / rank
|
||||
@@ -76,18 +96,195 @@ class GLoRAAdapter(WeightAdapterBase):
|
||||
|
||||
try:
|
||||
if old_glora:
|
||||
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
|
||||
lora_diff = (
|
||||
torch.mm(b2, b1)
|
||||
+ torch.mm(
|
||||
torch.mm(
|
||||
weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2
|
||||
),
|
||||
a1,
|
||||
)
|
||||
).reshape(
|
||||
weight.shape
|
||||
) # old lycoris glora
|
||||
else:
|
||||
if weight.dim() > 2:
|
||||
lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
||||
lora_diff = torch.einsum(
|
||||
"o i ..., i j -> o j ...",
|
||||
torch.einsum(
|
||||
"o i ..., i j -> o j ...",
|
||||
weight.to(dtype=intermediate_dtype),
|
||||
a1,
|
||||
),
|
||||
a2,
|
||||
).reshape(weight.shape)
|
||||
else:
|
||||
lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
||||
lora_diff = torch.mm(
|
||||
torch.mm(weight.to(dtype=intermediate_dtype), a1), a2
|
||||
).reshape(weight.shape)
|
||||
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
|
||||
|
||||
if dora_scale is not None:
|
||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||
weight = weight_decompose(
|
||||
dora_scale,
|
||||
weight,
|
||||
lora_diff,
|
||||
alpha,
|
||||
strength,
|
||||
intermediate_dtype,
|
||||
function,
|
||||
)
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||
return weight
|
||||
|
||||
def _compute_paths(self, x: torch.Tensor):
|
||||
"""
|
||||
Compute A path and B path outputs for GLoRA bypass.
|
||||
|
||||
GLoRA: f(x) = Wx + WAx + Bx
|
||||
- A path: a1(a2(x)) - modifies input to base forward
|
||||
- B path: b1(b2(x)) - additive component
|
||||
|
||||
Note:
|
||||
Does not access original model weights - bypass mode is designed
|
||||
for quantized models where weights may not be accessible.
|
||||
|
||||
Returns: (a_out, b_out)
|
||||
"""
|
||||
v = self.weights
|
||||
# v = (a1, a2, b1, b2, alpha, dora_scale)
|
||||
a1 = v[0]
|
||||
a2 = v[1]
|
||||
b1 = v[2]
|
||||
b2 = v[3]
|
||||
alpha = v[4]
|
||||
|
||||
dtype = x.dtype
|
||||
|
||||
# Cast dtype (weights should already be on correct device from inject())
|
||||
a1 = a1.to(dtype=dtype)
|
||||
a2 = a2.to(dtype=dtype)
|
||||
b1 = b1.to(dtype=dtype)
|
||||
b2 = b2.to(dtype=dtype)
|
||||
|
||||
# Determine rank and scale
|
||||
# Check for old vs new glora format
|
||||
old_glora = False
|
||||
if b2.shape[1] == b1.shape[0] == a1.shape[0] == a2.shape[1]:
|
||||
rank = a1.shape[0]
|
||||
old_glora = True
|
||||
|
||||
if b2.shape[0] == b1.shape[1] == a1.shape[1] == a2.shape[0]:
|
||||
if old_glora and a2.shape[0] == x.shape[-1]:
|
||||
pass
|
||||
else:
|
||||
old_glora = False
|
||||
rank = a2.shape[0]
|
||||
|
||||
if alpha is not None:
|
||||
scale = alpha / rank
|
||||
else:
|
||||
scale = 1.0
|
||||
|
||||
# Apply multiplier
|
||||
multiplier = getattr(self, "multiplier", 1.0)
|
||||
scale = scale * multiplier
|
||||
|
||||
# Use module info from bypass injection, not input tensor shape
|
||||
is_conv = getattr(self, "is_conv", False)
|
||||
conv_dim = getattr(self, "conv_dim", 0)
|
||||
kw_dict = getattr(self, "kw_dict", {})
|
||||
|
||||
if is_conv:
|
||||
# Conv case - conv_dim is 1/2/3 for conv1d/2d/3d
|
||||
conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1]
|
||||
|
||||
# Get module's stride/padding for spatial dimension handling
|
||||
module_stride = kw_dict.get("stride", (1,) * conv_dim)
|
||||
module_padding = kw_dict.get("padding", (0,) * conv_dim)
|
||||
kernel_size = getattr(self, "kernel_size", (1,) * conv_dim)
|
||||
in_channels = getattr(self, "in_channels", None)
|
||||
|
||||
# Ensure weights are in conv shape
|
||||
# a1, a2, b1 are always 1x1 kernels
|
||||
if a1.ndim == 2:
|
||||
a1 = a1.view(*a1.shape, *([1] * conv_dim))
|
||||
if a2.ndim == 2:
|
||||
a2 = a2.view(*a2.shape, *([1] * conv_dim))
|
||||
if b1.ndim == 2:
|
||||
b1 = b1.view(*b1.shape, *([1] * conv_dim))
|
||||
# b2 has actual kernel_size (like LoRA down)
|
||||
if b2.ndim == 2:
|
||||
if in_channels is not None:
|
||||
b2 = b2.view(b2.shape[0], in_channels, *kernel_size)
|
||||
else:
|
||||
b2 = b2.view(*b2.shape, *([1] * conv_dim))
|
||||
|
||||
# A path: a2(x) -> a1(...) - 1x1 convs, no stride/padding needed, a_out is added to x
|
||||
a2_out = conv_fn(x, a2)
|
||||
a_out = conv_fn(a2_out, a1) * scale
|
||||
|
||||
# B path: b2(x) with kernel/stride/padding -> b1(...) 1x1
|
||||
b2_out = conv_fn(x, b2, stride=module_stride, padding=module_padding)
|
||||
b_out = conv_fn(b2_out, b1) * scale
|
||||
else:
|
||||
# Linear case
|
||||
if old_glora:
|
||||
# Old format: a1 @ a2 @ x, b2 @ b1
|
||||
a_out = F.linear(F.linear(x, a2), a1) * scale
|
||||
b_out = F.linear(F.linear(x, b1), b2) * scale
|
||||
else:
|
||||
# New format: x @ a1 @ a2, b1 @ b2
|
||||
a_out = F.linear(F.linear(x, a1), a2) * scale
|
||||
b_out = F.linear(F.linear(x, b2), b1) * scale
|
||||
|
||||
return a_out, b_out
|
||||
|
||||
def bypass_forward(
|
||||
self,
|
||||
org_forward: Callable,
|
||||
x: torch.Tensor,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
GLoRA bypass forward: f(x + a(x)) + b(x)
|
||||
|
||||
Unlike standard adapters, GLoRA modifies the input to the base forward
|
||||
AND adds the B path output.
|
||||
|
||||
Note:
|
||||
Does not access original model weights - bypass mode is designed
|
||||
for quantized models where weights may not be accessible.
|
||||
|
||||
Reference: LyCORIS GLoRAModule._bypass_forward
|
||||
"""
|
||||
a_out, b_out = self._compute_paths(x)
|
||||
|
||||
# Call base forward with modified input
|
||||
base_out = org_forward(x + a_out, *args, **kwargs)
|
||||
|
||||
# Add B path
|
||||
return base_out + b_out
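A minimal numeric sketch (standalone toy tensors, not the adapter classes) of the identity bypass_forward relies on for a plain linear layer: folding the GLoRA update into the weight and running the bypass form agree.

import torch

torch.manual_seed(0)
d, r = 8, 2
W = torch.randn(d, d)                            # base linear weight
a1, a2 = torch.randn(d, r), torch.randn(r, d)    # A = a1 @ a2 (low rank)
b1, b2 = torch.randn(d, r), torch.randn(r, d)    # B = b1 @ b2 (low rank)
x = torch.randn(3, d)

A, B = a1 @ a2, b1 @ b2
merged = x @ (W + W @ A + B).T                   # f(x) = Wx + WAx + Bx with weights merged
bypass = (x + x @ A.T) @ W.T + x @ B.T           # f(x + a(x)) + b(x), no weight access
assert torch.allclose(merged, bypass, atol=1e-5)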
|
||||
|
||||
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
For GLoRA, h() returns the B path output.
|
||||
|
||||
Note:
|
||||
GLoRA's full bypass requires overriding bypass_forward() since
|
||||
it also modifies the input to org_forward. This h() is provided for
|
||||
compatibility but bypass_forward() should be used for correct behavior.
|
||||
|
||||
Does not access original model weights - bypass mode is designed
|
||||
for quantized models where weights may not be accessible.
|
||||
|
||||
Args:
|
||||
x: Input tensor
|
||||
base_out: Output from base forward (unused, for API consistency)
|
||||
"""
|
||||
_, b_out = self._compute_paths(x)
|
||||
return b_out
|
||||
|
||||
@@ -1,11 +1,22 @@
|
||||
import logging
|
||||
from functools import cache
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import comfy.model_management
|
||||
from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose
|
||||
|
||||
|
||||
@cache
|
||||
def _warn_loha_bypass_inefficient():
|
||||
"""One-time warning about LoHa bypass inefficiency."""
|
||||
logging.warning(
|
||||
"LoHa bypass mode is inefficient: full weight diff is computed each forward pass. "
|
||||
"Consider using LoRA or LoKr for training with bypass mode."
|
||||
)
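functools.cache on a zero-argument function memoizes its (None) result, so the warning body only runs on the first call. A minimal sketch of the one-shot pattern:

from functools import cache

@cache
def _warn_once():
    print("warned")   # stands in for logging.warning

_warn_once()  # runs the body
_warn_once()  # served from the cache; stays silent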
|
||||
|
||||
|
||||
class HadaWeight(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, w1u, w1d, w2u, w2d, scale=torch.tensor(1)):
|
||||
@@ -105,9 +116,19 @@ class LohaDiff(WeightAdapterTrainBase):
|
||||
|
||||
scale = self.alpha / self.rank
|
||||
if self.use_tucker:
|
||||
diff_weight = HadaWeightTucker.apply(self.hada_t1, self.hada_w1_a, self.hada_w1_b, self.hada_t2, self.hada_w2_a, self.hada_w2_b, scale)
|
||||
diff_weight = HadaWeightTucker.apply(
|
||||
self.hada_t1,
|
||||
self.hada_w1_a,
|
||||
self.hada_w1_b,
|
||||
self.hada_t2,
|
||||
self.hada_w2_a,
|
||||
self.hada_w2_b,
|
||||
scale,
|
||||
)
|
||||
else:
|
||||
diff_weight = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)
|
||||
diff_weight = HadaWeight.apply(
|
||||
self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale
|
||||
)
|
||||
|
||||
# Add the scaled difference to the original weight
|
||||
weight = w.to(diff_weight) + diff_weight.reshape(w.shape)
|
||||
@@ -138,9 +159,7 @@ class LoHaAdapter(WeightAdapterBase):
|
||||
mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
|
||||
torch.nn.init.normal_(mat3, 0.1)
|
||||
torch.nn.init.normal_(mat4, 0.01)
|
||||
return LohaDiff(
|
||||
(mat1, mat2, alpha, mat3, mat4, None, None, None)
|
||||
)
|
||||
return LohaDiff((mat1, mat2, alpha, mat3, mat4, None, None, None))
|
||||
|
||||
def to_train(self):
|
||||
return LohaDiff(self.weights)
|
||||
@@ -172,7 +191,16 @@ class LoHaAdapter(WeightAdapterBase):
|
||||
loaded_keys.add(hada_t1_name)
|
||||
loaded_keys.add(hada_t2_name)
|
||||
|
||||
weights = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale)
|
||||
weights = (
|
||||
lora[hada_w1_a_name],
|
||||
lora[hada_w1_b_name],
|
||||
alpha,
|
||||
lora[hada_w2_a_name],
|
||||
lora[hada_w2_b_name],
|
||||
hada_t1,
|
||||
hada_t2,
|
||||
dora_scale,
|
||||
)
|
||||
loaded_keys.add(hada_w1_a_name)
|
||||
loaded_keys.add(hada_w1_b_name)
|
||||
loaded_keys.add(hada_w2_a_name)
|
||||
@@ -203,30 +231,148 @@ class LoHaAdapter(WeightAdapterBase):
|
||||
w2a = v[3]
|
||||
w2b = v[4]
|
||||
dora_scale = v[7]
|
||||
if v[5] is not None: #cp decomposition
|
||||
if v[5] is not None: # cp decomposition
|
||||
t1 = v[5]
|
||||
t2 = v[6]
|
||||
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||
comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
|
||||
m1 = torch.einsum(
|
||||
"i j k l, j r, i p -> p r k l",
|
||||
comfy.model_management.cast_to_device(
|
||||
t1, weight.device, intermediate_dtype
|
||||
),
|
||||
comfy.model_management.cast_to_device(
|
||||
w1b, weight.device, intermediate_dtype
|
||||
),
|
||||
comfy.model_management.cast_to_device(
|
||||
w1a, weight.device, intermediate_dtype
|
||||
),
|
||||
)
|
||||
|
||||
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
|
||||
m2 = torch.einsum(
|
||||
"i j k l, j r, i p -> p r k l",
|
||||
comfy.model_management.cast_to_device(
|
||||
t2, weight.device, intermediate_dtype
|
||||
),
|
||||
comfy.model_management.cast_to_device(
|
||||
w2b, weight.device, intermediate_dtype
|
||||
),
|
||||
comfy.model_management.cast_to_device(
|
||||
w2a, weight.device, intermediate_dtype
|
||||
),
|
||||
)
|
||||
else:
|
||||
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
|
||||
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
|
||||
m1 = torch.mm(
|
||||
comfy.model_management.cast_to_device(
|
||||
w1a, weight.device, intermediate_dtype
|
||||
),
|
||||
comfy.model_management.cast_to_device(
|
||||
w1b, weight.device, intermediate_dtype
|
||||
),
|
||||
)
|
||||
m2 = torch.mm(
|
||||
comfy.model_management.cast_to_device(
|
||||
w2a, weight.device, intermediate_dtype
|
||||
),
|
||||
comfy.model_management.cast_to_device(
|
||||
w2b, weight.device, intermediate_dtype
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
lora_diff = (m1 * m2).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||
weight = weight_decompose(
|
||||
dora_scale,
|
||||
weight,
|
||||
lora_diff,
|
||||
alpha,
|
||||
strength,
|
||||
intermediate_dtype,
|
||||
function,
|
||||
)
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||
return weight
|
||||
|
||||
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Additive bypass component for LoHa: h(x) = diff_weight @ x
|
||||
|
||||
WARNING: Inefficient - computes full Hadamard product each forward.
|
||||
|
||||
Note:
|
||||
Does not access original model weights - bypass mode is designed
|
||||
for quantized models where weights may not be accessible.
|
||||
|
||||
Args:
|
||||
x: Input tensor
|
||||
base_out: Output from base forward (unused, for API consistency)
|
||||
|
||||
Reference: LyCORIS functional/loha.py bypass_forward_diff
|
||||
"""
|
||||
_warn_loha_bypass_inefficient()
|
||||
|
||||
# FUNC_LIST: [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]
|
||||
FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]
|
||||
|
||||
v = self.weights
|
||||
# v[0]=w1a, v[1]=w1b, v[2]=alpha, v[3]=w2a, v[4]=w2b, v[5]=t1, v[6]=t2, v[7]=dora
|
||||
w1a = v[0]
|
||||
w1b = v[1]
|
||||
alpha = v[2]
|
||||
w2a = v[3]
|
||||
w2b = v[4]
|
||||
t1 = v[5]
|
||||
t2 = v[6]
|
||||
|
||||
# Compute scale
|
||||
rank = w1b.shape[0]
|
||||
scale = (alpha / rank if alpha is not None else 1.0) * getattr(
|
||||
self, "multiplier", 1.0
|
||||
)
|
||||
|
||||
# Cast dtype
|
||||
w1a = w1a.to(dtype=x.dtype)
|
||||
w1b = w1b.to(dtype=x.dtype)
|
||||
w2a = w2a.to(dtype=x.dtype)
|
||||
w2b = w2b.to(dtype=x.dtype)
|
||||
|
||||
# Use module info from bypass injection, not weight dimension
|
||||
is_conv = getattr(self, "is_conv", False)
|
||||
conv_dim = getattr(self, "conv_dim", 0)
|
||||
kw_dict = getattr(self, "kw_dict", {})
|
||||
|
||||
# Compute diff weight using Hadamard product
|
||||
if t1 is not None and t2 is not None:
|
||||
t1 = t1.to(dtype=x.dtype)
|
||||
t2 = t2.to(dtype=x.dtype)
|
||||
m1 = torch.einsum("i j k l, j r, i p -> p r k l", t1, w1b, w1a)
|
||||
m2 = torch.einsum("i j k l, j r, i p -> p r k l", t2, w2b, w2a)
|
||||
diff_weight = (m1 * m2) * scale
|
||||
else:
|
||||
m1 = w1a @ w1b
|
||||
m2 = w2a @ w2b
|
||||
diff_weight = (m1 * m2) * scale
|
||||
|
||||
if is_conv:
|
||||
op = FUNC_LIST[conv_dim + 2]
|
||||
kernel_size = getattr(self, "kernel_size", (1,) * conv_dim)
|
||||
in_channels = getattr(self, "in_channels", None)
|
||||
|
||||
# Reshape 2D diff_weight to conv format using kernel_size
|
||||
# diff_weight: [out_channels, in_channels * prod(kernel_size)] -> [out_channels, in_channels, *kernel_size]
|
||||
if diff_weight.dim() == 2:
|
||||
if in_channels is not None:
|
||||
diff_weight = diff_weight.view(
|
||||
diff_weight.shape[0], in_channels, *kernel_size
|
||||
)
|
||||
else:
|
||||
diff_weight = diff_weight.view(
|
||||
*diff_weight.shape, *([1] * conv_dim)
|
||||
)
|
||||
else:
|
||||
op = F.linear
|
||||
kw_dict = {}
|
||||
|
||||
return op(x, diff_weight, **kw_dict)
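A minimal sketch (toy shapes, non-Tucker linear case) of what this bypass computes: the full Hadamard-product diff weight is rebuilt and applied on every call, which is exactly what the one-time warning above is about.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
out_dim, in_dim, rank, alpha = 6, 4, 2, 2.0
w1a, w1b = torch.randn(out_dim, rank), torch.randn(rank, in_dim)
w2a, w2b = torch.randn(out_dim, rank), torch.randn(rank, in_dim)
x = torch.randn(5, in_dim)

scale = alpha / rank
diff_w = (w1a @ w1b) * (w2a @ w2b) * scale   # full [out_dim, in_dim] matrix every forward
h_x = F.linear(x, diff_w)                    # additive bypass term, shape [5, out_dim]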
|
||||
|
||||
@@ -2,6 +2,7 @@ import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import comfy.model_management
|
||||
from .base import (
|
||||
WeightAdapterBase,
|
||||
@@ -14,7 +15,17 @@ from .base import (
|
||||
class LokrDiff(WeightAdapterTrainBase):
|
||||
def __init__(self, weights):
|
||||
super().__init__()
|
||||
(lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) = weights
|
||||
(
|
||||
lokr_w1,
|
||||
lokr_w2,
|
||||
alpha,
|
||||
lokr_w1_a,
|
||||
lokr_w1_b,
|
||||
lokr_w2_a,
|
||||
lokr_w2_b,
|
||||
lokr_t2,
|
||||
dora_scale,
|
||||
) = weights
|
||||
self.use_tucker = False
|
||||
if lokr_w1_a is not None:
|
||||
_, rank_a = lokr_w1_a.shape[0], lokr_w1_a.shape[1]
|
||||
@@ -57,10 +68,10 @@ class LokrDiff(WeightAdapterTrainBase):
|
||||
if self.w2_rebuild:
|
||||
if self.use_tucker:
|
||||
w2 = torch.einsum(
|
||||
'i j k l, j r, i p -> p r k l',
|
||||
"i j k l, j r, i p -> p r k l",
|
||||
self.lokr_t2,
|
||||
self.lokr_w2_b,
|
||||
self.lokr_w2_a
|
||||
self.lokr_w2_a,
|
||||
)
|
||||
else:
|
||||
w2 = self.lokr_w2_a @ self.lokr_w2_b
|
||||
@@ -69,9 +80,89 @@ class LokrDiff(WeightAdapterTrainBase):
|
||||
return self.lokr_w2
|
||||
|
||||
def __call__(self, w):
|
||||
diff = torch.kron(self.w1, self.w2)
|
||||
w1 = self.w1
|
||||
w2 = self.w2
|
||||
# Unsqueeze w1 to match w2 dims for proper kron product (like LyCORIS make_kron)
|
||||
for _ in range(w2.dim() - w1.dim()):
|
||||
w1 = w1.unsqueeze(-1)
|
||||
diff = torch.kron(w1, w2)
|
||||
return w + diff.reshape(w.shape).to(w)
|
||||
|
||||
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Additive bypass component for LoKr training: efficient Kronecker product.
|
||||
|
||||
Uses w1/w2 properties which handle both direct and decomposed cases.
|
||||
For create_train (direct w1/w2), no alpha scaling in properties.
|
||||
For to_train (decomposed), alpha/rank scaling is in properties.
|
||||
|
||||
Args:
|
||||
x: Input tensor
|
||||
base_out: Output from base forward (unused, for API consistency)
|
||||
"""
|
||||
# Get w1, w2 from properties (handles rebuild vs direct)
|
||||
w1 = self.w1
|
||||
w2 = self.w2
|
||||
|
||||
# Multiplier from bypass injection
|
||||
multiplier = getattr(self, "multiplier", 1.0)
|
||||
|
||||
# Get module info from bypass injection
|
||||
is_conv = getattr(self, "is_conv", False)
|
||||
conv_dim = getattr(self, "conv_dim", 0)
|
||||
kw_dict = getattr(self, "kw_dict", {})
|
||||
|
||||
# Efficient Kronecker application without materializing full weight
|
||||
# kron(w1, w2) @ x can be computed as nested operations
|
||||
# w1: [out_l, in_m], w2: [out_k, in_n, *k_size]
|
||||
# Full weight would be [out_l*out_k, in_m*in_n, *k_size]
|
||||
|
||||
uq = w1.size(1) # in_m - inner grouping dimension
|
||||
|
||||
if is_conv:
|
||||
conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1]
|
||||
|
||||
B, C_in, *spatial = x.shape
|
||||
# Reshape input for grouped application: [B * uq, C_in // uq, *spatial]
|
||||
h_in_group = x.reshape(B * uq, -1, *spatial)
|
||||
|
||||
# Ensure w2 has conv dims
|
||||
if w2.dim() == 2:
|
||||
w2 = w2.view(*w2.shape, *([1] * conv_dim))
|
||||
|
||||
# Apply w2 path with stride/padding
|
||||
hb = conv_fn(h_in_group, w2, **kw_dict)
|
||||
|
||||
# Reshape for cross-group operation
|
||||
hb = hb.view(B, -1, *hb.shape[1:])
|
||||
h_cross = hb.transpose(1, -1)
|
||||
|
||||
# Apply w1 (always 2D, applied as linear on channel dim)
|
||||
hc = F.linear(h_cross, w1)
|
||||
hc = hc.transpose(1, -1)
|
||||
|
||||
# Reshape to output
|
||||
out = hc.reshape(B, -1, *hc.shape[3:])
|
||||
else:
|
||||
# Linear case
|
||||
# Reshape input: [..., in_m * in_n] -> [..., uq (in_m), in_n]
|
||||
h_in_group = x.reshape(*x.shape[:-1], uq, -1)
|
||||
|
||||
# Apply w2: [..., uq, in_n] @ [out_k, in_n].T -> [..., uq, out_k]
|
||||
hb = F.linear(h_in_group, w2)
|
||||
|
||||
# Transpose for w1: [..., uq, out_k] -> [..., out_k, uq]
|
||||
h_cross = hb.transpose(-1, -2)
|
||||
|
||||
# Apply w1: [..., out_k, uq] @ [out_l, uq].T -> [..., out_k, out_l]
|
||||
hc = F.linear(h_cross, w1)
|
||||
|
||||
# Transpose back and flatten: [..., out_k, out_l] -> [..., out_l * out_k]
|
||||
hc = hc.transpose(-1, -2)
|
||||
out = hc.reshape(*hc.shape[:-2], -1)
|
||||
|
||||
return out * multiplier
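A minimal check (toy sizes, linear case) of the trick used above: the grouped reshape/linear chain matches applying the full Kronecker product without materializing it.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
out_l, in_m = 3, 2   # w1 shape
out_k, in_n = 4, 5   # w2 shape
w1 = torch.randn(out_l, in_m)
w2 = torch.randn(out_k, in_n)
x = torch.randn(7, in_m * in_n)

full = F.linear(x, torch.kron(w1, w2))            # reference: materialized weight

h = x.reshape(*x.shape[:-1], in_m, in_n)          # [..., in_m, in_n]
h = F.linear(h, w2)                               # [..., in_m, out_k]
h = F.linear(h.transpose(-1, -2), w1)             # [..., out_k, out_l]
grouped = h.transpose(-1, -2).reshape(*x.shape[:-1], -1)
assert torch.allclose(full, grouped, atol=1e-5)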
|
||||
|
||||
def passive_memory_usage(self):
|
||||
return sum(param.numel() * param.element_size() for param in self.parameters())
|
||||
|
||||
@@ -86,16 +177,22 @@ class LoKrAdapter(WeightAdapterBase):
|
||||
@classmethod
|
||||
def create_train(cls, weight, rank=1, alpha=1.0):
|
||||
out_dim = weight.shape[0]
|
||||
in_dim = weight.shape[1:].numel()
|
||||
out1, out2 = factorization(out_dim, rank)
|
||||
in1, in2 = factorization(in_dim, rank)
|
||||
mat1 = torch.empty(out1, in1, device=weight.device, dtype=torch.float32)
|
||||
mat2 = torch.empty(out2, in2, device=weight.device, dtype=torch.float32)
|
||||
in_dim = weight.shape[1] # Just in_channels, not flattened with kernel
|
||||
k_size = weight.shape[2:] if weight.dim() > 2 else ()
|
||||
|
||||
out_l, out_k = factorization(out_dim, rank)
|
||||
in_m, in_n = factorization(in_dim, rank)
|
||||
|
||||
# w1: [out_l, in_m]
|
||||
mat1 = torch.empty(out_l, in_m, device=weight.device, dtype=torch.float32)
|
||||
# w2: [out_k, in_n, *k_size] for conv, [out_k, in_n] for linear
|
||||
mat2 = torch.empty(
|
||||
out_k, in_n, *k_size, device=weight.device, dtype=torch.float32
|
||||
)
|
||||
|
||||
torch.nn.init.kaiming_uniform_(mat2, a=5**0.5)
|
||||
torch.nn.init.constant_(mat1, 0.0)
|
||||
return LokrDiff(
|
||||
(mat1, mat2, alpha, None, None, None, None, None, None)
|
||||
)
|
||||
return LokrDiff((mat1, mat2, alpha, None, None, None, None, None, None))
|
||||
|
||||
def to_train(self):
|
||||
return LokrDiff(self.weights)
|
||||
@@ -154,8 +251,23 @@ class LoKrAdapter(WeightAdapterBase):
|
||||
lokr_t2 = lora[lokr_t2_name]
|
||||
loaded_keys.add(lokr_t2_name)
|
||||
|
||||
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
|
||||
weights = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale)
|
||||
if (
|
||||
(lokr_w1 is not None)
|
||||
or (lokr_w2 is not None)
|
||||
or (lokr_w1_a is not None)
|
||||
or (lokr_w2_a is not None)
|
||||
):
|
||||
weights = (
|
||||
lokr_w1,
|
||||
lokr_w2,
|
||||
alpha,
|
||||
lokr_w1_a,
|
||||
lokr_w1_b,
|
||||
lokr_w2_a,
|
||||
lokr_w2_b,
|
||||
lokr_t2,
|
||||
dora_scale,
|
||||
)
|
||||
return cls(loaded_keys, weights)
|
||||
else:
|
||||
return None
|
||||
@@ -184,23 +296,47 @@ class LoKrAdapter(WeightAdapterBase):
|
||||
|
||||
if w1 is None:
|
||||
dim = w1_b.shape[0]
|
||||
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
|
||||
w1 = torch.mm(
|
||||
comfy.model_management.cast_to_device(
|
||||
w1_a, weight.device, intermediate_dtype
|
||||
),
|
||||
comfy.model_management.cast_to_device(
|
||||
w1_b, weight.device, intermediate_dtype
|
||||
),
|
||||
)
|
||||
else:
|
||||
w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
|
||||
w1 = comfy.model_management.cast_to_device(
|
||||
w1, weight.device, intermediate_dtype
|
||||
)
|
||||
|
||||
if w2 is None:
|
||||
dim = w2_b.shape[0]
|
||||
if t2 is None:
|
||||
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
|
||||
w2 = torch.mm(
|
||||
comfy.model_management.cast_to_device(
|
||||
w2_a, weight.device, intermediate_dtype
|
||||
),
|
||||
comfy.model_management.cast_to_device(
|
||||
w2_b, weight.device, intermediate_dtype
|
||||
),
|
||||
)
|
||||
else:
|
||||
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
|
||||
w2 = torch.einsum(
|
||||
"i j k l, j r, i p -> p r k l",
|
||||
comfy.model_management.cast_to_device(
|
||||
t2, weight.device, intermediate_dtype
|
||||
),
|
||||
comfy.model_management.cast_to_device(
|
||||
w2_b, weight.device, intermediate_dtype
|
||||
),
|
||||
comfy.model_management.cast_to_device(
|
||||
w2_a, weight.device, intermediate_dtype
|
||||
),
|
||||
)
|
||||
else:
|
||||
w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
|
||||
w2 = comfy.model_management.cast_to_device(
|
||||
w2, weight.device, intermediate_dtype
|
||||
)
|
||||
|
||||
if len(w2.shape) == 4:
|
||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||
@@ -212,9 +348,134 @@ class LoKrAdapter(WeightAdapterBase):
|
||||
try:
|
||||
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||
weight = weight_decompose(
|
||||
dora_scale,
|
||||
weight,
|
||||
lora_diff,
|
||||
alpha,
|
||||
strength,
|
||||
intermediate_dtype,
|
||||
function,
|
||||
)
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||
return weight
|
||||
|
||||
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Additive bypass component for LoKr: efficient Kronecker product application.
|
||||
|
||||
Note:
|
||||
Does not access original model weights - bypass mode is designed
|
||||
for quantized models where weights may not be accessible.
|
||||
|
||||
Args:
|
||||
x: Input tensor
|
||||
base_out: Output from base forward (unused, for API consistency)
|
||||
|
||||
Reference: LyCORIS functional/lokr.py bypass_forward_diff
|
||||
"""
|
||||
# FUNC_LIST: [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]
|
||||
FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]
|
||||
|
||||
v = self.weights
|
||||
# v[0]=w1, v[1]=w2, v[2]=alpha, v[3]=w1_a, v[4]=w1_b, v[5]=w2_a, v[6]=w2_b, v[7]=t2, v[8]=dora
|
||||
w1 = v[0]
|
||||
w2 = v[1]
|
||||
alpha = v[2]
|
||||
w1_a = v[3]
|
||||
w1_b = v[4]
|
||||
w2_a = v[5]
|
||||
w2_b = v[6]
|
||||
t2 = v[7]
|
||||
|
||||
use_w1 = w1 is not None
|
||||
use_w2 = w2 is not None
|
||||
tucker = t2 is not None
|
||||
|
||||
# Use module info from bypass injection, not weight dimension
|
||||
is_conv = getattr(self, "is_conv", False)
|
||||
conv_dim = getattr(self, "conv_dim", 0)
|
||||
kw_dict = getattr(self, "kw_dict", {}) if is_conv else {}
|
||||
|
||||
if is_conv:
|
||||
op = FUNC_LIST[conv_dim + 2]
|
||||
else:
|
||||
op = F.linear
|
||||
|
||||
# Determine rank and scale
|
||||
rank = w1_b.size(0) if not use_w1 else w2_b.size(0) if not use_w2 else alpha
|
||||
scale = (alpha / rank if alpha is not None else 1.0) * getattr(
|
||||
self, "multiplier", 1.0
|
||||
)
|
||||
|
||||
# Build c (w1)
|
||||
if use_w1:
|
||||
c = w1.to(dtype=x.dtype)
|
||||
else:
|
||||
c = w1_a.to(dtype=x.dtype) @ w1_b.to(dtype=x.dtype)
|
||||
uq = c.size(1)
|
||||
|
||||
# Build w2 components
|
||||
if use_w2:
|
||||
ba = w2.to(dtype=x.dtype)
|
||||
else:
|
||||
a = w2_b.to(dtype=x.dtype)
|
||||
b = w2_a.to(dtype=x.dtype)
|
||||
if is_conv:
|
||||
if tucker:
|
||||
# Tucker: a, b get 1s appended (kernel is in t2)
|
||||
if a.dim() == 2:
|
||||
a = a.view(*a.shape, *([1] * conv_dim))
|
||||
if b.dim() == 2:
|
||||
b = b.view(*b.shape, *([1] * conv_dim))
|
||||
else:
|
||||
# Non-tucker conv: b may need 1s appended
|
||||
if b.dim() == 2:
|
||||
b = b.view(*b.shape, *([1] * conv_dim))
|
||||
|
||||
# Reshape input by uq groups
|
||||
if is_conv:
|
||||
B, _, *rest = x.shape
|
||||
h_in_group = x.reshape(B * uq, -1, *rest)
|
||||
else:
|
||||
h_in_group = x.reshape(*x.shape[:-1], uq, -1)
|
||||
|
||||
# Apply w2 path
|
||||
if use_w2:
|
||||
hb = op(h_in_group, ba, **kw_dict)
|
||||
else:
|
||||
if is_conv:
|
||||
if tucker:
|
||||
t = t2.to(dtype=x.dtype)
|
||||
if t.dim() == 2:
|
||||
t = t.view(*t.shape, *([1] * conv_dim))
|
||||
ha = op(h_in_group, a)
|
||||
ht = op(ha, t, **kw_dict)
|
||||
hb = op(ht, b)
|
||||
else:
|
||||
ha = op(h_in_group, a, **kw_dict)
|
||||
hb = op(ha, b)
|
||||
else:
|
||||
ha = op(h_in_group, a)
|
||||
hb = op(ha, b)
|
||||
|
||||
# Reshape and apply c (w1)
|
||||
if is_conv:
|
||||
hb = hb.view(B, -1, *hb.shape[1:])
|
||||
h_cross_group = hb.transpose(1, -1)
|
||||
else:
|
||||
h_cross_group = hb.transpose(-1, -2)
|
||||
|
||||
hc = F.linear(h_cross_group, c)
|
||||
|
||||
if is_conv:
|
||||
hc = hc.transpose(1, -1)
|
||||
out = hc.reshape(B, -1, *hc.shape[3:])
|
||||
else:
|
||||
hc = hc.transpose(-1, -2)
|
||||
out = hc.reshape(*hc.shape[:-2], -1)
|
||||
|
||||
return out * scale
|
||||
|
||||
@@ -2,6 +2,7 @@ import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import comfy.model_management
|
||||
from .base import (
|
||||
WeightAdapterBase,
|
||||
@@ -20,11 +21,7 @@ class LoraDiff(WeightAdapterTrainBase):
|
||||
rank, in_dim = mat2.shape[0], mat2.shape[1]
|
||||
if mid is not None:
|
||||
convdim = mid.ndim - 2
|
||||
layer = (
|
||||
torch.nn.Conv1d,
|
||||
torch.nn.Conv2d,
|
||||
torch.nn.Conv3d
|
||||
)[convdim]
|
||||
layer = (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)[convdim]
|
||||
else:
|
||||
layer = torch.nn.Linear
|
||||
self.lora_up = layer(rank, out_dim, bias=False)
|
||||
@@ -51,6 +48,78 @@ class LoraDiff(WeightAdapterTrainBase):
|
||||
weight = w + scale * diff.reshape(w.shape)
|
||||
return weight.to(org_dtype)
|
||||
|
||||
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Additive bypass component for LoRA training: h(x) = up(down(x)) * scale
|
||||
|
||||
Simple implementation using the nn.Module weights directly.
|
||||
No mid/dora/reshape branches (create_train doesn't create them).
|
||||
|
||||
Args:
|
||||
x: Input tensor
|
||||
base_out: Output from base forward (unused, for API consistency)
|
||||
"""
|
||||
# Compute scale = alpha / rank * multiplier
|
||||
scale = (self.alpha / self.rank) * getattr(self, "multiplier", 1.0)
|
||||
|
||||
# Get module info from bypass injection
|
||||
is_conv = getattr(self, "is_conv", False)
|
||||
conv_dim = getattr(self, "conv_dim", 0)
|
||||
kw_dict = getattr(self, "kw_dict", {})
|
||||
|
||||
# Get weights (keep in original dtype for numerical stability)
|
||||
down_weight = self.lora_down.weight
|
||||
up_weight = self.lora_up.weight
|
||||
|
||||
if is_conv:
|
||||
# Conv path: use functional conv
|
||||
# conv_dim: 1=conv1d, 2=conv2d, 3=conv3d
|
||||
conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1]
|
||||
|
||||
# Reshape 2D weights to conv format if needed
|
||||
# down: [rank, in_features] -> [rank, in_channels, *kernel_size]
|
||||
# up: [out_features, rank] -> [out_features, rank, 1, 1, ...]
|
||||
if down_weight.dim() == 2:
|
||||
kernel_size = getattr(self, "kernel_size", (1,) * conv_dim)
|
||||
in_channels = getattr(self, "in_channels", None)
|
||||
if in_channels is not None:
|
||||
down_weight = down_weight.view(
|
||||
down_weight.shape[0], in_channels, *kernel_size
|
||||
)
|
||||
else:
|
||||
# Fallback: assume 1x1 kernel
|
||||
down_weight = down_weight.view(
|
||||
*down_weight.shape, *([1] * conv_dim)
|
||||
)
|
||||
if up_weight.dim() == 2:
|
||||
# up always uses 1x1 kernel
|
||||
up_weight = up_weight.view(*up_weight.shape, *([1] * conv_dim))
|
||||
|
||||
# down conv uses stride/padding from module, up is 1x1
|
||||
hidden = conv_fn(x, down_weight, **kw_dict)
|
||||
|
||||
# mid layer if exists (tucker decomposition)
|
||||
if self.lora_mid is not None:
|
||||
mid_weight = self.lora_mid.weight
|
||||
if mid_weight.dim() == 2:
|
||||
mid_weight = mid_weight.view(*mid_weight.shape, *([1] * conv_dim))
|
||||
hidden = conv_fn(hidden, mid_weight)
|
||||
|
||||
# up conv is always 1x1 (no stride/padding)
|
||||
out = conv_fn(hidden, up_weight)
|
||||
else:
|
||||
# Linear path: simple matmul chain
|
||||
hidden = F.linear(x, down_weight)
|
||||
|
||||
# mid layer if exists
|
||||
if self.lora_mid is not None:
|
||||
mid_weight = self.lora_mid.weight
|
||||
hidden = F.linear(hidden, mid_weight)
|
||||
|
||||
out = F.linear(hidden, up_weight)
|
||||
|
||||
return out * scale
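A minimal sketch (toy sizes, linear case, no mid layer) of the equivalence behind h(): chaining down then up with the alpha/rank scale matches applying the merged low-rank update.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
out_dim, in_dim, rank, alpha = 6, 4, 2, 1.0
down = torch.randn(rank, in_dim)
up = torch.randn(out_dim, rank)
x = torch.randn(5, in_dim)

scale = alpha / rank
bypass = F.linear(F.linear(x, down), up) * scale   # h(x) = up(down(x)) * scale
merged = F.linear(x, (up @ down) * scale)          # same update folded into one weight
assert torch.allclose(bypass, merged, atol=1e-5)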
|
||||
|
||||
def passive_memory_usage(self):
|
||||
return sum(param.numel() * param.element_size() for param in self.parameters())
|
||||
|
||||
@@ -70,9 +139,7 @@ class LoRAAdapter(WeightAdapterBase):
|
||||
mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
|
||||
torch.nn.init.kaiming_uniform_(mat1, a=5**0.5)
|
||||
torch.nn.init.constant_(mat2, 0.0)
|
||||
return LoraDiff(
|
||||
(mat1, mat2, alpha, None, None, None)
|
||||
)
|
||||
return LoraDiff((mat1, mat2, alpha, None, None, None))
|
||||
|
||||
def to_train(self):
|
||||
return LoraDiff(self.weights)
|
||||
@@ -210,3 +277,85 @@ class LoRAAdapter(WeightAdapterBase):
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||
return weight
|
||||
|
||||
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Additive bypass component for LoRA: h(x) = up(down(x)) * scale
|
||||
|
||||
Note:
|
||||
Does not access original model weights - bypass mode is designed
|
||||
for quantized models where weights may not be accessible.
|
||||
|
||||
Args:
|
||||
x: Input tensor
|
||||
base_out: Output from base forward (unused, for API consistency)
|
||||
|
||||
Reference: LyCORIS functional/locon.py bypass_forward_diff
|
||||
"""
|
||||
# FUNC_LIST: [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]
|
||||
FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]
|
||||
|
||||
v = self.weights
|
||||
# v[0]=up, v[1]=down, v[2]=alpha, v[3]=mid, v[4]=dora_scale, v[5]=reshape
|
||||
up = v[0]
|
||||
down = v[1]
|
||||
alpha = v[2]
|
||||
mid = v[3]
|
||||
|
||||
# Compute scale = alpha / rank
|
||||
rank = down.shape[0]
|
||||
if alpha is not None:
|
||||
scale = alpha / rank
|
||||
else:
|
||||
scale = 1.0
|
||||
scale = scale * getattr(self, "multiplier", 1.0)
|
||||
|
||||
# Cast dtype
|
||||
up = up.to(dtype=x.dtype)
|
||||
down = down.to(dtype=x.dtype)
|
||||
|
||||
# Use module info from bypass injection, not weight dimension
|
||||
is_conv = getattr(self, "is_conv", False)
|
||||
conv_dim = getattr(self, "conv_dim", 0)
|
||||
kw_dict = getattr(self, "kw_dict", {})
|
||||
|
||||
if is_conv:
|
||||
op = FUNC_LIST[
|
||||
conv_dim + 2
|
||||
] # conv_dim 1->conv1d(3), 2->conv2d(4), 3->conv3d(5)
|
||||
kernel_size = getattr(self, "kernel_size", (1,) * conv_dim)
|
||||
in_channels = getattr(self, "in_channels", None)
|
||||
|
||||
# Reshape 2D weights to conv format using kernel_size
|
||||
# down: [rank, in_channels * prod(kernel_size)] -> [rank, in_channels, *kernel_size]
|
||||
# up: [out_channels, rank] -> [out_channels, rank, 1, 1, ...] (1x1 kernel)
|
||||
if down.dim() == 2:
|
||||
# down.shape[1] = in_channels * prod(kernel_size)
|
||||
if in_channels is not None:
|
||||
down = down.view(down.shape[0], in_channels, *kernel_size)
|
||||
else:
|
||||
# Fallback: assume 1x1 kernel if in_channels unknown
|
||||
down = down.view(*down.shape, *([1] * conv_dim))
|
||||
if up.dim() == 2:
|
||||
# up always uses 1x1 kernel
|
||||
up = up.view(*up.shape, *([1] * conv_dim))
|
||||
if mid is not None:
|
||||
mid = mid.to(dtype=x.dtype)
|
||||
if mid.dim() == 2:
|
||||
mid = mid.view(*mid.shape, *([1] * conv_dim))
|
||||
else:
|
||||
op = F.linear
|
||||
kw_dict = {} # linear doesn't take stride/padding
|
||||
|
||||
# Simple chain: down -> mid (if tucker) -> up
|
||||
if mid is not None:
|
||||
if not is_conv:
|
||||
mid = mid.to(dtype=x.dtype)
|
||||
hidden = op(x, down)
|
||||
hidden = op(hidden, mid, **kw_dict)
|
||||
out = op(hidden, up)
|
||||
else:
|
||||
hidden = op(x, down, **kw_dict)
|
||||
out = op(hidden, up)
|
||||
|
||||
return out * scale
|
||||
|
||||
@@ -3,13 +3,18 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
import comfy.model_management
|
||||
from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose, factorization
|
||||
from .base import (
|
||||
WeightAdapterBase,
|
||||
WeightAdapterTrainBase,
|
||||
weight_decompose,
|
||||
factorization,
|
||||
)
|
||||
|
||||
|
||||
class OFTDiff(WeightAdapterTrainBase):
|
||||
def __init__(self, weights):
|
||||
super().__init__()
|
||||
# Unpack weights tuple from LoHaAdapter
|
||||
# Unpack weights tuple from OFTAdapter
|
||||
blocks, rescale, alpha, _ = weights
|
||||
|
||||
# Create trainable parameters
|
||||
@@ -52,6 +57,78 @@ class OFTDiff(WeightAdapterTrainBase):
|
||||
weight = self.rescale * weight
|
||||
return weight.to(org_dtype)
|
||||
|
||||
def _get_orthogonal_matrix(self, device, dtype):
|
||||
"""Compute the orthogonal rotation matrix R from OFT blocks."""
|
||||
blocks = self.oft_blocks.to(device=device, dtype=dtype)
|
||||
I = torch.eye(self.block_size, device=device, dtype=dtype)
|
||||
|
||||
# Q = blocks - blocks^T (skew-symmetric)
|
||||
q = blocks - blocks.transpose(1, 2)
|
||||
normed_q = q
|
||||
|
||||
# Apply constraint if set
|
||||
if self.constraint:
|
||||
q_norm = torch.norm(q) + 1e-8
|
||||
if q_norm > self.constraint:
|
||||
normed_q = q * self.constraint / q_norm
|
||||
|
||||
# Cayley transform: R = (I + Q)(I - Q)^-1
|
||||
r = (I + normed_q) @ (I - normed_q).float().inverse()
|
||||
return r.to(dtype)
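A quick numeric check (toy block, float64 for a tight tolerance) that the Cayley transform of a skew-symmetric Q is orthogonal, which is why the OFT rotation preserves the norm of the output it acts on.

import torch

torch.manual_seed(0)
n = 4
blocks = torch.randn(n, n, dtype=torch.float64)
q = blocks - blocks.T                        # skew-symmetric: q == -q.T
I = torch.eye(n, dtype=torch.float64)
r = (I + q) @ torch.linalg.inv(I - q)        # Cayley transform
assert torch.allclose(r.T @ r, I, atol=1e-8)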
|
||||
|
||||
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
OFT has no additive component - returns zeros matching base_out shape.
|
||||
|
||||
OFT only transforms the output via g(), it doesn't add to it.
|
||||
"""
|
||||
return torch.zeros_like(base_out)
|
||||
|
||||
def g(self, y: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Output transformation for OFT: applies orthogonal rotation.
|
||||
|
||||
OFT transforms output channels using block-diagonal orthogonal matrices.
|
||||
"""
|
||||
r = self._get_orthogonal_matrix(y.device, y.dtype)
|
||||
|
||||
# Apply multiplier to interpolate between identity and full transform
|
||||
multiplier = getattr(self, "multiplier", 1.0)
|
||||
I = torch.eye(self.block_size, device=y.device, dtype=y.dtype)
|
||||
r = r * multiplier + (1 - multiplier) * I
|
||||
|
||||
# Use module info from bypass injection
|
||||
is_conv = getattr(self, "is_conv", y.dim() > 2)
|
||||
|
||||
if is_conv:
|
||||
# Conv output: (N, C, H, W, ...) -> transpose to (N, H, W, ..., C)
|
||||
y = y.transpose(1, -1)
|
||||
|
||||
# y now has channels in last dim
|
||||
*batch_shape, out_features = y.shape
|
||||
|
||||
# Reshape to apply block-diagonal transform
|
||||
# (*, out_features) -> (*, block_num, block_size)
|
||||
y_blocked = y.reshape(*batch_shape, self.block_num, self.block_size)
|
||||
|
||||
# Apply orthogonal transform: R @ y for each block
|
||||
# r: (block_num, block_size, block_size), y_blocked: (*, block_num, block_size)
|
||||
out_blocked = torch.einsum("k n m, ... k n -> ... k m", r, y_blocked)
|
||||
|
||||
# Reshape back: (*, block_num, block_size) -> (*, out_features)
|
||||
out = out_blocked.reshape(*batch_shape, out_features)
|
||||
|
||||
# Apply rescale if present
|
||||
if self.rescaled:
|
||||
rescale = self.rescale.to(device=y.device, dtype=y.dtype)
|
||||
out = out * rescale.view(-1)
|
||||
|
||||
if is_conv:
|
||||
# Transpose back: (N, H, W, ..., C) -> (N, C, H, W, ...)
|
||||
out = out.transpose(1, -1)
|
||||
|
||||
return out
|
||||
|
||||
def passive_memory_usage(self):
|
||||
"""Calculates memory usage of the trainable parameters."""
|
||||
return sum(param.numel() * param.element_size() for param in self.parameters())
|
||||
@@ -68,10 +145,10 @@ class OFTAdapter(WeightAdapterBase):
|
||||
def create_train(cls, weight, rank=1, alpha=1.0):
|
||||
out_dim = weight.shape[0]
|
||||
block_size, block_num = factorization(out_dim, rank)
|
||||
block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=torch.float32)
|
||||
return OFTDiff(
|
||||
(block, None, alpha, None)
|
||||
block = torch.zeros(
|
||||
block_num, block_size, block_size, device=weight.device, dtype=torch.float32
|
||||
)
|
||||
return OFTDiff((block, None, alpha, None))
|
||||
|
||||
def to_train(self):
|
||||
return OFTDiff(self.weights)
|
||||
@@ -127,9 +204,13 @@ class OFTAdapter(WeightAdapterBase):
|
||||
alpha = 0
|
||||
dora_scale = v[3]
|
||||
|
||||
blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
|
||||
blocks = comfy.model_management.cast_to_device(
|
||||
blocks, weight.device, intermediate_dtype
|
||||
)
|
||||
if rescale is not None:
|
||||
rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype)
|
||||
rescale = comfy.model_management.cast_to_device(
|
||||
rescale, weight.device, intermediate_dtype
|
||||
)
|
||||
|
||||
block_num, block_size, *_ = blocks.shape
|
||||
|
||||
@@ -139,23 +220,108 @@ class OFTAdapter(WeightAdapterBase):
|
||||
# for Q = -Q^T
|
||||
q = blocks - blocks.transpose(1, 2)
|
||||
normed_q = q
|
||||
if alpha > 0: # alpha in oft/boft is for constraint
|
||||
if alpha > 0: # alpha in oft/boft is for constraint
|
||||
q_norm = torch.norm(q) + 1e-8
|
||||
if q_norm > alpha:
|
||||
normed_q = q * alpha / q_norm
|
||||
# use float() to prevent unsupported type in .inverse()
|
||||
r = (I + normed_q) @ (I - normed_q).float().inverse()
|
||||
r = r.to(weight)
|
||||
# Create I in weight's dtype for the einsum
|
||||
I_w = torch.eye(block_size, device=weight.device, dtype=weight.dtype)
|
||||
_, *shape = weight.shape
|
||||
lora_diff = torch.einsum(
|
||||
"k n m, k n ... -> k m ...",
|
||||
(r * strength) - strength * I,
|
||||
(r * strength) - strength * I_w,
|
||||
weight.view(block_num, block_size, *shape),
|
||||
).view(-1, *shape)
|
||||
if dora_scale is not None:
|
||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||
weight = weight_decompose(
|
||||
dora_scale,
|
||||
weight,
|
||||
lora_diff,
|
||||
alpha,
|
||||
strength,
|
||||
intermediate_dtype,
|
||||
function,
|
||||
)
|
||||
else:
|
||||
weight += function((strength * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||
return weight
|
||||
|
||||
def _get_orthogonal_matrix(self, device, dtype):
|
||||
"""Compute the orthogonal rotation matrix R from OFT blocks."""
|
||||
v = self.weights
|
||||
blocks = v[0].to(device=device, dtype=dtype)
|
||||
alpha = v[2]
|
||||
if alpha is None:
|
||||
alpha = 0
|
||||
|
||||
block_num, block_size, _ = blocks.shape
|
||||
I = torch.eye(block_size, device=device, dtype=dtype)
|
||||
|
||||
# Q = blocks - blocks^T (skew-symmetric)
|
||||
q = blocks - blocks.transpose(1, 2)
|
||||
normed_q = q
|
||||
|
||||
# Apply constraint if alpha > 0
|
||||
if alpha > 0:
|
||||
q_norm = torch.norm(q) + 1e-8
|
||||
if q_norm > alpha:
|
||||
normed_q = q * alpha / q_norm
|
||||
|
||||
# Cayley transform: R = (I + Q)(I - Q)^-1
|
||||
r = (I + normed_q) @ (I - normed_q).float().inverse()
|
||||
return r, block_num, block_size
|
||||
|
||||
def g(self, y: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Output transformation for OFT: applies orthogonal rotation to output.
|
||||
|
||||
OFT transforms the output channels using block-diagonal orthogonal matrices.
|
||||
|
||||
Reference: LyCORIS DiagOFTModule._bypass_forward
|
||||
"""
|
||||
v = self.weights
|
||||
rescale = v[1]
|
||||
|
||||
r, block_num, block_size = self._get_orthogonal_matrix(y.device, y.dtype)
|
||||
|
||||
# Apply multiplier to interpolate between identity and full transform
|
||||
multiplier = getattr(self, "multiplier", 1.0)
|
||||
I = torch.eye(block_size, device=y.device, dtype=y.dtype)
|
||||
r = r * multiplier + (1 - multiplier) * I
|
||||
|
||||
# Use module info from bypass injection to determine conv vs linear
|
||||
is_conv = getattr(self, "is_conv", y.dim() > 2)
|
||||
|
||||
if is_conv:
|
||||
# Conv output: (N, C, H, W, ...) -> transpose to (N, H, W, ..., C)
|
||||
y = y.transpose(1, -1)
|
||||
|
||||
# y now has channels in last dim
|
||||
*batch_shape, out_features = y.shape
|
||||
|
||||
# Reshape to apply block-diagonal transform
|
||||
# (*, out_features) -> (*, block_num, block_size)
|
||||
y_blocked = y.view(*batch_shape, block_num, block_size)
|
||||
|
||||
# Apply orthogonal transform: R @ y for each block
|
||||
# r: (block_num, block_size, block_size), y_blocked: (*, block_num, block_size)
|
||||
out_blocked = torch.einsum("k n m, ... k n -> ... k m", r, y_blocked)
|
||||
|
||||
# Reshape back: (*, block_num, block_size) -> (*, out_features)
|
||||
out = out_blocked.view(*batch_shape, out_features)
|
||||
|
||||
# Apply rescale if present
|
||||
if rescale is not None:
|
||||
rescale = rescale.to(device=y.device, dtype=y.dtype)
|
||||
out = out * rescale.view(-1)
|
||||
|
||||
if is_conv:
|
||||
# Transpose back: (N, H, W, ..., C) -> (N, C, H, W, ...)
|
||||
out = out.transpose(1, -1)
|
||||
|
||||
return out
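A minimal sketch (toy sizes) showing that the einsum over the reshaped output used above matches multiplying by the full block-diagonal matrix, just without ever building it.

import torch

torch.manual_seed(0)
block_num, block_size, batch = 3, 4, 5
r = torch.randn(block_num, block_size, block_size, dtype=torch.float64)
y = torch.randn(batch, block_num * block_size, dtype=torch.float64)

y_blocked = y.view(batch, block_num, block_size)
per_block = torch.einsum("k n m, ... k n -> ... k m", r, y_blocked).reshape(batch, -1)
full = y @ torch.block_diag(*r)              # materialized block-diagonal transform
assert torch.allclose(per_block, full, atol=1e-9)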
|
||||
|
||||
52
comfy/windows.py
Normal file
@@ -0,0 +1,52 @@
import ctypes
import logging
import psutil
from ctypes import wintypes

import comfy_aimdo.control

psapi = ctypes.WinDLL("psapi")
kernel32 = ctypes.WinDLL("kernel32")

class PERFORMANCE_INFORMATION(ctypes.Structure):
    _fields_ = [
        ("cb", wintypes.DWORD),
        ("CommitTotal", ctypes.c_size_t),
        ("CommitLimit", ctypes.c_size_t),
        ("CommitPeak", ctypes.c_size_t),
        ("PhysicalTotal", ctypes.c_size_t),
        ("PhysicalAvailable", ctypes.c_size_t),
        ("SystemCache", ctypes.c_size_t),
        ("KernelTotal", ctypes.c_size_t),
        ("KernelPaged", ctypes.c_size_t),
        ("KernelNonpaged", ctypes.c_size_t),
        ("PageSize", ctypes.c_size_t),
        ("HandleCount", wintypes.DWORD),
        ("ProcessCount", wintypes.DWORD),
        ("ThreadCount", wintypes.DWORD),
    ]

def get_free_ram():
    #Windows is way too conservative and chalks up recently used uncommitted model RAM
    #as "in-use". So, calculate free RAM for the sake of general use as the greater of:
    #
    #1: What psutil says
    #2: Total Memory - (Committed Memory - VRAM in use)
    #
    #We have to subtract VRAM in use from the committed memory as WDDM creates a naked
    #commit charge for all VRAM used just in case it wants to page it all out. This just
    #isn't realistic, so "overcommit" in our calculations by just subtracting it off.

    pi = PERFORMANCE_INFORMATION()
    pi.cb = ctypes.sizeof(pi)

    if not psapi.GetPerformanceInfo(ctypes.byref(pi), pi.cb):
        logging.warning("WARNING: Failed to query Windows performance info. RAM usage may be suboptimal")
        return psutil.virtual_memory().available

    committed = pi.CommitTotal * pi.PageSize
    total = pi.PhysicalTotal * pi.PageSize

    return max(psutil.virtual_memory().available,
               total - (committed - comfy_aimdo.control.get_total_vram_usage()))
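Worked example with illustrative numbers only: with 64 GiB physical RAM, 48 GiB committed, and 16 GiB of that commit charge merely shadowing VRAM, the second candidate is 64 - (48 - 16) = 32 GiB; if psutil reports just 20 GiB available, get_free_ram() returns the larger 32 GiB.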
@@ -7,10 +7,9 @@ from comfy_api.internal.singleton import ProxiedSingleton
|
||||
from comfy_api.internal.async_to_sync import create_sync_class
|
||||
from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
|
||||
from ._input_impl import VideoFromFile, VideoFromComponents
|
||||
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
|
||||
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, File3D
|
||||
from . import _io_public as io
|
||||
from . import _ui_public as ui
|
||||
from . import _node_replace_public as node_replace
|
||||
from comfy_execution.utils import get_executing_context
|
||||
from comfy_execution.progress import get_progress_state, PreviewImageTuple
|
||||
from PIL import Image
|
||||
@@ -106,6 +105,7 @@ class Types:
|
||||
VideoComponents = VideoComponents
|
||||
MESH = MESH
|
||||
VOXEL = VOXEL
|
||||
File3D = File3D
|
||||
|
||||
ComfyAPI = ComfyAPI_latest
|
||||
|
||||
@@ -131,5 +131,4 @@ __all__ = [
|
||||
"IO",
|
||||
"ui",
|
||||
"UI",
|
||||
"node_replace",
|
||||
]
|
||||
|
||||
@@ -34,6 +34,21 @@ class VideoInput(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def as_trimmed(
|
||||
self,
|
||||
start_time: float | None = None,
|
||||
duration: float | None = None,
|
||||
strict_duration: bool = False,
|
||||
) -> VideoInput | None:
|
||||
"""
|
||||
Create a new VideoInput which is trimmed to have the corresponding start_time and duration
|
||||
|
||||
Returns:
|
||||
A new VideoInput, or None if the result would have negative duration
|
||||
"""
|
||||
pass
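A minimal usage sketch, assuming video is an already-constructed VideoInput instance (for example taken from a node input); the variable name is illustrative only:

clip = video.as_trimmed(start_time=2.0, duration=3.0)  # video: assumed VideoInput instance
if clip is None:
    # the requested range would have negative duration
    ...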
|
||||
|
||||
def get_stream_source(self) -> Union[str, io.BytesIO]:
|
||||
"""
|
||||
Get a streamable source for the video. This allows processing without
|
||||
Some files were not shown because too many files have changed in this diff.