mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-14 09:38:05 +00:00
Compare commits
151 Commits
pysssss/ba
...
toolkit/wi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ad5b8ca494 | ||
|
|
07ca6852e8 | ||
|
|
f266b8d352 | ||
|
|
b6cb30bab5 | ||
|
|
ee72752162 | ||
|
|
7591d781a7 | ||
|
|
0bfb936ab4 | ||
|
|
602b2505a4 | ||
|
|
04a55d5019 | ||
|
|
5fb8f06495 | ||
|
|
5a182bfaf1 | ||
|
|
f394af8d0f | ||
|
|
aeb5bdc8f6 | ||
|
|
64953bda0a | ||
|
|
b254cecd03 | ||
|
|
1bb956fb66 | ||
|
|
96d6bd1a4a | ||
|
|
5f2117528a | ||
|
|
0301ccf745 | ||
|
|
4d172e9ad7 | ||
|
|
5632b2df9d | ||
|
|
2687652530 | ||
|
|
6d11cc7354 | ||
|
|
f262444dd4 | ||
|
|
239ddd3327 | ||
|
|
83dd65f23a | ||
|
|
8ad38d2073 | ||
|
|
6c14f129af | ||
|
|
58dcc97dcf | ||
|
|
19236edfa4 | ||
|
|
73c3f86973 | ||
|
|
262abf437b | ||
|
|
5284e6bf69 | ||
|
|
44f8598521 | ||
|
|
fe52843fe5 | ||
|
|
c39653163d | ||
|
|
18927538a1 | ||
|
|
8a6fbc2dc2 | ||
|
|
b44fc4c589 | ||
|
|
4454fab7f0 | ||
|
|
1978f59ffd | ||
|
|
88e6370527 | ||
|
|
c0370044cd | ||
|
|
ecd2a19661 | ||
|
|
2c1d06a4e3 | ||
|
|
e2c71ceb00 | ||
|
|
596ed68691 | ||
|
|
ce4a1ab48d | ||
|
|
e1ede29d82 | ||
|
|
df1e5e8514 | ||
|
|
dc9822b7df | ||
|
|
712efb466b | ||
|
|
726af73867 | ||
|
|
831351a29e | ||
|
|
e1add563f9 | ||
|
|
8902907d7a | ||
|
|
e03fe8b591 | ||
|
|
ae79e33345 | ||
|
|
117e214354 | ||
|
|
4a93a62371 | ||
|
|
66c18522fb | ||
|
|
e5ae670a40 | ||
|
|
3fe61cedda | ||
|
|
2a4328d639 | ||
|
|
d297a749a2 | ||
|
|
2b7cc7e3b6 | ||
|
|
4993411fd9 | ||
|
|
2c7cef4a23 | ||
|
|
76a7fa96db | ||
|
|
cdcf4119b3 | ||
|
|
dbe70b6821 | ||
|
|
00fff6019e | ||
|
|
123a7874a9 | ||
|
|
f719f9c062 | ||
|
|
fe053ba5eb | ||
|
|
6648ab68bc | ||
|
|
6615db925c | ||
|
|
8ca842a8ed | ||
|
|
c1b63a7e78 | ||
|
|
349a636a2b | ||
|
|
a4be04c5d7 | ||
|
|
baf8c87455 | ||
|
|
62315fbb15 | ||
|
|
a0302cc6a8 | ||
|
|
f350a84261 | ||
|
|
3760d74005 | ||
|
|
9bf5aa54db | ||
|
|
5ff4fdedba | ||
|
|
17e7df43d1 | ||
|
|
039955c527 | ||
|
|
6a26328842 | ||
|
|
204e65b8dc | ||
|
|
a831c19b70 | ||
|
|
eba6c940fd | ||
|
|
a1c101f861 | ||
|
|
c2d7f07dbf | ||
|
|
458292fef0 | ||
|
|
6555dc65b8 | ||
|
|
2b70ab9ad0 | ||
|
|
00efcc6cd0 | ||
|
|
cb459573c8 | ||
|
|
35183543e0 | ||
|
|
a246cc02b2 | ||
|
|
a50c32d63f | ||
|
|
6125b80979 | ||
|
|
c8fcbd66ee | ||
|
|
26dd7eb421 | ||
|
|
e77b34dfea | ||
|
|
ef73070ea4 | ||
|
|
d30c609f5a | ||
|
|
5087f1d497 | ||
|
|
a31681564d | ||
|
|
855849c658 | ||
|
|
fe2511468d | ||
|
|
3be0175166 | ||
|
|
b8315e66cb | ||
|
|
ab1050bec3 | ||
|
|
fb23935c11 | ||
|
|
85fc35e8fa | ||
|
|
223364743c | ||
|
|
affe881354 | ||
|
|
f5030e26fd | ||
|
|
66e1b07402 | ||
|
|
be4345d1c9 | ||
|
|
3c1a1a2df8 | ||
|
|
ba5bf3f1a8 | ||
|
|
c05a08ae66 | ||
|
|
de9ada6a41 | ||
|
|
37f711d4a1 | ||
|
|
dd86b15521 | ||
|
|
021ba20719 | ||
|
|
b60be02aaf | ||
|
|
2b5da3b72e | ||
|
|
794d05bdb1 | ||
|
|
361b9a82a3 | ||
|
|
667a1b8878 | ||
|
|
32621c6a11 | ||
|
|
f8acd9c402 | ||
|
|
873de5f37a | ||
|
|
aa6f7a83bb | ||
|
|
6ea8c128a3 | ||
|
|
6e469a3f35 | ||
|
|
b8f848bfe3 | ||
|
|
4064062e7d | ||
|
|
8aabe2403e | ||
|
|
0167653781 | ||
|
|
0a7993729c | ||
|
|
bbe2c13a70 | ||
|
|
3aace5c8dc | ||
|
|
b0d9708974 | ||
|
|
c9b633d84f |
127
.coderabbit.yaml
Normal file
127
.coderabbit.yaml
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
|
||||||
|
language: "en-US"
|
||||||
|
early_access: false
|
||||||
|
tone_instructions: "Only comment on issues introduced by this PR's changes. Do not flag pre-existing problems in moved, re-indented, or reformatted code."
|
||||||
|
|
||||||
|
reviews:
|
||||||
|
profile: "chill"
|
||||||
|
request_changes_workflow: false
|
||||||
|
high_level_summary: false
|
||||||
|
poem: false
|
||||||
|
review_status: false
|
||||||
|
review_details: false
|
||||||
|
commit_status: true
|
||||||
|
collapse_walkthrough: true
|
||||||
|
changed_files_summary: false
|
||||||
|
sequence_diagrams: false
|
||||||
|
estimate_code_review_effort: false
|
||||||
|
assess_linked_issues: false
|
||||||
|
related_issues: false
|
||||||
|
related_prs: false
|
||||||
|
suggested_labels: false
|
||||||
|
auto_apply_labels: false
|
||||||
|
suggested_reviewers: false
|
||||||
|
auto_assign_reviewers: false
|
||||||
|
in_progress_fortune: false
|
||||||
|
enable_prompt_for_ai_agents: true
|
||||||
|
|
||||||
|
path_filters:
|
||||||
|
- "!comfy_api_nodes/apis/**"
|
||||||
|
- "!**/generated/*.pyi"
|
||||||
|
- "!.ci/**"
|
||||||
|
- "!script_examples/**"
|
||||||
|
- "!**/__pycache__/**"
|
||||||
|
- "!**/*.ipynb"
|
||||||
|
- "!**/*.png"
|
||||||
|
- "!**/*.bat"
|
||||||
|
|
||||||
|
path_instructions:
|
||||||
|
- path: "**"
|
||||||
|
instructions: |
|
||||||
|
IMPORTANT: Only comment on issues directly introduced by this PR's code changes.
|
||||||
|
Do NOT flag pre-existing issues in code that was merely moved, re-indented,
|
||||||
|
de-indented, or reformatted without logic changes. If code appears in the diff
|
||||||
|
only due to whitespace or structural reformatting (e.g., removing a `with:` block),
|
||||||
|
treat it as unchanged. Contributors should not feel obligated to address
|
||||||
|
pre-existing issues outside the scope of their contribution.
|
||||||
|
- path: "comfy/**"
|
||||||
|
instructions: |
|
||||||
|
Core ML/diffusion engine. Focus on:
|
||||||
|
- Backward compatibility (breaking changes affect all custom nodes)
|
||||||
|
- Memory management and GPU resource handling
|
||||||
|
- Performance implications in hot paths
|
||||||
|
- Thread safety for concurrent execution
|
||||||
|
- path: "comfy_api_nodes/**"
|
||||||
|
instructions: |
|
||||||
|
Third-party API integration nodes. Focus on:
|
||||||
|
- No hardcoded API keys or secrets
|
||||||
|
- Proper error handling for API failures (timeouts, rate limits, auth errors)
|
||||||
|
- Correct Pydantic model usage
|
||||||
|
- Security of user data passed to external APIs
|
||||||
|
- path: "comfy_extras/**"
|
||||||
|
instructions: |
|
||||||
|
Community-contributed extra nodes. Focus on:
|
||||||
|
- Consistency with node patterns (INPUT_TYPES, RETURN_TYPES, FUNCTION, CATEGORY)
|
||||||
|
- No breaking changes to existing node interfaces
|
||||||
|
- path: "comfy_execution/**"
|
||||||
|
instructions: |
|
||||||
|
Execution engine (graph execution, caching, jobs). Focus on:
|
||||||
|
- Caching correctness
|
||||||
|
- Concurrent execution safety
|
||||||
|
- Graph validation edge cases
|
||||||
|
- path: "nodes.py"
|
||||||
|
instructions: |
|
||||||
|
Core node definitions (2500+ lines). Focus on:
|
||||||
|
- Backward compatibility of NODE_CLASS_MAPPINGS
|
||||||
|
- Consistency of INPUT_TYPES return format
|
||||||
|
- path: "alembic_db/**"
|
||||||
|
instructions: |
|
||||||
|
Database migrations. Focus on:
|
||||||
|
- Migration safety and rollback support
|
||||||
|
- Data preservation during schema changes
|
||||||
|
|
||||||
|
auto_review:
|
||||||
|
enabled: true
|
||||||
|
auto_incremental_review: true
|
||||||
|
drafts: false
|
||||||
|
ignore_title_keywords:
|
||||||
|
- "WIP"
|
||||||
|
- "DO NOT REVIEW"
|
||||||
|
- "DO NOT MERGE"
|
||||||
|
|
||||||
|
finishing_touches:
|
||||||
|
docstrings:
|
||||||
|
enabled: false
|
||||||
|
unit_tests:
|
||||||
|
enabled: false
|
||||||
|
|
||||||
|
tools:
|
||||||
|
ruff:
|
||||||
|
enabled: false
|
||||||
|
pylint:
|
||||||
|
enabled: false
|
||||||
|
flake8:
|
||||||
|
enabled: false
|
||||||
|
gitleaks:
|
||||||
|
enabled: true
|
||||||
|
shellcheck:
|
||||||
|
enabled: false
|
||||||
|
markdownlint:
|
||||||
|
enabled: false
|
||||||
|
yamllint:
|
||||||
|
enabled: false
|
||||||
|
languagetool:
|
||||||
|
enabled: false
|
||||||
|
github-checks:
|
||||||
|
enabled: true
|
||||||
|
timeout_ms: 90000
|
||||||
|
ast-grep:
|
||||||
|
essential_rules: true
|
||||||
|
|
||||||
|
chat:
|
||||||
|
auto_reply: true
|
||||||
|
|
||||||
|
knowledge_base:
|
||||||
|
opt_out: false
|
||||||
|
learnings:
|
||||||
|
scope: "auto"
|
||||||
36
.github/workflows/release-webhook.yml
vendored
36
.github/workflows/release-webhook.yml
vendored
@@ -7,6 +7,8 @@ on:
|
|||||||
jobs:
|
jobs:
|
||||||
send-webhook:
|
send-webhook:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
env:
|
||||||
|
DESKTOP_REPO_DISPATCH_TOKEN: ${{ secrets.DESKTOP_REPO_DISPATCH_TOKEN }}
|
||||||
steps:
|
steps:
|
||||||
- name: Send release webhook
|
- name: Send release webhook
|
||||||
env:
|
env:
|
||||||
@@ -106,3 +108,37 @@ jobs:
|
|||||||
--fail --silent --show-error
|
--fail --silent --show-error
|
||||||
|
|
||||||
echo "✅ Release webhook sent successfully"
|
echo "✅ Release webhook sent successfully"
|
||||||
|
|
||||||
|
- name: Send repository dispatch to desktop
|
||||||
|
env:
|
||||||
|
DISPATCH_TOKEN: ${{ env.DESKTOP_REPO_DISPATCH_TOKEN }}
|
||||||
|
RELEASE_TAG: ${{ github.event.release.tag_name }}
|
||||||
|
RELEASE_URL: ${{ github.event.release.html_url }}
|
||||||
|
run: |
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
if [ -z "${DISPATCH_TOKEN:-}" ]; then
|
||||||
|
echo "::error::DESKTOP_REPO_DISPATCH_TOKEN is required but not set."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
PAYLOAD="$(jq -n \
|
||||||
|
--arg release_tag "$RELEASE_TAG" \
|
||||||
|
--arg release_url "$RELEASE_URL" \
|
||||||
|
'{
|
||||||
|
event_type: "comfyui_release_published",
|
||||||
|
client_payload: {
|
||||||
|
release_tag: $release_tag,
|
||||||
|
release_url: $release_url
|
||||||
|
}
|
||||||
|
}')"
|
||||||
|
|
||||||
|
curl -fsSL \
|
||||||
|
-X POST \
|
||||||
|
-H "Accept: application/vnd.github+json" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-H "Authorization: Bearer ${DISPATCH_TOKEN}" \
|
||||||
|
https://api.github.com/repos/Comfy-Org/desktop/dispatches \
|
||||||
|
-d "$PAYLOAD"
|
||||||
|
|
||||||
|
echo "✅ Dispatched ComfyUI release ${RELEASE_TAG} to Comfy-Org/desktop"
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ on:
|
|||||||
description: 'python patch version'
|
description: 'python patch version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "9"
|
default: "11"
|
||||||
# push:
|
# push:
|
||||||
# branches:
|
# branches:
|
||||||
# - master
|
# - master
|
||||||
|
|||||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -11,7 +11,7 @@ extra_model_paths.yaml
|
|||||||
/.vs
|
/.vs
|
||||||
.vscode/
|
.vscode/
|
||||||
.idea/
|
.idea/
|
||||||
venv/
|
venv*/
|
||||||
.venv/
|
.venv/
|
||||||
/web/extensions/*
|
/web/extensions/*
|
||||||
!/web/extensions/logging.js.example
|
!/web/extensions/logging.js.example
|
||||||
|
|||||||
@@ -227,11 +227,11 @@ Put your VAE in: models/vae
|
|||||||
|
|
||||||
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
||||||
|
|
||||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
|
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.1```
|
||||||
|
|
||||||
This is the command to install the nightly with ROCm 7.1 which might have some performance improvements:
|
This is the command to install the nightly with ROCm 7.2 which might have some performance improvements:
|
||||||
|
|
||||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1```
|
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.2```
|
||||||
|
|
||||||
|
|
||||||
### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only.
|
### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only.
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
|
import urllib.parse
|
||||||
|
import os
|
||||||
|
import contextlib
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
@@ -8,6 +11,9 @@ import app.assets.manager as manager
|
|||||||
from app import user_manager
|
from app import user_manager
|
||||||
from app.assets.api import schemas_in
|
from app.assets.api import schemas_in
|
||||||
from app.assets.helpers import get_query_dict
|
from app.assets.helpers import get_query_dict
|
||||||
|
from app.assets.scanner import seed_assets
|
||||||
|
|
||||||
|
import folder_paths
|
||||||
|
|
||||||
ROUTES = web.RouteTableDef()
|
ROUTES = web.RouteTableDef()
|
||||||
USER_MANAGER: user_manager.UserManager | None = None
|
USER_MANAGER: user_manager.UserManager | None = None
|
||||||
@@ -15,6 +21,9 @@ USER_MANAGER: user_manager.UserManager | None = None
|
|||||||
# UUID regex (canonical hyphenated form, case-insensitive)
|
# UUID regex (canonical hyphenated form, case-insensitive)
|
||||||
UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
|
UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
|
||||||
|
|
||||||
|
# Note to any custom node developers reading this code:
|
||||||
|
# The assets system is not yet fully implemented, do not rely on the code in /app/assets remaining the same.
|
||||||
|
|
||||||
def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None:
|
def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None:
|
||||||
global USER_MANAGER
|
global USER_MANAGER
|
||||||
USER_MANAGER = user_manager_instance
|
USER_MANAGER = user_manager_instance
|
||||||
@@ -28,6 +37,18 @@ def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
|
|||||||
return _error_response(400, code, "Validation failed.", {"errors": ve.json()})
|
return _error_response(400, code, "Validation failed.", {"errors": ve.json()})
|
||||||
|
|
||||||
|
|
||||||
|
@ROUTES.head("/api/assets/hash/{hash}")
|
||||||
|
async def head_asset_by_hash(request: web.Request) -> web.Response:
|
||||||
|
hash_str = request.match_info.get("hash", "").strip().lower()
|
||||||
|
if not hash_str or ":" not in hash_str:
|
||||||
|
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||||
|
algo, digest = hash_str.split(":", 1)
|
||||||
|
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||||
|
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||||
|
exists = manager.asset_exists(asset_hash=hash_str)
|
||||||
|
return web.Response(status=200 if exists else 404)
|
||||||
|
|
||||||
|
|
||||||
@ROUTES.get("/api/assets")
|
@ROUTES.get("/api/assets")
|
||||||
async def list_assets(request: web.Request) -> web.Response:
|
async def list_assets(request: web.Request) -> web.Response:
|
||||||
"""
|
"""
|
||||||
@@ -50,7 +71,7 @@ async def list_assets(request: web.Request) -> web.Response:
|
|||||||
order=q.order,
|
order=q.order,
|
||||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||||
)
|
)
|
||||||
return web.json_response(payload.model_dump(mode="json"))
|
return web.json_response(payload.model_dump(mode="json", exclude_none=True))
|
||||||
|
|
||||||
|
|
||||||
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
|
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
|
||||||
@@ -76,6 +97,314 @@ async def get_asset(request: web.Request) -> web.Response:
|
|||||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||||
|
|
||||||
|
|
||||||
|
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
|
||||||
|
async def download_asset_content(request: web.Request) -> web.Response:
|
||||||
|
# question: do we need disposition? could we just stick with one of these?
|
||||||
|
disposition = request.query.get("disposition", "attachment").lower().strip()
|
||||||
|
if disposition not in {"inline", "attachment"}:
|
||||||
|
disposition = "attachment"
|
||||||
|
|
||||||
|
try:
|
||||||
|
abs_path, content_type, filename = manager.resolve_asset_content_for_download(
|
||||||
|
asset_info_id=str(uuid.UUID(request.match_info["id"])),
|
||||||
|
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||||
|
)
|
||||||
|
except ValueError as ve:
|
||||||
|
return _error_response(404, "ASSET_NOT_FOUND", str(ve))
|
||||||
|
except NotImplementedError as nie:
|
||||||
|
return _error_response(501, "BACKEND_UNSUPPORTED", str(nie))
|
||||||
|
except FileNotFoundError:
|
||||||
|
return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.")
|
||||||
|
|
||||||
|
quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'")
|
||||||
|
cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}'
|
||||||
|
|
||||||
|
file_size = os.path.getsize(abs_path)
|
||||||
|
logging.info(
|
||||||
|
"download_asset_content: path=%s, size=%d bytes (%.2f MB), content_type=%s, filename=%s",
|
||||||
|
abs_path,
|
||||||
|
file_size,
|
||||||
|
file_size / (1024 * 1024),
|
||||||
|
content_type,
|
||||||
|
filename,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def file_sender():
|
||||||
|
chunk_size = 64 * 1024
|
||||||
|
with open(abs_path, "rb") as f:
|
||||||
|
while True:
|
||||||
|
chunk = f.read(chunk_size)
|
||||||
|
if not chunk:
|
||||||
|
break
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
return web.Response(
|
||||||
|
body=file_sender(),
|
||||||
|
content_type=content_type,
|
||||||
|
headers={
|
||||||
|
"Content-Disposition": cd,
|
||||||
|
"Content-Length": str(file_size),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ROUTES.post("/api/assets/from-hash")
|
||||||
|
async def create_asset_from_hash(request: web.Request) -> web.Response:
|
||||||
|
try:
|
||||||
|
payload = await request.json()
|
||||||
|
body = schemas_in.CreateFromHashBody.model_validate(payload)
|
||||||
|
except ValidationError as ve:
|
||||||
|
return _validation_error_response("INVALID_BODY", ve)
|
||||||
|
except Exception:
|
||||||
|
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||||
|
|
||||||
|
result = manager.create_asset_from_hash(
|
||||||
|
hash_str=body.hash,
|
||||||
|
name=body.name,
|
||||||
|
tags=body.tags,
|
||||||
|
user_metadata=body.user_metadata,
|
||||||
|
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||||
|
)
|
||||||
|
if result is None:
|
||||||
|
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist")
|
||||||
|
return web.json_response(result.model_dump(mode="json"), status=201)
|
||||||
|
|
||||||
|
|
||||||
|
@ROUTES.post("/api/assets")
|
||||||
|
async def upload_asset(request: web.Request) -> web.Response:
|
||||||
|
"""Multipart/form-data endpoint for Asset uploads."""
|
||||||
|
if not (request.content_type or "").lower().startswith("multipart/"):
|
||||||
|
return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.")
|
||||||
|
|
||||||
|
reader = await request.multipart()
|
||||||
|
|
||||||
|
file_present = False
|
||||||
|
file_client_name: str | None = None
|
||||||
|
tags_raw: list[str] = []
|
||||||
|
provided_name: str | None = None
|
||||||
|
user_metadata_raw: str | None = None
|
||||||
|
provided_hash: str | None = None
|
||||||
|
provided_hash_exists: bool | None = None
|
||||||
|
|
||||||
|
file_written = 0
|
||||||
|
tmp_path: str | None = None
|
||||||
|
while True:
|
||||||
|
field = await reader.next()
|
||||||
|
if field is None:
|
||||||
|
break
|
||||||
|
|
||||||
|
fname = getattr(field, "name", "") or ""
|
||||||
|
|
||||||
|
if fname == "hash":
|
||||||
|
try:
|
||||||
|
s = ((await field.text()) or "").strip().lower()
|
||||||
|
except Exception:
|
||||||
|
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||||
|
|
||||||
|
if s:
|
||||||
|
if ":" not in s:
|
||||||
|
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||||
|
algo, digest = s.split(":", 1)
|
||||||
|
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||||
|
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||||
|
provided_hash = f"{algo}:{digest}"
|
||||||
|
try:
|
||||||
|
provided_hash_exists = manager.asset_exists(asset_hash=provided_hash)
|
||||||
|
except Exception:
|
||||||
|
provided_hash_exists = None # do not fail the whole request here
|
||||||
|
|
||||||
|
elif fname == "file":
|
||||||
|
file_present = True
|
||||||
|
file_client_name = (field.filename or "").strip()
|
||||||
|
|
||||||
|
if provided_hash and provided_hash_exists is True:
|
||||||
|
# If client supplied a hash that we know exists, drain but do not write to disk
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
chunk = await field.read_chunk(8 * 1024 * 1024)
|
||||||
|
if not chunk:
|
||||||
|
break
|
||||||
|
file_written += len(chunk)
|
||||||
|
except Exception:
|
||||||
|
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file.")
|
||||||
|
continue # Do not create temp file; we will create AssetInfo from the existing content
|
||||||
|
|
||||||
|
# Otherwise, store to temp for hashing/ingest
|
||||||
|
uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
|
||||||
|
unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
|
||||||
|
os.makedirs(unique_dir, exist_ok=True)
|
||||||
|
tmp_path = os.path.join(unique_dir, ".upload.part")
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(tmp_path, "wb") as f:
|
||||||
|
while True:
|
||||||
|
chunk = await field.read_chunk(8 * 1024 * 1024)
|
||||||
|
if not chunk:
|
||||||
|
break
|
||||||
|
f.write(chunk)
|
||||||
|
file_written += len(chunk)
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
if os.path.exists(tmp_path or ""):
|
||||||
|
os.remove(tmp_path)
|
||||||
|
finally:
|
||||||
|
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.")
|
||||||
|
elif fname == "tags":
|
||||||
|
tags_raw.append((await field.text()) or "")
|
||||||
|
elif fname == "name":
|
||||||
|
provided_name = (await field.text()) or None
|
||||||
|
elif fname == "user_metadata":
|
||||||
|
user_metadata_raw = (await field.text()) or None
|
||||||
|
|
||||||
|
# If client did not send file, and we are not doing a from-hash fast path -> error
|
||||||
|
if not file_present and not (provided_hash and provided_hash_exists):
|
||||||
|
return _error_response(400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'.")
|
||||||
|
|
||||||
|
if file_present and file_written == 0 and not (provided_hash and provided_hash_exists):
|
||||||
|
# Empty upload is only acceptable if we are fast-pathing from existing hash
|
||||||
|
try:
|
||||||
|
if tmp_path and os.path.exists(tmp_path):
|
||||||
|
os.remove(tmp_path)
|
||||||
|
finally:
|
||||||
|
return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
spec = schemas_in.UploadAssetSpec.model_validate({
|
||||||
|
"tags": tags_raw,
|
||||||
|
"name": provided_name,
|
||||||
|
"user_metadata": user_metadata_raw,
|
||||||
|
"hash": provided_hash,
|
||||||
|
})
|
||||||
|
except ValidationError as ve:
|
||||||
|
try:
|
||||||
|
if tmp_path and os.path.exists(tmp_path):
|
||||||
|
os.remove(tmp_path)
|
||||||
|
finally:
|
||||||
|
return _validation_error_response("INVALID_BODY", ve)
|
||||||
|
|
||||||
|
# Validate models category against configured folders (consistent with previous behavior)
|
||||||
|
if spec.tags and spec.tags[0] == "models":
|
||||||
|
if len(spec.tags) < 2 or spec.tags[1] not in folder_paths.folder_names_and_paths:
|
||||||
|
if tmp_path and os.path.exists(tmp_path):
|
||||||
|
os.remove(tmp_path)
|
||||||
|
return _error_response(
|
||||||
|
400, "INVALID_BODY", f"unknown models category '{spec.tags[1] if len(spec.tags) >= 2 else ''}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
owner_id = USER_MANAGER.get_request_user_id(request)
|
||||||
|
|
||||||
|
# Fast path: if a valid provided hash exists, create AssetInfo without writing anything
|
||||||
|
if spec.hash and provided_hash_exists is True:
|
||||||
|
try:
|
||||||
|
result = manager.create_asset_from_hash(
|
||||||
|
hash_str=spec.hash,
|
||||||
|
name=spec.name or (spec.hash.split(":", 1)[1]),
|
||||||
|
tags=spec.tags,
|
||||||
|
user_metadata=spec.user_metadata or {},
|
||||||
|
owner_id=owner_id,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logging.exception("create_asset_from_hash failed for hash=%s, owner_id=%s", spec.hash, owner_id)
|
||||||
|
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||||
|
|
||||||
|
if result is None:
|
||||||
|
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist")
|
||||||
|
|
||||||
|
# Drain temp if we accidentally saved (e.g., hash field came after file)
|
||||||
|
if tmp_path and os.path.exists(tmp_path):
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
os.remove(tmp_path)
|
||||||
|
|
||||||
|
status = 200 if (not result.created_new) else 201
|
||||||
|
return web.json_response(result.model_dump(mode="json"), status=status)
|
||||||
|
|
||||||
|
# Otherwise, we must have a temp file path to ingest
|
||||||
|
if not tmp_path or not os.path.exists(tmp_path):
|
||||||
|
# The only case we reach here without a temp file is: client sent a hash that does not exist and no file
|
||||||
|
return _error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
created = manager.upload_asset_from_temp_path(
|
||||||
|
spec,
|
||||||
|
temp_path=tmp_path,
|
||||||
|
client_filename=file_client_name,
|
||||||
|
owner_id=owner_id,
|
||||||
|
expected_asset_hash=spec.hash,
|
||||||
|
)
|
||||||
|
status = 201 if created.created_new else 200
|
||||||
|
return web.json_response(created.model_dump(mode="json"), status=status)
|
||||||
|
except ValueError as e:
|
||||||
|
if tmp_path and os.path.exists(tmp_path):
|
||||||
|
os.remove(tmp_path)
|
||||||
|
msg = str(e)
|
||||||
|
if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH":
|
||||||
|
return _error_response(
|
||||||
|
400,
|
||||||
|
"HASH_MISMATCH",
|
||||||
|
"Uploaded file hash does not match provided hash.",
|
||||||
|
)
|
||||||
|
return _error_response(400, "BAD_REQUEST", "Invalid inputs.")
|
||||||
|
except Exception:
|
||||||
|
if tmp_path and os.path.exists(tmp_path):
|
||||||
|
os.remove(tmp_path)
|
||||||
|
logging.exception("upload_asset_from_temp_path failed for tmp_path=%s, owner_id=%s", tmp_path, owner_id)
|
||||||
|
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||||
|
|
||||||
|
|
||||||
|
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
|
||||||
|
async def update_asset(request: web.Request) -> web.Response:
|
||||||
|
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||||
|
try:
|
||||||
|
body = schemas_in.UpdateAssetBody.model_validate(await request.json())
|
||||||
|
except ValidationError as ve:
|
||||||
|
return _validation_error_response("INVALID_BODY", ve)
|
||||||
|
except Exception:
|
||||||
|
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = manager.update_asset(
|
||||||
|
asset_info_id=asset_info_id,
|
||||||
|
name=body.name,
|
||||||
|
user_metadata=body.user_metadata,
|
||||||
|
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||||
|
)
|
||||||
|
except (ValueError, PermissionError) as ve:
|
||||||
|
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||||
|
except Exception:
|
||||||
|
logging.exception(
|
||||||
|
"update_asset failed for asset_info_id=%s, owner_id=%s",
|
||||||
|
asset_info_id,
|
||||||
|
USER_MANAGER.get_request_user_id(request),
|
||||||
|
)
|
||||||
|
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||||
|
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||||
|
|
||||||
|
|
||||||
|
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
|
||||||
|
async def delete_asset(request: web.Request) -> web.Response:
|
||||||
|
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||||
|
delete_content = request.query.get("delete_content")
|
||||||
|
delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
deleted = manager.delete_asset_reference(
|
||||||
|
asset_info_id=asset_info_id,
|
||||||
|
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||||
|
delete_content_if_orphan=delete_content,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logging.exception(
|
||||||
|
"delete_asset_reference failed for asset_info_id=%s, owner_id=%s",
|
||||||
|
asset_info_id,
|
||||||
|
USER_MANAGER.get_request_user_id(request),
|
||||||
|
)
|
||||||
|
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||||
|
|
||||||
|
if not deleted:
|
||||||
|
return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.")
|
||||||
|
return web.Response(status=204)
|
||||||
|
|
||||||
|
|
||||||
@ROUTES.get("/api/tags")
|
@ROUTES.get("/api/tags")
|
||||||
async def get_tags(request: web.Request) -> web.Response:
|
async def get_tags(request: web.Request) -> web.Response:
|
||||||
"""
|
"""
|
||||||
@@ -100,3 +429,86 @@ async def get_tags(request: web.Request) -> web.Response:
|
|||||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||||
)
|
)
|
||||||
return web.json_response(result.model_dump(mode="json"))
|
return web.json_response(result.model_dump(mode="json"))
|
||||||
|
|
||||||
|
|
||||||
|
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
||||||
|
async def add_asset_tags(request: web.Request) -> web.Response:
|
||||||
|
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||||
|
try:
|
||||||
|
payload = await request.json()
|
||||||
|
data = schemas_in.TagsAdd.model_validate(payload)
|
||||||
|
except ValidationError as ve:
|
||||||
|
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()})
|
||||||
|
except Exception:
|
||||||
|
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = manager.add_tags_to_asset(
|
||||||
|
asset_info_id=asset_info_id,
|
||||||
|
tags=data.tags,
|
||||||
|
origin="manual",
|
||||||
|
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||||
|
)
|
||||||
|
except (ValueError, PermissionError) as ve:
|
||||||
|
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||||
|
except Exception:
|
||||||
|
logging.exception(
|
||||||
|
"add_tags_to_asset failed for asset_info_id=%s, owner_id=%s",
|
||||||
|
asset_info_id,
|
||||||
|
USER_MANAGER.get_request_user_id(request),
|
||||||
|
)
|
||||||
|
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||||
|
|
||||||
|
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||||
|
|
||||||
|
|
||||||
|
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
||||||
|
async def delete_asset_tags(request: web.Request) -> web.Response:
|
||||||
|
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||||
|
try:
|
||||||
|
payload = await request.json()
|
||||||
|
data = schemas_in.TagsRemove.model_validate(payload)
|
||||||
|
except ValidationError as ve:
|
||||||
|
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()})
|
||||||
|
except Exception:
|
||||||
|
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = manager.remove_tags_from_asset(
|
||||||
|
asset_info_id=asset_info_id,
|
||||||
|
tags=data.tags,
|
||||||
|
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||||
|
)
|
||||||
|
except ValueError as ve:
|
||||||
|
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||||
|
except Exception:
|
||||||
|
logging.exception(
|
||||||
|
"remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s",
|
||||||
|
asset_info_id,
|
||||||
|
USER_MANAGER.get_request_user_id(request),
|
||||||
|
)
|
||||||
|
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||||
|
|
||||||
|
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||||
|
|
||||||
|
|
||||||
|
@ROUTES.post("/api/assets/seed")
|
||||||
|
async def seed_assets_endpoint(request: web.Request) -> web.Response:
|
||||||
|
"""Trigger asset seeding for specified roots (models, input, output)."""
|
||||||
|
try:
|
||||||
|
payload = await request.json()
|
||||||
|
roots = payload.get("roots", ["models", "input", "output"])
|
||||||
|
except Exception:
|
||||||
|
roots = ["models", "input", "output"]
|
||||||
|
|
||||||
|
valid_roots = [r for r in roots if r in ("models", "input", "output")]
|
||||||
|
if not valid_roots:
|
||||||
|
return _error_response(400, "INVALID_BODY", "No valid roots specified")
|
||||||
|
|
||||||
|
try:
|
||||||
|
seed_assets(tuple(valid_roots))
|
||||||
|
except Exception:
|
||||||
|
logging.exception("seed_assets failed for roots=%s", valid_roots)
|
||||||
|
return _error_response(500, "INTERNAL", "Seed operation failed")
|
||||||
|
|
||||||
|
return web.json_response({"seeded": valid_roots}, status=200)
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import json
|
import json
|
||||||
import uuid
|
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
@@ -8,9 +7,9 @@ from pydantic import (
|
|||||||
Field,
|
Field,
|
||||||
conint,
|
conint,
|
||||||
field_validator,
|
field_validator,
|
||||||
|
model_validator,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ListAssetsQuery(BaseModel):
|
class ListAssetsQuery(BaseModel):
|
||||||
include_tags: list[str] = Field(default_factory=list)
|
include_tags: list[str] = Field(default_factory=list)
|
||||||
exclude_tags: list[str] = Field(default_factory=list)
|
exclude_tags: list[str] = Field(default_factory=list)
|
||||||
@@ -57,6 +56,57 @@ class ListAssetsQuery(BaseModel):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateAssetBody(BaseModel):
|
||||||
|
name: str | None = None
|
||||||
|
user_metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def _at_least_one(self):
|
||||||
|
if self.name is None and self.user_metadata is None:
|
||||||
|
raise ValueError("Provide at least one of: name, user_metadata.")
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class CreateFromHashBody(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||||
|
|
||||||
|
hash: str
|
||||||
|
name: str
|
||||||
|
tags: list[str] = Field(default_factory=list)
|
||||||
|
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
@field_validator("hash")
|
||||||
|
@classmethod
|
||||||
|
def _require_blake3(cls, v):
|
||||||
|
s = (v or "").strip().lower()
|
||||||
|
if ":" not in s:
|
||||||
|
raise ValueError("hash must be 'blake3:<hex>'")
|
||||||
|
algo, digest = s.split(":", 1)
|
||||||
|
if algo != "blake3":
|
||||||
|
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
|
||||||
|
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||||
|
raise ValueError("hash digest must be lowercase hex")
|
||||||
|
return s
|
||||||
|
|
||||||
|
@field_validator("tags", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _tags_norm(cls, v):
|
||||||
|
if v is None:
|
||||||
|
return []
|
||||||
|
if isinstance(v, list):
|
||||||
|
out = [str(t).strip().lower() for t in v if str(t).strip()]
|
||||||
|
seen = set()
|
||||||
|
dedup = []
|
||||||
|
for t in out:
|
||||||
|
if t not in seen:
|
||||||
|
seen.add(t)
|
||||||
|
dedup.append(t)
|
||||||
|
return dedup
|
||||||
|
if isinstance(v, str):
|
||||||
|
return [t.strip().lower() for t in v.split(",") if t.strip()]
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
class TagsListQuery(BaseModel):
|
class TagsListQuery(BaseModel):
|
||||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||||
|
|
||||||
@@ -75,20 +125,140 @@ class TagsListQuery(BaseModel):
|
|||||||
return v.lower() or None
|
return v.lower() or None
|
||||||
|
|
||||||
|
|
||||||
class SetPreviewBody(BaseModel):
|
class TagsAdd(BaseModel):
|
||||||
"""Set or clear the preview for an AssetInfo. Provide an Asset.id or null."""
|
model_config = ConfigDict(extra="ignore")
|
||||||
preview_id: str | None = None
|
tags: list[str] = Field(..., min_length=1)
|
||||||
|
|
||||||
@field_validator("preview_id", mode="before")
|
@field_validator("tags")
|
||||||
@classmethod
|
@classmethod
|
||||||
def _norm_uuid(cls, v):
|
def normalize_tags(cls, v: list[str]) -> list[str]:
|
||||||
|
out = []
|
||||||
|
for t in v:
|
||||||
|
if not isinstance(t, str):
|
||||||
|
raise TypeError("tags must be strings")
|
||||||
|
tnorm = t.strip().lower()
|
||||||
|
if tnorm:
|
||||||
|
out.append(tnorm)
|
||||||
|
seen = set()
|
||||||
|
deduplicated = []
|
||||||
|
for x in out:
|
||||||
|
if x not in seen:
|
||||||
|
seen.add(x)
|
||||||
|
deduplicated.append(x)
|
||||||
|
return deduplicated
|
||||||
|
|
||||||
|
|
||||||
|
class TagsRemove(TagsAdd):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class UploadAssetSpec(BaseModel):
|
||||||
|
"""Upload Asset operation.
|
||||||
|
- tags: ordered; first is root ('models'|'input'|'output');
|
||||||
|
if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths
|
||||||
|
- name: display name
|
||||||
|
- user_metadata: arbitrary JSON object (optional)
|
||||||
|
- hash: optional canonical 'blake3:<hex>' provided by the client for validation / fast-path
|
||||||
|
|
||||||
|
Files created via this endpoint are stored on disk using the **content hash** as the filename stem
|
||||||
|
and the original extension is preserved when available.
|
||||||
|
"""
|
||||||
|
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||||
|
|
||||||
|
tags: list[str] = Field(..., min_length=1)
|
||||||
|
name: str | None = Field(default=None, max_length=512, description="Display Name")
|
||||||
|
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
hash: str | None = Field(default=None)
|
||||||
|
|
||||||
|
@field_validator("hash", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _parse_hash(cls, v):
|
||||||
if v is None:
|
if v is None:
|
||||||
return None
|
return None
|
||||||
s = str(v).strip()
|
s = str(v).strip().lower()
|
||||||
if not s:
|
if not s:
|
||||||
return None
|
return None
|
||||||
try:
|
if ":" not in s:
|
||||||
uuid.UUID(s)
|
raise ValueError("hash must be 'blake3:<hex>'")
|
||||||
except Exception:
|
algo, digest = s.split(":", 1)
|
||||||
raise ValueError("preview_id must be a UUID")
|
if algo != "blake3":
|
||||||
return s
|
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
|
||||||
|
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||||
|
raise ValueError("hash digest must be lowercase hex")
|
||||||
|
return f"{algo}:{digest}"
|
||||||
|
|
||||||
|
@field_validator("tags", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _parse_tags(cls, v):
|
||||||
|
"""
|
||||||
|
Accepts a list of strings (possibly multiple form fields),
|
||||||
|
where each string can be:
|
||||||
|
- JSON array (e.g., '["models","loras","foo"]')
|
||||||
|
- comma-separated ('models, loras, foo')
|
||||||
|
- single token ('models')
|
||||||
|
Returns a normalized, deduplicated, ordered list.
|
||||||
|
"""
|
||||||
|
items: list[str] = []
|
||||||
|
if v is None:
|
||||||
|
return []
|
||||||
|
if isinstance(v, str):
|
||||||
|
v = [v]
|
||||||
|
|
||||||
|
if isinstance(v, list):
|
||||||
|
for item in v:
|
||||||
|
if item is None:
|
||||||
|
continue
|
||||||
|
s = str(item).strip()
|
||||||
|
if not s:
|
||||||
|
continue
|
||||||
|
if s.startswith("["):
|
||||||
|
try:
|
||||||
|
arr = json.loads(s)
|
||||||
|
if isinstance(arr, list):
|
||||||
|
items.extend(str(x) for x in arr)
|
||||||
|
continue
|
||||||
|
except Exception:
|
||||||
|
pass # fallback to CSV parse below
|
||||||
|
items.extend([p for p in s.split(",") if p.strip()])
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# normalize + dedupe
|
||||||
|
norm = []
|
||||||
|
seen = set()
|
||||||
|
for t in items:
|
||||||
|
tnorm = str(t).strip().lower()
|
||||||
|
if tnorm and tnorm not in seen:
|
||||||
|
seen.add(tnorm)
|
||||||
|
norm.append(tnorm)
|
||||||
|
return norm
|
||||||
|
|
||||||
|
@field_validator("user_metadata", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _parse_metadata_json(cls, v):
|
||||||
|
if v is None or isinstance(v, dict):
|
||||||
|
return v or {}
|
||||||
|
if isinstance(v, str):
|
||||||
|
s = v.strip()
|
||||||
|
if not s:
|
||||||
|
return {}
|
||||||
|
try:
|
||||||
|
parsed = json.loads(s)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"user_metadata must be JSON: {e}") from e
|
||||||
|
if not isinstance(parsed, dict):
|
||||||
|
raise ValueError("user_metadata must be a JSON object")
|
||||||
|
return parsed
|
||||||
|
return {}
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def _validate_order(self):
|
||||||
|
if not self.tags:
|
||||||
|
raise ValueError("tags must be provided and non-empty")
|
||||||
|
root = self.tags[0]
|
||||||
|
if root not in {"models", "input", "output"}:
|
||||||
|
raise ValueError("first tag must be one of: models, input, output")
|
||||||
|
if root == "models":
|
||||||
|
if len(self.tags) < 2:
|
||||||
|
raise ValueError("models uploads require a category tag as the second tag")
|
||||||
|
return self
|
||||||
|
|||||||
@@ -29,6 +29,21 @@ class AssetsList(BaseModel):
|
|||||||
has_more: bool
|
has_more: bool
|
||||||
|
|
||||||
|
|
||||||
|
class AssetUpdated(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
asset_hash: str | None = None
|
||||||
|
tags: list[str] = Field(default_factory=list)
|
||||||
|
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
updated_at: datetime | None = None
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
@field_serializer("updated_at")
|
||||||
|
def _ser_updated(self, v: datetime | None, _info):
|
||||||
|
return v.isoformat() if v else None
|
||||||
|
|
||||||
|
|
||||||
class AssetDetail(BaseModel):
|
class AssetDetail(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
@@ -48,6 +63,10 @@ class AssetDetail(BaseModel):
|
|||||||
return v.isoformat() if v else None
|
return v.isoformat() if v else None
|
||||||
|
|
||||||
|
|
||||||
|
class AssetCreated(AssetDetail):
|
||||||
|
created_new: bool
|
||||||
|
|
||||||
|
|
||||||
class TagUsage(BaseModel):
|
class TagUsage(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
count: int
|
count: int
|
||||||
@@ -58,3 +77,17 @@ class TagsList(BaseModel):
|
|||||||
tags: list[TagUsage] = Field(default_factory=list)
|
tags: list[TagUsage] = Field(default_factory=list)
|
||||||
total: int
|
total: int
|
||||||
has_more: bool
|
has_more: bool
|
||||||
|
|
||||||
|
|
||||||
|
class TagsAdd(BaseModel):
|
||||||
|
model_config = ConfigDict(str_strip_whitespace=True)
|
||||||
|
added: list[str] = Field(default_factory=list)
|
||||||
|
already_present: list[str] = Field(default_factory=list)
|
||||||
|
total_tags: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class TagsRemove(BaseModel):
|
||||||
|
model_config = ConfigDict(str_strip_whitespace=True)
|
||||||
|
removed: list[str] = Field(default_factory=list)
|
||||||
|
not_present: list[str] = Field(default_factory=list)
|
||||||
|
total_tags: list[str] = Field(default_factory=list)
|
||||||
|
|||||||
@@ -1,9 +1,17 @@
|
|||||||
|
import os
|
||||||
|
import logging
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from sqlalchemy import select, exists, func
|
from datetime import datetime
|
||||||
|
from typing import Iterable, Any
|
||||||
|
from sqlalchemy import select, delete, exists, func
|
||||||
|
from sqlalchemy.dialects import sqlite
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
from sqlalchemy.orm import Session, contains_eager, noload
|
from sqlalchemy.orm import Session, contains_eager, noload
|
||||||
from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
|
from app.assets.database.models import Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag
|
||||||
from app.assets.helpers import escape_like_prefix, normalize_tags
|
from app.assets.helpers import (
|
||||||
|
compute_relative_filename, escape_like_prefix, normalize_tags, project_kv, utcnow
|
||||||
|
)
|
||||||
from typing import Sequence
|
from typing import Sequence
|
||||||
|
|
||||||
|
|
||||||
@@ -15,6 +23,22 @@ def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
|
|||||||
return AssetInfo.owner_id.in_(["", owner_id])
|
return AssetInfo.owner_id.in_(["", owner_id])
|
||||||
|
|
||||||
|
|
||||||
|
def pick_best_live_path(states: Sequence[AssetCacheState]) -> str:
|
||||||
|
"""
|
||||||
|
Return the best on-disk path among cache states:
|
||||||
|
1) Prefer a path that exists with needs_verify == False (already verified).
|
||||||
|
2) Otherwise, pick the first path that exists.
|
||||||
|
3) Otherwise return empty string.
|
||||||
|
"""
|
||||||
|
alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)]
|
||||||
|
if not alive:
|
||||||
|
return ""
|
||||||
|
for s in alive:
|
||||||
|
if not getattr(s, "needs_verify", False):
|
||||||
|
return s.file_path
|
||||||
|
return alive[0].file_path
|
||||||
|
|
||||||
|
|
||||||
def apply_tag_filters(
|
def apply_tag_filters(
|
||||||
stmt: sa.sql.Select,
|
stmt: sa.sql.Select,
|
||||||
include_tags: Sequence[str] | None = None,
|
include_tags: Sequence[str] | None = None,
|
||||||
@@ -42,6 +66,7 @@ def apply_tag_filters(
|
|||||||
)
|
)
|
||||||
return stmt
|
return stmt
|
||||||
|
|
||||||
|
|
||||||
def apply_metadata_filter(
|
def apply_metadata_filter(
|
||||||
stmt: sa.sql.Select,
|
stmt: sa.sql.Select,
|
||||||
metadata_filter: dict | None = None,
|
metadata_filter: dict | None = None,
|
||||||
@@ -94,7 +119,11 @@ def apply_metadata_filter(
|
|||||||
return stmt
|
return stmt
|
||||||
|
|
||||||
|
|
||||||
def asset_exists_by_hash(session: Session, asset_hash: str) -> bool:
|
def asset_exists_by_hash(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
asset_hash: str,
|
||||||
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if an asset with a given hash exists in database.
|
Check if an asset with a given hash exists in database.
|
||||||
"""
|
"""
|
||||||
@@ -105,9 +134,39 @@ def asset_exists_by_hash(session: Session, asset_hash: str) -> bool:
|
|||||||
).first()
|
).first()
|
||||||
return row is not None
|
return row is not None
|
||||||
|
|
||||||
def get_asset_info_by_id(session: Session, asset_info_id: str) -> AssetInfo | None:
|
|
||||||
|
def asset_info_exists_for_asset_id(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
asset_id: str,
|
||||||
|
) -> bool:
|
||||||
|
q = (
|
||||||
|
select(sa.literal(True))
|
||||||
|
.select_from(AssetInfo)
|
||||||
|
.where(AssetInfo.asset_id == asset_id)
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
return (session.execute(q)).first() is not None
|
||||||
|
|
||||||
|
|
||||||
|
def get_asset_by_hash(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
asset_hash: str,
|
||||||
|
) -> Asset | None:
|
||||||
|
return (
|
||||||
|
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
|
||||||
|
).scalars().first()
|
||||||
|
|
||||||
|
|
||||||
|
def get_asset_info_by_id(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
) -> AssetInfo | None:
|
||||||
return session.get(AssetInfo, asset_info_id)
|
return session.get(AssetInfo, asset_info_id)
|
||||||
|
|
||||||
|
|
||||||
def list_asset_infos_page(
|
def list_asset_infos_page(
|
||||||
session: Session,
|
session: Session,
|
||||||
owner_id: str = "",
|
owner_id: str = "",
|
||||||
@@ -171,12 +230,14 @@ def list_asset_infos_page(
|
|||||||
select(AssetInfoTag.asset_info_id, Tag.name)
|
select(AssetInfoTag.asset_info_id, Tag.name)
|
||||||
.join(Tag, Tag.name == AssetInfoTag.tag_name)
|
.join(Tag, Tag.name == AssetInfoTag.tag_name)
|
||||||
.where(AssetInfoTag.asset_info_id.in_(id_list))
|
.where(AssetInfoTag.asset_info_id.in_(id_list))
|
||||||
|
.order_by(AssetInfoTag.added_at)
|
||||||
)
|
)
|
||||||
for aid, tag_name in rows.all():
|
for aid, tag_name in rows.all():
|
||||||
tag_map[aid].append(tag_name)
|
tag_map[aid].append(tag_name)
|
||||||
|
|
||||||
return infos, tag_map, total
|
return infos, tag_map, total
|
||||||
|
|
||||||
|
|
||||||
def fetch_asset_info_asset_and_tags(
|
def fetch_asset_info_asset_and_tags(
|
||||||
session: Session,
|
session: Session,
|
||||||
asset_info_id: str,
|
asset_info_id: str,
|
||||||
@@ -208,6 +269,494 @@ def fetch_asset_info_asset_and_tags(
|
|||||||
tags.append(tag_name)
|
tags.append(tag_name)
|
||||||
return first_info, first_asset, tags
|
return first_info, first_asset, tags
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_asset_info_and_asset(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
owner_id: str = "",
|
||||||
|
) -> tuple[AssetInfo, Asset] | None:
|
||||||
|
stmt = (
|
||||||
|
select(AssetInfo, Asset)
|
||||||
|
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||||
|
.where(
|
||||||
|
AssetInfo.id == asset_info_id,
|
||||||
|
visible_owner_clause(owner_id),
|
||||||
|
)
|
||||||
|
.limit(1)
|
||||||
|
.options(noload(AssetInfo.tags))
|
||||||
|
)
|
||||||
|
row = session.execute(stmt)
|
||||||
|
pair = row.first()
|
||||||
|
if not pair:
|
||||||
|
return None
|
||||||
|
return pair[0], pair[1]
|
||||||
|
|
||||||
|
def list_cache_states_by_asset_id(
|
||||||
|
session: Session, *, asset_id: str
|
||||||
|
) -> Sequence[AssetCacheState]:
|
||||||
|
return (
|
||||||
|
session.execute(
|
||||||
|
select(AssetCacheState)
|
||||||
|
.where(AssetCacheState.asset_id == asset_id)
|
||||||
|
.order_by(AssetCacheState.id.asc())
|
||||||
|
)
|
||||||
|
).scalars().all()
|
||||||
|
|
||||||
|
|
||||||
|
def touch_asset_info_by_id(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
ts: datetime | None = None,
|
||||||
|
only_if_newer: bool = True,
|
||||||
|
) -> None:
|
||||||
|
ts = ts or utcnow()
|
||||||
|
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
|
||||||
|
if only_if_newer:
|
||||||
|
stmt = stmt.where(
|
||||||
|
sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts)
|
||||||
|
)
|
||||||
|
session.execute(stmt.values(last_access_time=ts))
|
||||||
|
|
||||||
|
|
||||||
|
def create_asset_info_for_existing_asset(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
asset_hash: str,
|
||||||
|
name: str,
|
||||||
|
user_metadata: dict | None = None,
|
||||||
|
tags: Sequence[str] | None = None,
|
||||||
|
tag_origin: str = "manual",
|
||||||
|
owner_id: str = "",
|
||||||
|
) -> AssetInfo:
|
||||||
|
"""Create or return an existing AssetInfo for an Asset identified by asset_hash."""
|
||||||
|
now = utcnow()
|
||||||
|
asset = get_asset_by_hash(session, asset_hash=asset_hash)
|
||||||
|
if not asset:
|
||||||
|
raise ValueError(f"Unknown asset hash {asset_hash}")
|
||||||
|
|
||||||
|
info = AssetInfo(
|
||||||
|
owner_id=owner_id,
|
||||||
|
name=name,
|
||||||
|
asset_id=asset.id,
|
||||||
|
preview_id=None,
|
||||||
|
created_at=now,
|
||||||
|
updated_at=now,
|
||||||
|
last_access_time=now,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
with session.begin_nested():
|
||||||
|
session.add(info)
|
||||||
|
session.flush()
|
||||||
|
except IntegrityError:
|
||||||
|
existing = (
|
||||||
|
session.execute(
|
||||||
|
select(AssetInfo)
|
||||||
|
.options(noload(AssetInfo.tags))
|
||||||
|
.where(
|
||||||
|
AssetInfo.asset_id == asset.id,
|
||||||
|
AssetInfo.name == name,
|
||||||
|
AssetInfo.owner_id == owner_id,
|
||||||
|
)
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
).unique().scalars().first()
|
||||||
|
if not existing:
|
||||||
|
raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
|
||||||
|
return existing
|
||||||
|
|
||||||
|
# metadata["filename"] hack
|
||||||
|
new_meta = dict(user_metadata or {})
|
||||||
|
computed_filename = None
|
||||||
|
try:
|
||||||
|
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
|
||||||
|
if p:
|
||||||
|
computed_filename = compute_relative_filename(p)
|
||||||
|
except Exception:
|
||||||
|
computed_filename = None
|
||||||
|
if computed_filename:
|
||||||
|
new_meta["filename"] = computed_filename
|
||||||
|
if new_meta:
|
||||||
|
replace_asset_info_metadata_projection(
|
||||||
|
session,
|
||||||
|
asset_info_id=info.id,
|
||||||
|
user_metadata=new_meta,
|
||||||
|
)
|
||||||
|
|
||||||
|
if tags is not None:
|
||||||
|
set_asset_info_tags(
|
||||||
|
session,
|
||||||
|
asset_info_id=info.id,
|
||||||
|
tags=tags,
|
||||||
|
origin=tag_origin,
|
||||||
|
)
|
||||||
|
return info
|
||||||
|
|
||||||
|
|
||||||
|
def set_asset_info_tags(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
tags: Sequence[str],
|
||||||
|
origin: str = "manual",
|
||||||
|
) -> dict:
|
||||||
|
desired = normalize_tags(tags)
|
||||||
|
|
||||||
|
current = set(
|
||||||
|
tag_name for (tag_name,) in (
|
||||||
|
session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
|
||||||
|
).all()
|
||||||
|
)
|
||||||
|
|
||||||
|
to_add = [t for t in desired if t not in current]
|
||||||
|
to_remove = [t for t in current if t not in desired]
|
||||||
|
|
||||||
|
if to_add:
|
||||||
|
ensure_tags_exist(session, to_add, tag_type="user")
|
||||||
|
session.add_all([
|
||||||
|
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow())
|
||||||
|
for t in to_add
|
||||||
|
])
|
||||||
|
session.flush()
|
||||||
|
|
||||||
|
if to_remove:
|
||||||
|
session.execute(
|
||||||
|
delete(AssetInfoTag)
|
||||||
|
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
|
||||||
|
)
|
||||||
|
session.flush()
|
||||||
|
|
||||||
|
return {"added": to_add, "removed": to_remove, "total": desired}
|
||||||
|
|
||||||
|
|
||||||
|
def replace_asset_info_metadata_projection(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
user_metadata: dict | None = None,
|
||||||
|
) -> None:
|
||||||
|
info = session.get(AssetInfo, asset_info_id)
|
||||||
|
if not info:
|
||||||
|
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||||
|
|
||||||
|
info.user_metadata = user_metadata or {}
|
||||||
|
info.updated_at = utcnow()
|
||||||
|
session.flush()
|
||||||
|
|
||||||
|
session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
|
||||||
|
session.flush()
|
||||||
|
|
||||||
|
if not user_metadata:
|
||||||
|
return
|
||||||
|
|
||||||
|
rows: list[AssetInfoMeta] = []
|
||||||
|
for k, v in user_metadata.items():
|
||||||
|
for r in project_kv(k, v):
|
||||||
|
rows.append(
|
||||||
|
AssetInfoMeta(
|
||||||
|
asset_info_id=asset_info_id,
|
||||||
|
key=r["key"],
|
||||||
|
ordinal=int(r["ordinal"]),
|
||||||
|
val_str=r.get("val_str"),
|
||||||
|
val_num=r.get("val_num"),
|
||||||
|
val_bool=r.get("val_bool"),
|
||||||
|
val_json=r.get("val_json"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if rows:
|
||||||
|
session.add_all(rows)
|
||||||
|
session.flush()
|
||||||
|
|
||||||
|
|
||||||
|
def ingest_fs_asset(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
asset_hash: str,
|
||||||
|
abs_path: str,
|
||||||
|
size_bytes: int,
|
||||||
|
mtime_ns: int,
|
||||||
|
mime_type: str | None = None,
|
||||||
|
info_name: str | None = None,
|
||||||
|
owner_id: str = "",
|
||||||
|
preview_id: str | None = None,
|
||||||
|
user_metadata: dict | None = None,
|
||||||
|
tags: Sequence[str] = (),
|
||||||
|
tag_origin: str = "manual",
|
||||||
|
require_existing_tags: bool = False,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Idempotently upsert:
|
||||||
|
- Asset by content hash (create if missing)
|
||||||
|
- AssetCacheState(file_path) pointing to asset_id
|
||||||
|
- Optionally AssetInfo + tag links and metadata projection
|
||||||
|
Returns flags and ids.
|
||||||
|
"""
|
||||||
|
locator = os.path.abspath(abs_path)
|
||||||
|
now = utcnow()
|
||||||
|
|
||||||
|
if preview_id:
|
||||||
|
if not session.get(Asset, preview_id):
|
||||||
|
preview_id = None
|
||||||
|
|
||||||
|
out: dict[str, Any] = {
|
||||||
|
"asset_created": False,
|
||||||
|
"asset_updated": False,
|
||||||
|
"state_created": False,
|
||||||
|
"state_updated": False,
|
||||||
|
"asset_info_id": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 1) Asset by hash
|
||||||
|
asset = (
|
||||||
|
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
|
||||||
|
).scalars().first()
|
||||||
|
if not asset:
|
||||||
|
vals = {
|
||||||
|
"hash": asset_hash,
|
||||||
|
"size_bytes": int(size_bytes),
|
||||||
|
"mime_type": mime_type,
|
||||||
|
"created_at": now,
|
||||||
|
}
|
||||||
|
res = session.execute(
|
||||||
|
sqlite.insert(Asset)
|
||||||
|
.values(**vals)
|
||||||
|
.on_conflict_do_nothing(index_elements=[Asset.hash])
|
||||||
|
)
|
||||||
|
if int(res.rowcount or 0) > 0:
|
||||||
|
out["asset_created"] = True
|
||||||
|
asset = (
|
||||||
|
session.execute(
|
||||||
|
select(Asset).where(Asset.hash == asset_hash).limit(1)
|
||||||
|
)
|
||||||
|
).scalars().first()
|
||||||
|
if not asset:
|
||||||
|
raise RuntimeError("Asset row not found after upsert.")
|
||||||
|
else:
|
||||||
|
changed = False
|
||||||
|
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
|
||||||
|
asset.size_bytes = int(size_bytes)
|
||||||
|
changed = True
|
||||||
|
if mime_type and asset.mime_type != mime_type:
|
||||||
|
asset.mime_type = mime_type
|
||||||
|
changed = True
|
||||||
|
if changed:
|
||||||
|
out["asset_updated"] = True
|
||||||
|
|
||||||
|
# 2) AssetCacheState upsert by file_path (unique)
|
||||||
|
vals = {
|
||||||
|
"asset_id": asset.id,
|
||||||
|
"file_path": locator,
|
||||||
|
"mtime_ns": int(mtime_ns),
|
||||||
|
}
|
||||||
|
ins = (
|
||||||
|
sqlite.insert(AssetCacheState)
|
||||||
|
.values(**vals)
|
||||||
|
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
|
||||||
|
)
|
||||||
|
|
||||||
|
res = session.execute(ins)
|
||||||
|
if int(res.rowcount or 0) > 0:
|
||||||
|
out["state_created"] = True
|
||||||
|
else:
|
||||||
|
upd = (
|
||||||
|
sa.update(AssetCacheState)
|
||||||
|
.where(AssetCacheState.file_path == locator)
|
||||||
|
.where(
|
||||||
|
sa.or_(
|
||||||
|
AssetCacheState.asset_id != asset.id,
|
||||||
|
AssetCacheState.mtime_ns.is_(None),
|
||||||
|
AssetCacheState.mtime_ns != int(mtime_ns),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.values(asset_id=asset.id, mtime_ns=int(mtime_ns))
|
||||||
|
)
|
||||||
|
res2 = session.execute(upd)
|
||||||
|
if int(res2.rowcount or 0) > 0:
|
||||||
|
out["state_updated"] = True
|
||||||
|
|
||||||
|
# 3) Optional AssetInfo + tags + metadata
|
||||||
|
if info_name:
|
||||||
|
try:
|
||||||
|
with session.begin_nested():
|
||||||
|
info = AssetInfo(
|
||||||
|
owner_id=owner_id,
|
||||||
|
name=info_name,
|
||||||
|
asset_id=asset.id,
|
||||||
|
preview_id=preview_id,
|
||||||
|
created_at=now,
|
||||||
|
updated_at=now,
|
||||||
|
last_access_time=now,
|
||||||
|
)
|
||||||
|
session.add(info)
|
||||||
|
session.flush()
|
||||||
|
out["asset_info_id"] = info.id
|
||||||
|
except IntegrityError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
existing_info = (
|
||||||
|
session.execute(
|
||||||
|
select(AssetInfo)
|
||||||
|
.where(
|
||||||
|
AssetInfo.asset_id == asset.id,
|
||||||
|
AssetInfo.name == info_name,
|
||||||
|
(AssetInfo.owner_id == owner_id),
|
||||||
|
)
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
).unique().scalar_one_or_none()
|
||||||
|
if not existing_info:
|
||||||
|
raise RuntimeError("Failed to update or insert AssetInfo.")
|
||||||
|
|
||||||
|
if preview_id and existing_info.preview_id != preview_id:
|
||||||
|
existing_info.preview_id = preview_id
|
||||||
|
|
||||||
|
existing_info.updated_at = now
|
||||||
|
if existing_info.last_access_time < now:
|
||||||
|
existing_info.last_access_time = now
|
||||||
|
session.flush()
|
||||||
|
out["asset_info_id"] = existing_info.id
|
||||||
|
|
||||||
|
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||||
|
if norm and out["asset_info_id"] is not None:
|
||||||
|
if not require_existing_tags:
|
||||||
|
ensure_tags_exist(session, norm, tag_type="user")
|
||||||
|
|
||||||
|
existing_tag_names = set(
|
||||||
|
name for (name,) in (session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
|
||||||
|
)
|
||||||
|
missing = [t for t in norm if t not in existing_tag_names]
|
||||||
|
if missing and require_existing_tags:
|
||||||
|
raise ValueError(f"Unknown tags: {missing}")
|
||||||
|
|
||||||
|
existing_links = set(
|
||||||
|
tag_name
|
||||||
|
for (tag_name,) in (
|
||||||
|
session.execute(
|
||||||
|
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
)
|
||||||
|
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
|
||||||
|
if to_add:
|
||||||
|
session.add_all(
|
||||||
|
[
|
||||||
|
AssetInfoTag(
|
||||||
|
asset_info_id=out["asset_info_id"],
|
||||||
|
tag_name=t,
|
||||||
|
origin=tag_origin,
|
||||||
|
added_at=now,
|
||||||
|
)
|
||||||
|
for t in to_add
|
||||||
|
]
|
||||||
|
)
|
||||||
|
session.flush()
|
||||||
|
|
||||||
|
# metadata["filename"] hack
|
||||||
|
if out["asset_info_id"] is not None:
|
||||||
|
primary_path = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
|
||||||
|
computed_filename = compute_relative_filename(primary_path) if primary_path else None
|
||||||
|
|
||||||
|
current_meta = existing_info.user_metadata or {}
|
||||||
|
new_meta = dict(current_meta)
|
||||||
|
if user_metadata is not None:
|
||||||
|
for k, v in user_metadata.items():
|
||||||
|
new_meta[k] = v
|
||||||
|
if computed_filename:
|
||||||
|
new_meta["filename"] = computed_filename
|
||||||
|
|
||||||
|
if new_meta != current_meta:
|
||||||
|
replace_asset_info_metadata_projection(
|
||||||
|
session,
|
||||||
|
asset_info_id=out["asset_info_id"],
|
||||||
|
user_metadata=new_meta,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||||
|
except Exception:
|
||||||
|
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def update_asset_info_full(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
name: str | None = None,
|
||||||
|
tags: Sequence[str] | None = None,
|
||||||
|
user_metadata: dict | None = None,
|
||||||
|
tag_origin: str = "manual",
|
||||||
|
asset_info_row: Any = None,
|
||||||
|
) -> AssetInfo:
|
||||||
|
if not asset_info_row:
|
||||||
|
info = session.get(AssetInfo, asset_info_id)
|
||||||
|
if not info:
|
||||||
|
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||||
|
else:
|
||||||
|
info = asset_info_row
|
||||||
|
|
||||||
|
touched = False
|
||||||
|
if name is not None and name != info.name:
|
||||||
|
info.name = name
|
||||||
|
touched = True
|
||||||
|
|
||||||
|
computed_filename = None
|
||||||
|
try:
|
||||||
|
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=info.asset_id))
|
||||||
|
if p:
|
||||||
|
computed_filename = compute_relative_filename(p)
|
||||||
|
except Exception:
|
||||||
|
computed_filename = None
|
||||||
|
|
||||||
|
if user_metadata is not None:
|
||||||
|
new_meta = dict(user_metadata)
|
||||||
|
if computed_filename:
|
||||||
|
new_meta["filename"] = computed_filename
|
||||||
|
replace_asset_info_metadata_projection(
|
||||||
|
session, asset_info_id=asset_info_id, user_metadata=new_meta
|
||||||
|
)
|
||||||
|
touched = True
|
||||||
|
else:
|
||||||
|
if computed_filename:
|
||||||
|
current_meta = info.user_metadata or {}
|
||||||
|
if current_meta.get("filename") != computed_filename:
|
||||||
|
new_meta = dict(current_meta)
|
||||||
|
new_meta["filename"] = computed_filename
|
||||||
|
replace_asset_info_metadata_projection(
|
||||||
|
session, asset_info_id=asset_info_id, user_metadata=new_meta
|
||||||
|
)
|
||||||
|
touched = True
|
||||||
|
|
||||||
|
if tags is not None:
|
||||||
|
set_asset_info_tags(
|
||||||
|
session,
|
||||||
|
asset_info_id=asset_info_id,
|
||||||
|
tags=tags,
|
||||||
|
origin=tag_origin,
|
||||||
|
)
|
||||||
|
touched = True
|
||||||
|
|
||||||
|
if touched and user_metadata is None:
|
||||||
|
info.updated_at = utcnow()
|
||||||
|
session.flush()
|
||||||
|
|
||||||
|
return info
|
||||||
|
|
||||||
|
|
||||||
|
def delete_asset_info_by_id(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
owner_id: str,
|
||||||
|
) -> bool:
|
||||||
|
stmt = sa.delete(AssetInfo).where(
|
||||||
|
AssetInfo.id == asset_info_id,
|
||||||
|
visible_owner_clause(owner_id),
|
||||||
|
)
|
||||||
|
return int((session.execute(stmt)).rowcount or 0) > 0
|
||||||
|
|
||||||
|
|
||||||
def list_tags_with_usage(
|
def list_tags_with_usage(
|
||||||
session: Session,
|
session: Session,
|
||||||
prefix: str | None = None,
|
prefix: str | None = None,
|
||||||
@@ -265,3 +814,163 @@ def list_tags_with_usage(
|
|||||||
|
|
||||||
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
|
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
|
||||||
return rows_norm, int(total or 0)
|
return rows_norm, int(total or 0)
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
|
||||||
|
wanted = normalize_tags(list(names))
|
||||||
|
if not wanted:
|
||||||
|
return
|
||||||
|
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
|
||||||
|
ins = (
|
||||||
|
sqlite.insert(Tag)
|
||||||
|
.values(rows)
|
||||||
|
.on_conflict_do_nothing(index_elements=[Tag.name])
|
||||||
|
)
|
||||||
|
session.execute(ins)
|
||||||
|
|
||||||
|
|
||||||
|
def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]:
|
||||||
|
return [
|
||||||
|
tag_name for (tag_name,) in (
|
||||||
|
session.execute(
|
||||||
|
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def add_tags_to_asset_info(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
tags: Sequence[str],
|
||||||
|
origin: str = "manual",
|
||||||
|
create_if_missing: bool = True,
|
||||||
|
asset_info_row: Any = None,
|
||||||
|
) -> dict:
|
||||||
|
if not asset_info_row:
|
||||||
|
info = session.get(AssetInfo, asset_info_id)
|
||||||
|
if not info:
|
||||||
|
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||||
|
|
||||||
|
norm = normalize_tags(tags)
|
||||||
|
if not norm:
|
||||||
|
total = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||||
|
return {"added": [], "already_present": [], "total_tags": total}
|
||||||
|
|
||||||
|
if create_if_missing:
|
||||||
|
ensure_tags_exist(session, norm, tag_type="user")
|
||||||
|
|
||||||
|
current = {
|
||||||
|
tag_name
|
||||||
|
for (tag_name,) in (
|
||||||
|
session.execute(
|
||||||
|
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
}
|
||||||
|
|
||||||
|
want = set(norm)
|
||||||
|
to_add = sorted(want - current)
|
||||||
|
|
||||||
|
if to_add:
|
||||||
|
with session.begin_nested() as nested:
|
||||||
|
try:
|
||||||
|
session.add_all(
|
||||||
|
[
|
||||||
|
AssetInfoTag(
|
||||||
|
asset_info_id=asset_info_id,
|
||||||
|
tag_name=t,
|
||||||
|
origin=origin,
|
||||||
|
added_at=utcnow(),
|
||||||
|
)
|
||||||
|
for t in to_add
|
||||||
|
]
|
||||||
|
)
|
||||||
|
session.flush()
|
||||||
|
except IntegrityError:
|
||||||
|
nested.rollback()
|
||||||
|
|
||||||
|
after = set(get_asset_tags(session, asset_info_id=asset_info_id))
|
||||||
|
return {
|
||||||
|
"added": sorted(((after - current) & want)),
|
||||||
|
"already_present": sorted(want & current),
|
||||||
|
"total_tags": sorted(after),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def remove_tags_from_asset_info(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
tags: Sequence[str],
|
||||||
|
) -> dict:
|
||||||
|
info = session.get(AssetInfo, asset_info_id)
|
||||||
|
if not info:
|
||||||
|
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||||
|
|
||||||
|
norm = normalize_tags(tags)
|
||||||
|
if not norm:
|
||||||
|
total = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||||
|
return {"removed": [], "not_present": [], "total_tags": total}
|
||||||
|
|
||||||
|
existing = {
|
||||||
|
tag_name
|
||||||
|
for (tag_name,) in (
|
||||||
|
session.execute(
|
||||||
|
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
}
|
||||||
|
|
||||||
|
to_remove = sorted(set(t for t in norm if t in existing))
|
||||||
|
not_present = sorted(set(t for t in norm if t not in existing))
|
||||||
|
|
||||||
|
if to_remove:
|
||||||
|
session.execute(
|
||||||
|
delete(AssetInfoTag)
|
||||||
|
.where(
|
||||||
|
AssetInfoTag.asset_info_id == asset_info_id,
|
||||||
|
AssetInfoTag.tag_name.in_(to_remove),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
session.flush()
|
||||||
|
|
||||||
|
total = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||||
|
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
|
||||||
|
|
||||||
|
|
||||||
|
def remove_missing_tag_for_asset_id(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
asset_id: str,
|
||||||
|
) -> None:
|
||||||
|
session.execute(
|
||||||
|
sa.delete(AssetInfoTag).where(
|
||||||
|
AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
|
||||||
|
AssetInfoTag.tag_name == "missing",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def set_asset_info_preview(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
preview_asset_id: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
|
||||||
|
info = session.get(AssetInfo, asset_info_id)
|
||||||
|
if not info:
|
||||||
|
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||||
|
|
||||||
|
if preview_asset_id is None:
|
||||||
|
info.preview_id = None
|
||||||
|
else:
|
||||||
|
# validate preview asset exists
|
||||||
|
if not session.get(Asset, preview_asset_id):
|
||||||
|
raise ValueError(f"Preview Asset {preview_asset_id} not found")
|
||||||
|
info.preview_id = preview_asset_id
|
||||||
|
|
||||||
|
info.updated_at = utcnow()
|
||||||
|
session.flush()
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
import os
|
import os
|
||||||
|
from decimal import Decimal
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -87,6 +88,40 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
|||||||
targets.append((name, paths))
|
targets.append((name, paths))
|
||||||
return targets
|
return targets
|
||||||
|
|
||||||
|
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
|
||||||
|
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
|
||||||
|
root = tags[0]
|
||||||
|
if root == "models":
|
||||||
|
if len(tags) < 2:
|
||||||
|
raise ValueError("at least two tags required for model asset")
|
||||||
|
try:
|
||||||
|
bases = folder_paths.folder_names_and_paths[tags[1]][0]
|
||||||
|
except KeyError:
|
||||||
|
raise ValueError(f"unknown model category '{tags[1]}'")
|
||||||
|
if not bases:
|
||||||
|
raise ValueError(f"no base path configured for category '{tags[1]}'")
|
||||||
|
base_dir = os.path.abspath(bases[0])
|
||||||
|
raw_subdirs = tags[2:]
|
||||||
|
else:
|
||||||
|
base_dir = os.path.abspath(
|
||||||
|
folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory()
|
||||||
|
)
|
||||||
|
raw_subdirs = tags[1:]
|
||||||
|
for i in raw_subdirs:
|
||||||
|
if i in (".", ".."):
|
||||||
|
raise ValueError("invalid path component in tags")
|
||||||
|
|
||||||
|
return base_dir, raw_subdirs if raw_subdirs else []
|
||||||
|
|
||||||
|
def ensure_within_base(candidate: str, base: str) -> None:
|
||||||
|
cand_abs = os.path.abspath(candidate)
|
||||||
|
base_abs = os.path.abspath(base)
|
||||||
|
try:
|
||||||
|
if os.path.commonpath([cand_abs, base_abs]) != base_abs:
|
||||||
|
raise ValueError("destination escapes base directory")
|
||||||
|
except Exception:
|
||||||
|
raise ValueError("invalid destination path")
|
||||||
|
|
||||||
def compute_relative_filename(file_path: str) -> str | None:
|
def compute_relative_filename(file_path: str) -> str | None:
|
||||||
"""
|
"""
|
||||||
Return the model's path relative to the last well-known folder (the model category),
|
Return the model's path relative to the last well-known folder (the model category),
|
||||||
@@ -113,7 +148,6 @@ def compute_relative_filename(file_path: str) -> str | None:
|
|||||||
return "/".join(inside)
|
return "/".join(inside)
|
||||||
return "/".join(parts) # input/output: keep all parts
|
return "/".join(parts) # input/output: keep all parts
|
||||||
|
|
||||||
|
|
||||||
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
|
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
|
||||||
"""Given an absolute or relative file path, determine which root category the path belongs to:
|
"""Given an absolute or relative file path, determine which root category the path belongs to:
|
||||||
- 'input' if the file resides under `folder_paths.get_input_directory()`
|
- 'input' if the file resides under `folder_paths.get_input_directory()`
|
||||||
@@ -215,3 +249,64 @@ def collect_models_files() -> list[str]:
|
|||||||
if allowed:
|
if allowed:
|
||||||
out.append(abs_path)
|
out.append(abs_path)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def is_scalar(v):
|
||||||
|
if v is None:
|
||||||
|
return True
|
||||||
|
if isinstance(v, bool):
|
||||||
|
return True
|
||||||
|
if isinstance(v, (int, float, Decimal, str)):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def project_kv(key: str, value):
|
||||||
|
"""
|
||||||
|
Turn a metadata key/value into typed projection rows.
|
||||||
|
Returns list[dict] with keys:
|
||||||
|
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
|
||||||
|
"""
|
||||||
|
rows: list[dict] = []
|
||||||
|
|
||||||
|
def _null_row(ordinal: int) -> dict:
|
||||||
|
return {
|
||||||
|
"key": key, "ordinal": ordinal,
|
||||||
|
"val_str": None, "val_num": None, "val_bool": None, "val_json": None
|
||||||
|
}
|
||||||
|
|
||||||
|
if value is None:
|
||||||
|
rows.append(_null_row(0))
|
||||||
|
return rows
|
||||||
|
|
||||||
|
if is_scalar(value):
|
||||||
|
if isinstance(value, bool):
|
||||||
|
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
|
||||||
|
elif isinstance(value, (int, float, Decimal)):
|
||||||
|
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||||
|
rows.append({"key": key, "ordinal": 0, "val_num": num})
|
||||||
|
elif isinstance(value, str):
|
||||||
|
rows.append({"key": key, "ordinal": 0, "val_str": value})
|
||||||
|
else:
|
||||||
|
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||||
|
return rows
|
||||||
|
|
||||||
|
if isinstance(value, list):
|
||||||
|
if all(is_scalar(x) for x in value):
|
||||||
|
for i, x in enumerate(value):
|
||||||
|
if x is None:
|
||||||
|
rows.append(_null_row(i))
|
||||||
|
elif isinstance(x, bool):
|
||||||
|
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
|
||||||
|
elif isinstance(x, (int, float, Decimal)):
|
||||||
|
num = x if isinstance(x, Decimal) else Decimal(str(x))
|
||||||
|
rows.append({"key": key, "ordinal": i, "val_num": num})
|
||||||
|
elif isinstance(x, str):
|
||||||
|
rows.append({"key": key, "ordinal": i, "val_str": x})
|
||||||
|
else:
|
||||||
|
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||||
|
return rows
|
||||||
|
for i, x in enumerate(value):
|
||||||
|
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||||
|
return rows
|
||||||
|
|
||||||
|
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||||
|
return rows
|
||||||
|
|||||||
@@ -1,13 +1,33 @@
|
|||||||
|
import os
|
||||||
|
import mimetypes
|
||||||
|
import contextlib
|
||||||
from typing import Sequence
|
from typing import Sequence
|
||||||
|
|
||||||
from app.database.db import create_session
|
from app.database.db import create_session
|
||||||
from app.assets.api import schemas_out
|
from app.assets.api import schemas_out, schemas_in
|
||||||
from app.assets.database.queries import (
|
from app.assets.database.queries import (
|
||||||
asset_exists_by_hash,
|
asset_exists_by_hash,
|
||||||
|
asset_info_exists_for_asset_id,
|
||||||
|
get_asset_by_hash,
|
||||||
|
get_asset_info_by_id,
|
||||||
fetch_asset_info_asset_and_tags,
|
fetch_asset_info_asset_and_tags,
|
||||||
|
fetch_asset_info_and_asset,
|
||||||
|
create_asset_info_for_existing_asset,
|
||||||
|
touch_asset_info_by_id,
|
||||||
|
update_asset_info_full,
|
||||||
|
delete_asset_info_by_id,
|
||||||
|
list_cache_states_by_asset_id,
|
||||||
list_asset_infos_page,
|
list_asset_infos_page,
|
||||||
list_tags_with_usage,
|
list_tags_with_usage,
|
||||||
|
get_asset_tags,
|
||||||
|
add_tags_to_asset_info,
|
||||||
|
remove_tags_from_asset_info,
|
||||||
|
pick_best_live_path,
|
||||||
|
ingest_fs_asset,
|
||||||
|
set_asset_info_preview,
|
||||||
)
|
)
|
||||||
|
from app.assets.helpers import resolve_destination_from_tags, ensure_within_base
|
||||||
|
from app.assets.database.models import Asset
|
||||||
|
|
||||||
|
|
||||||
def _safe_sort_field(requested: str | None) -> str:
|
def _safe_sort_field(requested: str | None) -> str:
|
||||||
@@ -19,11 +39,28 @@ def _safe_sort_field(requested: str | None) -> str:
|
|||||||
return "created_at"
|
return "created_at"
|
||||||
|
|
||||||
|
|
||||||
def asset_exists(asset_hash: str) -> bool:
|
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
|
||||||
|
st = os.stat(path, follow_symlinks=True)
|
||||||
|
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_filename(name: str | None, fallback: str) -> str:
|
||||||
|
n = os.path.basename((name or "").strip() or fallback)
|
||||||
|
if n:
|
||||||
|
return n
|
||||||
|
return fallback
|
||||||
|
|
||||||
|
|
||||||
|
def asset_exists(*, asset_hash: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if an asset with a given hash exists in database.
|
||||||
|
"""
|
||||||
with create_session() as session:
|
with create_session() as session:
|
||||||
return asset_exists_by_hash(session, asset_hash=asset_hash)
|
return asset_exists_by_hash(session, asset_hash=asset_hash)
|
||||||
|
|
||||||
|
|
||||||
def list_assets(
|
def list_assets(
|
||||||
|
*,
|
||||||
include_tags: Sequence[str] | None = None,
|
include_tags: Sequence[str] | None = None,
|
||||||
exclude_tags: Sequence[str] | None = None,
|
exclude_tags: Sequence[str] | None = None,
|
||||||
name_contains: str | None = None,
|
name_contains: str | None = None,
|
||||||
@@ -63,7 +100,6 @@ def list_assets(
|
|||||||
size=int(asset.size_bytes) if asset else None,
|
size=int(asset.size_bytes) if asset else None,
|
||||||
mime_type=asset.mime_type if asset else None,
|
mime_type=asset.mime_type if asset else None,
|
||||||
tags=tags,
|
tags=tags,
|
||||||
preview_url=f"/api/assets/{info.id}/content",
|
|
||||||
created_at=info.created_at,
|
created_at=info.created_at,
|
||||||
updated_at=info.updated_at,
|
updated_at=info.updated_at,
|
||||||
last_access_time=info.last_access_time,
|
last_access_time=info.last_access_time,
|
||||||
@@ -76,7 +112,12 @@ def list_assets(
|
|||||||
has_more=(offset + len(summaries)) < total,
|
has_more=(offset + len(summaries)) < total,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail:
|
|
||||||
|
def get_asset(
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
owner_id: str = "",
|
||||||
|
) -> schemas_out.AssetDetail:
|
||||||
with create_session() as session:
|
with create_session() as session:
|
||||||
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||||
if not res:
|
if not res:
|
||||||
@@ -97,6 +138,358 @@ def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail
|
|||||||
last_access_time=info.last_access_time,
|
last_access_time=info.last_access_time,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_asset_content_for_download(
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
owner_id: str = "",
|
||||||
|
) -> tuple[str, str, str]:
|
||||||
|
with create_session() as session:
|
||||||
|
pair = fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||||
|
if not pair:
|
||||||
|
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||||
|
|
||||||
|
info, asset = pair
|
||||||
|
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
|
||||||
|
abs_path = pick_best_live_path(states)
|
||||||
|
if not abs_path:
|
||||||
|
raise FileNotFoundError
|
||||||
|
|
||||||
|
touch_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
|
||||||
|
download_name = info.name or os.path.basename(abs_path)
|
||||||
|
return abs_path, ctype, download_name
|
||||||
|
|
||||||
|
|
||||||
|
def upload_asset_from_temp_path(
|
||||||
|
spec: schemas_in.UploadAssetSpec,
|
||||||
|
*,
|
||||||
|
temp_path: str,
|
||||||
|
client_filename: str | None = None,
|
||||||
|
owner_id: str = "",
|
||||||
|
expected_asset_hash: str | None = None,
|
||||||
|
) -> schemas_out.AssetCreated:
|
||||||
|
"""
|
||||||
|
Create new asset or update existing asset from a temporary file path.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# NOTE: blake3 is not required right now, so this will fail if blake3 is not installed in local environment
|
||||||
|
import app.assets.hashing as hashing
|
||||||
|
digest = hashing.blake3_hash(temp_path)
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"failed to hash uploaded file: {e}")
|
||||||
|
asset_hash = "blake3:" + digest
|
||||||
|
|
||||||
|
if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower():
|
||||||
|
raise ValueError("HASH_MISMATCH")
|
||||||
|
|
||||||
|
with create_session() as session:
|
||||||
|
existing = get_asset_by_hash(session, asset_hash=asset_hash)
|
||||||
|
if existing is not None:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
if temp_path and os.path.exists(temp_path):
|
||||||
|
os.remove(temp_path)
|
||||||
|
|
||||||
|
display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest)
|
||||||
|
info = create_asset_info_for_existing_asset(
|
||||||
|
session,
|
||||||
|
asset_hash=asset_hash,
|
||||||
|
name=display_name,
|
||||||
|
user_metadata=spec.user_metadata or {},
|
||||||
|
tags=spec.tags or [],
|
||||||
|
tag_origin="manual",
|
||||||
|
owner_id=owner_id,
|
||||||
|
)
|
||||||
|
tag_names = get_asset_tags(session, asset_info_id=info.id)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return schemas_out.AssetCreated(
|
||||||
|
id=info.id,
|
||||||
|
name=info.name,
|
||||||
|
asset_hash=existing.hash,
|
||||||
|
size=int(existing.size_bytes) if existing.size_bytes is not None else None,
|
||||||
|
mime_type=existing.mime_type,
|
||||||
|
tags=tag_names,
|
||||||
|
user_metadata=info.user_metadata or {},
|
||||||
|
preview_id=info.preview_id,
|
||||||
|
created_at=info.created_at,
|
||||||
|
last_access_time=info.last_access_time,
|
||||||
|
created_new=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
base_dir, subdirs = resolve_destination_from_tags(spec.tags)
|
||||||
|
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
|
||||||
|
os.makedirs(dest_dir, exist_ok=True)
|
||||||
|
|
||||||
|
src_for_ext = (client_filename or spec.name or "").strip()
|
||||||
|
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
|
||||||
|
ext = _ext if 0 < len(_ext) <= 16 else ""
|
||||||
|
hashed_basename = f"{digest}{ext}"
|
||||||
|
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
|
||||||
|
ensure_within_base(dest_abs, base_dir)
|
||||||
|
|
||||||
|
content_type = (
|
||||||
|
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
|
||||||
|
or mimetypes.guess_type(hashed_basename, strict=False)[0]
|
||||||
|
or "application/octet-stream"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
os.replace(temp_path, dest_abs)
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"failed to move uploaded file into place: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs)
|
||||||
|
except OSError as e:
|
||||||
|
raise RuntimeError(f"failed to stat destination file: {e}")
|
||||||
|
|
||||||
|
with create_session() as session:
|
||||||
|
result = ingest_fs_asset(
|
||||||
|
session,
|
||||||
|
asset_hash=asset_hash,
|
||||||
|
abs_path=dest_abs,
|
||||||
|
size_bytes=size_bytes,
|
||||||
|
mtime_ns=mtime_ns,
|
||||||
|
mime_type=content_type,
|
||||||
|
info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest),
|
||||||
|
owner_id=owner_id,
|
||||||
|
preview_id=None,
|
||||||
|
user_metadata=spec.user_metadata or {},
|
||||||
|
tags=spec.tags,
|
||||||
|
tag_origin="manual",
|
||||||
|
require_existing_tags=False,
|
||||||
|
)
|
||||||
|
info_id = result["asset_info_id"]
|
||||||
|
if not info_id:
|
||||||
|
raise RuntimeError("failed to create asset metadata")
|
||||||
|
|
||||||
|
pair = fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id)
|
||||||
|
if not pair:
|
||||||
|
raise RuntimeError("inconsistent DB state after ingest")
|
||||||
|
info, asset = pair
|
||||||
|
tag_names = get_asset_tags(session, asset_info_id=info.id)
|
||||||
|
created_result = schemas_out.AssetCreated(
|
||||||
|
id=info.id,
|
||||||
|
name=info.name,
|
||||||
|
asset_hash=asset.hash,
|
||||||
|
size=int(asset.size_bytes),
|
||||||
|
mime_type=asset.mime_type,
|
||||||
|
tags=tag_names,
|
||||||
|
user_metadata=info.user_metadata or {},
|
||||||
|
preview_id=info.preview_id,
|
||||||
|
created_at=info.created_at,
|
||||||
|
last_access_time=info.last_access_time,
|
||||||
|
created_new=result["asset_created"],
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return created_result
|
||||||
|
|
||||||
|
|
||||||
|
def update_asset(
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
name: str | None = None,
|
||||||
|
tags: list[str] | None = None,
|
||||||
|
user_metadata: dict | None = None,
|
||||||
|
owner_id: str = "",
|
||||||
|
) -> schemas_out.AssetUpdated:
|
||||||
|
with create_session() as session:
|
||||||
|
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||||
|
if not info_row:
|
||||||
|
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||||
|
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||||
|
raise PermissionError("not owner")
|
||||||
|
|
||||||
|
info = update_asset_info_full(
|
||||||
|
session,
|
||||||
|
asset_info_id=asset_info_id,
|
||||||
|
name=name,
|
||||||
|
tags=tags,
|
||||||
|
user_metadata=user_metadata,
|
||||||
|
tag_origin="manual",
|
||||||
|
asset_info_row=info_row,
|
||||||
|
)
|
||||||
|
|
||||||
|
tag_names = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||||
|
result = schemas_out.AssetUpdated(
|
||||||
|
id=info.id,
|
||||||
|
name=info.name,
|
||||||
|
asset_hash=info.asset.hash if info.asset else None,
|
||||||
|
tags=tag_names,
|
||||||
|
user_metadata=info.user_metadata or {},
|
||||||
|
updated_at=info.updated_at,
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def set_asset_preview(
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
preview_asset_id: str | None = None,
|
||||||
|
owner_id: str = "",
|
||||||
|
) -> schemas_out.AssetDetail:
|
||||||
|
with create_session() as session:
|
||||||
|
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||||
|
if not info_row:
|
||||||
|
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||||
|
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||||
|
raise PermissionError("not owner")
|
||||||
|
|
||||||
|
set_asset_info_preview(
|
||||||
|
session,
|
||||||
|
asset_info_id=asset_info_id,
|
||||||
|
preview_asset_id=preview_asset_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||||
|
if not res:
|
||||||
|
raise RuntimeError("State changed during preview update")
|
||||||
|
info, asset, tags = res
|
||||||
|
result = schemas_out.AssetDetail(
|
||||||
|
id=info.id,
|
||||||
|
name=info.name,
|
||||||
|
asset_hash=asset.hash if asset else None,
|
||||||
|
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
|
||||||
|
mime_type=asset.mime_type if asset else None,
|
||||||
|
tags=tags,
|
||||||
|
user_metadata=info.user_metadata or {},
|
||||||
|
preview_id=info.preview_id,
|
||||||
|
created_at=info.created_at,
|
||||||
|
last_access_time=info.last_access_time,
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
|
||||||
|
with create_session() as session:
|
||||||
|
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||||
|
asset_id = info_row.asset_id if info_row else None
|
||||||
|
deleted = delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||||
|
if not deleted:
|
||||||
|
session.commit()
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not delete_content_if_orphan or not asset_id:
|
||||||
|
session.commit()
|
||||||
|
return True
|
||||||
|
|
||||||
|
still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id)
|
||||||
|
if still_exists:
|
||||||
|
session.commit()
|
||||||
|
return True
|
||||||
|
|
||||||
|
states = list_cache_states_by_asset_id(session, asset_id=asset_id)
|
||||||
|
file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]
|
||||||
|
|
||||||
|
asset_row = session.get(Asset, asset_id)
|
||||||
|
if asset_row is not None:
|
||||||
|
session.delete(asset_row)
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
for p in file_paths:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
if p and os.path.isfile(p):
|
||||||
|
os.remove(p)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def create_asset_from_hash(
|
||||||
|
*,
|
||||||
|
hash_str: str,
|
||||||
|
name: str,
|
||||||
|
tags: list[str] | None = None,
|
||||||
|
user_metadata: dict | None = None,
|
||||||
|
owner_id: str = "",
|
||||||
|
) -> schemas_out.AssetCreated | None:
|
||||||
|
canonical = hash_str.strip().lower()
|
||||||
|
with create_session() as session:
|
||||||
|
asset = get_asset_by_hash(session, asset_hash=canonical)
|
||||||
|
if not asset:
|
||||||
|
return None
|
||||||
|
|
||||||
|
info = create_asset_info_for_existing_asset(
|
||||||
|
session,
|
||||||
|
asset_hash=canonical,
|
||||||
|
name=_safe_filename(name, fallback=canonical.split(":", 1)[1]),
|
||||||
|
user_metadata=user_metadata or {},
|
||||||
|
tags=tags or [],
|
||||||
|
tag_origin="manual",
|
||||||
|
owner_id=owner_id,
|
||||||
|
)
|
||||||
|
tag_names = get_asset_tags(session, asset_info_id=info.id)
|
||||||
|
result = schemas_out.AssetCreated(
|
||||||
|
id=info.id,
|
||||||
|
name=info.name,
|
||||||
|
asset_hash=asset.hash,
|
||||||
|
size=int(asset.size_bytes),
|
||||||
|
mime_type=asset.mime_type,
|
||||||
|
tags=tag_names,
|
||||||
|
user_metadata=info.user_metadata or {},
|
||||||
|
preview_id=info.preview_id,
|
||||||
|
created_at=info.created_at,
|
||||||
|
last_access_time=info.last_access_time,
|
||||||
|
created_new=False,
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def add_tags_to_asset(
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
tags: list[str],
|
||||||
|
origin: str = "manual",
|
||||||
|
owner_id: str = "",
|
||||||
|
) -> schemas_out.TagsAdd:
|
||||||
|
with create_session() as session:
|
||||||
|
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||||
|
if not info_row:
|
||||||
|
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||||
|
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||||
|
raise PermissionError("not owner")
|
||||||
|
data = add_tags_to_asset_info(
|
||||||
|
session,
|
||||||
|
asset_info_id=asset_info_id,
|
||||||
|
tags=tags,
|
||||||
|
origin=origin,
|
||||||
|
create_if_missing=True,
|
||||||
|
asset_info_row=info_row,
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
return schemas_out.TagsAdd(**data)
|
||||||
|
|
||||||
|
|
||||||
|
def remove_tags_from_asset(
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
tags: list[str],
|
||||||
|
owner_id: str = "",
|
||||||
|
) -> schemas_out.TagsRemove:
|
||||||
|
with create_session() as session:
|
||||||
|
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||||
|
if not info_row:
|
||||||
|
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||||
|
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||||
|
raise PermissionError("not owner")
|
||||||
|
|
||||||
|
data = remove_tags_from_asset_info(
|
||||||
|
session,
|
||||||
|
asset_info_id=asset_info_id,
|
||||||
|
tags=tags,
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
return schemas_out.TagsRemove(**data)
|
||||||
|
|
||||||
|
|
||||||
def list_tags(
|
def list_tags(
|
||||||
prefix: str | None = None,
|
prefix: str | None = None,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No
|
|||||||
t_start = time.perf_counter()
|
t_start = time.perf_counter()
|
||||||
created = 0
|
created = 0
|
||||||
skipped_existing = 0
|
skipped_existing = 0
|
||||||
|
orphans_pruned = 0
|
||||||
paths: list[str] = []
|
paths: list[str] = []
|
||||||
try:
|
try:
|
||||||
existing_paths: set[str] = set()
|
existing_paths: set[str] = set()
|
||||||
@@ -38,6 +39,11 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.exception("fast DB scan failed for %s: %s", r, e)
|
logging.exception("fast DB scan failed for %s: %s", r, e)
|
||||||
|
|
||||||
|
try:
|
||||||
|
orphans_pruned = _prune_orphaned_assets(roots)
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception("orphan pruning failed: %s", e)
|
||||||
|
|
||||||
if "models" in roots:
|
if "models" in roots:
|
||||||
paths.extend(collect_models_files())
|
paths.extend(collect_models_files())
|
||||||
if "input" in roots:
|
if "input" in roots:
|
||||||
@@ -85,15 +91,43 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No
|
|||||||
finally:
|
finally:
|
||||||
if enable_logging:
|
if enable_logging:
|
||||||
logging.info(
|
logging.info(
|
||||||
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, total_seen=%d)",
|
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, orphans_pruned=%d, total_seen=%d)",
|
||||||
roots,
|
roots,
|
||||||
time.perf_counter() - t_start,
|
time.perf_counter() - t_start,
|
||||||
created,
|
created,
|
||||||
skipped_existing,
|
skipped_existing,
|
||||||
|
orphans_pruned,
|
||||||
len(paths),
|
len(paths),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _prune_orphaned_assets(roots: tuple[RootType, ...]) -> int:
|
||||||
|
"""Prune cache states outside configured prefixes, then delete orphaned seed assets."""
|
||||||
|
all_prefixes = [os.path.abspath(p) for r in roots for p in prefixes_for_root(r)]
|
||||||
|
if not all_prefixes:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def make_prefix_condition(prefix: str):
|
||||||
|
base = prefix if prefix.endswith(os.sep) else prefix + os.sep
|
||||||
|
escaped, esc = escape_like_prefix(base)
|
||||||
|
return AssetCacheState.file_path.like(escaped + "%", escape=esc)
|
||||||
|
|
||||||
|
matches_valid_prefix = sqlalchemy.or_(*[make_prefix_condition(p) for p in all_prefixes])
|
||||||
|
|
||||||
|
orphan_subq = (
|
||||||
|
sqlalchemy.select(Asset.id)
|
||||||
|
.outerjoin(AssetCacheState, AssetCacheState.asset_id == Asset.id)
|
||||||
|
.where(Asset.hash.is_(None), AssetCacheState.id.is_(None))
|
||||||
|
).scalar_subquery()
|
||||||
|
|
||||||
|
with create_session() as sess:
|
||||||
|
sess.execute(sqlalchemy.delete(AssetCacheState).where(~matches_valid_prefix))
|
||||||
|
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id.in_(orphan_subq)))
|
||||||
|
result = sess.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(orphan_subq)))
|
||||||
|
sess.commit()
|
||||||
|
return result.rowcount
|
||||||
|
|
||||||
|
|
||||||
def _fast_db_consistency_pass(
|
def _fast_db_consistency_pass(
|
||||||
root: RootType,
|
root: RootType,
|
||||||
*,
|
*,
|
||||||
|
|||||||
105
app/node_replace_manager.py
Normal file
105
app/node_replace_manager.py
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, TypedDict
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from comfy_api.latest._io_public import NodeReplace
|
||||||
|
|
||||||
|
from comfy_execution.graph_utils import is_link
|
||||||
|
import nodes
|
||||||
|
|
||||||
|
class NodeStruct(TypedDict):
|
||||||
|
inputs: dict[str, str | int | float | bool | tuple[str, int]]
|
||||||
|
class_type: str
|
||||||
|
_meta: dict[str, str]
|
||||||
|
|
||||||
|
def copy_node_struct(node_struct: NodeStruct, empty_inputs: bool = False) -> NodeStruct:
|
||||||
|
new_node_struct = node_struct.copy()
|
||||||
|
if empty_inputs:
|
||||||
|
new_node_struct["inputs"] = {}
|
||||||
|
else:
|
||||||
|
new_node_struct["inputs"] = node_struct["inputs"].copy()
|
||||||
|
new_node_struct["_meta"] = node_struct["_meta"].copy()
|
||||||
|
return new_node_struct
|
||||||
|
|
||||||
|
|
||||||
|
class NodeReplaceManager:
|
||||||
|
"""Manages node replacement registrations."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._replacements: dict[str, list[NodeReplace]] = {}
|
||||||
|
|
||||||
|
def register(self, node_replace: NodeReplace):
|
||||||
|
"""Register a node replacement mapping."""
|
||||||
|
self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace)
|
||||||
|
|
||||||
|
def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
|
||||||
|
"""Get replacements for an old node ID."""
|
||||||
|
return self._replacements.get(old_node_id)
|
||||||
|
|
||||||
|
def has_replacement(self, old_node_id: str) -> bool:
|
||||||
|
"""Check if a replacement exists for an old node ID."""
|
||||||
|
return old_node_id in self._replacements
|
||||||
|
|
||||||
|
def apply_replacements(self, prompt: dict[str, NodeStruct]):
|
||||||
|
connections: dict[str, list[tuple[str, str, int]]] = {}
|
||||||
|
need_replacement: set[str] = set()
|
||||||
|
for node_number, node_struct in prompt.items():
|
||||||
|
class_type = node_struct["class_type"]
|
||||||
|
# need replacement if not in NODE_CLASS_MAPPINGS and has replacement
|
||||||
|
if class_type not in nodes.NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type):
|
||||||
|
need_replacement.add(node_number)
|
||||||
|
# keep track of connections
|
||||||
|
for input_id, input_value in node_struct["inputs"].items():
|
||||||
|
if is_link(input_value):
|
||||||
|
conn_number = input_value[0]
|
||||||
|
connections.setdefault(conn_number, []).append((node_number, input_id, input_value[1]))
|
||||||
|
for node_number in need_replacement:
|
||||||
|
node_struct = prompt[node_number]
|
||||||
|
class_type = node_struct["class_type"]
|
||||||
|
replacements = self.get_replacement(class_type)
|
||||||
|
if replacements is None:
|
||||||
|
continue
|
||||||
|
# just use the first replacement
|
||||||
|
replacement = replacements[0]
|
||||||
|
new_node_id = replacement.new_node_id
|
||||||
|
# if replacement is not a valid node, skip trying to replace it as will only cause confusion
|
||||||
|
if new_node_id not in nodes.NODE_CLASS_MAPPINGS.keys():
|
||||||
|
continue
|
||||||
|
# first, replace node id (class_type)
|
||||||
|
new_node_struct = copy_node_struct(node_struct, empty_inputs=True)
|
||||||
|
new_node_struct["class_type"] = new_node_id
|
||||||
|
# TODO: consider replacing display_name in _meta as well for error reporting purposes; would need to query node schema
|
||||||
|
# second, replace inputs
|
||||||
|
if replacement.input_mapping is not None:
|
||||||
|
for input_map in replacement.input_mapping:
|
||||||
|
if "set_value" in input_map:
|
||||||
|
new_node_struct["inputs"][input_map["new_id"]] = input_map["set_value"]
|
||||||
|
elif "old_id" in input_map:
|
||||||
|
new_node_struct["inputs"][input_map["new_id"]] = node_struct["inputs"][input_map["old_id"]]
|
||||||
|
# finalize input replacement
|
||||||
|
prompt[node_number] = new_node_struct
|
||||||
|
# third, replace outputs
|
||||||
|
if replacement.output_mapping is not None:
|
||||||
|
# re-mapping outputs requires changing the input values of nodes that receive connections from this one
|
||||||
|
if node_number in connections:
|
||||||
|
for conns in connections[node_number]:
|
||||||
|
conn_node_number, conn_input_id, old_output_idx = conns
|
||||||
|
for output_map in replacement.output_mapping:
|
||||||
|
if output_map["old_idx"] == old_output_idx:
|
||||||
|
new_output_idx = output_map["new_idx"]
|
||||||
|
previous_input = prompt[conn_node_number]["inputs"][conn_input_id]
|
||||||
|
previous_input[1] = new_output_idx
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
|
"""Serialize all replacements to dict."""
|
||||||
|
return {
|
||||||
|
k: [v.as_dict() for v in v_list]
|
||||||
|
for k, v_list in self._replacements.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
def add_routes(self, routes):
|
||||||
|
@routes.get("/node_replacements")
|
||||||
|
async def get_node_replacements(request):
|
||||||
|
return web.json_response(self.as_dict())
|
||||||
@@ -1,7 +1,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TypedDict
|
from typing import TypedDict, NotRequired
|
||||||
import os
|
import os
|
||||||
|
import json
|
||||||
import folder_paths
|
import folder_paths
|
||||||
import glob
|
import glob
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
@@ -31,6 +32,7 @@ class SubgraphEntry(TypedDict):
|
|||||||
Additional info about subgraph; in the case of custom_nodes, will contain nodepack name
|
Additional info about subgraph; in the case of custom_nodes, will contain nodepack name
|
||||||
"""
|
"""
|
||||||
data: str
|
data: str
|
||||||
|
essentials_category: NotRequired[str]
|
||||||
|
|
||||||
class CustomNodeSubgraphEntryInfo(TypedDict):
|
class CustomNodeSubgraphEntryInfo(TypedDict):
|
||||||
node_pack: str
|
node_pack: str
|
||||||
@@ -53,7 +55,7 @@ class SubgraphManager:
|
|||||||
return entry_id, entry
|
return entry_id, entry
|
||||||
|
|
||||||
async def load_entry_data(self, entry: SubgraphEntry):
|
async def load_entry_data(self, entry: SubgraphEntry):
|
||||||
with open(entry['path'], 'r') as f:
|
with open(entry['path'], 'r', encoding='utf-8') as f:
|
||||||
entry['data'] = f.read()
|
entry['data'] = f.read()
|
||||||
return entry
|
return entry
|
||||||
|
|
||||||
@@ -101,6 +103,16 @@ class SubgraphManager:
|
|||||||
for file in glob.glob(os.path.join(blueprints_dir, "*.json")):
|
for file in glob.glob(os.path.join(blueprints_dir, "*.json")):
|
||||||
file = file.replace('\\', '/')
|
file = file.replace('\\', '/')
|
||||||
entry_id, entry = self._create_entry(file, Source.templates, "comfyui")
|
entry_id, entry = self._create_entry(file, Source.templates, "comfyui")
|
||||||
|
try:
|
||||||
|
with open(file, 'r', encoding='utf-8') as f:
|
||||||
|
bp_data = json.load(f)
|
||||||
|
subgraphs = bp_data.get('definitions', {}).get('subgraphs', [])
|
||||||
|
if subgraphs:
|
||||||
|
ec = subgraphs[0].get('essentials_category')
|
||||||
|
if ec:
|
||||||
|
entry['essentials_category'] = ec
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
subgraphs_dict[entry_id] = entry
|
subgraphs_dict[entry_id] = entry
|
||||||
|
|
||||||
self.cached_blueprint_subgraphs = subgraphs_dict
|
self.cached_blueprint_subgraphs = subgraphs_dict
|
||||||
|
|||||||
44
blueprints/.glsl/Brightness_and_Contrast_1.frag
Normal file
44
blueprints/.glsl/Brightness_and_Contrast_1.frag
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
#version 300 es
|
||||||
|
precision highp float;
|
||||||
|
|
||||||
|
uniform sampler2D u_image0;
|
||||||
|
uniform float u_float0; // Brightness slider -100..100
|
||||||
|
uniform float u_float1; // Contrast slider -100..100
|
||||||
|
|
||||||
|
in vec2 v_texCoord;
|
||||||
|
out vec4 fragColor;
|
||||||
|
|
||||||
|
const float MID_GRAY = 0.18; // 18% reflectance
|
||||||
|
|
||||||
|
// sRGB gamma 2.2 approximation
|
||||||
|
vec3 srgbToLinear(vec3 c) {
|
||||||
|
return pow(max(c, 0.0), vec3(2.2));
|
||||||
|
}
|
||||||
|
|
||||||
|
vec3 linearToSrgb(vec3 c) {
|
||||||
|
return pow(max(c, 0.0), vec3(1.0/2.2));
|
||||||
|
}
|
||||||
|
|
||||||
|
float mapBrightness(float b) {
|
||||||
|
return clamp(b / 100.0, -1.0, 1.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
float mapContrast(float c) {
|
||||||
|
return clamp(c / 100.0 + 1.0, 0.0, 2.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
vec4 orig = texture(u_image0, v_texCoord);
|
||||||
|
|
||||||
|
float brightness = mapBrightness(u_float0);
|
||||||
|
float contrast = mapContrast(u_float1);
|
||||||
|
|
||||||
|
vec3 lin = srgbToLinear(orig.rgb);
|
||||||
|
|
||||||
|
lin = (lin - MID_GRAY) * contrast + brightness + MID_GRAY;
|
||||||
|
|
||||||
|
// Convert back to sRGB
|
||||||
|
vec3 result = linearToSrgb(clamp(lin, 0.0, 1.0));
|
||||||
|
|
||||||
|
fragColor = vec4(result, orig.a);
|
||||||
|
}
|
||||||
72
blueprints/.glsl/Chromatic_Aberration_16.frag
Normal file
72
blueprints/.glsl/Chromatic_Aberration_16.frag
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
#version 300 es
|
||||||
|
precision highp float;
|
||||||
|
|
||||||
|
uniform sampler2D u_image0;
|
||||||
|
uniform vec2 u_resolution;
|
||||||
|
uniform int u_int0; // Mode
|
||||||
|
uniform float u_float0; // Amount (0 to 100)
|
||||||
|
|
||||||
|
in vec2 v_texCoord;
|
||||||
|
out vec4 fragColor;
|
||||||
|
|
||||||
|
const int MODE_LINEAR = 0;
|
||||||
|
const int MODE_RADIAL = 1;
|
||||||
|
const int MODE_BARREL = 2;
|
||||||
|
const int MODE_SWIRL = 3;
|
||||||
|
const int MODE_DIAGONAL = 4;
|
||||||
|
|
||||||
|
const float AMOUNT_SCALE = 0.0005;
|
||||||
|
const float RADIAL_MULT = 4.0;
|
||||||
|
const float BARREL_MULT = 8.0;
|
||||||
|
const float INV_SQRT2 = 0.70710678118;
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
vec2 uv = v_texCoord;
|
||||||
|
vec4 original = texture(u_image0, uv);
|
||||||
|
|
||||||
|
float amount = u_float0 * AMOUNT_SCALE;
|
||||||
|
|
||||||
|
if (amount < 0.000001) {
|
||||||
|
fragColor = original;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Aspect-corrected coordinates for circular effects
|
||||||
|
float aspect = u_resolution.x / u_resolution.y;
|
||||||
|
vec2 centered = uv - 0.5;
|
||||||
|
vec2 corrected = vec2(centered.x * aspect, centered.y);
|
||||||
|
float r = length(corrected);
|
||||||
|
vec2 dir = r > 0.0001 ? corrected / r : vec2(0.0);
|
||||||
|
vec2 offset = vec2(0.0);
|
||||||
|
|
||||||
|
if (u_int0 == MODE_LINEAR) {
|
||||||
|
// Horizontal shift (no aspect correction needed)
|
||||||
|
offset = vec2(amount, 0.0);
|
||||||
|
}
|
||||||
|
else if (u_int0 == MODE_RADIAL) {
|
||||||
|
// Outward from center, stronger at edges
|
||||||
|
offset = dir * r * amount * RADIAL_MULT;
|
||||||
|
offset.x /= aspect; // Convert back to UV space
|
||||||
|
}
|
||||||
|
else if (u_int0 == MODE_BARREL) {
|
||||||
|
// Lens distortion simulation (r² falloff)
|
||||||
|
offset = dir * r * r * amount * BARREL_MULT;
|
||||||
|
offset.x /= aspect; // Convert back to UV space
|
||||||
|
}
|
||||||
|
else if (u_int0 == MODE_SWIRL) {
|
||||||
|
// Perpendicular to radial (rotational aberration)
|
||||||
|
vec2 perp = vec2(-dir.y, dir.x);
|
||||||
|
offset = perp * r * amount * RADIAL_MULT;
|
||||||
|
offset.x /= aspect; // Convert back to UV space
|
||||||
|
}
|
||||||
|
else if (u_int0 == MODE_DIAGONAL) {
|
||||||
|
// 45° offset (no aspect correction needed)
|
||||||
|
offset = vec2(amount, amount) * INV_SQRT2;
|
||||||
|
}
|
||||||
|
|
||||||
|
float red = texture(u_image0, uv + offset).r;
|
||||||
|
float green = original.g;
|
||||||
|
float blue = texture(u_image0, uv - offset).b;
|
||||||
|
|
||||||
|
fragColor = vec4(red, green, blue, original.a);
|
||||||
|
}
|
||||||
78
blueprints/.glsl/Color_Adjustment_15.frag
Normal file
78
blueprints/.glsl/Color_Adjustment_15.frag
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
#version 300 es
|
||||||
|
precision highp float;
|
||||||
|
|
||||||
|
uniform sampler2D u_image0;
|
||||||
|
uniform float u_float0; // temperature (-100 to 100)
|
||||||
|
uniform float u_float1; // tint (-100 to 100)
|
||||||
|
uniform float u_float2; // vibrance (-100 to 100)
|
||||||
|
uniform float u_float3; // saturation (-100 to 100)
|
||||||
|
|
||||||
|
in vec2 v_texCoord;
|
||||||
|
out vec4 fragColor;
|
||||||
|
|
||||||
|
const float INPUT_SCALE = 0.01;
|
||||||
|
const float TEMP_TINT_PRIMARY = 0.3;
|
||||||
|
const float TEMP_TINT_SECONDARY = 0.15;
|
||||||
|
const float VIBRANCE_BOOST = 2.0;
|
||||||
|
const float SATURATION_BOOST = 2.0;
|
||||||
|
const float SKIN_PROTECTION = 0.5;
|
||||||
|
const float EPSILON = 0.001;
|
||||||
|
const vec3 LUMA_WEIGHTS = vec3(0.299, 0.587, 0.114);
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
vec4 tex = texture(u_image0, v_texCoord);
|
||||||
|
vec3 color = tex.rgb;
|
||||||
|
|
||||||
|
// Scale inputs: -100/100 → -1/1
|
||||||
|
float temperature = u_float0 * INPUT_SCALE;
|
||||||
|
float tint = u_float1 * INPUT_SCALE;
|
||||||
|
float vibrance = u_float2 * INPUT_SCALE;
|
||||||
|
float saturation = u_float3 * INPUT_SCALE;
|
||||||
|
|
||||||
|
// Temperature (warm/cool): positive = warm, negative = cool
|
||||||
|
color.r += temperature * TEMP_TINT_PRIMARY;
|
||||||
|
color.b -= temperature * TEMP_TINT_PRIMARY;
|
||||||
|
|
||||||
|
// Tint (green/magenta): positive = green, negative = magenta
|
||||||
|
color.g += tint * TEMP_TINT_PRIMARY;
|
||||||
|
color.r -= tint * TEMP_TINT_SECONDARY;
|
||||||
|
color.b -= tint * TEMP_TINT_SECONDARY;
|
||||||
|
|
||||||
|
// Single clamp after temperature/tint
|
||||||
|
color = clamp(color, 0.0, 1.0);
|
||||||
|
|
||||||
|
// Vibrance with skin protection
|
||||||
|
if (vibrance != 0.0) {
|
||||||
|
float maxC = max(color.r, max(color.g, color.b));
|
||||||
|
float minC = min(color.r, min(color.g, color.b));
|
||||||
|
float sat = maxC - minC;
|
||||||
|
float gray = dot(color, LUMA_WEIGHTS);
|
||||||
|
|
||||||
|
if (vibrance < 0.0) {
|
||||||
|
// Desaturate: -100 → gray
|
||||||
|
color = mix(vec3(gray), color, 1.0 + vibrance);
|
||||||
|
} else {
|
||||||
|
// Boost less saturated colors more
|
||||||
|
float vibranceAmt = vibrance * (1.0 - sat);
|
||||||
|
|
||||||
|
// Branchless skin tone protection
|
||||||
|
float isWarmTone = step(color.b, color.g) * step(color.g, color.r);
|
||||||
|
float warmth = (color.r - color.b) / max(maxC, EPSILON);
|
||||||
|
float skinTone = isWarmTone * warmth * sat * (1.0 - sat);
|
||||||
|
vibranceAmt *= (1.0 - skinTone * SKIN_PROTECTION);
|
||||||
|
|
||||||
|
color = mix(vec3(gray), color, 1.0 + vibranceAmt * VIBRANCE_BOOST);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Saturation
|
||||||
|
if (saturation != 0.0) {
|
||||||
|
float gray = dot(color, LUMA_WEIGHTS);
|
||||||
|
float satMix = saturation < 0.0
|
||||||
|
? 1.0 + saturation // -100 → gray
|
||||||
|
: 1.0 + saturation * SATURATION_BOOST; // +100 → 3x boost
|
||||||
|
color = mix(vec3(gray), color, satMix);
|
||||||
|
}
|
||||||
|
|
||||||
|
fragColor = vec4(clamp(color, 0.0, 1.0), tex.a);
|
||||||
|
}
|
||||||
94
blueprints/.glsl/Edge-Preserving_Blur_128.frag
Normal file
94
blueprints/.glsl/Edge-Preserving_Blur_128.frag
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
#version 300 es
|
||||||
|
precision highp float;
|
||||||
|
|
||||||
|
uniform sampler2D u_image0;
|
||||||
|
uniform float u_float0; // Blur radius (0–20, default ~5)
|
||||||
|
uniform float u_float1; // Edge threshold (0–100, default ~30)
|
||||||
|
uniform int u_int0; // Step size (0/1 = every pixel, 2+ = skip pixels)
|
||||||
|
|
||||||
|
in vec2 v_texCoord;
|
||||||
|
out vec4 fragColor;
|
||||||
|
|
||||||
|
const int MAX_RADIUS = 20;
|
||||||
|
const float EPSILON = 0.0001;
|
||||||
|
|
||||||
|
// Perceptual luminance
|
||||||
|
float getLuminance(vec3 rgb) {
|
||||||
|
return dot(rgb, vec3(0.299, 0.587, 0.114));
|
||||||
|
}
|
||||||
|
|
||||||
|
vec4 bilateralFilter(vec2 uv, vec2 texelSize, int radius,
|
||||||
|
float sigmaSpatial, float sigmaColor)
|
||||||
|
{
|
||||||
|
vec4 center = texture(u_image0, uv);
|
||||||
|
vec3 centerRGB = center.rgb;
|
||||||
|
|
||||||
|
float invSpatial2 = -0.5 / (sigmaSpatial * sigmaSpatial);
|
||||||
|
float invColor2 = -0.5 / (sigmaColor * sigmaColor + EPSILON);
|
||||||
|
|
||||||
|
vec3 sumRGB = vec3(0.0);
|
||||||
|
float sumWeight = 0.0;
|
||||||
|
|
||||||
|
int step = max(u_int0, 1);
|
||||||
|
float radius2 = float(radius * radius);
|
||||||
|
|
||||||
|
for (int dy = -MAX_RADIUS; dy <= MAX_RADIUS; dy++) {
|
||||||
|
if (dy < -radius || dy > radius) continue;
|
||||||
|
if (abs(dy) % step != 0) continue;
|
||||||
|
|
||||||
|
for (int dx = -MAX_RADIUS; dx <= MAX_RADIUS; dx++) {
|
||||||
|
if (dx < -radius || dx > radius) continue;
|
||||||
|
if (abs(dx) % step != 0) continue;
|
||||||
|
|
||||||
|
vec2 offset = vec2(float(dx), float(dy));
|
||||||
|
float dist2 = dot(offset, offset);
|
||||||
|
if (dist2 > radius2) continue;
|
||||||
|
|
||||||
|
vec3 sampleRGB = texture(u_image0, uv + offset * texelSize).rgb;
|
||||||
|
|
||||||
|
// Spatial Gaussian
|
||||||
|
float spatialWeight = exp(dist2 * invSpatial2);
|
||||||
|
|
||||||
|
// Perceptual color distance (weighted RGB)
|
||||||
|
vec3 diff = sampleRGB - centerRGB;
|
||||||
|
float colorDist = dot(diff * diff, vec3(0.299, 0.587, 0.114));
|
||||||
|
float colorWeight = exp(colorDist * invColor2);
|
||||||
|
|
||||||
|
float w = spatialWeight * colorWeight;
|
||||||
|
sumRGB += sampleRGB * w;
|
||||||
|
sumWeight += w;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
vec3 resultRGB = sumRGB / max(sumWeight, EPSILON);
|
||||||
|
return vec4(resultRGB, center.a); // preserve center alpha
|
||||||
|
}
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0));
|
||||||
|
|
||||||
|
float radiusF = clamp(u_float0, 0.0, float(MAX_RADIUS));
|
||||||
|
int radius = int(radiusF + 0.5);
|
||||||
|
|
||||||
|
if (radius == 0) {
|
||||||
|
fragColor = texture(u_image0, v_texCoord);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Edge threshold → color sigma
|
||||||
|
// Squared curve for better low-end control
|
||||||
|
float t = clamp(u_float1, 0.0, 100.0) / 100.0;
|
||||||
|
t *= t;
|
||||||
|
float sigmaColor = mix(0.01, 0.5, t);
|
||||||
|
|
||||||
|
// Spatial sigma tied to radius
|
||||||
|
float sigmaSpatial = max(radiusF * 0.75, 0.5);
|
||||||
|
|
||||||
|
fragColor = bilateralFilter(
|
||||||
|
v_texCoord,
|
||||||
|
texelSize,
|
||||||
|
radius,
|
||||||
|
sigmaSpatial,
|
||||||
|
sigmaColor
|
||||||
|
);
|
||||||
|
}
|
||||||
124
blueprints/.glsl/Film_Grain_15.frag
Normal file
124
blueprints/.glsl/Film_Grain_15.frag
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
#version 300 es
|
||||||
|
precision highp float;
|
||||||
|
|
||||||
|
uniform sampler2D u_image0;
|
||||||
|
uniform vec2 u_resolution;
|
||||||
|
uniform float u_float0; // grain amount [0.0 – 1.0] typical: 0.2–0.8
|
||||||
|
uniform float u_float1; // grain size [0.3 – 3.0] lower = finer grain
|
||||||
|
uniform float u_float2; // color amount [0.0 – 1.0] 0 = monochrome, 1 = RGB grain
|
||||||
|
uniform float u_float3; // luminance bias [0.0 – 1.0] 0 = uniform, 1 = shadows only
|
||||||
|
uniform int u_int0; // noise mode [0 or 1] 0 = smooth, 1 = grainy
|
||||||
|
|
||||||
|
in vec2 v_texCoord;
|
||||||
|
layout(location = 0) out vec4 fragColor0;
|
||||||
|
|
||||||
|
// High-quality integer hash (pcg-like)
|
||||||
|
uint pcg(uint v) {
|
||||||
|
uint state = v * 747796405u + 2891336453u;
|
||||||
|
uint word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u;
|
||||||
|
return (word >> 22u) ^ word;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2D -> 1D hash input
|
||||||
|
uint hash2d(uvec2 p) {
|
||||||
|
return pcg(p.x + pcg(p.y));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hash to float [0, 1]
|
||||||
|
float hashf(uvec2 p) {
|
||||||
|
return float(hash2d(p)) / float(0xffffffffu);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hash to float with offset (for RGB channels)
|
||||||
|
float hashf(uvec2 p, uint offset) {
|
||||||
|
return float(pcg(hash2d(p) + offset)) / float(0xffffffffu);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert uniform [0,1] to roughly Gaussian distribution
|
||||||
|
// Using simple approximation: average of multiple samples
|
||||||
|
float toGaussian(uvec2 p) {
|
||||||
|
float sum = hashf(p, 0u) + hashf(p, 1u) + hashf(p, 2u) + hashf(p, 3u);
|
||||||
|
return (sum - 2.0) * 0.7; // Centered, scaled
|
||||||
|
}
|
||||||
|
|
||||||
|
float toGaussian(uvec2 p, uint offset) {
|
||||||
|
float sum = hashf(p, offset) + hashf(p, offset + 1u)
|
||||||
|
+ hashf(p, offset + 2u) + hashf(p, offset + 3u);
|
||||||
|
return (sum - 2.0) * 0.7;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Smooth noise with better interpolation
|
||||||
|
float smoothNoise(vec2 p) {
|
||||||
|
vec2 i = floor(p);
|
||||||
|
vec2 f = fract(p);
|
||||||
|
|
||||||
|
// Quintic interpolation (less banding than cubic)
|
||||||
|
f = f * f * f * (f * (f * 6.0 - 15.0) + 10.0);
|
||||||
|
|
||||||
|
uvec2 ui = uvec2(i);
|
||||||
|
float a = toGaussian(ui);
|
||||||
|
float b = toGaussian(ui + uvec2(1u, 0u));
|
||||||
|
float c = toGaussian(ui + uvec2(0u, 1u));
|
||||||
|
float d = toGaussian(ui + uvec2(1u, 1u));
|
||||||
|
|
||||||
|
return mix(mix(a, b, f.x), mix(c, d, f.x), f.y);
|
||||||
|
}
|
||||||
|
|
||||||
|
float smoothNoise(vec2 p, uint offset) {
|
||||||
|
vec2 i = floor(p);
|
||||||
|
vec2 f = fract(p);
|
||||||
|
|
||||||
|
f = f * f * f * (f * (f * 6.0 - 15.0) + 10.0);
|
||||||
|
|
||||||
|
uvec2 ui = uvec2(i);
|
||||||
|
float a = toGaussian(ui, offset);
|
||||||
|
float b = toGaussian(ui + uvec2(1u, 0u), offset);
|
||||||
|
float c = toGaussian(ui + uvec2(0u, 1u), offset);
|
||||||
|
float d = toGaussian(ui + uvec2(1u, 1u), offset);
|
||||||
|
|
||||||
|
return mix(mix(a, b, f.x), mix(c, d, f.x), f.y);
|
||||||
|
}
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
vec4 color = texture(u_image0, v_texCoord);
|
||||||
|
|
||||||
|
// Luminance (Rec.709)
|
||||||
|
float luma = dot(color.rgb, vec3(0.2126, 0.7152, 0.0722));
|
||||||
|
|
||||||
|
// Grain UV (resolution-independent)
|
||||||
|
vec2 grainUV = v_texCoord * u_resolution / max(u_float1, 0.01);
|
||||||
|
uvec2 grainPixel = uvec2(grainUV);
|
||||||
|
|
||||||
|
float g;
|
||||||
|
vec3 grainRGB;
|
||||||
|
|
||||||
|
if (u_int0 == 1) {
|
||||||
|
// Grainy mode: pure hash noise (no interpolation = no banding)
|
||||||
|
g = toGaussian(grainPixel);
|
||||||
|
grainRGB = vec3(
|
||||||
|
toGaussian(grainPixel, 100u),
|
||||||
|
toGaussian(grainPixel, 200u),
|
||||||
|
toGaussian(grainPixel, 300u)
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
// Smooth mode: interpolated with quintic curve
|
||||||
|
g = smoothNoise(grainUV);
|
||||||
|
grainRGB = vec3(
|
||||||
|
smoothNoise(grainUV, 100u),
|
||||||
|
smoothNoise(grainUV, 200u),
|
||||||
|
smoothNoise(grainUV, 300u)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Luminance weighting (less grain in highlights)
|
||||||
|
float lumWeight = mix(1.0, 1.0 - luma, clamp(u_float3, 0.0, 1.0));
|
||||||
|
|
||||||
|
// Strength
|
||||||
|
float strength = u_float0 * 0.15;
|
||||||
|
|
||||||
|
// Color vs monochrome grain
|
||||||
|
vec3 grainColor = mix(vec3(g), grainRGB, clamp(u_float2, 0.0, 1.0));
|
||||||
|
|
||||||
|
color.rgb += grainColor * strength * lumWeight;
|
||||||
|
fragColor0 = vec4(clamp(color.rgb, 0.0, 1.0), color.a);
|
||||||
|
}
|
||||||
133
blueprints/.glsl/Glow_30.frag
Normal file
133
blueprints/.glsl/Glow_30.frag
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
#version 300 es
|
||||||
|
precision mediump float;
|
||||||
|
|
||||||
|
uniform sampler2D u_image0;
|
||||||
|
uniform vec2 u_resolution;
|
||||||
|
uniform int u_int0; // Blend mode
|
||||||
|
uniform int u_int1; // Color tint
|
||||||
|
uniform float u_float0; // Intensity
|
||||||
|
uniform float u_float1; // Radius
|
||||||
|
uniform float u_float2; // Threshold
|
||||||
|
|
||||||
|
in vec2 v_texCoord;
|
||||||
|
out vec4 fragColor;
|
||||||
|
|
||||||
|
const int BLEND_ADD = 0;
|
||||||
|
const int BLEND_SCREEN = 1;
|
||||||
|
const int BLEND_SOFT = 2;
|
||||||
|
const int BLEND_OVERLAY = 3;
|
||||||
|
const int BLEND_LIGHTEN = 4;
|
||||||
|
|
||||||
|
const float GOLDEN_ANGLE = 2.39996323;
|
||||||
|
const int MAX_SAMPLES = 48;
|
||||||
|
const vec3 LUMA = vec3(0.299, 0.587, 0.114);
|
||||||
|
|
||||||
|
float hash(vec2 p) {
|
||||||
|
p = fract(p * vec2(123.34, 456.21));
|
||||||
|
p += dot(p, p + 45.32);
|
||||||
|
return fract(p.x * p.y);
|
||||||
|
}
|
||||||
|
|
||||||
|
vec3 hexToRgb(int h) {
|
||||||
|
return vec3(
|
||||||
|
float((h >> 16) & 255),
|
||||||
|
float((h >> 8) & 255),
|
||||||
|
float(h & 255)
|
||||||
|
) * (1.0 / 255.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
vec3 blend(vec3 base, vec3 glow, int mode) {
|
||||||
|
if (mode == BLEND_SCREEN) {
|
||||||
|
return 1.0 - (1.0 - base) * (1.0 - glow);
|
||||||
|
}
|
||||||
|
if (mode == BLEND_SOFT) {
|
||||||
|
return mix(
|
||||||
|
base - (1.0 - 2.0 * glow) * base * (1.0 - base),
|
||||||
|
base + (2.0 * glow - 1.0) * (sqrt(base) - base),
|
||||||
|
step(0.5, glow)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (mode == BLEND_OVERLAY) {
|
||||||
|
return mix(
|
||||||
|
2.0 * base * glow,
|
||||||
|
1.0 - 2.0 * (1.0 - base) * (1.0 - glow),
|
||||||
|
step(0.5, base)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (mode == BLEND_LIGHTEN) {
|
||||||
|
return max(base, glow);
|
||||||
|
}
|
||||||
|
return base + glow;
|
||||||
|
}
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
vec4 original = texture(u_image0, v_texCoord);
|
||||||
|
|
||||||
|
float intensity = u_float0 * 0.05;
|
||||||
|
float radius = u_float1 * u_float1 * 0.012;
|
||||||
|
|
||||||
|
if (intensity < 0.001 || radius < 0.1) {
|
||||||
|
fragColor = original;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
float threshold = 1.0 - u_float2 * 0.01;
|
||||||
|
float t0 = threshold - 0.15;
|
||||||
|
float t1 = threshold + 0.15;
|
||||||
|
|
||||||
|
vec2 texelSize = 1.0 / u_resolution;
|
||||||
|
float radius2 = radius * radius;
|
||||||
|
|
||||||
|
float sampleScale = clamp(radius * 0.75, 0.35, 1.0);
|
||||||
|
int samples = int(float(MAX_SAMPLES) * sampleScale);
|
||||||
|
|
||||||
|
float noise = hash(gl_FragCoord.xy);
|
||||||
|
float angleOffset = noise * GOLDEN_ANGLE;
|
||||||
|
float radiusJitter = 0.85 + noise * 0.3;
|
||||||
|
|
||||||
|
float ca = cos(GOLDEN_ANGLE);
|
||||||
|
float sa = sin(GOLDEN_ANGLE);
|
||||||
|
vec2 dir = vec2(cos(angleOffset), sin(angleOffset));
|
||||||
|
|
||||||
|
vec3 glow = vec3(0.0);
|
||||||
|
float totalWeight = 0.0;
|
||||||
|
|
||||||
|
// Center tap
|
||||||
|
float centerMask = smoothstep(t0, t1, dot(original.rgb, LUMA));
|
||||||
|
glow += original.rgb * centerMask * 2.0;
|
||||||
|
totalWeight += 2.0;
|
||||||
|
|
||||||
|
for (int i = 1; i < MAX_SAMPLES; i++) {
|
||||||
|
if (i >= samples) break;
|
||||||
|
|
||||||
|
float fi = float(i);
|
||||||
|
float dist = sqrt(fi / float(samples)) * radius * radiusJitter;
|
||||||
|
|
||||||
|
vec2 offset = dir * dist * texelSize;
|
||||||
|
vec3 c = texture(u_image0, v_texCoord + offset).rgb;
|
||||||
|
float mask = smoothstep(t0, t1, dot(c, LUMA));
|
||||||
|
|
||||||
|
float w = 1.0 - (dist * dist) / (radius2 * 1.5);
|
||||||
|
w = max(w, 0.0);
|
||||||
|
w *= w;
|
||||||
|
|
||||||
|
glow += c * mask * w;
|
||||||
|
totalWeight += w;
|
||||||
|
|
||||||
|
dir = vec2(
|
||||||
|
dir.x * ca - dir.y * sa,
|
||||||
|
dir.x * sa + dir.y * ca
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
glow *= intensity / max(totalWeight, 0.001);
|
||||||
|
|
||||||
|
if (u_int1 > 0) {
|
||||||
|
glow *= hexToRgb(u_int1);
|
||||||
|
}
|
||||||
|
|
||||||
|
vec3 result = blend(original.rgb, glow, u_int0);
|
||||||
|
result += (noise - 0.5) * (1.0 / 255.0);
|
||||||
|
|
||||||
|
fragColor = vec4(clamp(result, 0.0, 1.0), original.a);
|
||||||
|
}
|
||||||
222
blueprints/.glsl/Hue_and_Saturation_1.frag
Normal file
222
blueprints/.glsl/Hue_and_Saturation_1.frag
Normal file
@@ -0,0 +1,222 @@
|
|||||||
|
#version 300 es
precision highp float;

// Photoshop-style Hue/Saturation adjustment with per-color-range targeting.
uniform sampler2D u_image0;
uniform int u_int0; // Mode: 0=Master, 1=Reds, 2=Yellows, 3=Greens, 4=Cyans, 5=Blues, 6=Magentas, 7=Colorize
uniform int u_int1; // Color Space: 0=HSL, 1=HSB/HSV
uniform float u_float0; // Hue (-180 to 180)
uniform float u_float1; // Saturation (-100 to 100)
uniform float u_float2; // Lightness/Brightness (-100 to 100)
uniform float u_float3; // Overlap (0 to 100) - feathering between adjacent color ranges

in vec2 v_texCoord;
out vec4 fragColor;

// Color range modes
const int MODE_MASTER = 0;
const int MODE_RED = 1;
const int MODE_YELLOW = 2;
const int MODE_GREEN = 3;
const int MODE_CYAN = 4;
const int MODE_BLUE = 5;
const int MODE_MAGENTA = 6;
const int MODE_COLORIZE = 7;

// Color space modes
const int COLORSPACE_HSL = 0;
const int COLORSPACE_HSB = 1;

const float EPSILON = 0.0001;

//=============================================================================
// RGB <-> HSL Conversions
//=============================================================================

// Returns (h, s, l), each component normalized to [0, 1]. Hue 0 = red.
vec3 rgb2hsl(vec3 c) {
    float maxC = max(max(c.r, c.g), c.b);
    float minC = min(min(c.r, c.g), c.b);
    float delta = maxC - minC;

    float h = 0.0;
    float s = 0.0;
    float l = (maxC + minC) * 0.5;

    if (delta > EPSILON) {
        // HSL saturation formula depends on which half of the lightness range we're in.
        s = l < 0.5
            ? delta / (maxC + minC)
            : delta / (2.0 - maxC - minC);

        if (maxC == c.r) {
            h = (c.g - c.b) / delta + (c.g < c.b ? 6.0 : 0.0);
        } else if (maxC == c.g) {
            h = (c.b - c.r) / delta + 2.0;
        } else {
            h = (c.r - c.g) / delta + 4.0;
        }
        h /= 6.0;
    }

    return vec3(h, s, l);
}

// Helper for hsl2rgb: maps one hue sector to a channel value.
float hue2rgb(float p, float q, float t) {
    t = fract(t);
    if (t < 1.0/6.0) return p + (q - p) * 6.0 * t;
    if (t < 0.5) return q;
    if (t < 2.0/3.0) return p + (q - p) * (2.0/3.0 - t) * 6.0;
    return p;
}

vec3 hsl2rgb(vec3 hsl) {
    // Achromatic shortcut: zero saturation is pure gray at the given lightness.
    if (hsl.y < EPSILON) return vec3(hsl.z);

    float q = hsl.z < 0.5
        ? hsl.z * (1.0 + hsl.y)
        : hsl.z + hsl.y - hsl.z * hsl.y;
    float p = 2.0 * hsl.z - q;

    return vec3(
        hue2rgb(p, q, hsl.x + 1.0/3.0),
        hue2rgb(p, q, hsl.x),
        hue2rgb(p, q, hsl.x - 1.0/3.0)
    );
}

// Returns (h, s, b) — HSB/HSV with all components in [0, 1].
vec3 rgb2hsb(vec3 c) {
    float maxC = max(max(c.r, c.g), c.b);
    float minC = min(min(c.r, c.g), c.b);
    float delta = maxC - minC;

    float h = 0.0;
    float s = (maxC > EPSILON) ? delta / maxC : 0.0;
    float b = maxC;

    if (delta > EPSILON) {
        if (maxC == c.r) {
            h = (c.g - c.b) / delta + (c.g < c.b ? 6.0 : 0.0);
        } else if (maxC == c.g) {
            h = (c.b - c.r) / delta + 2.0;
        } else {
            h = (c.r - c.g) / delta + 4.0;
        }
        h /= 6.0;
    }

    return vec3(h, s, b);
}

vec3 hsb2rgb(vec3 hsb) {
    // Branch-free HSV-to-RGB using the standard "triangle wave" construction.
    vec3 rgb = clamp(abs(mod(hsb.x * 6.0 + vec3(0.0, 4.0, 2.0), 6.0) - 3.0) - 1.0, 0.0, 1.0);
    return hsb.z * mix(vec3(1.0), rgb, hsb.y);
}

//=============================================================================
// Color Range Weight Calculation
//=============================================================================

// Shortest distance between two hues on the [0,1) hue circle.
float hueDistance(float a, float b) {
    float d = abs(a - b);
    return min(d, 1.0 - d);
}

// Weight in [0,1]: 1 inside the target hue band, feathered to 0 outside.
float getHueWeight(float hue, float center, float overlap) {
    float baseWidth = 1.0 / 6.0;
    float feather = baseWidth * overlap;

    float d = hueDistance(hue, center);

    float inner = baseWidth * 0.5;
    float outer = inner + feather;

    return 1.0 - smoothstep(inner, outer, d);
}

float getModeWeight(float hue, int mode, float overlap) {
    if (mode == MODE_MASTER || mode == MODE_COLORIZE) return 1.0;

    // Red straddles the hue wrap point, so test both ends of the circle.
    if (mode == MODE_RED) {
        return max(
            getHueWeight(hue, 0.0, overlap),
            getHueWeight(hue, 1.0, overlap)
        );
    }

    // Modes 2..6 map to hue centers at 1/6, 2/6, ... 5/6.
    float center = float(mode - 1) / 6.0;
    return getHueWeight(hue, center, overlap);
}

//=============================================================================
// Adjustment Functions
//=============================================================================

// Positive amount lifts toward white, negative scales toward black.
float adjustLightness(float l, float amount) {
    return amount > 0.0
        ? l + (1.0 - l) * amount
        : l + l * amount;
}

float adjustBrightness(float b, float amount) {
    return clamp(b + amount, 0.0, 1.0);
}

// Same relative push/pull as adjustLightness, applied to saturation.
float adjustSaturation(float s, float amount) {
    return amount > 0.0
        ? s + (1.0 - s) * amount
        : s + s * amount;
}

// Colorize mode: discard the source hue, tint luminance with a single hue/sat.
vec3 colorize(vec3 rgb, float hue, float sat, float light) {
    float lum = dot(rgb, vec3(0.299, 0.587, 0.114));
    float l = adjustLightness(lum, light);

    vec3 hsl = vec3(fract(hue), clamp(sat, 0.0, 1.0), clamp(l, 0.0, 1.0));
    return hsl2rgb(hsl);
}

//=============================================================================
// Main
//=============================================================================

void main() {
    vec4 original = texture(u_image0, v_texCoord);

    float hueShift = u_float0 / 360.0; // -180..180 -> -0.5..0.5
    float satAmount = u_float1 / 100.0; // -100..100 -> -1..1
    float lightAmount= u_float2 / 100.0; // -100..100 -> -1..1
    float overlap = u_float3 / 100.0; // 0..100 -> 0..1

    vec3 result;

    if (u_int0 == MODE_COLORIZE) {
        result = colorize(original.rgb, hueShift, satAmount, lightAmount);
        fragColor = vec4(result, original.a);
        return;
    }

    vec3 hsx = (u_int1 == COLORSPACE_HSL)
        ? rgb2hsl(original.rgb)
        : rgb2hsb(original.rgb);

    float weight = getModeWeight(hsx.x, u_int0, overlap);

    // Neutral (unsaturated) pixels have no meaningful hue — exclude them
    // from any targeted color-range adjustment.
    if (u_int0 != MODE_MASTER && hsx.y < EPSILON) {
        weight = 0.0;
    }

    if (weight > EPSILON) {
        float h = fract(hsx.x + hueShift * weight);
        float s = clamp(adjustSaturation(hsx.y, satAmount * weight), 0.0, 1.0);
        float v = (u_int1 == COLORSPACE_HSL)
            ? clamp(adjustLightness(hsx.z, lightAmount * weight), 0.0, 1.0)
            : clamp(adjustBrightness(hsx.z, lightAmount * weight), 0.0, 1.0);

        vec3 adjusted = vec3(h, s, v);
        result = (u_int1 == COLORSPACE_HSL)
            ? hsl2rgb(adjusted)
            : hsb2rgb(adjusted);
    } else {
        result = original.rgb;
    }

    fragColor = vec4(result, original.a);
}
|
||||||
111
blueprints/.glsl/Image_Blur_1.frag
Normal file
111
blueprints/.glsl/Image_Blur_1.frag
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
#version 300 es
#pragma passes 2
precision highp float;

// Image blur: separable Gaussian or Box blur (two passes), or a single-pass
// radial (angular) blur around the image center.

// Blur type constants
const int BLUR_GAUSSIAN = 0;
const int BLUR_BOX = 1;
const int BLUR_RADIAL = 2;

// Radial blur config
const int RADIAL_SAMPLES = 12;
const float RADIAL_STRENGTH = 0.0003;

uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform int u_int0; // Blur type (BLUR_GAUSSIAN, BLUR_BOX, BLUR_RADIAL)
uniform float u_float0; // Blur radius/amount
uniform int u_pass; // Pass index (0 = horizontal, 1 = vertical)

in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;

float gaussian(float x, float sigma) {
    return exp(-(x * x) / (2.0 * sigma * sigma));
}

void main() {
    vec2 texelSize = 1.0 / u_resolution;
    float radius = max(u_float0, 0.0);

    // Radial (angular) blur - single pass, doesn't use separable
    if (u_int0 == BLUR_RADIAL) {
        // Only execute on first pass
        if (u_pass > 0) {
            fragColor0 = texture(u_image0, v_texCoord);
            return;
        }

        vec2 center = vec2(0.5);
        vec2 dir = v_texCoord - center;
        float dist = length(dir);

        // At the exact center the direction is undefined; pass through.
        if (dist < 1e-4) {
            fragColor0 = texture(u_image0, v_texCoord);
            return;
        }

        vec4 sum = vec4(0.0);
        float totalWeight = 0.0;
        float angleStep = radius * RADIAL_STRENGTH;

        dir /= dist;

        float cosStep = cos(angleStep);
        float sinStep = sin(angleStep);

        // Start the sweep RADIAL_SAMPLES steps behind the pixel's own angle,
        // then rotate forward incrementally (one 2x2 rotation per sample).
        float negAngle = -float(RADIAL_SAMPLES) * angleStep;
        vec2 rotDir = vec2(
            dir.x * cos(negAngle) - dir.y * sin(negAngle),
            dir.x * sin(negAngle) + dir.y * cos(negAngle)
        );

        for (int i = -RADIAL_SAMPLES; i <= RADIAL_SAMPLES; i++) {
            vec2 uv = center + rotDir * dist;
            // Triangular falloff: full weight at the pixel's own angle.
            float w = 1.0 - abs(float(i)) / float(RADIAL_SAMPLES);
            sum += texture(u_image0, uv) * w;
            totalWeight += w;

            rotDir = vec2(
                rotDir.x * cosStep - rotDir.y * sinStep,
                rotDir.x * sinStep + rotDir.y * cosStep
            );
        }

        fragColor0 = sum / max(totalWeight, 0.001);
        return;
    }

    // Separable Gaussian / Box blur
    int samples = int(ceil(radius));

    // Zero radius: nothing to do.
    if (samples == 0) {
        fragColor0 = texture(u_image0, v_texCoord);
        return;
    }

    // Direction: pass 0 = horizontal, pass 1 = vertical
    vec2 dir = (u_pass == 0) ? vec2(1.0, 0.0) : vec2(0.0, 1.0);

    vec4 color = vec4(0.0);
    float totalWeight = 0.0;
    float sigma = radius / 2.0;

    for (int i = -samples; i <= samples; i++) {
        vec2 offset = dir * float(i) * texelSize;
        vec4 sample_color = texture(u_image0, v_texCoord + offset);

        float weight;
        if (u_int0 == BLUR_GAUSSIAN) {
            weight = gaussian(float(i), sigma);
        } else {
            // BLUR_BOX
            weight = 1.0;
        }

        color += sample_color * weight;
        totalWeight += weight;
    }

    // totalWeight >= 1 here (center tap always contributes), so no zero guard needed.
    fragColor0 = color / totalWeight;
}
|
||||||
19
blueprints/.glsl/Image_Channels_23.frag
Normal file
19
blueprints/.glsl/Image_Channels_23.frag
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
#version 300 es
precision highp float;

// Split the input image into its four channels, one grayscale image per
// render target (MRT): R -> 0, G -> 1, B -> 2, A -> 3.

uniform sampler2D u_image0;

in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;
layout(location = 1) out vec4 fragColor1;
layout(location = 2) out vec4 fragColor2;
layout(location = 3) out vec4 fragColor3;

void main() {
    vec4 src = texture(u_image0, v_texCoord);

    fragColor0 = vec4(src.rrr, 1.0); // Red channel as grayscale
    fragColor1 = vec4(src.ggg, 1.0); // Green channel as grayscale
    fragColor2 = vec4(src.bbb, 1.0); // Blue channel as grayscale
    fragColor3 = vec4(src.aaa, 1.0); // Alpha channel as grayscale
}
|
||||||
71
blueprints/.glsl/Image_Levels_1.frag
Normal file
71
blueprints/.glsl/Image_Levels_1.frag
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
#version 300 es
precision highp float;

// Levels Adjustment
// u_int0: channel (0=RGB, 1=R, 2=G, 3=B) default: 0
// u_float0: input black (0-255) default: 0
// u_float1: input white (0-255) default: 255
// u_float2: gamma (0.01-9.99) default: 1.0
// u_float3: output black (0-255) default: 0
// u_float4: output white (0-255) default: 255

uniform sampler2D u_image0;
uniform int u_int0;
uniform float u_float0;
uniform float u_float1;
uniform float u_float2;
uniform float u_float3;
uniform float u_float4;

in vec2 v_texCoord;
out vec4 fragColor;

// Classic levels pipeline: remap [inBlack, inWhite] -> [0,1], apply a gamma
// curve, then remap [0,1] -> [outBlack, outWhite].
vec3 applyLevels(vec3 color, float inBlack, float inWhite, float gamma, float outBlack, float outWhite) {
    float inRange = max(inWhite - inBlack, 0.0001);
    vec3 result = clamp((color - inBlack) / inRange, 0.0, 1.0);
    result = pow(result, vec3(1.0 / gamma));
    result = mix(vec3(outBlack), vec3(outWhite), result);
    return result;
}

// Scalar variant of applyLevels for single-channel adjustment.
float applySingleChannel(float value, float inBlack, float inWhite, float gamma, float outBlack, float outWhite) {
    float inRange = max(inWhite - inBlack, 0.0001);
    float result = clamp((value - inBlack) / inRange, 0.0, 1.0);
    result = pow(result, 1.0 / gamma);
    result = mix(outBlack, outWhite, result);
    return result;
}

void main() {
    vec4 texColor = texture(u_image0, v_texCoord);
    vec3 color = texColor.rgb;

    float inBlack = u_float0 / 255.0;
    float inWhite = u_float1 / 255.0;
    // Guard against gamma <= 0: the levels functions divide by gamma, and a
    // zero uniform (the GL default for an unset uniform) would produce NaN/inf.
    float gamma = max(u_float2, 0.0001);
    float outBlack = u_float3 / 255.0;
    float outWhite = u_float4 / 255.0;

    vec3 result;

    if (u_int0 == 0) {
        result = applyLevels(color, inBlack, inWhite, gamma, outBlack, outWhite);
    }
    else if (u_int0 == 1) {
        result = color;
        result.r = applySingleChannel(color.r, inBlack, inWhite, gamma, outBlack, outWhite);
    }
    else if (u_int0 == 2) {
        result = color;
        result.g = applySingleChannel(color.g, inBlack, inWhite, gamma, outBlack, outWhite);
    }
    else if (u_int0 == 3) {
        result = color;
        result.b = applySingleChannel(color.b, inBlack, inWhite, gamma, outBlack, outWhite);
    }
    else {
        // Unknown channel selector: pass through unchanged.
        result = color;
    }

    fragColor = vec4(result, texColor.a);
}
|
||||||
28
blueprints/.glsl/README.md
Normal file
28
blueprints/.glsl/README.md
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
# GLSL Shader Sources
|
||||||
|
|
||||||
|
This folder contains the GLSL fragment shaders extracted from blueprint JSON files for easier editing and version control.
|
||||||
|
|
||||||
|
## File Naming Convention
|
||||||
|
|
||||||
|
`{Blueprint_Name}_{node_id}.frag`
|
||||||
|
|
||||||
|
- **Blueprint_Name**: The JSON filename with spaces/special chars replaced by underscores
|
||||||
|
- **node_id**: The GLSLShader node ID within the subgraph
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Extract shaders from blueprint JSONs to this folder
|
||||||
|
python update_blueprints.py extract
|
||||||
|
|
||||||
|
# Patch edited shaders back into blueprint JSONs
|
||||||
|
python update_blueprints.py patch
|
||||||
|
```
|
||||||
|
|
||||||
|
## Workflow
|
||||||
|
|
||||||
|
1. Run `extract` to pull current shaders from JSONs
|
||||||
|
2. Edit `.frag` files
|
||||||
|
3. Run `patch` to update the blueprint JSONs
|
||||||
|
4. Test the patched blueprints in the application before committing
|
||||||
|
5. Commit both `.frag` files and updated JSONs
|
||||||
28
blueprints/.glsl/Sharpen_23.frag
Normal file
28
blueprints/.glsl/Sharpen_23.frag
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
#version 300 es
precision highp float;

// Laplacian sharpen: boost the high-frequency detail of the image by the
// given strength. Alpha is passed through untouched.

uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform float u_float0; // strength [0.0 – 2.0] typical: 0.3–1.0

in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;

void main() {
    vec2 px = 1.0 / u_resolution;

    // 4-neighborhood taps
    vec4 c = texture(u_image0, v_texCoord);
    vec4 n = texture(u_image0, v_texCoord + vec2(0.0, -px.y));
    vec4 s = texture(u_image0, v_texCoord + vec2(0.0,  px.y));
    vec4 w = texture(u_image0, v_texCoord + vec2(-px.x, 0.0));
    vec4 e = texture(u_image0, v_texCoord + vec2( px.x, 0.0));

    // High-pass (Laplacian): positive where the center stands out
    // from its neighborhood.
    vec4 highPass = 4.0 * c - (n + s + w + e);

    // Add the detail back, scaled by strength.
    vec4 boosted = c + highPass * u_float0;

    fragColor0 = vec4(clamp(boosted.rgb, 0.0, 1.0), c.a);
}
|
||||||
61
blueprints/.glsl/Unsharp_Mask_26.frag
Normal file
61
blueprints/.glsl/Unsharp_Mask_26.frag
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
#version 300 es
precision highp float;

// Unsharp mask sharpening: sharpened = original + (original - gaussian_blur) * amount,
// with a luminance-delta threshold so flat regions (and noise) are left alone.

uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform float u_float0; // amount [0.0 - 3.0] typical: 0.5-1.5
uniform float u_float1; // radius [0.5 - 10.0] blur radius in pixels
uniform float u_float2; // threshold [0.0 - 0.1] min difference to sharpen

in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;

float gaussian(float x, float sigma) {
    return exp(-(x * x) / (2.0 * sigma * sigma));
}

// Rec. 709 luma weights.
float getLuminance(vec3 color) {
    return dot(color, vec3(0.2126, 0.7152, 0.0722));
}

void main() {
    vec2 texel = 1.0 / u_resolution;
    float radius = max(u_float1, 0.5);
    float amount = u_float0;
    float threshold = u_float2;

    vec4 original = texture(u_image0, v_texCoord);

    // Gaussian blur for the "unsharp" mask
    // NOTE: full 2D kernel, (2*samples+1)^2 taps per pixel — O(radius^2).
    int samples = int(ceil(radius));
    float sigma = radius / 2.0;

    vec4 blurred = vec4(0.0);
    float totalWeight = 0.0;

    for (int x = -samples; x <= samples; x++) {
        for (int y = -samples; y <= samples; y++) {
            vec2 offset = vec2(float(x), float(y)) * texel;
            vec4 sample_color = texture(u_image0, v_texCoord + offset);

            float dist = length(vec2(float(x), float(y)));
            float weight = gaussian(dist, sigma);
            blurred += sample_color * weight;
            totalWeight += weight;
        }
    }
    // totalWeight > 0 is guaranteed: the center tap has weight 1.
    blurred /= totalWeight;

    // Unsharp mask = original - blurred
    vec3 mask = original.rgb - blurred.rgb;

    // Luminance-based threshold with smooth falloff
    float lumaDelta = abs(getLuminance(original.rgb) - getLuminance(blurred.rgb));
    float thresholdScale = smoothstep(0.0, threshold, lumaDelta);
    mask *= thresholdScale;

    // Sharpen: original + mask * amount
    vec3 sharpened = original.rgb + mask * amount;

    fragColor0 = vec4(clamp(sharpened, 0.0, 1.0), original.a);
}
|
||||||
159
blueprints/.glsl/update_blueprints.py
Normal file
159
blueprints/.glsl/update_blueprints.py
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Shader Blueprint Updater
|
||||||
|
|
||||||
|
Syncs GLSL shader files between this folder and blueprint JSON files.
|
||||||
|
|
||||||
|
File naming convention:
|
||||||
|
{Blueprint Name}_{node_id}.frag
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python update_blueprints.py extract # Extract shaders from JSONs to here
|
||||||
|
python update_blueprints.py patch # Patch shaders back into JSONs
|
||||||
|
python update_blueprints.py # Same as patch (default)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO, format='%(message)s')
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
GLSL_DIR = Path(__file__).parent
|
||||||
|
BLUEPRINTS_DIR = GLSL_DIR.parent
|
||||||
|
|
||||||
|
|
||||||
|
def get_blueprint_files():
    """Return every blueprint JSON file in the blueprints folder, sorted by path."""
    json_files = BLUEPRINTS_DIR.glob("*.json")
    return sorted(json_files)
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_filename(name):
    """Convert a blueprint name to a filesystem-safe filename.

    Every character that is not a word character or a hyphen is replaced
    with an underscore.
    """
    safe = re.sub(r'[^\w\-]', '_', name)
    return safe
|
||||||
|
|
||||||
|
|
||||||
|
def extract_shaders():
    """Extract all GLSLShader fragment sources from blueprint JSONs to this folder.

    For each GLSLShader node found inside a blueprint's subgraph definitions,
    the first widget string that starts with ``#version`` is written out as
    ``{Blueprint_Name}_{node_id}.frag``.
    """
    extracted = 0
    for json_path in get_blueprint_files():
        blueprint_name = json_path.stem

        try:
            # Explicit UTF-8: blueprint JSONs may contain non-ASCII text and
            # the platform default encoding is not guaranteed to be UTF-8.
            data = json.loads(json_path.read_text(encoding='utf-8'))
        except (json.JSONDecodeError, IOError) as e:
            logger.warning("Skipping %s: %s", json_path.name, e)
            continue

        # Find GLSLShader nodes in subgraphs
        for subgraph in data.get('definitions', {}).get('subgraphs', []):
            for node in subgraph.get('nodes', []):
                if node.get('type') != 'GLSLShader':
                    continue

                node_id = node.get('id')
                widgets = node.get('widgets_values', [])

                # Find shader code (first string that looks like GLSL)
                for widget in widgets:
                    if isinstance(widget, str) and widget.startswith('#version'):
                        safe_name = sanitize_filename(blueprint_name)
                        frag_name = f"{safe_name}_{node_id}.frag"
                        frag_path = GLSL_DIR / frag_name

                        frag_path.write_text(widget, encoding='utf-8')

                        logger.info(" Extracted: %s", frag_name)
                        extracted += 1
                        break

    logger.info("\nExtracted %d shader(s)", extracted)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_shaders():
    """Patch shaders from this folder back into blueprint JSONs.

    Reads every ``{blueprint_name}_{node_id}.frag`` file and writes its
    contents into the matching GLSLShader node's shader widget.
    """
    # Build lookup: blueprint_name -> [(node_id, shader_code), ...]
    shader_updates = {}

    for frag_path in sorted(GLSL_DIR.glob("*.frag")):
        # Parse filename: {blueprint_name}_{node_id}.frag
        parts = frag_path.stem.rsplit('_', 1)
        if len(parts) != 2:
            logger.warning("Skipping %s: invalid filename format", frag_path.name)
            continue

        blueprint_name, node_id_str = parts

        try:
            node_id = int(node_id_str)
        except ValueError:
            logger.warning("Skipping %s: invalid node_id", frag_path.name)
            continue

        # Explicit UTF-8 so shader text round-trips identically on every platform.
        shader_code = frag_path.read_text(encoding='utf-8')

        shader_updates.setdefault(blueprint_name, []).append((node_id, shader_code))

    # Apply updates to JSON files
    patched = 0
    for json_path in get_blueprint_files():
        blueprint_name = sanitize_filename(json_path.stem)

        if blueprint_name not in shader_updates:
            continue

        try:
            with open(json_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (json.JSONDecodeError, IOError) as e:
            logger.error("Error reading %s: %s", json_path.name, e)
            continue

        modified = False
        for node_id, shader_code in shader_updates[blueprint_name]:
            # Find the node and update
            for subgraph in data.get('definitions', {}).get('subgraphs', []):
                for node in subgraph.get('nodes', []):
                    if node.get('id') != node_id or node.get('type') != 'GLSLShader':
                        continue

                    widgets = node.get('widgets_values', [])
                    # BUGFIX: extract_shaders() pulls the first '#version' string
                    # at ANY widget index, so patching must target that same
                    # widget instead of blindly overwriting widgets[0] (which
                    # could clobber an unrelated widget value).
                    for i, widget in enumerate(widgets):
                        if isinstance(widget, str) and widget.startswith('#version'):
                            if widget != shader_code:
                                widgets[i] = shader_code
                                modified = True
                                logger.info(" Patched: %s (node %d)", json_path.name, node_id)
                                patched += 1
                            break

        if modified:
            with open(json_path, 'w', encoding='utf-8') as f:
                json.dump(data, f)

    if patched == 0:
        logger.info("No changes to apply.")
    else:
        logger.info("\nPatched %d shader(s)", patched)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """CLI entry point: dispatch on the first positional argument.

    Defaults to ``patch`` when no command is given; unknown commands print
    the module usage text and exit with status 1.
    """
    command = sys.argv[1].lower() if len(sys.argv) > 1 else "patch"

    if command == "extract":
        logger.info("Extracting shaders from blueprints...")
        extract_shaders()
        return

    if command in ("patch", "update", "apply"):
        logger.info("Patching shaders into blueprints...")
        patch_shaders()
        return

    logger.info(__doc__)
    sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
1
blueprints/Brightness and Contrast.json
Normal file
1
blueprints/Brightness and Contrast.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Canny to Image (Z-Image-Turbo).json
Normal file
1
blueprints/Canny to Image (Z-Image-Turbo).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Canny to Video (LTX 2.0).json
Normal file
1
blueprints/Canny to Video (LTX 2.0).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Chromatic Aberration.json
Normal file
1
blueprints/Chromatic Aberration.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Color Adjustment.json
Normal file
1
blueprints/Color Adjustment.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Depth to Image (Z-Image-Turbo).json
Normal file
1
blueprints/Depth to Image (Z-Image-Turbo).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Depth to Video (ltx 2.0).json
Normal file
1
blueprints/Depth to Video (ltx 2.0).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Edge-Preserving Blur.json
Normal file
1
blueprints/Edge-Preserving Blur.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Film Grain.json
Normal file
1
blueprints/Film Grain.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Glow.json
Normal file
1
blueprints/Glow.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Hue and Saturation.json
Normal file
1
blueprints/Hue and Saturation.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Blur.json
Normal file
1
blueprints/Image Blur.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Captioning (gemini).json
Normal file
1
blueprints/Image Captioning (gemini).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Channels.json
Normal file
1
blueprints/Image Channels.json
Normal file
@@ -0,0 +1 @@
|
|||||||
|
{"revision": 0, "last_node_id": 29, "last_link_id": 0, "nodes": [{"id": 29, "type": "4c9d6ea4-b912-40e5-8766-6793a9758c53", "pos": [1970, -230], "size": [180, 86], "flags": {}, "order": 5, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "R", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}, {"label": "G", "localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": []}, {"label": "B", "localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": []}, {"label": "A", "localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": []}], "title": "Image Channels", "properties": {"proxyWidgets": []}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "4c9d6ea4-b912-40e5-8766-6793a9758c53", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 28, "lastLinkId": 39, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Image Channels", "inputNode": {"id": -10, "bounding": [1820, -185, 120, 60]}, "outputNode": {"id": -20, "bounding": [2460, -215, 120, 120]}, "inputs": [{"id": "3522932b-2d86-4a1f-a02a-cb29f3a9d7fe", "name": "images.image0", "type": "IMAGE", "linkIds": [39], "localized_name": "images.image0", "label": "image", "pos": [1920, -165]}], "outputs": [{"id": "605cb9c3-b065-4d9b-81d2-3ec331889b2b", "name": "IMAGE0", "type": "IMAGE", "linkIds": [26], "localized_name": "IMAGE0", "label": "R", "pos": [2480, -195]}, {"id": "fb44a77e-0522-43e9-9527-82e7465b3596", "name": "IMAGE1", "type": "IMAGE", "linkIds": [27], "localized_name": "IMAGE1", "label": "G", "pos": [2480, -175]}, {"id": "81460ee6-0131-402a-874f-6bf3001fc4ff", "name": "IMAGE2", "type": "IMAGE", "linkIds": [28], "localized_name": "IMAGE2", "label": "B", "pos": [2480, -155]}, {"id": "ae690246-80d4-4951-b1d9-9306d8a77417", "name": "IMAGE3", "type": "IMAGE", "linkIds": [29], "localized_name": "IMAGE3", 
"label": "A", "pos": [2480, -135]}], "widgets": [], "nodes": [{"id": 23, "type": "GLSLShader", "pos": [2000, -330], "size": [400, 172], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 39}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}], "outputs": [{"label": "R", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [26]}, {"label": "G", "localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": [27]}, {"label": "B", "localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": [28]}, {"label": "A", "localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": [29]}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\nlayout(location = 1) out vec4 fragColor1;\nlayout(location = 2) out vec4 fragColor2;\nlayout(location = 3) out vec4 fragColor3;\n\nvoid main() {\n vec4 color = texture(u_image0, v_texCoord);\n // Output each channel as grayscale to separate render targets\n fragColor0 = vec4(vec3(color.r), 1.0); // Red channel\n fragColor1 = vec4(vec3(color.g), 1.0); // Green channel\n fragColor2 = vec4(vec3(color.b), 1.0); // Blue channel\n fragColor3 = vec4(vec3(color.a), 1.0); // Alpha channel\n}\n", "from_input"]}], "groups": [], "links": [{"id": 39, "origin_id": -10, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "IMAGE"}, {"id": 26, "origin_id": 23, "origin_slot": 0, "target_id": -20, "target_slot": 
0, "type": "IMAGE"}, {"id": 27, "origin_id": 23, "origin_slot": 1, "target_id": -20, "target_slot": 1, "type": "IMAGE"}, {"id": 28, "origin_id": 23, "origin_slot": 2, "target_id": -20, "target_slot": 2, "type": "IMAGE"}, {"id": 29, "origin_id": 23, "origin_slot": 3, "target_id": -20, "target_slot": 3, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Color adjust", "essentials_category": "Image Tools"}]}}
|
||||||
1
blueprints/Image Edit (Flux.2 Klein 4B).json
Normal file
1
blueprints/Image Edit (Flux.2 Klein 4B).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Edit (Qwen 2511).json
Normal file
1
blueprints/Image Edit (Qwen 2511).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Inpainting (Qwen-image).json
Normal file
1
blueprints/Image Inpainting (Qwen-image).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Levels.json
Normal file
1
blueprints/Image Levels.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Outpainting (Qwen-Image).json
Normal file
1
blueprints/Image Outpainting (Qwen-Image).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Upscale(Z-image-Turbo).json
Normal file
1
blueprints/Image Upscale(Z-image-Turbo).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image to Depth Map (Lotus).json
Normal file
1
blueprints/Image to Depth Map (Lotus).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image to Layers(Qwen-Image Layered).json
Normal file
1
blueprints/Image to Layers(Qwen-Image Layered).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image to Model (Hunyuan3d 2.1).json
Normal file
1
blueprints/Image to Model (Hunyuan3d 2.1).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image to Video (Wan 2.2).json
Normal file
1
blueprints/Image to Video (Wan 2.2).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Pose to Image (Z-Image-Turbo).json
Normal file
1
blueprints/Pose to Image (Z-Image-Turbo).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Pose to Video (LTX 2.0).json
Normal file
1
blueprints/Pose to Video (LTX 2.0).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Prompt Enhance.json
Normal file
1
blueprints/Prompt Enhance.json
Normal file
@@ -0,0 +1 @@
|
|||||||
|
{"revision": 0, "last_node_id": 15, "last_link_id": 0, "nodes": [{"id": 15, "type": "24d8bbfd-39d4-4774-bff0-3de40cc7a471", "pos": [-1490, 2040], "size": [400, 260], "flags": {}, "order": 0, "mode": 0, "inputs": [{"name": "prompt", "type": "STRING", "widget": {"name": "prompt"}, "link": null}, {"label": "reference images", "name": "images", "type": "IMAGE", "link": null}], "outputs": [{"name": "STRING", "type": "STRING", "links": null}], "title": "Prompt Enhance", "properties": {"proxyWidgets": [["-1", "prompt"]], "cnr_id": "comfy-core", "ver": "0.14.1"}, "widgets_values": [""]}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "24d8bbfd-39d4-4774-bff0-3de40cc7a471", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 15, "lastLinkId": 14, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Prompt Enhance", "inputNode": {"id": -10, "bounding": [-2170, 2110, 138.876953125, 80]}, "outputNode": {"id": -20, "bounding": [-640, 2110, 120, 60]}, "inputs": [{"id": "aeab7216-00e0-4528-a09b-bba50845c5a6", "name": "prompt", "type": "STRING", "linkIds": [11], "pos": [-2051.123046875, 2130]}, {"id": "7b73fd36-aa31-4771-9066-f6c83879994b", "name": "images", "type": "IMAGE", "linkIds": [14], "label": "reference images", "pos": [-2051.123046875, 2150]}], "outputs": [{"id": "c7b0d930-68a1-48d1-b496-0519e5837064", "name": "STRING", "type": "STRING", "linkIds": [13], "pos": [-620, 2130]}], "widgets": [], "nodes": [{"id": 11, "type": "GeminiNode", "pos": [-1560, 1990], "size": [470, 470], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "shape": 7, "type": "IMAGE", "link": 14}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": null}, {"localized_name": "video", "name": "video", "shape": 7, "type": "VIDEO", "link": null}, {"localized_name": "files", "name": "files", "shape": 7, "type": "GEMINI_INPUT_FILES", "link": null}, {"localized_name": "prompt", "name": "prompt", "type": 
"STRING", "widget": {"name": "prompt"}, "link": 11}, {"localized_name": "model", "name": "model", "type": "COMBO", "widget": {"name": "model"}, "link": null}, {"localized_name": "seed", "name": "seed", "type": "INT", "widget": {"name": "seed"}, "link": null}, {"localized_name": "system_prompt", "name": "system_prompt", "shape": 7, "type": "STRING", "widget": {"name": "system_prompt"}, "link": null}], "outputs": [{"localized_name": "STRING", "name": "STRING", "type": "STRING", "links": [13]}], "properties": {"cnr_id": "comfy-core", "ver": "0.14.1", "Node name for S&R": "GeminiNode"}, "widgets_values": ["", "gemini-3-pro-preview", 42, "randomize", "You are an expert in prompt writing.\nBased on the input, rewrite the user's input into a detailed prompt.\nincluding camera settings, lighting, composition, and style.\nReturn the prompt only"], "color": "#432", "bgcolor": "#653"}], "groups": [], "links": [{"id": 11, "origin_id": -10, "origin_slot": 0, "target_id": 11, "target_slot": 4, "type": "STRING"}, {"id": 13, "origin_id": 11, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "STRING"}, {"id": 14, "origin_id": -10, "origin_slot": 1, "target_id": 11, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Text generation/Prompt enhance", "essentials_category": "Text Generation"}]}, "extra": {}}
|
||||||
1
blueprints/Sharpen.json
Normal file
1
blueprints/Sharpen.json
Normal file
@@ -0,0 +1 @@
|
|||||||
|
{"revision": 0, "last_node_id": 25, "last_link_id": 0, "nodes": [{"id": 25, "type": "621ba4e2-22a8-482d-a369-023753198b7b", "pos": [4610, -790], "size": [230, 58], "flags": {}, "order": 4, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "IMAGE", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}], "title": "Sharpen", "properties": {"proxyWidgets": [["24", "value"]]}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "621ba4e2-22a8-482d-a369-023753198b7b", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 24, "lastLinkId": 36, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Sharpen", "inputNode": {"id": -10, "bounding": [4090, -825, 120, 60]}, "outputNode": {"id": -20, "bounding": [5150, -825, 120, 60]}, "inputs": [{"id": "37011fb7-14b7-4e0e-b1a0-6a02e8da1fd7", "name": "images.image0", "type": "IMAGE", "linkIds": [34], "localized_name": "images.image0", "label": "image", "pos": [4190, -805]}], "outputs": [{"id": "e9182b3f-635c-4cd4-a152-4b4be17ae4b9", "name": "IMAGE0", "type": "IMAGE", "linkIds": [35], "localized_name": "IMAGE0", "label": "IMAGE", "pos": [5170, -805]}], "widgets": [], "nodes": [{"id": 24, "type": "PrimitiveFloat", "pos": [4280, -1240], "size": [270, 58], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "strength", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [36]}], "properties": {"Node name for S&R": "PrimitiveFloat", "min": 0, "max": 3, "precision": 2, "step": 0.05}, "widgets_values": [0.5]}, {"id": 23, "type": "GLSLShader", "pos": [4570, -1240], "size": [370, 192], "flags": {}, "order": 1, "mode": 0, "inputs": [{"label": "image0", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", 
"link": 34}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}, {"label": "u_float0", "localized_name": "floats.u_float0", "name": "floats.u_float0", "shape": 7, "type": "FLOAT", "link": 36}, {"label": "u_float1", "localized_name": "floats.u_float1", "name": "floats.u_float1", "shape": 7, "type": "FLOAT", "link": null}, {"label": "u_int0", "localized_name": "ints.u_int0", "name": "ints.u_int0", "shape": 7, "type": "INT", "link": null}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}], "outputs": [{"localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [35]}, {"localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": null}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform float u_float0; // strength [0.0 – 2.0] typical: 0.3–1.0\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nvoid main() {\n vec2 texel = 1.0 / u_resolution;\n \n // Sample center and neighbors\n vec4 center = texture(u_image0, v_texCoord);\n vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));\n vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));\n vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));\n vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));\n \n // Edge enhancement (Laplacian)\n vec4 edges = center * 4.0 - top - bottom - left - right;\n \n // Add edges back scaled by strength\n vec4 sharpened = center + edges * 
u_float0;\n \n fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);\n}", "from_input"]}], "groups": [], "links": [{"id": 36, "origin_id": 24, "origin_slot": 0, "target_id": 23, "target_slot": 2, "type": "FLOAT"}, {"id": 34, "origin_id": -10, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "IMAGE"}, {"id": 35, "origin_id": 23, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Sharpen", "essentials_category": "Image Tools"}]}}
|
||||||
1
blueprints/Text to Audio (ACE-Step 1.5).json
Normal file
1
blueprints/Text to Audio (ACE-Step 1.5).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Text to Image (Z-Image-Turbo).json
Normal file
1
blueprints/Text to Image (Z-Image-Turbo).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Text to Video (Wan 2.2).json
Normal file
1
blueprints/Text to Video (Wan 2.2).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Unsharp Mask.json
Normal file
1
blueprints/Unsharp Mask.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Video Captioning (Gemini).json
Normal file
1
blueprints/Video Captioning (Gemini).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Video Inpaint(Wan2.1 VACE).json
Normal file
1
blueprints/Video Inpaint(Wan2.1 VACE).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Video Stitch.json
Normal file
1
blueprints/Video Stitch.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Video Upscale(GAN x4).json
Normal file
1
blueprints/Video Upscale(GAN x4).json
Normal file
@@ -0,0 +1 @@
|
|||||||
|
{"revision": 0, "last_node_id": 13, "last_link_id": 0, "nodes": [{"id": 13, "type": "cf95b747-3e17-46cb-8097-cac60ff9b2e1", "pos": [1120, 330], "size": [240, 58], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "video", "name": "video", "type": "VIDEO", "link": null}, {"name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": null}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": []}], "title": "Video Upscale(GAN x4)", "properties": {"proxyWidgets": [["-1", "model_name"]], "cnr_id": "comfy-core", "ver": "0.14.1"}, "widgets_values": ["RealESRGAN_x4plus.safetensors"]}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "cf95b747-3e17-46cb-8097-cac60ff9b2e1", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 13, "lastLinkId": 19, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Video Upscale(GAN x4)", "inputNode": {"id": -10, "bounding": [550, 460, 120, 80]}, "outputNode": {"id": -20, "bounding": [1490, 460, 120, 60]}, "inputs": [{"id": "666d633e-93e7-42dc-8d11-2b7b99b0f2a6", "name": "video", "type": "VIDEO", "linkIds": [10], "localized_name": "video", "pos": [650, 480]}, {"id": "2e23a087-caa8-4d65-99e6-662761aa905a", "name": "model_name", "type": "COMBO", "linkIds": [19], "pos": [650, 500]}], "outputs": [{"id": "0c1768ea-3ec2-412f-9af6-8e0fa36dae70", "name": "VIDEO", "type": "VIDEO", "linkIds": [15], "localized_name": "VIDEO", "pos": [1510, 480]}], "widgets": [], "nodes": [{"id": 2, "type": "ImageUpscaleWithModel", "pos": [1110, 450], "size": [320, 46], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "upscale_model", "name": "upscale_model", "type": "UPSCALE_MODEL", "link": 1}, {"localized_name": "image", "name": "image", "type": "IMAGE", "link": 14}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [13]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": 
"ImageUpscaleWithModel"}}, {"id": 11, "type": "CreateVideo", "pos": [1110, 550], "size": [320, 78], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "link": 13}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": 16}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "widget": {"name": "fps"}, "link": 12}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": [15]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "CreateVideo"}, "widgets_values": [30]}, {"id": 10, "type": "GetVideoComponents", "pos": [1110, 330], "size": [320, 70], "flags": {}, "order": 2, "mode": 0, "inputs": [{"localized_name": "video", "name": "video", "type": "VIDEO", "link": 10}], "outputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "links": [14]}, {"localized_name": "audio", "name": "audio", "type": "AUDIO", "links": [16]}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "links": [12]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "GetVideoComponents"}}, {"id": 1, "type": "UpscaleModelLoader", "pos": [750, 450], "size": [280, 60], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "model_name", "name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": 19}], "outputs": [{"localized_name": "UPSCALE_MODEL", "name": "UPSCALE_MODEL", "type": "UPSCALE_MODEL", "links": [1]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "UpscaleModelLoader", "models": [{"name": "RealESRGAN_x4plus.safetensors", "url": "https://huggingface.co/Comfy-Org/Real-ESRGAN_repackaged/resolve/main/RealESRGAN_x4plus.safetensors", "directory": "upscale_models"}]}, "widgets_values": ["RealESRGAN_x4plus.safetensors"]}], "groups": [], "links": [{"id": 1, "origin_id": 1, "origin_slot": 0, "target_id": 2, "target_slot": 0, "type": 
"UPSCALE_MODEL"}, {"id": 14, "origin_id": 10, "origin_slot": 0, "target_id": 2, "target_slot": 1, "type": "IMAGE"}, {"id": 13, "origin_id": 2, "origin_slot": 0, "target_id": 11, "target_slot": 0, "type": "IMAGE"}, {"id": 16, "origin_id": 10, "origin_slot": 1, "target_id": 11, "target_slot": 1, "type": "AUDIO"}, {"id": 12, "origin_id": 10, "origin_slot": 2, "target_id": 11, "target_slot": 2, "type": "FLOAT"}, {"id": 10, "origin_id": -10, "origin_slot": 0, "target_id": 10, "target_slot": 0, "type": "VIDEO"}, {"id": 15, "origin_id": 11, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "VIDEO"}, {"id": 19, "origin_id": -10, "origin_slot": 1, "target_id": 1, "target_slot": 0, "type": "COMBO"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Video generation and editing/Enhance video", "essentials_category": "Video Generation"}]}, "extra": {}}
|
||||||
@@ -25,11 +25,11 @@ class AudioEncoderModel():
|
|||||||
elif model_type == "whisper3":
|
elif model_type == "whisper3":
|
||||||
self.model = WhisperLargeV3(**model_config)
|
self.model = WhisperLargeV3(**model_config)
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||||
self.model_sample_rate = 16000
|
self.model_sample_rate = 16000
|
||||||
|
|
||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
return self.model.load_state_dict(sd, strict=False)
|
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
|
||||||
|
|
||||||
def get_sd(self):
|
def get_sd(self):
|
||||||
return self.model.state_dict()
|
return self.model.state_dict()
|
||||||
|
|||||||
@@ -1,13 +0,0 @@
|
|||||||
import pickle
|
|
||||||
|
|
||||||
load = pickle.load
|
|
||||||
|
|
||||||
class Empty:
|
|
||||||
pass
|
|
||||||
|
|
||||||
class Unpickler(pickle.Unpickler):
|
|
||||||
def find_class(self, module, name):
|
|
||||||
#TODO: safe unpickle
|
|
||||||
if module.startswith("pytorch_lightning"):
|
|
||||||
return Empty
|
|
||||||
return super().find_class(module, name)
|
|
||||||
@@ -159,6 +159,7 @@ class PerformanceFeature(enum.Enum):
|
|||||||
Fp8MatrixMultiplication = "fp8_matrix_mult"
|
Fp8MatrixMultiplication = "fp8_matrix_mult"
|
||||||
CublasOps = "cublas_ops"
|
CublasOps = "cublas_ops"
|
||||||
AutoTune = "autotune"
|
AutoTune = "autotune"
|
||||||
|
DynamicVRAM = "dynamic_vram"
|
||||||
|
|
||||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
|
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
|
||||||
|
|
||||||
@@ -257,3 +258,6 @@ elif args.fast == []:
|
|||||||
# '--fast' is provided with a list of performance features, use that list
|
# '--fast' is provided with a list of performance features, use that list
|
||||||
else:
|
else:
|
||||||
args.fast = set(args.fast)
|
args.fast = set(args.fast)
|
||||||
|
|
||||||
|
def enables_dynamic_vram():
|
||||||
|
return PerformanceFeature.DynamicVRAM in args.fast and not args.highvram and not args.gpu_only
|
||||||
|
|||||||
@@ -47,10 +47,10 @@ class ClipVisionModel():
|
|||||||
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
|
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
|
||||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||||
|
|
||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
return self.model.load_state_dict(sd, strict=False)
|
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
|
||||||
|
|
||||||
def get_sd(self):
|
def get_sd(self):
|
||||||
return self.model.state_dict()
|
return self.model.state_dict()
|
||||||
|
|||||||
@@ -176,6 +176,8 @@ class InputTypeOptions(TypedDict):
|
|||||||
"""COMBO type only. Specifies the configuration for a multi-select widget.
|
"""COMBO type only. Specifies the configuration for a multi-select widget.
|
||||||
Available after ComfyUI frontend v1.13.4
|
Available after ComfyUI frontend v1.13.4
|
||||||
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
|
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
|
||||||
|
gradient_stops: NotRequired[list[list[float]]]
|
||||||
|
"""Gradient color stops for gradientslider display mode. Each stop is [offset, r, g, b] (``FLOAT``)."""
|
||||||
|
|
||||||
|
|
||||||
class HiddenInputTypeDict(TypedDict):
|
class HiddenInputTypeDict(TypedDict):
|
||||||
|
|||||||
@@ -203,7 +203,7 @@ class ControlNet(ControlBase):
|
|||||||
self.control_model = control_model
|
self.control_model = control_model
|
||||||
self.load_device = load_device
|
self.load_device = load_device
|
||||||
if control_model is not None:
|
if control_model is not None:
|
||||||
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
|
self.control_model_wrapped = comfy.model_patcher.CoreModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
|
||||||
|
|
||||||
self.compression_ratio = compression_ratio
|
self.compression_ratio = compression_ratio
|
||||||
self.global_average_pooling = global_average_pooling
|
self.global_average_pooling = global_average_pooling
|
||||||
@@ -297,6 +297,30 @@ class ControlNet(ControlBase):
|
|||||||
self.model_sampling_current = None
|
self.model_sampling_current = None
|
||||||
super().cleanup()
|
super().cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
class QwenFunControlNet(ControlNet):
|
||||||
|
def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
|
||||||
|
# Fun checkpoints are more sensitive to high strengths in the generic
|
||||||
|
# ControlNet merge path. Use a soft response curve so strength=1.0 stays
|
||||||
|
# unchanged while >1 grows more gently.
|
||||||
|
original_strength = self.strength
|
||||||
|
self.strength = math.sqrt(max(self.strength, 0.0))
|
||||||
|
try:
|
||||||
|
return super().get_control(x_noisy, t, cond, batched_number, transformer_options)
|
||||||
|
finally:
|
||||||
|
self.strength = original_strength
|
||||||
|
|
||||||
|
def pre_run(self, model, percent_to_timestep_function):
|
||||||
|
super().pre_run(model, percent_to_timestep_function)
|
||||||
|
self.set_extra_arg("base_model", model.diffusion_model)
|
||||||
|
|
||||||
|
def copy(self):
|
||||||
|
c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
||||||
|
c.control_model = self.control_model
|
||||||
|
c.control_model_wrapped = self.control_model_wrapped
|
||||||
|
self.copy_to(c)
|
||||||
|
return c
|
||||||
|
|
||||||
class ControlLoraOps:
|
class ControlLoraOps:
|
||||||
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
|
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
|
||||||
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
||||||
@@ -560,6 +584,7 @@ def load_controlnet_hunyuandit(controlnet_data, model_options={}):
|
|||||||
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
|
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
|
||||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
|
||||||
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||||
|
sd = model_config.process_unet_state_dict(sd)
|
||||||
control_model = controlnet_load_state_dict(control_model, sd)
|
control_model = controlnet_load_state_dict(control_model, sd)
|
||||||
extra_conds = ['y', 'guidance']
|
extra_conds = ['y', 'guidance']
|
||||||
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||||
@@ -605,6 +630,53 @@ def load_controlnet_qwen_instantx(sd, model_options={}):
|
|||||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
|
def load_controlnet_qwen_fun(sd, model_options={}):
|
||||||
|
load_device = comfy.model_management.get_torch_device()
|
||||||
|
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||||
|
unet_dtype = model_options.get("dtype", weight_dtype)
|
||||||
|
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||||
|
|
||||||
|
operations = model_options.get("custom_operations", None)
|
||||||
|
if operations is None:
|
||||||
|
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
|
||||||
|
|
||||||
|
in_features = sd["control_img_in.weight"].shape[1]
|
||||||
|
inner_dim = sd["control_img_in.weight"].shape[0]
|
||||||
|
|
||||||
|
block_weight = sd["control_blocks.0.attn.to_q.weight"]
|
||||||
|
attention_head_dim = sd["control_blocks.0.attn.norm_q.weight"].shape[0]
|
||||||
|
num_attention_heads = max(1, block_weight.shape[0] // max(1, attention_head_dim))
|
||||||
|
|
||||||
|
model = comfy.ldm.qwen_image.controlnet.QwenImageFunControlNetModel(
|
||||||
|
control_in_features=in_features,
|
||||||
|
inner_dim=inner_dim,
|
||||||
|
num_attention_heads=num_attention_heads,
|
||||||
|
attention_head_dim=attention_head_dim,
|
||||||
|
num_control_blocks=5,
|
||||||
|
main_model_double=60,
|
||||||
|
injection_layers=(0, 12, 24, 36, 48),
|
||||||
|
operations=operations,
|
||||||
|
device=comfy.model_management.unet_offload_device(),
|
||||||
|
dtype=unet_dtype,
|
||||||
|
)
|
||||||
|
model = controlnet_load_state_dict(model, sd)
|
||||||
|
|
||||||
|
latent_format = comfy.latent_formats.Wan21()
|
||||||
|
control = QwenFunControlNet(
|
||||||
|
model,
|
||||||
|
compression_ratio=1,
|
||||||
|
latent_format=latent_format,
|
||||||
|
# Fun checkpoints already expect their own 33-channel context handling.
|
||||||
|
# Enabling generic concat_mask injects an extra mask channel at apply-time
|
||||||
|
# and breaks the intended fallback packing path.
|
||||||
|
concat_mask=False,
|
||||||
|
load_device=load_device,
|
||||||
|
manual_cast_dtype=manual_cast_dtype,
|
||||||
|
extra_conds=[],
|
||||||
|
)
|
||||||
|
return control
|
||||||
|
|
||||||
def convert_mistoline(sd):
|
def convert_mistoline(sd):
|
||||||
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
||||||
|
|
||||||
@@ -682,6 +754,8 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
|||||||
return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
|
return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
|
||||||
elif "controlnet_x_embedder.weight" in controlnet_data:
|
elif "controlnet_x_embedder.weight" in controlnet_data:
|
||||||
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
|
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
|
||||||
|
elif "control_blocks.0.after_proj.weight" in controlnet_data and "control_img_in.weight" in controlnet_data:
|
||||||
|
return load_controlnet_qwen_fun(controlnet_data, model_options=model_options)
|
||||||
|
|
||||||
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
|
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
|
||||||
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
|
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from scipy import integrate
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
import torchsde
|
import torchsde
|
||||||
from tqdm.auto import trange, tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
from . import utils
|
from . import utils
|
||||||
from . import deis
|
from . import deis
|
||||||
@@ -13,6 +13,9 @@ from . import sa_solver
|
|||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
import comfy.model_sampling
|
import comfy.model_sampling
|
||||||
|
|
||||||
|
import comfy.memory_management
|
||||||
|
from comfy.utils import model_trange as trange
|
||||||
|
|
||||||
def append_zero(x):
|
def append_zero(x):
|
||||||
return torch.cat([x, x.new_zeros([1])])
|
return torch.cat([x, x.new_zeros([1])])
|
||||||
|
|
||||||
|
|||||||
@@ -81,6 +81,7 @@ class SD_X4(LatentFormat):
|
|||||||
|
|
||||||
class SC_Prior(LatentFormat):
|
class SC_Prior(LatentFormat):
|
||||||
latent_channels = 16
|
latent_channels = 16
|
||||||
|
spacial_downscale_ratio = 42
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.scale_factor = 1.0
|
self.scale_factor = 1.0
|
||||||
self.latent_rgb_factors = [
|
self.latent_rgb_factors = [
|
||||||
@@ -103,6 +104,7 @@ class SC_Prior(LatentFormat):
|
|||||||
]
|
]
|
||||||
|
|
||||||
class SC_B(LatentFormat):
|
class SC_B(LatentFormat):
|
||||||
|
spacial_downscale_ratio = 4
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.scale_factor = 1.0 / 0.43
|
self.scale_factor = 1.0 / 0.43
|
||||||
self.latent_rgb_factors = [
|
self.latent_rgb_factors = [
|
||||||
@@ -274,6 +276,7 @@ class Mochi(LatentFormat):
|
|||||||
class LTXV(LatentFormat):
|
class LTXV(LatentFormat):
|
||||||
latent_channels = 128
|
latent_channels = 128
|
||||||
latent_dimensions = 3
|
latent_dimensions = 3
|
||||||
|
spacial_downscale_ratio = 32
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.latent_rgb_factors = [
|
self.latent_rgb_factors = [
|
||||||
@@ -517,6 +520,7 @@ class Wan21(LatentFormat):
|
|||||||
class Wan22(Wan21):
|
class Wan22(Wan21):
|
||||||
latent_channels = 48
|
latent_channels = 48
|
||||||
latent_dimensions = 3
|
latent_dimensions = 3
|
||||||
|
spacial_downscale_ratio = 16
|
||||||
|
|
||||||
latent_rgb_factors = [
|
latent_rgb_factors = [
|
||||||
[ 0.0119, 0.0103, 0.0046],
|
[ 0.0119, 0.0103, 0.0046],
|
||||||
@@ -751,6 +755,10 @@ class ACEAudio(LatentFormat):
|
|||||||
latent_channels = 8
|
latent_channels = 8
|
||||||
latent_dimensions = 2
|
latent_dimensions = 2
|
||||||
|
|
||||||
|
class ACEAudio15(LatentFormat):
|
||||||
|
latent_channels = 64
|
||||||
|
latent_dimensions = 1
|
||||||
|
|
||||||
class ChromaRadiance(LatentFormat):
|
class ChromaRadiance(LatentFormat):
|
||||||
latent_channels = 3
|
latent_channels = 3
|
||||||
spacial_downscale_ratio = 1
|
spacial_downscale_ratio = 1
|
||||||
|
|||||||
1155
comfy/ldm/ace/ace_step15.py
Normal file
1155
comfy/ldm/ace/ace_step15.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -179,8 +179,8 @@ class LLMAdapter(nn.Module):
|
|||||||
if source_attention_mask.ndim == 2:
|
if source_attention_mask.ndim == 2:
|
||||||
source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1)
|
source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1)
|
||||||
|
|
||||||
x = self.in_proj(self.embed(target_input_ids))
|
|
||||||
context = source_hidden_states
|
context = source_hidden_states
|
||||||
|
x = self.in_proj(self.embed(target_input_ids, out_dtype=context.dtype))
|
||||||
position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
|
position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
|
||||||
position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
|
position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
|
||||||
position_embeddings = self.rotary_emb(x, position_ids)
|
position_embeddings = self.rotary_emb(x, position_ids)
|
||||||
@@ -195,8 +195,20 @@ class Anima(MiniTrainDIT):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations"))
|
self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations"))
|
||||||
|
|
||||||
def preprocess_text_embeds(self, text_embeds, text_ids):
|
def preprocess_text_embeds(self, text_embeds, text_ids, t5xxl_weights=None):
|
||||||
if text_ids is not None:
|
if text_ids is not None:
|
||||||
return self.llm_adapter(text_embeds, text_ids)
|
out = self.llm_adapter(text_embeds, text_ids)
|
||||||
|
if t5xxl_weights is not None:
|
||||||
|
out = out * t5xxl_weights
|
||||||
|
|
||||||
|
if out.shape[1] < 512:
|
||||||
|
out = torch.nn.functional.pad(out, (0, 0, 0, 512 - out.shape[1]))
|
||||||
|
return out
|
||||||
else:
|
else:
|
||||||
return text_embeds
|
return text_embeds
|
||||||
|
|
||||||
|
def forward(self, x, timesteps, context, **kwargs):
|
||||||
|
t5xxl_ids = kwargs.pop("t5xxl_ids", None)
|
||||||
|
if t5xxl_ids is not None:
|
||||||
|
context = self.preprocess_text_embeds(context, t5xxl_ids, t5xxl_weights=kwargs.pop("t5xxl_weights", None))
|
||||||
|
return super().forward(x, timesteps, context, **kwargs)
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ from torch import Tensor, nn
|
|||||||
|
|
||||||
from comfy.ldm.flux.layers import (
|
from comfy.ldm.flux.layers import (
|
||||||
MLPEmbedder,
|
MLPEmbedder,
|
||||||
RMSNorm,
|
|
||||||
ModulationOut,
|
ModulationOut,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -29,7 +28,7 @@ class Approximator(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
|
self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
|
||||||
self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
|
self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
|
||||||
self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
|
self.norms = nn.ModuleList([operations.RMSNorm(hidden_dim, dtype=dtype, device=device) for x in range( n_layers)])
|
||||||
self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
|
self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -152,6 +152,7 @@ class Chroma(nn.Module):
|
|||||||
transformer_options={},
|
transformer_options={},
|
||||||
attn_mask: Tensor = None,
|
attn_mask: Tensor = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
|
transformer_options = transformer_options.copy()
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
|
|
||||||
# running on sequences img
|
# running on sequences img
|
||||||
@@ -228,6 +229,7 @@ class Chroma(nn.Module):
|
|||||||
|
|
||||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||||
transformer_options["block_type"] = "single"
|
transformer_options["block_type"] = "single"
|
||||||
|
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
|
||||||
for i, block in enumerate(self.single_blocks):
|
for i, block in enumerate(self.single_blocks):
|
||||||
transformer_options["block_index"] = i
|
transformer_options["block_index"] = i
|
||||||
if i not in self.skip_dit:
|
if i not in self.skip_dit:
|
||||||
|
|||||||
@@ -4,8 +4,6 @@ from functools import lru_cache
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from comfy.ldm.flux.layers import RMSNorm
|
|
||||||
|
|
||||||
|
|
||||||
class NerfEmbedder(nn.Module):
|
class NerfEmbedder(nn.Module):
|
||||||
"""
|
"""
|
||||||
@@ -145,7 +143,7 @@ class NerfGLUBlock(nn.Module):
|
|||||||
# We now need to generate parameters for 3 matrices.
|
# We now need to generate parameters for 3 matrices.
|
||||||
total_params = 3 * hidden_size_x**2 * mlp_ratio
|
total_params = 3 * hidden_size_x**2 * mlp_ratio
|
||||||
self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device)
|
self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device)
|
||||||
self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations)
|
self.norm = operations.RMSNorm(hidden_size_x, dtype=dtype, device=device)
|
||||||
self.mlp_ratio = mlp_ratio
|
self.mlp_ratio = mlp_ratio
|
||||||
|
|
||||||
|
|
||||||
@@ -178,7 +176,7 @@ class NerfGLUBlock(nn.Module):
|
|||||||
class NerfFinalLayer(nn.Module):
|
class NerfFinalLayer(nn.Module):
|
||||||
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
|
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
|
self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
|
||||||
self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
|
self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
@@ -190,7 +188,7 @@ class NerfFinalLayer(nn.Module):
|
|||||||
class NerfFinalLayerConv(nn.Module):
|
class NerfFinalLayerConv(nn.Module):
|
||||||
def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
|
def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
|
self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
|
||||||
self.conv = operations.Conv2d(
|
self.conv = operations.Conv2d(
|
||||||
in_channels=hidden_size,
|
in_channels=hidden_size,
|
||||||
out_channels=out_channels,
|
out_channels=out_channels,
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from torchvision import transforms
|
|||||||
|
|
||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
from comfy.ldm.modules.attention import optimized_attention
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
|
import comfy.ldm.common_dit
|
||||||
|
|
||||||
def apply_rotary_pos_emb(
|
def apply_rotary_pos_emb(
|
||||||
t: torch.Tensor,
|
t: torch.Tensor,
|
||||||
@@ -334,7 +335,7 @@ class FinalLayer(nn.Module):
|
|||||||
device=None, dtype=None, operations=None
|
device=None, dtype=None, operations=None
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
self.layer_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||||
self.linear = operations.Linear(
|
self.linear = operations.Linear(
|
||||||
hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype
|
hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype
|
||||||
)
|
)
|
||||||
@@ -462,6 +463,8 @@ class Block(nn.Module):
|
|||||||
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
|
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
|
||||||
transformer_options: Optional[dict] = {},
|
transformer_options: Optional[dict] = {},
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
residual_dtype = x_B_T_H_W_D.dtype
|
||||||
|
compute_dtype = emb_B_T_D.dtype
|
||||||
if extra_per_block_pos_emb is not None:
|
if extra_per_block_pos_emb is not None:
|
||||||
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
|
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
|
||||||
|
|
||||||
@@ -511,7 +514,7 @@ class Block(nn.Module):
|
|||||||
result_B_T_H_W_D = rearrange(
|
result_B_T_H_W_D = rearrange(
|
||||||
self.self_attn(
|
self.self_attn(
|
||||||
# normalized_x_B_T_HW_D,
|
# normalized_x_B_T_HW_D,
|
||||||
rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
|
rearrange(normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
|
||||||
None,
|
None,
|
||||||
rope_emb=rope_emb_L_1_1_D,
|
rope_emb=rope_emb_L_1_1_D,
|
||||||
transformer_options=transformer_options,
|
transformer_options=transformer_options,
|
||||||
@@ -521,7 +524,7 @@ class Block(nn.Module):
|
|||||||
h=H,
|
h=H,
|
||||||
w=W,
|
w=W,
|
||||||
)
|
)
|
||||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D
|
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
|
||||||
|
|
||||||
def _x_fn(
|
def _x_fn(
|
||||||
_x_B_T_H_W_D: torch.Tensor,
|
_x_B_T_H_W_D: torch.Tensor,
|
||||||
@@ -535,7 +538,7 @@ class Block(nn.Module):
|
|||||||
)
|
)
|
||||||
_result_B_T_H_W_D = rearrange(
|
_result_B_T_H_W_D = rearrange(
|
||||||
self.cross_attn(
|
self.cross_attn(
|
||||||
rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
|
rearrange(_normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
|
||||||
crossattn_emb,
|
crossattn_emb,
|
||||||
rope_emb=rope_emb_L_1_1_D,
|
rope_emb=rope_emb_L_1_1_D,
|
||||||
transformer_options=transformer_options,
|
transformer_options=transformer_options,
|
||||||
@@ -554,7 +557,7 @@ class Block(nn.Module):
|
|||||||
shift_cross_attn_B_T_1_1_D,
|
shift_cross_attn_B_T_1_1_D,
|
||||||
transformer_options=transformer_options,
|
transformer_options=transformer_options,
|
||||||
)
|
)
|
||||||
x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
|
x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D
|
||||||
|
|
||||||
normalized_x_B_T_H_W_D = _fn(
|
normalized_x_B_T_H_W_D = _fn(
|
||||||
x_B_T_H_W_D,
|
x_B_T_H_W_D,
|
||||||
@@ -562,8 +565,8 @@ class Block(nn.Module):
|
|||||||
scale_mlp_B_T_1_1_D,
|
scale_mlp_B_T_1_1_D,
|
||||||
shift_mlp_B_T_1_1_D,
|
shift_mlp_B_T_1_1_D,
|
||||||
)
|
)
|
||||||
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D)
|
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype))
|
||||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D
|
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
|
||||||
return x_B_T_H_W_D
|
return x_B_T_H_W_D
|
||||||
|
|
||||||
|
|
||||||
@@ -835,6 +838,8 @@ class MiniTrainDIT(nn.Module):
|
|||||||
padding_mask: Optional[torch.Tensor] = None,
|
padding_mask: Optional[torch.Tensor] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
orig_shape = list(x.shape)
|
||||||
|
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_temporal, self.patch_spatial, self.patch_spatial))
|
||||||
x_B_C_T_H_W = x
|
x_B_C_T_H_W = x
|
||||||
timesteps_B_T = timesteps
|
timesteps_B_T = timesteps
|
||||||
crossattn_emb = context
|
crossattn_emb = context
|
||||||
@@ -873,6 +878,14 @@ class MiniTrainDIT(nn.Module):
|
|||||||
"extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
|
"extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
|
||||||
"transformer_options": kwargs.get("transformer_options", {}),
|
"transformer_options": kwargs.get("transformer_options", {}),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# The residual stream for this model has large values. To make fp16 compute_dtype work, we keep the residual stream
|
||||||
|
# in fp32, but run attention and MLP modules in fp16.
|
||||||
|
# An alternate method that clamps fp16 values "works" in the sense that it makes coherent images, but there is noticeable
|
||||||
|
# quality degradation and visual artifacts.
|
||||||
|
if x_B_T_H_W_D.dtype == torch.float16:
|
||||||
|
x_B_T_H_W_D = x_B_T_H_W_D.float()
|
||||||
|
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
x_B_T_H_W_D = block(
|
x_B_T_H_W_D = block(
|
||||||
x_B_T_H_W_D,
|
x_B_T_H_W_D,
|
||||||
@@ -881,6 +894,6 @@ class MiniTrainDIT(nn.Module):
|
|||||||
**block_kwargs,
|
**block_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
|
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D.to(crossattn_emb.dtype), t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
|
||||||
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
|
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)[:, :, :orig_shape[-3], :orig_shape[-2], :orig_shape[-1]]
|
||||||
return x_B_C_Tt_Hp_Wp
|
return x_B_C_Tt_Hp_Wp
|
||||||
|
|||||||
@@ -5,9 +5,9 @@ import torch
|
|||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
from .math import attention, rope
|
from .math import attention, rope
|
||||||
import comfy.ops
|
|
||||||
import comfy.ldm.common_dit
|
|
||||||
|
|
||||||
|
# Fix import for some custom nodes, TODO: delete eventually.
|
||||||
|
RMSNorm = None
|
||||||
|
|
||||||
class EmbedND(nn.Module):
|
class EmbedND(nn.Module):
|
||||||
def __init__(self, dim: int, theta: int, axes_dim: list):
|
def __init__(self, dim: int, theta: int, axes_dim: list):
|
||||||
@@ -87,20 +87,12 @@ def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dt
|
|||||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||||
)
|
)
|
||||||
|
|
||||||
class RMSNorm(torch.nn.Module):
|
|
||||||
def __init__(self, dim: int, dtype=None, device=None, operations=None):
|
|
||||||
super().__init__()
|
|
||||||
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
|
|
||||||
|
|
||||||
def forward(self, x: Tensor):
|
|
||||||
return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
|
|
||||||
|
|
||||||
|
|
||||||
class QKNorm(torch.nn.Module):
|
class QKNorm(torch.nn.Module):
|
||||||
def __init__(self, dim: int, dtype=None, device=None, operations=None):
|
def __init__(self, dim: int, dtype=None, device=None, operations=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
|
self.query_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
|
||||||
self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
|
self.key_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
|
||||||
|
|
||||||
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
|
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
|
||||||
q = self.query_norm(q)
|
q = self.query_norm(q)
|
||||||
@@ -169,7 +161,7 @@ class SiLUActivation(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class DoubleStreamBlock(nn.Module):
|
class DoubleStreamBlock(nn.Module):
|
||||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
|
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||||
@@ -197,8 +189,6 @@ class DoubleStreamBlock(nn.Module):
|
|||||||
|
|
||||||
self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
|
self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
self.flipped_img_txt = flipped_img_txt
|
|
||||||
|
|
||||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
|
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
|
||||||
if self.modulation:
|
if self.modulation:
|
||||||
img_mod1, img_mod2 = self.img_mod(vec)
|
img_mod1, img_mod2 = self.img_mod(vec)
|
||||||
@@ -206,6 +196,9 @@ class DoubleStreamBlock(nn.Module):
|
|||||||
else:
|
else:
|
||||||
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
|
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
|
||||||
|
|
||||||
|
transformer_patches = transformer_options.get("patches", {})
|
||||||
|
extra_options = transformer_options.copy()
|
||||||
|
|
||||||
# prepare image for attention
|
# prepare image for attention
|
||||||
img_modulated = self.img_norm1(img)
|
img_modulated = self.img_norm1(img)
|
||||||
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
|
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
|
||||||
@@ -224,32 +217,23 @@ class DoubleStreamBlock(nn.Module):
|
|||||||
del txt_qkv
|
del txt_qkv
|
||||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||||
|
|
||||||
if self.flipped_img_txt:
|
q = torch.cat((txt_q, img_q), dim=2)
|
||||||
q = torch.cat((img_q, txt_q), dim=2)
|
del txt_q, img_q
|
||||||
del img_q, txt_q
|
k = torch.cat((txt_k, img_k), dim=2)
|
||||||
k = torch.cat((img_k, txt_k), dim=2)
|
del txt_k, img_k
|
||||||
del img_k, txt_k
|
v = torch.cat((txt_v, img_v), dim=2)
|
||||||
v = torch.cat((img_v, txt_v), dim=2)
|
del txt_v, img_v
|
||||||
del img_v, txt_v
|
# run actual attention
|
||||||
# run actual attention
|
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||||
attn = attention(q, k, v,
|
del q, k, v
|
||||||
pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
|
||||||
del q, k, v
|
|
||||||
|
|
||||||
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
|
if "attn1_output_patch" in transformer_patches:
|
||||||
else:
|
extra_options["img_slice"] = [txt.shape[1], attn.shape[1]]
|
||||||
q = torch.cat((txt_q, img_q), dim=2)
|
patch = transformer_patches["attn1_output_patch"]
|
||||||
del txt_q, img_q
|
for p in patch:
|
||||||
k = torch.cat((txt_k, img_k), dim=2)
|
attn = p(attn, extra_options)
|
||||||
del txt_k, img_k
|
|
||||||
v = torch.cat((txt_v, img_v), dim=2)
|
|
||||||
del txt_v, img_v
|
|
||||||
# run actual attention
|
|
||||||
attn = attention(q, k, v,
|
|
||||||
pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
|
||||||
del q, k, v
|
|
||||||
|
|
||||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||||
|
|
||||||
# calculate the img bloks
|
# calculate the img bloks
|
||||||
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
|
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
|
||||||
@@ -328,6 +312,9 @@ class SingleStreamBlock(nn.Module):
|
|||||||
else:
|
else:
|
||||||
mod = vec
|
mod = vec
|
||||||
|
|
||||||
|
transformer_patches = transformer_options.get("patches", {})
|
||||||
|
extra_options = transformer_options.copy()
|
||||||
|
|
||||||
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)
|
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)
|
||||||
|
|
||||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||||
@@ -337,6 +324,12 @@ class SingleStreamBlock(nn.Module):
|
|||||||
# compute attention
|
# compute attention
|
||||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||||
del q, k, v
|
del q, k, v
|
||||||
|
|
||||||
|
if "attn1_output_patch" in transformer_patches:
|
||||||
|
patch = transformer_patches["attn1_output_patch"]
|
||||||
|
for p in patch:
|
||||||
|
attn = p(attn, extra_options)
|
||||||
|
|
||||||
# compute activation in mlp stream, cat again and run second linear layer
|
# compute activation in mlp stream, cat again and run second linear layer
|
||||||
if self.yak_mlp:
|
if self.yak_mlp:
|
||||||
mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
|
mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
|
||||||
|
|||||||
@@ -29,19 +29,34 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
|||||||
return out.to(dtype=torch.float32, device=pos.device)
|
return out.to(dtype=torch.float32, device=pos.device)
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_rope1(x: Tensor, freqs_cis: Tensor):
|
||||||
|
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
||||||
|
|
||||||
|
x_out = freqs_cis[..., 0] * x_[..., 0]
|
||||||
|
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
|
||||||
|
|
||||||
|
return x_out.reshape(*x.shape).type_as(x)
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
||||||
|
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import comfy.quant_ops
|
import comfy.quant_ops
|
||||||
apply_rope = comfy.quant_ops.ck.apply_rope
|
q_apply_rope = comfy.quant_ops.ck.apply_rope
|
||||||
apply_rope1 = comfy.quant_ops.ck.apply_rope1
|
q_apply_rope1 = comfy.quant_ops.ck.apply_rope1
|
||||||
|
def apply_rope(xq, xk, freqs_cis):
|
||||||
|
if comfy.model_management.in_training:
|
||||||
|
return _apply_rope(xq, xk, freqs_cis)
|
||||||
|
else:
|
||||||
|
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||||
|
def apply_rope1(x, freqs_cis):
|
||||||
|
if comfy.model_management.in_training:
|
||||||
|
return _apply_rope1(x, freqs_cis)
|
||||||
|
else:
|
||||||
|
return q_apply_rope1(x, freqs_cis)
|
||||||
except:
|
except:
|
||||||
logging.warning("No comfy kitchen, using old apply_rope functions.")
|
logging.warning("No comfy kitchen, using old apply_rope functions.")
|
||||||
def apply_rope1(x: Tensor, freqs_cis: Tensor):
|
apply_rope = _apply_rope
|
||||||
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
apply_rope1 = _apply_rope1
|
||||||
|
|
||||||
x_out = freqs_cis[..., 0] * x_[..., 0]
|
|
||||||
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
|
|
||||||
|
|
||||||
return x_out.reshape(*x.shape).type_as(x)
|
|
||||||
|
|
||||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
|
||||||
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ from .layers import (
|
|||||||
SingleStreamBlock,
|
SingleStreamBlock,
|
||||||
timestep_embedding,
|
timestep_embedding,
|
||||||
Modulation,
|
Modulation,
|
||||||
RMSNorm
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -81,7 +80,7 @@ class Flux(nn.Module):
|
|||||||
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
|
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
|
||||||
|
|
||||||
if params.txt_norm:
|
if params.txt_norm:
|
||||||
self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
|
self.txt_norm = operations.RMSNorm(params.context_in_dim, dtype=dtype, device=device)
|
||||||
else:
|
else:
|
||||||
self.txt_norm = None
|
self.txt_norm = None
|
||||||
|
|
||||||
@@ -143,6 +142,7 @@ class Flux(nn.Module):
|
|||||||
attn_mask: Tensor = None,
|
attn_mask: Tensor = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
|
|
||||||
|
transformer_options = transformer_options.copy()
|
||||||
patches = transformer_options.get("patches", {})
|
patches = transformer_options.get("patches", {})
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
if img.ndim != 3 or txt.ndim != 3:
|
if img.ndim != 3 or txt.ndim != 3:
|
||||||
@@ -232,6 +232,7 @@ class Flux(nn.Module):
|
|||||||
|
|
||||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||||
transformer_options["block_type"] = "single"
|
transformer_options["block_type"] = "single"
|
||||||
|
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
|
||||||
for i, block in enumerate(self.single_blocks):
|
for i, block in enumerate(self.single_blocks):
|
||||||
transformer_options["block_index"] = i
|
transformer_options["block_index"] = i
|
||||||
if ("single_block", i) in blocks_replace:
|
if ("single_block", i) in blocks_replace:
|
||||||
|
|||||||
@@ -241,7 +241,6 @@ class HunyuanVideo(nn.Module):
|
|||||||
self.num_heads,
|
self.num_heads,
|
||||||
mlp_ratio=params.mlp_ratio,
|
mlp_ratio=params.mlp_ratio,
|
||||||
qkv_bias=params.qkv_bias,
|
qkv_bias=params.qkv_bias,
|
||||||
flipped_img_txt=True,
|
|
||||||
dtype=dtype, device=device, operations=operations
|
dtype=dtype, device=device, operations=operations
|
||||||
)
|
)
|
||||||
for _ in range(params.depth)
|
for _ in range(params.depth)
|
||||||
@@ -305,6 +304,7 @@ class HunyuanVideo(nn.Module):
|
|||||||
control=None,
|
control=None,
|
||||||
transformer_options={},
|
transformer_options={},
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
|
transformer_options = transformer_options.copy()
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
|
|
||||||
initial_shape = list(img.shape)
|
initial_shape = list(img.shape)
|
||||||
@@ -378,14 +378,14 @@ class HunyuanVideo(nn.Module):
|
|||||||
extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
|
extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
|
||||||
txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
|
txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
|
||||||
|
|
||||||
ids = torch.cat((img_ids, txt_ids), dim=1)
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||||
pe = self.pe_embedder(ids)
|
pe = self.pe_embedder(ids)
|
||||||
|
|
||||||
img_len = img.shape[1]
|
img_len = img.shape[1]
|
||||||
if txt_mask is not None:
|
if txt_mask is not None:
|
||||||
attn_mask_len = img_len + txt.shape[1]
|
attn_mask_len = img_len + txt.shape[1]
|
||||||
attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
|
attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
|
||||||
attn_mask[:, 0, img_len:] = txt_mask
|
attn_mask[:, 0, :txt.shape[1]] = txt_mask
|
||||||
else:
|
else:
|
||||||
attn_mask = None
|
attn_mask = None
|
||||||
|
|
||||||
@@ -413,10 +413,11 @@ class HunyuanVideo(nn.Module):
|
|||||||
if add is not None:
|
if add is not None:
|
||||||
img += add
|
img += add
|
||||||
|
|
||||||
img = torch.cat((img, txt), 1)
|
img = torch.cat((txt, img), 1)
|
||||||
|
|
||||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||||
transformer_options["block_type"] = "single"
|
transformer_options["block_type"] = "single"
|
||||||
|
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
|
||||||
for i, block in enumerate(self.single_blocks):
|
for i, block in enumerate(self.single_blocks):
|
||||||
transformer_options["block_index"] = i
|
transformer_options["block_index"] = i
|
||||||
if ("single_block", i) in blocks_replace:
|
if ("single_block", i) in blocks_replace:
|
||||||
@@ -435,9 +436,9 @@ class HunyuanVideo(nn.Module):
|
|||||||
if i < len(control_o):
|
if i < len(control_o):
|
||||||
add = control_o[i]
|
add = control_o[i]
|
||||||
if add is not None:
|
if add is not None:
|
||||||
img[:, : img_len] += add
|
img[:, txt.shape[1]: img_len + txt.shape[1]] += add
|
||||||
|
|
||||||
img = img[:, : img_len]
|
img = img[:, txt.shape[1]: img_len + txt.shape[1]]
|
||||||
if ref_latent is not None:
|
if ref_latent is not None:
|
||||||
img = img[:, ref_latent.shape[1]:]
|
img = img[:, ref_latent.shape[1]:]
|
||||||
|
|
||||||
|
|||||||
@@ -109,10 +109,10 @@ class HunyuanVideo15SRModel():
|
|||||||
self.model_class = UPSAMPLERS.get(model_type)
|
self.model_class = UPSAMPLERS.get(model_type)
|
||||||
self.model = self.model_class(**config).eval()
|
self.model = self.model_class(**config).eval()
|
||||||
|
|
||||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||||
|
|
||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
return self.model.load_state_dict(sd, strict=True)
|
return self.model.load_state_dict(sd, strict=True, assign=self.patcher.is_dynamic())
|
||||||
|
|
||||||
def get_sd(self):
|
def get_sd(self):
|
||||||
return self.model.state_dict()
|
return self.model.state_dict()
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from comfy.ldm.lightricks.model import (
|
|||||||
LTXVModel,
|
LTXVModel,
|
||||||
)
|
)
|
||||||
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||||
|
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
||||||
import comfy.ldm.common_dit
|
import comfy.ldm.common_dit
|
||||||
|
|
||||||
class CompressedTimestep:
|
class CompressedTimestep:
|
||||||
@@ -450,6 +451,29 @@ class LTXAVModel(LTXVModel):
|
|||||||
operations=self.operations,
|
operations=self.operations,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.audio_embeddings_connector = Embeddings1DConnector(
|
||||||
|
split_rope=True,
|
||||||
|
double_precision_rope=True,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=self.operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.video_embeddings_connector = Embeddings1DConnector(
|
||||||
|
split_rope=True,
|
||||||
|
double_precision_rope=True,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=self.operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
def preprocess_text_embeds(self, context):
|
||||||
|
if context.shape[-1] == self.caption_channels * 2:
|
||||||
|
return context
|
||||||
|
out_vid = self.video_embeddings_connector(context)[0]
|
||||||
|
out_audio = self.audio_embeddings_connector(context)[0]
|
||||||
|
return torch.concat((out_vid, out_audio), dim=-1)
|
||||||
|
|
||||||
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||||
"""Initialize transformer blocks for LTXAV."""
|
"""Initialize transformer blocks for LTXAV."""
|
||||||
self.transformer_blocks = nn.ModuleList(
|
self.transformer_blocks = nn.ModuleList(
|
||||||
|
|||||||
@@ -234,7 +234,7 @@ class Embeddings1DConnector(nn.Module):
|
|||||||
|
|
||||||
return indices
|
return indices
|
||||||
|
|
||||||
def precompute_freqs_cis(self, indices_grid, spacing="exp"):
|
def precompute_freqs_cis(self, indices_grid, spacing="exp", out_dtype=None):
|
||||||
dim = self.inner_dim
|
dim = self.inner_dim
|
||||||
n_elem = 2 # 2 because of cos and sin
|
n_elem = 2 # 2 because of cos and sin
|
||||||
freqs = self.precompute_freqs(indices_grid, spacing)
|
freqs = self.precompute_freqs(indices_grid, spacing)
|
||||||
@@ -247,7 +247,7 @@ class Embeddings1DConnector(nn.Module):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
|
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
|
||||||
return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope
|
return cos_freq.to(dtype=out_dtype), sin_freq.to(dtype=out_dtype), self.split_rope
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -288,7 +288,7 @@ class Embeddings1DConnector(nn.Module):
|
|||||||
hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device
|
hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device
|
||||||
)
|
)
|
||||||
indices_grid = indices_grid[None, None, :]
|
indices_grid = indices_grid[None, None, :]
|
||||||
freqs_cis = self.precompute_freqs_cis(indices_grid)
|
freqs_cis = self.precompute_freqs_cis(indices_grid, out_dtype=hidden_states.dtype)
|
||||||
|
|
||||||
# 2. Blocks
|
# 2. Blocks
|
||||||
for block_idx, block in enumerate(self.transformer_1d_blocks):
|
for block_idx, block in enumerate(self.transformer_1d_blocks):
|
||||||
|
|||||||
@@ -524,6 +524,9 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
|||||||
|
|
||||||
@wrap_attn
|
@wrap_attn
|
||||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||||
|
if kwargs.get("low_precision_attention", True) is False:
|
||||||
|
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
|
||||||
|
|
||||||
exception_fallback = False
|
exception_fallback = False
|
||||||
if skip_reshape:
|
if skip_reshape:
|
||||||
b, _, _, dim_head = q.shape
|
b, _, _, dim_head = q.shape
|
||||||
|
|||||||
@@ -102,19 +102,7 @@ class VideoConv3d(nn.Module):
|
|||||||
return self.conv(x)
|
return self.conv(x)
|
||||||
|
|
||||||
def interpolate_up(x, scale_factor):
|
def interpolate_up(x, scale_factor):
|
||||||
try:
|
return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
|
||||||
return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
|
|
||||||
except: #operation not implemented for bf16
|
|
||||||
orig_shape = list(x.shape)
|
|
||||||
out_shape = orig_shape[:2]
|
|
||||||
for i in range(len(orig_shape) - 2):
|
|
||||||
out_shape.append(round(orig_shape[i + 2] * scale_factor[i]))
|
|
||||||
out = torch.empty(out_shape, dtype=x.dtype, layout=x.layout, device=x.device)
|
|
||||||
split = 8
|
|
||||||
l = out.shape[1] // split
|
|
||||||
for i in range(0, out.shape[1], l):
|
|
||||||
out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=scale_factor, mode="nearest").to(x.dtype)
|
|
||||||
return out
|
|
||||||
|
|
||||||
class Upsample(nn.Module):
|
class Upsample(nn.Module):
|
||||||
def __init__(self, in_channels, with_conv, conv_op=ops.Conv2d, scale_factor=2.0):
|
def __init__(self, in_channels, with_conv, conv_op=ops.Conv2d, scale_factor=2.0):
|
||||||
|
|||||||
@@ -2,6 +2,196 @@ import torch
|
|||||||
import math
|
import math
|
||||||
|
|
||||||
from .model import QwenImageTransformer2DModel
|
from .model import QwenImageTransformer2DModel
|
||||||
|
from .model import QwenImageTransformerBlock
|
||||||
|
|
||||||
|
|
||||||
|
class QwenImageFunControlBlock(QwenImageTransformerBlock):
|
||||||
|
def __init__(self, dim, num_attention_heads, attention_head_dim, has_before_proj=False, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__(
|
||||||
|
dim=dim,
|
||||||
|
num_attention_heads=num_attention_heads,
|
||||||
|
attention_head_dim=attention_head_dim,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
self.has_before_proj = has_before_proj
|
||||||
|
if has_before_proj:
|
||||||
|
self.before_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
|
||||||
|
self.after_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class QwenImageFunControlNetModel(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
control_in_features=132,
|
||||||
|
inner_dim=3072,
|
||||||
|
num_attention_heads=24,
|
||||||
|
attention_head_dim=128,
|
||||||
|
num_control_blocks=5,
|
||||||
|
main_model_double=60,
|
||||||
|
injection_layers=(0, 12, 24, 36, 48),
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.dtype = dtype
|
||||||
|
self.main_model_double = main_model_double
|
||||||
|
self.injection_layers = tuple(injection_layers)
|
||||||
|
# Keep base hint scaling at 1.0 so user-facing strength behaves similarly
|
||||||
|
# to the reference Gen2/VideoX implementation around strength=1.
|
||||||
|
self.hint_scale = 1.0
|
||||||
|
self.control_img_in = operations.Linear(control_in_features, inner_dim, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
self.control_blocks = torch.nn.ModuleList([])
|
||||||
|
for i in range(num_control_blocks):
|
||||||
|
self.control_blocks.append(
|
||||||
|
QwenImageFunControlBlock(
|
||||||
|
dim=inner_dim,
|
||||||
|
num_attention_heads=num_attention_heads,
|
||||||
|
attention_head_dim=attention_head_dim,
|
||||||
|
has_before_proj=(i == 0),
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _process_hint_tokens(self, hint):
|
||||||
|
if hint is None:
|
||||||
|
return None
|
||||||
|
if hint.ndim == 4:
|
||||||
|
hint = hint.unsqueeze(2)
|
||||||
|
|
||||||
|
# Fun checkpoints are trained with 33 latent channels before 2x2 packing:
|
||||||
|
# [control_latent(16), mask(1), inpaint_latent(16)] -> 132 features.
|
||||||
|
# Default behavior (no inpaint input in stock Apply ControlNet) should use
|
||||||
|
# zeros for mask/inpaint branches, matching VideoX fallback semantics.
|
||||||
|
expected_c = self.control_img_in.weight.shape[1] // 4
|
||||||
|
if hint.shape[1] == 16 and expected_c == 33:
|
||||||
|
zeros_mask = torch.zeros_like(hint[:, :1])
|
||||||
|
zeros_inpaint = torch.zeros_like(hint)
|
||||||
|
hint = torch.cat([hint, zeros_mask, zeros_inpaint], dim=1)
|
||||||
|
|
||||||
|
bs, c, t, h, w = hint.shape
|
||||||
|
hidden_states = torch.nn.functional.pad(hint, (0, w % 2, 0, h % 2))
|
||||||
|
orig_shape = hidden_states.shape
|
||||||
|
hidden_states = hidden_states.view(
|
||||||
|
orig_shape[0],
|
||||||
|
orig_shape[1],
|
||||||
|
orig_shape[-3],
|
||||||
|
orig_shape[-2] // 2,
|
||||||
|
2,
|
||||||
|
orig_shape[-1] // 2,
|
||||||
|
2,
|
||||||
|
)
|
||||||
|
hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
|
||||||
|
hidden_states = hidden_states.reshape(
|
||||||
|
bs,
|
||||||
|
t * ((h + 1) // 2) * ((w + 1) // 2),
|
||||||
|
c * 4,
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_in = self.control_img_in.weight.shape[1]
|
||||||
|
cur_in = hidden_states.shape[-1]
|
||||||
|
if cur_in < expected_in:
|
||||||
|
pad = torch.zeros(
|
||||||
|
(hidden_states.shape[0], hidden_states.shape[1], expected_in - cur_in),
|
||||||
|
device=hidden_states.device,
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
)
|
||||||
|
hidden_states = torch.cat([hidden_states, pad], dim=-1)
|
||||||
|
elif cur_in > expected_in:
|
||||||
|
hidden_states = hidden_states[:, :, :expected_in]
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
timesteps,
|
||||||
|
context,
|
||||||
|
attention_mask=None,
|
||||||
|
guidance: torch.Tensor = None,
|
||||||
|
hint=None,
|
||||||
|
transformer_options={},
|
||||||
|
base_model=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if base_model is None:
|
||||||
|
raise RuntimeError("Qwen Fun ControlNet requires a QwenImage base model at runtime.")
|
||||||
|
|
||||||
|
encoder_hidden_states_mask = attention_mask
|
||||||
|
# Keep attention mask disabled inside Fun control blocks to mirror
|
||||||
|
# VideoX behavior (they rely on seq lengths for RoPE, not masked attention).
|
||||||
|
encoder_hidden_states_mask = None
|
||||||
|
|
||||||
|
hidden_states, img_ids, _ = base_model.process_img(x)
|
||||||
|
hint_tokens = self._process_hint_tokens(hint)
|
||||||
|
if hint_tokens is None:
|
||||||
|
raise RuntimeError("Qwen Fun ControlNet requires a control hint image.")
|
||||||
|
|
||||||
|
if hint_tokens.shape[1] != hidden_states.shape[1]:
|
||||||
|
max_tokens = min(hint_tokens.shape[1], hidden_states.shape[1])
|
||||||
|
hint_tokens = hint_tokens[:, :max_tokens]
|
||||||
|
hidden_states = hidden_states[:, :max_tokens]
|
||||||
|
img_ids = img_ids[:, :max_tokens]
|
||||||
|
|
||||||
|
txt_start = round(
|
||||||
|
max(
|
||||||
|
((x.shape[-1] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
|
||||||
|
((x.shape[-2] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||||
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||||
|
image_rotary_emb = base_model.pe_embedder(ids).to(x.dtype).contiguous()
|
||||||
|
|
||||||
|
hidden_states = base_model.img_in(hidden_states)
|
||||||
|
encoder_hidden_states = base_model.txt_norm(context)
|
||||||
|
encoder_hidden_states = base_model.txt_in(encoder_hidden_states)
|
||||||
|
|
||||||
|
if guidance is not None:
|
||||||
|
guidance = guidance * 1000
|
||||||
|
|
||||||
|
temb = (
|
||||||
|
base_model.time_text_embed(timesteps, hidden_states)
|
||||||
|
if guidance is None
|
||||||
|
else base_model.time_text_embed(timesteps, guidance, hidden_states)
|
||||||
|
)
|
||||||
|
|
||||||
|
c = self.control_img_in(hint_tokens)
|
||||||
|
|
||||||
|
for i, block in enumerate(self.control_blocks):
|
||||||
|
if i == 0:
|
||||||
|
c_in = block.before_proj(c) + hidden_states
|
||||||
|
all_c = []
|
||||||
|
else:
|
||||||
|
all_c = list(torch.unbind(c, dim=0))
|
||||||
|
c_in = all_c.pop(-1)
|
||||||
|
|
||||||
|
encoder_hidden_states, c_out = block(
|
||||||
|
hidden_states=c_in,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||||
|
temb=temb,
|
||||||
|
image_rotary_emb=image_rotary_emb,
|
||||||
|
transformer_options=transformer_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
c_skip = block.after_proj(c_out) * self.hint_scale
|
||||||
|
all_c += [c_skip, c_out]
|
||||||
|
c = torch.stack(all_c, dim=0)
|
||||||
|
|
||||||
|
hints = torch.unbind(c, dim=0)[:-1]
|
||||||
|
|
||||||
|
controlnet_block_samples = [None] * self.main_model_double
|
||||||
|
for local_idx, base_idx in enumerate(self.injection_layers):
|
||||||
|
if local_idx < len(hints) and base_idx < len(controlnet_block_samples):
|
||||||
|
controlnet_block_samples[base_idx] = hints[local_idx]
|
||||||
|
|
||||||
|
return {"input": controlnet_block_samples}
|
||||||
|
|
||||||
|
|
||||||
class QwenImageControlNetModel(QwenImageTransformer2DModel):
|
class QwenImageControlNetModel(QwenImageTransformer2DModel):
|
||||||
|
|||||||
@@ -332,6 +332,12 @@ def model_lora_keys_unet(model, key_map={}):
|
|||||||
key_map["{}".format(key_lora)] = k
|
key_map["{}".format(key_lora)] = k
|
||||||
key_map["transformer.{}".format(key_lora)] = k
|
key_map["transformer.{}".format(key_lora)] = k
|
||||||
|
|
||||||
|
if isinstance(model, comfy.model_base.ACEStep15):
|
||||||
|
for k in sdk:
|
||||||
|
if k.startswith("diffusion_model.decoder.") and k.endswith(".weight"):
|
||||||
|
key_lora = k[len("diffusion_model.decoder."):-len(".weight")]
|
||||||
|
key_map["base_model.model.{}".format(key_lora)] = k # Official base model loras
|
||||||
|
|
||||||
return key_map
|
return key_map
|
||||||
|
|
||||||
|
|
||||||
@@ -368,6 +374,31 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten
|
|||||||
|
|
||||||
return padded_tensor
|
return padded_tensor
|
||||||
|
|
||||||
|
def calculate_shape(patches, weight, key, original_weights=None):
|
||||||
|
current_shape = weight.shape
|
||||||
|
|
||||||
|
for p in patches:
|
||||||
|
v = p[1]
|
||||||
|
offset = p[3]
|
||||||
|
|
||||||
|
# Offsets restore the old shape; lists force a diff without metadata
|
||||||
|
if offset is not None or isinstance(v, list):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if isinstance(v, weight_adapter.WeightAdapterBase):
|
||||||
|
adapter_shape = v.calculate_shape(key)
|
||||||
|
if adapter_shape is not None:
|
||||||
|
current_shape = adapter_shape
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Standard diff logic with padding
|
||||||
|
if len(v) == 2:
|
||||||
|
patch_type, patch_data = v[0], v[1]
|
||||||
|
if patch_type == "diff" and len(patch_data) > 1 and patch_data[1]['pad_weight']:
|
||||||
|
current_shape = patch_data[0].shape
|
||||||
|
|
||||||
|
return current_shape
|
||||||
|
|
||||||
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None):
|
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None):
|
||||||
for p in patches:
|
for p in patches:
|
||||||
strength = p[0]
|
strength = p[0]
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import comfy.utils
|
|||||||
def convert_lora_bfl_control(sd): #BFL loras for Flux
|
def convert_lora_bfl_control(sd): #BFL loras for Flux
|
||||||
sd_out = {}
|
sd_out = {}
|
||||||
for k in sd:
|
for k in sd:
|
||||||
k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
|
k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.set_weight"))
|
||||||
sd_out[k_to] = sd[k]
|
sd_out[k_to] = sd[k]
|
||||||
|
|
||||||
sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]])
|
sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]])
|
||||||
|
|||||||
81
comfy/memory_management.py
Normal file
81
comfy/memory_management.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
import math
|
||||||
|
import torch
|
||||||
|
from typing import NamedTuple
|
||||||
|
|
||||||
|
from comfy.quant_ops import QuantizedTensor
|
||||||
|
|
||||||
|
class TensorGeometry(NamedTuple):
|
||||||
|
shape: any
|
||||||
|
dtype: torch.dtype
|
||||||
|
|
||||||
|
def element_size(self):
|
||||||
|
info = torch.finfo(self.dtype) if self.dtype.is_floating_point else torch.iinfo(self.dtype)
|
||||||
|
return info.bits // 8
|
||||||
|
|
||||||
|
def numel(self):
|
||||||
|
return math.prod(self.shape)
|
||||||
|
|
||||||
|
def tensors_to_geometries(tensors, dtype=None):
|
||||||
|
geometries = []
|
||||||
|
for t in tensors:
|
||||||
|
if t is None or isinstance(t, QuantizedTensor):
|
||||||
|
geometries.append(t)
|
||||||
|
continue
|
||||||
|
tdtype = t.dtype
|
||||||
|
if hasattr(t, "_model_dtype"):
|
||||||
|
tdtype = t._model_dtype
|
||||||
|
if dtype is not None:
|
||||||
|
tdtype = dtype
|
||||||
|
geometries.append(TensorGeometry(shape=t.shape, dtype=tdtype))
|
||||||
|
return geometries
|
||||||
|
|
||||||
|
def vram_aligned_size(tensor):
|
||||||
|
if isinstance(tensor, list):
|
||||||
|
return sum([vram_aligned_size(t) for t in tensor])
|
||||||
|
|
||||||
|
if isinstance(tensor, QuantizedTensor):
|
||||||
|
inner_tensors, _ = tensor.__tensor_flatten__()
|
||||||
|
return vram_aligned_size([ getattr(tensor, attr) for attr in inner_tensors ])
|
||||||
|
|
||||||
|
if tensor is None:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
size = tensor.numel() * tensor.element_size()
|
||||||
|
aligment_req = 1024
|
||||||
|
return (size + aligment_req - 1) // aligment_req * aligment_req
|
||||||
|
|
||||||
|
def interpret_gathered_like(tensors, gathered):
|
||||||
|
offset = 0
|
||||||
|
dest_views = []
|
||||||
|
|
||||||
|
if gathered.dim() != 1 or gathered.element_size() != 1:
|
||||||
|
raise ValueError(f"Buffer must be 1D and single-byte (got {gathered.dim()}D {gathered.dtype})")
|
||||||
|
|
||||||
|
for tensor in tensors:
|
||||||
|
|
||||||
|
if tensor is None:
|
||||||
|
dest_views.append(None)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if isinstance(tensor, QuantizedTensor):
|
||||||
|
inner_tensors, qt_ctx = tensor.__tensor_flatten__()
|
||||||
|
templates = { attr: getattr(tensor, attr) for attr in inner_tensors }
|
||||||
|
else:
|
||||||
|
templates = { "data": tensor }
|
||||||
|
|
||||||
|
actuals = {}
|
||||||
|
for attr, template in templates.items():
|
||||||
|
size = template.numel() * template.element_size()
|
||||||
|
if offset + size > gathered.numel():
|
||||||
|
raise ValueError(f"Buffer too small: needs {offset + size} bytes, but only has {gathered.numel()}. ")
|
||||||
|
actuals[attr] = gathered[offset:offset+size].view(dtype=template.dtype).view(template.shape)
|
||||||
|
offset += vram_aligned_size(template)
|
||||||
|
|
||||||
|
if isinstance(tensor, QuantizedTensor):
|
||||||
|
dest_views.append(QuantizedTensor.__tensor_unflatten__(actuals, qt_ctx, 0, 0))
|
||||||
|
else:
|
||||||
|
dest_views.append(actuals["data"])
|
||||||
|
|
||||||
|
return dest_views
|
||||||
|
|
||||||
|
aimdo_enabled = False
|
||||||
@@ -50,6 +50,7 @@ import comfy.ldm.omnigen.omnigen2
|
|||||||
import comfy.ldm.qwen_image.model
|
import comfy.ldm.qwen_image.model
|
||||||
import comfy.ldm.kandinsky5.model
|
import comfy.ldm.kandinsky5.model
|
||||||
import comfy.ldm.anima.model
|
import comfy.ldm.anima.model
|
||||||
|
import comfy.ldm.ace.ace_step15
|
||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
@@ -146,6 +147,8 @@ class BaseModel(torch.nn.Module):
|
|||||||
self.diffusion_model.to(memory_format=torch.channels_last)
|
self.diffusion_model.to(memory_format=torch.channels_last)
|
||||||
logging.debug("using channels last mode for diffusion model")
|
logging.debug("using channels last mode for diffusion model")
|
||||||
logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
|
logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
|
||||||
|
comfy.model_management.archive_model_dtypes(self.diffusion_model)
|
||||||
|
|
||||||
self.model_type = model_type
|
self.model_type = model_type
|
||||||
self.model_sampling = model_sampling(model_config, model_type)
|
self.model_sampling = model_sampling(model_config, model_type)
|
||||||
|
|
||||||
@@ -175,10 +178,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1)
|
xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1)
|
||||||
|
|
||||||
context = c_crossattn
|
context = c_crossattn
|
||||||
dtype = self.get_dtype()
|
dtype = self.get_dtype_inference()
|
||||||
|
|
||||||
if self.manual_cast_dtype is not None:
|
|
||||||
dtype = self.manual_cast_dtype
|
|
||||||
|
|
||||||
xc = xc.to(dtype)
|
xc = xc.to(dtype)
|
||||||
device = xc.device
|
device = xc.device
|
||||||
@@ -215,6 +215,13 @@ class BaseModel(torch.nn.Module):
|
|||||||
def get_dtype(self):
|
def get_dtype(self):
|
||||||
return self.diffusion_model.dtype
|
return self.diffusion_model.dtype
|
||||||
|
|
||||||
|
def get_dtype_inference(self):
|
||||||
|
dtype = self.get_dtype()
|
||||||
|
|
||||||
|
if self.manual_cast_dtype is not None:
|
||||||
|
dtype = self.manual_cast_dtype
|
||||||
|
return dtype
|
||||||
|
|
||||||
def encode_adm(self, **kwargs):
|
def encode_adm(self, **kwargs):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -299,7 +306,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def load_model_weights(self, sd, unet_prefix=""):
|
def load_model_weights(self, sd, unet_prefix="", assign=False):
|
||||||
to_load = {}
|
to_load = {}
|
||||||
keys = list(sd.keys())
|
keys = list(sd.keys())
|
||||||
for k in keys:
|
for k in keys:
|
||||||
@@ -307,7 +314,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
to_load[k[len(unet_prefix):]] = sd.pop(k)
|
to_load[k[len(unet_prefix):]] = sd.pop(k)
|
||||||
|
|
||||||
to_load = self.model_config.process_unet_state_dict(to_load)
|
to_load = self.model_config.process_unet_state_dict(to_load)
|
||||||
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
|
m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=assign)
|
||||||
if len(m) > 0:
|
if len(m) > 0:
|
||||||
logging.warning("unet missing: {}".format(m))
|
logging.warning("unet missing: {}".format(m))
|
||||||
|
|
||||||
@@ -322,7 +329,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
def process_latent_out(self, latent):
|
def process_latent_out(self, latent):
|
||||||
return self.latent_format.process_out(latent)
|
return self.latent_format.process_out(latent)
|
||||||
|
|
||||||
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
|
def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
|
||||||
extra_sds = []
|
extra_sds = []
|
||||||
if clip_state_dict is not None:
|
if clip_state_dict is not None:
|
||||||
extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict))
|
extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict))
|
||||||
@@ -330,10 +337,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict))
|
extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict))
|
||||||
if clip_vision_state_dict is not None:
|
if clip_vision_state_dict is not None:
|
||||||
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
|
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
|
||||||
|
|
||||||
unet_state_dict = self.diffusion_model.state_dict()
|
|
||||||
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
|
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
|
||||||
|
|
||||||
if self.model_type == ModelType.V_PREDICTION:
|
if self.model_type == ModelType.V_PREDICTION:
|
||||||
unet_state_dict["v_pred"] = torch.tensor([])
|
unet_state_dict["v_pred"] = torch.tensor([])
|
||||||
|
|
||||||
@@ -372,9 +376,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
input_shapes += shape
|
input_shapes += shape
|
||||||
|
|
||||||
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
|
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
|
||||||
dtype = self.get_dtype()
|
dtype = self.get_dtype_inference()
|
||||||
if self.manual_cast_dtype is not None:
|
|
||||||
dtype = self.manual_cast_dtype
|
|
||||||
#TODO: this needs to be tweaked
|
#TODO: this needs to be tweaked
|
||||||
area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
|
area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
|
||||||
return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
|
return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
|
||||||
@@ -776,8 +778,8 @@ class StableAudio1(BaseModel):
|
|||||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
|
def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
|
||||||
sd = super().state_dict_for_saving(clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
|
sd = super().state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
|
||||||
d = {"conditioner.conditioners.seconds_start.": self.seconds_start_embedder.state_dict(), "conditioner.conditioners.seconds_total.": self.seconds_total_embedder.state_dict()}
|
d = {"conditioner.conditioners.seconds_start.": self.seconds_start_embedder.state_dict(), "conditioner.conditioners.seconds_total.": self.seconds_total_embedder.state_dict()}
|
||||||
for k in d:
|
for k in d:
|
||||||
s = d[k]
|
s = d[k]
|
||||||
@@ -986,10 +988,14 @@ class LTXAV(BaseModel):
|
|||||||
def extra_conds(self, **kwargs):
|
def extra_conds(self, **kwargs):
|
||||||
out = super().extra_conds(**kwargs)
|
out = super().extra_conds(**kwargs)
|
||||||
attention_mask = kwargs.get("attention_mask", None)
|
attention_mask = kwargs.get("attention_mask", None)
|
||||||
|
device = kwargs["device"]
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||||
cross_attn = kwargs.get("cross_attn", None)
|
cross_attn = kwargs.get("cross_attn", None)
|
||||||
if cross_attn is not None:
|
if cross_attn is not None:
|
||||||
|
if hasattr(self.diffusion_model, "preprocess_text_embeds"):
|
||||||
|
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()))
|
||||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||||
|
|
||||||
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
|
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
|
||||||
@@ -1160,12 +1166,16 @@ class Anima(BaseModel):
|
|||||||
device = kwargs["device"]
|
device = kwargs["device"]
|
||||||
if cross_attn is not None:
|
if cross_attn is not None:
|
||||||
if t5xxl_ids is not None:
|
if t5xxl_ids is not None:
|
||||||
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.unsqueeze(0).to(device=device))
|
|
||||||
if t5xxl_weights is not None:
|
if t5xxl_weights is not None:
|
||||||
cross_attn *= t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
|
t5xxl_weights = t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
|
||||||
|
t5xxl_ids = t5xxl_ids.unsqueeze(0)
|
||||||
|
|
||||||
|
if torch.is_inference_mode_enabled(): # if not we are training
|
||||||
|
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype_inference()))
|
||||||
|
else:
|
||||||
|
out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
|
||||||
|
out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)
|
||||||
|
|
||||||
if cross_attn.shape[1] < 512:
|
|
||||||
cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, 0, 512 - cross_attn.shape[1]))
|
|
||||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@@ -1541,6 +1551,49 @@ class ACEStep(BaseModel):
|
|||||||
out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0))
|
out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class ACEStep15(BaseModel):
|
||||||
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.ace_step15.AceStepConditionGenerationModel)
|
||||||
|
|
||||||
|
def extra_conds(self, **kwargs):
|
||||||
|
out = super().extra_conds(**kwargs)
|
||||||
|
device = kwargs["device"]
|
||||||
|
noise = kwargs["noise"]
|
||||||
|
|
||||||
|
cross_attn = kwargs.get("cross_attn", None)
|
||||||
|
if cross_attn is not None:
|
||||||
|
if torch.count_nonzero(cross_attn) == 0:
|
||||||
|
out['replace_with_null_embeds'] = comfy.conds.CONDConstant(True)
|
||||||
|
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||||
|
|
||||||
|
conditioning_lyrics = kwargs.get("conditioning_lyrics", None)
|
||||||
|
if cross_attn is not None:
|
||||||
|
out['lyric_embed'] = comfy.conds.CONDRegular(conditioning_lyrics)
|
||||||
|
|
||||||
|
refer_audio = kwargs.get("reference_audio_timbre_latents", None)
|
||||||
|
if refer_audio is None or len(refer_audio) == 0:
|
||||||
|
refer_audio = comfy.ldm.ace.ace_step15.get_silence_latent(noise.shape[2], device)
|
||||||
|
pass_audio_codes = True
|
||||||
|
else:
|
||||||
|
refer_audio = refer_audio[-1][:, :, :noise.shape[2]]
|
||||||
|
out['is_covers'] = comfy.conds.CONDConstant(True)
|
||||||
|
pass_audio_codes = False
|
||||||
|
|
||||||
|
if pass_audio_codes:
|
||||||
|
audio_codes = kwargs.get("audio_codes", None)
|
||||||
|
if audio_codes is not None:
|
||||||
|
out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device))
|
||||||
|
refer_audio = refer_audio[:, :, :750]
|
||||||
|
else:
|
||||||
|
out['is_covers'] = comfy.conds.CONDConstant(False)
|
||||||
|
|
||||||
|
if refer_audio.shape[2] < noise.shape[2]:
|
||||||
|
pad = comfy.ldm.ace.ace_step15.get_silence_latent(noise.shape[2], device)
|
||||||
|
refer_audio = torch.cat([refer_audio.to(pad), pad[:, :, refer_audio.shape[2]:]], dim=2)
|
||||||
|
|
||||||
|
out['refer_audio'] = comfy.conds.CONDRegular(refer_audio)
|
||||||
|
return out
|
||||||
|
|
||||||
class Omnigen2(BaseModel):
|
class Omnigen2(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.omnigen.omnigen2.OmniGen2Transformer2DModel)
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.omnigen.omnigen2.OmniGen2Transformer2DModel)
|
||||||
|
|||||||
@@ -19,6 +19,12 @@ def count_blocks(state_dict_keys, prefix_string):
|
|||||||
count += 1
|
count += 1
|
||||||
return count
|
return count
|
||||||
|
|
||||||
|
def any_suffix_in(keys, prefix, main, suffix_list=[]):
|
||||||
|
for x in suffix_list:
|
||||||
|
if "{}{}{}".format(prefix, main, x) in keys:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
|
def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
|
||||||
context_dim = None
|
context_dim = None
|
||||||
use_linear_in_transformer = False
|
use_linear_in_transformer = False
|
||||||
@@ -186,7 +192,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["meanflow_sum"] = False
|
dit_config["meanflow_sum"] = False
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
|
if any_suffix_in(state_dict_keys, key_prefix, 'double_blocks.0.img_attn.norm.key_norm.', ["weight", "scale"]) and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"])): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
|
||||||
dit_config = {}
|
dit_config = {}
|
||||||
if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys:
|
if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys:
|
||||||
dit_config["image_model"] = "flux2"
|
dit_config["image_model"] = "flux2"
|
||||||
@@ -241,7 +247,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
|
|
||||||
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
|
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
|
||||||
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
|
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
|
||||||
if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
|
|
||||||
|
if any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.0.norms.0.', ["weight", "scale"]) or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"]): #Chroma
|
||||||
dit_config["image_model"] = "chroma"
|
dit_config["image_model"] = "chroma"
|
||||||
dit_config["in_channels"] = 64
|
dit_config["in_channels"] = 64
|
||||||
dit_config["out_channels"] = 64
|
dit_config["out_channels"] = 64
|
||||||
@@ -249,7 +256,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["out_dim"] = 3072
|
dit_config["out_dim"] = 3072
|
||||||
dit_config["hidden_dim"] = 5120
|
dit_config["hidden_dim"] = 5120
|
||||||
dit_config["n_layers"] = 5
|
dit_config["n_layers"] = 5
|
||||||
if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance
|
|
||||||
|
if any_suffix_in(state_dict_keys, key_prefix, 'nerf_blocks.0.norm.', ["weight", "scale"]): #Chroma Radiance
|
||||||
dit_config["image_model"] = "chroma_radiance"
|
dit_config["image_model"] = "chroma_radiance"
|
||||||
dit_config["in_channels"] = 3
|
dit_config["in_channels"] = 3
|
||||||
dit_config["out_channels"] = 3
|
dit_config["out_channels"] = 3
|
||||||
@@ -259,7 +267,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["nerf_depth"] = 4
|
dit_config["nerf_depth"] = 4
|
||||||
dit_config["nerf_max_freqs"] = 8
|
dit_config["nerf_max_freqs"] = 8
|
||||||
dit_config["nerf_tile_size"] = 512
|
dit_config["nerf_tile_size"] = 512
|
||||||
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
|
dit_config["nerf_final_head_type"] = "conv" if any_suffix_in(state_dict_keys, key_prefix, 'nerf_final_layer_conv.norm.', ["weight", "scale"]) else "linear"
|
||||||
dit_config["nerf_embedder_dtype"] = torch.float32
|
dit_config["nerf_embedder_dtype"] = torch.float32
|
||||||
if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred
|
if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred
|
||||||
dit_config["use_x0"] = True
|
dit_config["use_x0"] = True
|
||||||
@@ -268,7 +276,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
else:
|
else:
|
||||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||||
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
|
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
|
||||||
dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys
|
dit_config["txt_norm"] = any_suffix_in(state_dict_keys, key_prefix, 'txt_norm.', ["weight", "scale"])
|
||||||
if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
|
if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
|
||||||
dit_config["txt_ids_dims"] = [1, 2]
|
dit_config["txt_ids_dims"] = [1, 2]
|
||||||
|
|
||||||
@@ -655,6 +663,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.')
|
dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.')
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
|
if '{}encoder.lyric_encoder.layers.0.input_layernorm.weight'.format(key_prefix) in state_dict_keys:
|
||||||
|
dit_config = {}
|
||||||
|
dit_config["audio_model"] = "ace1.5"
|
||||||
|
return dit_config
|
||||||
|
|
||||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -20,12 +20,20 @@ import psutil
|
|||||||
import logging
|
import logging
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from comfy.cli_args import args, PerformanceFeature
|
from comfy.cli_args import args, PerformanceFeature
|
||||||
|
import threading
|
||||||
import torch
|
import torch
|
||||||
import sys
|
import sys
|
||||||
import platform
|
import platform
|
||||||
import weakref
|
import weakref
|
||||||
import gc
|
import gc
|
||||||
import os
|
import os
|
||||||
|
from contextlib import nullcontext
|
||||||
|
import comfy.memory_management
|
||||||
|
import comfy.utils
|
||||||
|
import comfy.quant_ops
|
||||||
|
|
||||||
|
import comfy_aimdo.torch
|
||||||
|
import comfy_aimdo.model_vbar
|
||||||
|
|
||||||
class VRAMState(Enum):
|
class VRAMState(Enum):
|
||||||
DISABLED = 0 #No vram present: no need to move models to vram
|
DISABLED = 0 #No vram present: no need to move models to vram
|
||||||
@@ -47,6 +55,11 @@ cpu_state = CPUState.GPU
|
|||||||
|
|
||||||
total_vram = 0
|
total_vram = 0
|
||||||
|
|
||||||
|
|
||||||
|
# Training Related State
|
||||||
|
in_training = False
|
||||||
|
|
||||||
|
|
||||||
def get_supported_float8_types():
|
def get_supported_float8_types():
|
||||||
float8_types = []
|
float8_types = []
|
||||||
try:
|
try:
|
||||||
@@ -578,9 +591,15 @@ WINDOWS = any(platform.win32_ver())
|
|||||||
|
|
||||||
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
|
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
|
||||||
if WINDOWS:
|
if WINDOWS:
|
||||||
|
import comfy.windows
|
||||||
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
|
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
|
||||||
if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards
|
if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards
|
||||||
EXTRA_RESERVED_VRAM += 100 * 1024 * 1024
|
EXTRA_RESERVED_VRAM += 100 * 1024 * 1024
|
||||||
|
def get_free_ram():
|
||||||
|
return comfy.windows.get_free_ram()
|
||||||
|
else:
|
||||||
|
def get_free_ram():
|
||||||
|
return psutil.virtual_memory().available
|
||||||
|
|
||||||
if args.reserve_vram is not None:
|
if args.reserve_vram is not None:
|
||||||
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
|
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
|
||||||
@@ -592,7 +611,7 @@ def extra_reserved_memory():
|
|||||||
def minimum_inference_memory():
|
def minimum_inference_memory():
|
||||||
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
|
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
|
||||||
|
|
||||||
def free_memory(memory_required, device, keep_loaded=[]):
|
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_required=0):
|
||||||
cleanup_models_gc()
|
cleanup_models_gc()
|
||||||
unloaded_model = []
|
unloaded_model = []
|
||||||
can_unload = []
|
can_unload = []
|
||||||
@@ -607,15 +626,23 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
|||||||
|
|
||||||
for x in sorted(can_unload):
|
for x in sorted(can_unload):
|
||||||
i = x[-1]
|
i = x[-1]
|
||||||
memory_to_free = None
|
memory_to_free = 1e32
|
||||||
|
ram_to_free = 1e32
|
||||||
if not DISABLE_SMART_MEMORY:
|
if not DISABLE_SMART_MEMORY:
|
||||||
free_mem = get_free_memory(device)
|
memory_to_free = memory_required - get_free_memory(device)
|
||||||
if free_mem > memory_required:
|
ram_to_free = ram_required - get_free_ram()
|
||||||
break
|
|
||||||
memory_to_free = memory_required - free_mem
|
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
|
||||||
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
#don't actually unload dynamic models for the sake of other dynamic models
|
||||||
if current_loaded_models[i].model_unload(memory_to_free):
|
#as that works on-demand.
|
||||||
|
memory_required -= current_loaded_models[i].model.loaded_size()
|
||||||
|
memory_to_free = 0
|
||||||
|
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
|
||||||
|
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||||
unloaded_model.append(i)
|
unloaded_model.append(i)
|
||||||
|
if ram_to_free > 0:
|
||||||
|
logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||||
|
current_loaded_models[i].model.partially_unload_ram(ram_to_free)
|
||||||
|
|
||||||
for i in sorted(unloaded_model, reverse=True):
|
for i in sorted(unloaded_model, reverse=True):
|
||||||
unloaded_models.append(current_loaded_models.pop(i))
|
unloaded_models.append(current_loaded_models.pop(i))
|
||||||
@@ -650,7 +677,10 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
|||||||
|
|
||||||
models_to_load = []
|
models_to_load = []
|
||||||
|
|
||||||
|
free_for_dynamic=True
|
||||||
for x in models:
|
for x in models:
|
||||||
|
if not x.is_dynamic():
|
||||||
|
free_for_dynamic = False
|
||||||
loaded_model = LoadedModel(x)
|
loaded_model = LoadedModel(x)
|
||||||
try:
|
try:
|
||||||
loaded_model_index = current_loaded_models.index(loaded_model)
|
loaded_model_index = current_loaded_models.index(loaded_model)
|
||||||
@@ -676,19 +706,25 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
|||||||
model_to_unload.model.detach(unpatch_all=False)
|
model_to_unload.model.detach(unpatch_all=False)
|
||||||
model_to_unload.model_finalizer.detach()
|
model_to_unload.model_finalizer.detach()
|
||||||
|
|
||||||
|
|
||||||
total_memory_required = {}
|
total_memory_required = {}
|
||||||
|
total_ram_required = {}
|
||||||
for loaded_model in models_to_load:
|
for loaded_model in models_to_load:
|
||||||
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
|
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
|
||||||
|
#x2, one to make sure the OS can fit the model for loading in disk cache, and for us to do any pinning we
|
||||||
|
#want to do.
|
||||||
|
#FIXME: This should subtract off the to_load current pin consumption.
|
||||||
|
total_ram_required[loaded_model.device] = total_ram_required.get(loaded_model.device, 0) + loaded_model.model_memory() * 2
|
||||||
|
|
||||||
for device in total_memory_required:
|
for device in total_memory_required:
|
||||||
if device != torch.device("cpu"):
|
if device != torch.device("cpu"):
|
||||||
free_memory(total_memory_required[device] * 1.1 + extra_mem, device)
|
free_memory(total_memory_required[device] * 1.1 + extra_mem, device, for_dynamic=free_for_dynamic, ram_required=total_ram_required[device])
|
||||||
|
|
||||||
for device in total_memory_required:
|
for device in total_memory_required:
|
||||||
if device != torch.device("cpu"):
|
if device != torch.device("cpu"):
|
||||||
free_mem = get_free_memory(device)
|
free_mem = get_free_memory(device)
|
||||||
if free_mem < minimum_memory_required:
|
if free_mem < minimum_memory_required:
|
||||||
models_l = free_memory(minimum_memory_required, device)
|
models_l = free_memory(minimum_memory_required, device, for_dynamic=free_for_dynamic)
|
||||||
logging.info("{} models unloaded.".format(len(models_l)))
|
logging.info("{} models unloaded.".format(len(models_l)))
|
||||||
|
|
||||||
for loaded_model in models_to_load:
|
for loaded_model in models_to_load:
|
||||||
@@ -732,6 +768,9 @@ def loaded_models(only_currently_used=False):
|
|||||||
|
|
||||||
def cleanup_models_gc():
|
def cleanup_models_gc():
|
||||||
do_gc = False
|
do_gc = False
|
||||||
|
|
||||||
|
reset_cast_buffers()
|
||||||
|
|
||||||
for i in range(len(current_loaded_models)):
|
for i in range(len(current_loaded_models)):
|
||||||
cur = current_loaded_models[i]
|
cur = current_loaded_models[i]
|
||||||
if cur.is_dead():
|
if cur.is_dead():
|
||||||
@@ -749,6 +788,11 @@ def cleanup_models_gc():
|
|||||||
logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))
|
logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))
|
||||||
|
|
||||||
|
|
||||||
|
def archive_model_dtypes(model):
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
for param_name, param in module.named_parameters(recurse=False):
|
||||||
|
setattr(module, f"{param_name}_comfy_model_dtype", param.dtype)
|
||||||
|
|
||||||
|
|
||||||
def cleanup_models():
|
def cleanup_models():
|
||||||
to_delete = []
|
to_delete = []
|
||||||
@@ -792,7 +836,7 @@ def unet_inital_load_device(parameters, dtype):
|
|||||||
|
|
||||||
mem_dev = get_free_memory(torch_dev)
|
mem_dev = get_free_memory(torch_dev)
|
||||||
mem_cpu = get_free_memory(cpu_dev)
|
mem_cpu = get_free_memory(cpu_dev)
|
||||||
if mem_dev > mem_cpu and model_size < mem_dev:
|
if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_enabled:
|
||||||
return torch_dev
|
return torch_dev
|
||||||
else:
|
else:
|
||||||
return cpu_dev
|
return cpu_dev
|
||||||
@@ -1051,6 +1095,50 @@ def current_stream(device):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
stream_counters = {}
|
stream_counters = {}
|
||||||
|
|
||||||
|
STREAM_CAST_BUFFERS = {}
|
||||||
|
LARGEST_CASTED_WEIGHT = (None, 0)
|
||||||
|
|
||||||
|
def get_cast_buffer(offload_stream, device, size, ref):
|
||||||
|
global LARGEST_CASTED_WEIGHT
|
||||||
|
|
||||||
|
if offload_stream is not None:
|
||||||
|
wf_context = offload_stream
|
||||||
|
if hasattr(wf_context, "as_context"):
|
||||||
|
wf_context = wf_context.as_context(offload_stream)
|
||||||
|
else:
|
||||||
|
wf_context = nullcontext()
|
||||||
|
|
||||||
|
cast_buffer = STREAM_CAST_BUFFERS.get(offload_stream, None)
|
||||||
|
if cast_buffer is None or cast_buffer.numel() < size:
|
||||||
|
if ref is LARGEST_CASTED_WEIGHT[0]:
|
||||||
|
#If there is one giant weight we do not want both streams to
|
||||||
|
#allocate a buffer for it. It's up to the caster to get the other
|
||||||
|
#offload stream in this corner case
|
||||||
|
return None
|
||||||
|
if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2):
|
||||||
|
#I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now
|
||||||
|
synchronize()
|
||||||
|
del STREAM_CAST_BUFFERS[offload_stream]
|
||||||
|
del cast_buffer
|
||||||
|
soft_empty_cache()
|
||||||
|
with wf_context:
|
||||||
|
cast_buffer = torch.empty((size), dtype=torch.int8, device=device)
|
||||||
|
STREAM_CAST_BUFFERS[offload_stream] = cast_buffer
|
||||||
|
|
||||||
|
if size > LARGEST_CASTED_WEIGHT[1]:
|
||||||
|
LARGEST_CASTED_WEIGHT = (ref, size)
|
||||||
|
|
||||||
|
return cast_buffer
|
||||||
|
|
||||||
|
def reset_cast_buffers():
|
||||||
|
global LARGEST_CASTED_WEIGHT
|
||||||
|
LARGEST_CASTED_WEIGHT = (None, 0)
|
||||||
|
for offload_stream in STREAM_CAST_BUFFERS:
|
||||||
|
offload_stream.synchronize()
|
||||||
|
STREAM_CAST_BUFFERS.clear()
|
||||||
|
soft_empty_cache()
|
||||||
|
|
||||||
def get_offload_stream(device):
|
def get_offload_stream(device):
|
||||||
stream_counter = stream_counters.get(device, 0)
|
stream_counter = stream_counters.get(device, 0)
|
||||||
if NUM_STREAMS == 0:
|
if NUM_STREAMS == 0:
|
||||||
@@ -1093,7 +1181,61 @@ def sync_stream(device, stream):
|
|||||||
return
|
return
|
||||||
current_stream(device).wait_stream(stream)
|
current_stream(device).wait_stream(stream)
|
||||||
|
|
||||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
|
|
||||||
|
def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
|
||||||
|
wf_context = nullcontext()
|
||||||
|
if stream is not None:
|
||||||
|
wf_context = stream
|
||||||
|
if hasattr(wf_context, "as_context"):
|
||||||
|
wf_context = wf_context.as_context(stream)
|
||||||
|
|
||||||
|
dest_views = comfy.memory_management.interpret_gathered_like(tensors, r)
|
||||||
|
with wf_context:
|
||||||
|
for tensor in tensors:
|
||||||
|
dest_view = dest_views.pop(0)
|
||||||
|
if tensor is None:
|
||||||
|
continue
|
||||||
|
dest_view.copy_(tensor, non_blocking=non_blocking)
|
||||||
|
|
||||||
|
|
||||||
|
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
|
||||||
|
if hasattr(weight, "_v"):
|
||||||
|
#Unexpected usage patterns. There is no reason these don't work but they
|
||||||
|
#have no testing and no callers do this.
|
||||||
|
assert r is None
|
||||||
|
assert stream is None
|
||||||
|
|
||||||
|
cast_geometry = comfy.memory_management.tensors_to_geometries([ weight ])
|
||||||
|
|
||||||
|
if dtype is None:
|
||||||
|
dtype = weight._model_dtype
|
||||||
|
|
||||||
|
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
|
||||||
|
if signature is not None:
|
||||||
|
if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
|
||||||
|
v_tensor = weight._v_tensor
|
||||||
|
else:
|
||||||
|
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
|
||||||
|
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
|
||||||
|
weight._v_tensor = v_tensor
|
||||||
|
weight._v_signature = signature
|
||||||
|
#Send it over
|
||||||
|
v_tensor.copy_(weight, non_blocking=non_blocking)
|
||||||
|
return v_tensor.to(dtype=dtype)
|
||||||
|
|
||||||
|
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
if weight.dtype != r.dtype and weight.dtype != weight._model_dtype:
|
||||||
|
#Offloaded casting could skip this, however it would make the quantizations
|
||||||
|
#inconsistent between loaded and offloaded weights. So force the double casting
|
||||||
|
#that would happen in regular flow to make offload deterministic.
|
||||||
|
cast_buffer = torch.empty_like(weight, dtype=weight._model_dtype, device=device)
|
||||||
|
cast_buffer.copy_(weight, non_blocking=non_blocking)
|
||||||
|
weight = cast_buffer
|
||||||
|
r.copy_(weight, non_blocking=non_blocking)
|
||||||
|
|
||||||
|
return r
|
||||||
|
|
||||||
if device is None or weight.device == device:
|
if device is None or weight.device == device:
|
||||||
if not copy:
|
if not copy:
|
||||||
if dtype is None or weight.dtype == dtype:
|
if dtype is None or weight.dtype == dtype:
|
||||||
@@ -1112,10 +1254,12 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
|
|||||||
if hasattr(wf_context, "as_context"):
|
if hasattr(wf_context, "as_context"):
|
||||||
wf_context = wf_context.as_context(stream)
|
wf_context = wf_context.as_context(stream)
|
||||||
with wf_context:
|
with wf_context:
|
||||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
if r is None:
|
||||||
|
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||||
r.copy_(weight, non_blocking=non_blocking)
|
r.copy_(weight, non_blocking=non_blocking)
|
||||||
else:
|
else:
|
||||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
if r is None:
|
||||||
|
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||||
r.copy_(weight, non_blocking=non_blocking)
|
r.copy_(weight, non_blocking=non_blocking)
|
||||||
return r
|
return r
|
||||||
|
|
||||||
@@ -1135,14 +1279,14 @@ if not args.disable_pinned_memory:
|
|||||||
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
|
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
|
||||||
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
|
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
|
||||||
|
|
||||||
PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
|
PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"])
|
||||||
|
|
||||||
def discard_cuda_async_error():
|
def discard_cuda_async_error():
|
||||||
try:
|
try:
|
||||||
a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
|
a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
|
||||||
b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
|
b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
|
||||||
_ = a + b
|
_ = a + b
|
||||||
torch.cuda.synchronize()
|
synchronize()
|
||||||
except torch.AcceleratorError:
|
except torch.AcceleratorError:
|
||||||
#Dump it! We already know about it from the synchronous return
|
#Dump it! We already know about it from the synchronous return
|
||||||
pass
|
pass
|
||||||
@@ -1546,6 +1690,12 @@ def lora_compute_dtype(device):
|
|||||||
LORA_COMPUTE_DTYPES[device] = dtype
|
LORA_COMPUTE_DTYPES[device] = dtype
|
||||||
return dtype
|
return dtype
|
||||||
|
|
||||||
|
def synchronize():
|
||||||
|
if is_intel_xpu():
|
||||||
|
torch.xpu.synchronize()
|
||||||
|
elif torch.cuda.is_available():
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
def soft_empty_cache(force=False):
|
def soft_empty_cache(force=False):
|
||||||
global cpu_state
|
global cpu_state
|
||||||
if cpu_state == CPUState.MPS:
|
if cpu_state == CPUState.MPS:
|
||||||
@@ -1557,6 +1707,7 @@ def soft_empty_cache(force=False):
|
|||||||
elif is_mlu():
|
elif is_mlu():
|
||||||
torch.mlu.empty_cache()
|
torch.mlu.empty_cache()
|
||||||
elif torch.cuda.is_available():
|
elif torch.cuda.is_available():
|
||||||
|
torch.cuda.synchronize()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
torch.cuda.ipc_collect()
|
torch.cuda.ipc_collect()
|
||||||
|
|
||||||
@@ -1568,9 +1719,6 @@ def debug_memory_summary():
|
|||||||
return torch.cuda.memory.memory_summary()
|
return torch.cuda.memory.memory_summary()
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
#TODO: might be cleaner to put this somewhere else
|
|
||||||
import threading
|
|
||||||
|
|
||||||
class InterruptProcessingException(Exception):
|
class InterruptProcessingException(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
import copy
|
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
@@ -38,19 +37,7 @@ from comfy.comfy_types import UnetWrapperFunction
|
|||||||
from comfy.quant_ops import QuantizedTensor
|
from comfy.quant_ops import QuantizedTensor
|
||||||
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
|
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
|
||||||
|
|
||||||
|
import comfy_aimdo.model_vbar
|
||||||
def string_to_seed(data):
|
|
||||||
crc = 0xFFFFFFFF
|
|
||||||
for byte in data:
|
|
||||||
if isinstance(byte, str):
|
|
||||||
byte = ord(byte)
|
|
||||||
crc ^= byte
|
|
||||||
for _ in range(8):
|
|
||||||
if crc & 1:
|
|
||||||
crc = (crc >> 1) ^ 0xEDB88320
|
|
||||||
else:
|
|
||||||
crc >>= 1
|
|
||||||
return crc ^ 0xFFFFFFFF
|
|
||||||
|
|
||||||
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
|
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
|
||||||
to = model_options["transformer_options"].copy()
|
to = model_options["transformer_options"].copy()
|
||||||
@@ -123,6 +110,10 @@ def move_weight_functions(m, device):
|
|||||||
memory += f.move_to(device=device)
|
memory += f.move_to(device=device)
|
||||||
return memory
|
return memory
|
||||||
|
|
||||||
|
def string_to_seed(data):
|
||||||
|
logging.warning("WARNING: string_to_seed has moved from comfy.model_patcher to comfy.utils")
|
||||||
|
return comfy.utils.string_to_seed(data)
|
||||||
|
|
||||||
class LowVramPatch:
|
class LowVramPatch:
|
||||||
def __init__(self, key, patches, convert_func=None, set_func=None):
|
def __init__(self, key, patches, convert_func=None, set_func=None):
|
||||||
self.key = key
|
self.key = key
|
||||||
@@ -169,6 +160,11 @@ def get_key_weight(model, key):
|
|||||||
|
|
||||||
return weight, set_func, convert_func
|
return weight, set_func, convert_func
|
||||||
|
|
||||||
|
def key_param_name_to_key(key, param):
|
||||||
|
if len(key) == 0:
|
||||||
|
return param
|
||||||
|
return "{}.{}".format(key, param)
|
||||||
|
|
||||||
class AutoPatcherEjector:
|
class AutoPatcherEjector:
|
||||||
def __init__(self, model: 'ModelPatcher', skip_and_inject_on_exit_only=False):
|
def __init__(self, model: 'ModelPatcher', skip_and_inject_on_exit_only=False):
|
||||||
self.model = model
|
self.model = model
|
||||||
@@ -212,6 +208,27 @@ class MemoryCounter:
|
|||||||
def decrement(self, used: int):
|
def decrement(self, used: int):
|
||||||
self.value -= used
|
self.value -= used
|
||||||
|
|
||||||
|
CustomTorchDevice = collections.namedtuple("FakeDevice", ["type", "index"])("comfy-lazy-caster", 0)
|
||||||
|
|
||||||
|
class LazyCastingParam(torch.nn.Parameter):
|
||||||
|
def __new__(cls, model, key, tensor):
|
||||||
|
return super().__new__(cls, tensor)
|
||||||
|
|
||||||
|
def __init__(self, model, key, tensor):
|
||||||
|
self.model = model
|
||||||
|
self.key = key
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
return CustomTorchDevice
|
||||||
|
|
||||||
|
#safetensors will .to() us to the cpu which we catch here to cast on demand. The returned tensor is
|
||||||
|
#then just a short lived thing in the safetensors serialization logic inside its big for loop over
|
||||||
|
#all weights getting garbage collected per-weight
|
||||||
|
def to(self, *args, **kwargs):
|
||||||
|
return self.model.patch_weight_to_device(self.key, device_to=self.model.load_device, return_weight=True).to("cpu")
|
||||||
|
|
||||||
|
|
||||||
class ModelPatcher:
|
class ModelPatcher:
|
||||||
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
||||||
self.size = size
|
self.size = size
|
||||||
@@ -269,6 +286,9 @@ class ModelPatcher:
|
|||||||
if not hasattr(self.model, 'model_offload_buffer_memory'):
|
if not hasattr(self.model, 'model_offload_buffer_memory'):
|
||||||
self.model.model_offload_buffer_memory = 0
|
self.model.model_offload_buffer_memory = 0
|
||||||
|
|
||||||
|
def is_dynamic(self):
|
||||||
|
return False
|
||||||
|
|
||||||
def model_size(self):
|
def model_size(self):
|
||||||
if self.size > 0:
|
if self.size > 0:
|
||||||
return self.size
|
return self.size
|
||||||
@@ -284,6 +304,9 @@ class ModelPatcher:
|
|||||||
def lowvram_patch_counter(self):
|
def lowvram_patch_counter(self):
|
||||||
return self.model.lowvram_patch_counter
|
return self.model.lowvram_patch_counter
|
||||||
|
|
||||||
|
def get_free_memory(self, device):
|
||||||
|
return comfy.model_management.get_free_memory(device)
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
|
n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
|
||||||
n.patches = {}
|
n.patches = {}
|
||||||
@@ -293,7 +316,7 @@ class ModelPatcher:
|
|||||||
|
|
||||||
n.object_patches = self.object_patches.copy()
|
n.object_patches = self.object_patches.copy()
|
||||||
n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
|
n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
|
||||||
n.model_options = copy.deepcopy(self.model_options)
|
n.model_options = comfy.utils.deepcopy_list_dict(self.model_options)
|
||||||
n.backup = self.backup
|
n.backup = self.backup
|
||||||
n.object_patches_backup = self.object_patches_backup
|
n.object_patches_backup = self.object_patches_backup
|
||||||
n.parent = self
|
n.parent = self
|
||||||
@@ -383,13 +406,16 @@ class ModelPatcher:
|
|||||||
def memory_required(self, input_shape):
|
def memory_required(self, input_shape):
|
||||||
return self.model.memory_required(input_shape=input_shape)
|
return self.model.memory_required(input_shape=input_shape)
|
||||||
|
|
||||||
|
def disable_model_cfg1_optimization(self):
|
||||||
|
self.model_options["disable_cfg1_optimization"] = True
|
||||||
|
|
||||||
def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
|
def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
|
||||||
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
|
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
|
||||||
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
|
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
|
||||||
else:
|
else:
|
||||||
self.model_options["sampler_cfg_function"] = sampler_cfg_function
|
self.model_options["sampler_cfg_function"] = sampler_cfg_function
|
||||||
if disable_cfg1_optimization:
|
if disable_cfg1_optimization:
|
||||||
self.model_options["disable_cfg1_optimization"] = True
|
self.disable_model_cfg1_optimization()
|
||||||
|
|
||||||
def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
|
def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
|
||||||
self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization)
|
self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization)
|
||||||
@@ -611,14 +637,14 @@ class ModelPatcher:
|
|||||||
sd.pop(k)
|
sd.pop(k)
|
||||||
return sd
|
return sd
|
||||||
|
|
||||||
def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
|
def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False):
|
||||||
if key not in self.patches:
|
|
||||||
return
|
|
||||||
|
|
||||||
weight, set_func, convert_func = get_key_weight(self.model, key)
|
weight, set_func, convert_func = get_key_weight(self.model, key)
|
||||||
|
if key not in self.patches:
|
||||||
|
return weight
|
||||||
|
|
||||||
inplace_update = self.weight_inplace_update or inplace_update
|
inplace_update = self.weight_inplace_update or inplace_update
|
||||||
|
|
||||||
if key not in self.backup:
|
if key not in self.backup and not return_weight:
|
||||||
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
|
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
|
||||||
|
|
||||||
temp_dtype = comfy.model_management.lora_compute_dtype(device_to)
|
temp_dtype = comfy.model_management.lora_compute_dtype(device_to)
|
||||||
@@ -631,13 +657,15 @@ class ModelPatcher:
|
|||||||
|
|
||||||
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
|
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
|
||||||
if set_func is None:
|
if set_func is None:
|
||||||
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
|
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key))
|
||||||
if inplace_update:
|
if return_weight:
|
||||||
|
return out_weight
|
||||||
|
elif inplace_update:
|
||||||
comfy.utils.copy_to_param(self.model, key, out_weight)
|
comfy.utils.copy_to_param(self.model, key, out_weight)
|
||||||
else:
|
else:
|
||||||
comfy.utils.set_attr_param(self.model, key, out_weight)
|
comfy.utils.set_attr_param(self.model, key, out_weight)
|
||||||
else:
|
else:
|
||||||
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
|
return set_func(out_weight, inplace_update=inplace_update, seed=comfy.utils.string_to_seed(key), return_weight=return_weight)
|
||||||
|
|
||||||
def pin_weight_to_device(self, key):
|
def pin_weight_to_device(self, key):
|
||||||
weight, set_func, convert_func = get_key_weight(self.model, key)
|
weight, set_func, convert_func = get_key_weight(self.model, key)
|
||||||
@@ -654,18 +682,19 @@ class ModelPatcher:
|
|||||||
for key in list(self.pinned):
|
for key in list(self.pinned):
|
||||||
self.unpin_weight(key)
|
self.unpin_weight(key)
|
||||||
|
|
||||||
def _load_list(self):
|
def _load_list(self, prio_comfy_cast_weights=False, default_device=None):
|
||||||
loading = []
|
loading = []
|
||||||
for n, m in self.model.named_modules():
|
for n, m in self.model.named_modules():
|
||||||
params = []
|
default = False
|
||||||
skip = False
|
params = { name: param for name, param in m.named_parameters(recurse=False) }
|
||||||
for name, param in m.named_parameters(recurse=False):
|
|
||||||
params.append(name)
|
|
||||||
for name, param in m.named_parameters(recurse=True):
|
for name, param in m.named_parameters(recurse=True):
|
||||||
if name not in params:
|
if name not in params:
|
||||||
skip = True # skip random weights in non leaf modules
|
default = True # default random weights in non leaf modules
|
||||||
break
|
break
|
||||||
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
|
if default and default_device is not None:
|
||||||
|
for param in params.values():
|
||||||
|
param.data = param.data.to(device=default_device)
|
||||||
|
if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
|
||||||
module_mem = comfy.model_management.module_size(m)
|
module_mem = comfy.model_management.module_size(m)
|
||||||
module_offload_mem = module_mem
|
module_offload_mem = module_mem
|
||||||
if hasattr(m, "comfy_cast_weights"):
|
if hasattr(m, "comfy_cast_weights"):
|
||||||
@@ -681,7 +710,8 @@ class ModelPatcher:
|
|||||||
return 0
|
return 0
|
||||||
module_offload_mem += check_module_offload_mem("{}.weight".format(n))
|
module_offload_mem += check_module_offload_mem("{}.weight".format(n))
|
||||||
module_offload_mem += check_module_offload_mem("{}.bias".format(n))
|
module_offload_mem += check_module_offload_mem("{}.bias".format(n))
|
||||||
loading.append((module_offload_mem, module_mem, n, m, params))
|
prepend = (not hasattr(m, "comfy_cast_weights"),) if prio_comfy_cast_weights else ()
|
||||||
|
loading.append(prepend + (module_offload_mem, module_mem, n, m, params))
|
||||||
return loading
|
return loading
|
||||||
|
|
||||||
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
|
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
|
||||||
@@ -773,7 +803,7 @@ class ModelPatcher:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
for param in params:
|
for param in params:
|
||||||
key = "{}.{}".format(n, param)
|
key = key_param_name_to_key(n, param)
|
||||||
self.unpin_weight(key)
|
self.unpin_weight(key)
|
||||||
self.patch_weight_to_device(key, device_to=device_to)
|
self.patch_weight_to_device(key, device_to=device_to)
|
||||||
if comfy.model_management.is_device_cuda(device_to):
|
if comfy.model_management.is_device_cuda(device_to):
|
||||||
@@ -789,7 +819,7 @@ class ModelPatcher:
|
|||||||
n = x[1]
|
n = x[1]
|
||||||
params = x[3]
|
params = x[3]
|
||||||
for param in params:
|
for param in params:
|
||||||
self.pin_weight_to_device("{}.{}".format(n, param))
|
self.pin_weight_to_device(key_param_name_to_key(n, param))
|
||||||
|
|
||||||
usable_stat = "{:.2f} MB usable,".format(lowvram_model_memory / (1024 * 1024)) if lowvram_model_memory < 1e32 else ""
|
usable_stat = "{:.2f} MB usable,".format(lowvram_model_memory / (1024 * 1024)) if lowvram_model_memory < 1e32 else ""
|
||||||
if lowvram_counter > 0:
|
if lowvram_counter > 0:
|
||||||
@@ -895,7 +925,7 @@ class ModelPatcher:
|
|||||||
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
|
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
|
||||||
move_weight = True
|
move_weight = True
|
||||||
for param in params:
|
for param in params:
|
||||||
key = "{}.{}".format(n, param)
|
key = key_param_name_to_key(n, param)
|
||||||
bk = self.backup.get(key, None)
|
bk = self.backup.get(key, None)
|
||||||
if bk is not None:
|
if bk is not None:
|
||||||
if not lowvram_possible:
|
if not lowvram_possible:
|
||||||
@@ -946,7 +976,7 @@ class ModelPatcher:
|
|||||||
logging.debug("freed {}".format(n))
|
logging.debug("freed {}".format(n))
|
||||||
|
|
||||||
for param in params:
|
for param in params:
|
||||||
self.pin_weight_to_device("{}.{}".format(n, param))
|
self.pin_weight_to_device(key_param_name_to_key(n, param))
|
||||||
|
|
||||||
|
|
||||||
self.model.model_lowvram = True
|
self.model.model_lowvram = True
|
||||||
@@ -984,6 +1014,9 @@ class ModelPatcher:
|
|||||||
|
|
||||||
return self.model.model_loaded_weight_memory - current_used
|
return self.model.model_loaded_weight_memory - current_used
|
||||||
|
|
||||||
|
def partially_unload_ram(self, ram_to_unload):
|
||||||
|
pass
|
||||||
|
|
||||||
def detach(self, unpatch_all=True):
|
def detach(self, unpatch_all=True):
|
||||||
self.eject_model()
|
self.eject_model()
|
||||||
self.model_patches_to(self.offload_device)
|
self.model_patches_to(self.offload_device)
|
||||||
@@ -1317,10 +1350,10 @@ class ModelPatcher:
|
|||||||
key, original_weights=original_weights)
|
key, original_weights=original_weights)
|
||||||
del original_weights[key]
|
del original_weights[key]
|
||||||
if set_func is None:
|
if set_func is None:
|
||||||
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
|
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key))
|
||||||
comfy.utils.copy_to_param(self.model, key, out_weight)
|
comfy.utils.copy_to_param(self.model, key, out_weight)
|
||||||
else:
|
else:
|
||||||
set_func(out_weight, inplace_update=True, seed=string_to_seed(key))
|
set_func(out_weight, inplace_update=True, seed=comfy.utils.string_to_seed(key))
|
||||||
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
|
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
|
||||||
# TODO: disable caching if not enough system RAM to do so
|
# TODO: disable caching if not enough system RAM to do so
|
||||||
target_device = self.offload_device
|
target_device = self.offload_device
|
||||||
@@ -1355,7 +1388,275 @@ class ModelPatcher:
|
|||||||
self.unpatch_hooks()
|
self.unpatch_hooks()
|
||||||
self.clear_cached_hook_weights()
|
self.clear_cached_hook_weights()
|
||||||
|
|
||||||
|
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
|
||||||
|
unet_state_dict = self.model.diffusion_model.state_dict()
|
||||||
|
for k, v in unet_state_dict.items():
|
||||||
|
op_keys = k.rsplit('.', 1)
|
||||||
|
if (len(op_keys) < 2) or op_keys[1] not in ["weight", "bias"]:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0])
|
||||||
|
except:
|
||||||
|
continue
|
||||||
|
if not op or not hasattr(op, "comfy_cast_weights") or \
|
||||||
|
(hasattr(op, "comfy_patched_weights") and op.comfy_patched_weights == True):
|
||||||
|
continue
|
||||||
|
key = "diffusion_model." + k
|
||||||
|
unet_state_dict[k] = LazyCastingParam(self, key, comfy.utils.get_attr(self.model, key))
|
||||||
|
return self.model.state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
self.unpin_all_weights()
|
self.unpin_all_weights()
|
||||||
self.detach(unpatch_all=False)
|
self.detach(unpatch_all=False)
|
||||||
|
|
||||||
|
class ModelPatcherDynamic(ModelPatcher):
|
||||||
|
|
||||||
|
def __new__(cls, model=None, load_device=None, offload_device=None, size=0, weight_inplace_update=False):
|
||||||
|
if load_device is not None and comfy.model_management.is_device_cpu(load_device):
|
||||||
|
#reroute to default MP for CPUs
|
||||||
|
return ModelPatcher(model, load_device, offload_device, size, weight_inplace_update)
|
||||||
|
return super().__new__(cls)
|
||||||
|
|
||||||
|
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
||||||
|
super().__init__(model, load_device, offload_device, size, weight_inplace_update)
|
||||||
|
#this is now way more dynamic and we dont support the same base model for both Dynamic
|
||||||
|
#and non-dynamic patchers.
|
||||||
|
if hasattr(self.model, "model_loaded_weight_memory"):
|
||||||
|
del self.model.model_loaded_weight_memory
|
||||||
|
if not hasattr(self.model, "dynamic_vbars"):
|
||||||
|
self.model.dynamic_vbars = {}
|
||||||
|
assert load_device is not None
|
||||||
|
|
||||||
|
def is_dynamic(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _vbar_get(self, create=False):
|
||||||
|
if self.load_device == torch.device("cpu"):
|
||||||
|
return None
|
||||||
|
vbar = self.model.dynamic_vbars.get(self.load_device, None)
|
||||||
|
if create and vbar is None:
|
||||||
|
# x10. We dont know what model defined type casts we have in the vbar, but virtual address
|
||||||
|
# space is pretty free. This will cover someone casting an entire model from FP4 to FP32
|
||||||
|
# with some left over.
|
||||||
|
vbar = comfy_aimdo.model_vbar.ModelVBAR(self.model_size() * 10, self.load_device.index)
|
||||||
|
self.model.dynamic_vbars[self.load_device] = vbar
|
||||||
|
return vbar
|
||||||
|
|
||||||
|
def loaded_size(self):
|
||||||
|
vbar = self._vbar_get()
|
||||||
|
if vbar is None:
|
||||||
|
return 0
|
||||||
|
return vbar.loaded_size()
|
||||||
|
|
||||||
|
def get_free_memory(self, device):
|
||||||
|
#NOTE: on high condition / batch counts, estimate should have already vacated
|
||||||
|
#all non-dynamic models so this is safe even if its not 100% true that this
|
||||||
|
#would all be avaiable for inference use.
|
||||||
|
return comfy.model_management.get_total_memory(device) - self.model_size()
|
||||||
|
|
||||||
|
#Pinning is deferred to ops time. Assert against this API to avoid pin leaks.
|
||||||
|
|
||||||
|
def pin_weight_to_device(self, key):
|
||||||
|
raise RuntimeError("pin_weight_to_device invalid for dymamic weight loading")
|
||||||
|
|
||||||
|
def unpin_weight(self, key):
|
||||||
|
raise RuntimeError("unpin_weight invalid for dymamic weight loading")
|
||||||
|
|
||||||
|
def unpin_all_weights(self):
|
||||||
|
self.partially_unload_ram(1e32)
|
||||||
|
|
||||||
|
def memory_required(self, input_shape):
|
||||||
|
#Pad this significantly. We are trying to get away from precise estimates. This
|
||||||
|
#estimate is only used when using the ModelPatcherDynamic after ModelPatcher. If you
|
||||||
|
#use all ModelPatcherDynamic this is ignored and its all done dynamically.
|
||||||
|
return super().memory_required(input_shape=input_shape) * 1.3 + (1024 ** 3)
|
||||||
|
|
||||||
|
|
||||||
|
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False, dirty=False):
|
||||||
|
|
||||||
|
#Force patching doesn't make sense in Dynamic loading, as you dont know what does and
|
||||||
|
#doesn't need to be forced at this stage. The only thing you could do would be patch
|
||||||
|
#it all on CPU which consumes huge RAM.
|
||||||
|
assert not force_patch_weights
|
||||||
|
|
||||||
|
#Full load doesn't make sense as we dont actually have any loader capability here and
|
||||||
|
#now.
|
||||||
|
assert not full_load
|
||||||
|
|
||||||
|
assert device_to == self.load_device
|
||||||
|
|
||||||
|
num_patches = 0
|
||||||
|
allocated_size = 0
|
||||||
|
|
||||||
|
with self.use_ejected():
|
||||||
|
self.unpatch_hooks()
|
||||||
|
|
||||||
|
vbar = self._vbar_get(create=True)
|
||||||
|
if vbar is not None:
|
||||||
|
vbar.prioritize()
|
||||||
|
|
||||||
|
#We force reserve VRAM for the non comfy-weight so we dont have to deal
|
||||||
|
#with pin and unpin syncrhonization which can be expensive for small weights
|
||||||
|
#with a high layer rate (e.g. autoregressive LLMs).
|
||||||
|
#prioritize the non-comfy weights (note the order reverse).
|
||||||
|
loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to)
|
||||||
|
loading.sort(reverse=True)
|
||||||
|
|
||||||
|
for x in loading:
|
||||||
|
_, _, _, n, m, params = x
|
||||||
|
|
||||||
|
def set_dirty(item, dirty):
|
||||||
|
if dirty or not hasattr(item, "_v_signature"):
|
||||||
|
item._v_signature = None
|
||||||
|
|
||||||
|
def setup_param(self, m, n, param_key):
|
||||||
|
nonlocal num_patches
|
||||||
|
key = key_param_name_to_key(n, param_key)
|
||||||
|
|
||||||
|
weight_function = []
|
||||||
|
|
||||||
|
weight, _, _ = get_key_weight(self.model, key)
|
||||||
|
if weight is None:
|
||||||
|
return (False, 0)
|
||||||
|
if key in self.patches:
|
||||||
|
if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape:
|
||||||
|
return (True, 0)
|
||||||
|
setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
|
||||||
|
num_patches += 1
|
||||||
|
else:
|
||||||
|
setattr(m, param_key + "_lowvram_function", None)
|
||||||
|
|
||||||
|
if key in self.weight_wrapper_patches:
|
||||||
|
weight_function.extend(self.weight_wrapper_patches[key])
|
||||||
|
setattr(m, param_key + "_function", weight_function)
|
||||||
|
geometry = weight
|
||||||
|
if not isinstance(weight, QuantizedTensor):
|
||||||
|
model_dtype = getattr(m, param_key + "_comfy_model_dtype", None) or weight.dtype
|
||||||
|
weight._model_dtype = model_dtype
|
||||||
|
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
|
||||||
|
return (False, comfy.memory_management.vram_aligned_size(geometry))
|
||||||
|
|
||||||
|
def force_load_param(self, param_key, device_to):
|
||||||
|
key = key_param_name_to_key(n, param_key)
|
||||||
|
if key in self.backup:
|
||||||
|
comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
|
||||||
|
self.patch_weight_to_device(key, device_to=device_to)
|
||||||
|
|
||||||
|
if hasattr(m, "comfy_cast_weights"):
|
||||||
|
m.comfy_cast_weights = True
|
||||||
|
m.pin_failed = False
|
||||||
|
m.seed_key = n
|
||||||
|
set_dirty(m, dirty)
|
||||||
|
|
||||||
|
force_load, v_weight_size = setup_param(self, m, n, "weight")
|
||||||
|
force_load_bias, v_weight_bias = setup_param(self, m, n, "bias")
|
||||||
|
force_load = force_load or force_load_bias
|
||||||
|
v_weight_size += v_weight_bias
|
||||||
|
|
||||||
|
if force_load:
|
||||||
|
logging.info(f"Module {n} has resizing Lora - force loading")
|
||||||
|
force_load_param(self, "weight", device_to)
|
||||||
|
force_load_param(self, "bias", device_to)
|
||||||
|
else:
|
||||||
|
if vbar is not None and not hasattr(m, "_v"):
|
||||||
|
m._v = vbar.alloc(v_weight_size)
|
||||||
|
allocated_size += v_weight_size
|
||||||
|
|
||||||
|
else:
|
||||||
|
for param in params:
|
||||||
|
key = key_param_name_to_key(n, param)
|
||||||
|
weight, _, _ = get_key_weight(self.model, key)
|
||||||
|
weight.seed_key = key
|
||||||
|
set_dirty(weight, dirty)
|
||||||
|
geometry = weight
|
||||||
|
model_dtype = getattr(m, param + "_comfy_model_dtype", None) or weight.dtype
|
||||||
|
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
|
||||||
|
weight_size = geometry.numel() * geometry.element_size()
|
||||||
|
if vbar is not None and not hasattr(weight, "_v"):
|
||||||
|
weight._v = vbar.alloc(weight_size)
|
||||||
|
weight._model_dtype = model_dtype
|
||||||
|
allocated_size += weight_size
|
||||||
|
vbar.set_watermark_limit(allocated_size)
|
||||||
|
|
||||||
|
move_weight_functions(m, device_to)
|
||||||
|
|
||||||
|
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
|
||||||
|
|
||||||
|
self.model.device = device_to
|
||||||
|
self.model.current_weight_patches_uuid = self.patches_uuid
|
||||||
|
|
||||||
|
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
|
||||||
|
#These are all super dangerous. Who knows what the custom nodes actually do here...
|
||||||
|
callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load)
|
||||||
|
|
||||||
|
self.apply_hooks(self.forced_hooks, force_apply=True)
|
||||||
|
|
||||||
|
def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
|
||||||
|
assert not force_patch_weights #See above
|
||||||
|
assert self.load_device != torch.device("cpu")
|
||||||
|
|
||||||
|
vbar = self._vbar_get()
|
||||||
|
return 0 if vbar is None else vbar.free_memory(memory_to_free)
|
||||||
|
|
||||||
|
def partially_unload_ram(self, ram_to_unload):
|
||||||
|
loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device)
|
||||||
|
for x in loading:
|
||||||
|
_, _, _, _, m, _ = x
|
||||||
|
ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
|
||||||
|
if ram_to_unload <= 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
|
||||||
|
#This isn't used by the core at all and can only be to load a model out of
|
||||||
|
#the control of proper model_managment. If you are a custom node author reading
|
||||||
|
#this, the correct pattern is to call load_models_gpu() to get a proper
|
||||||
|
#managed load of your model.
|
||||||
|
assert not load_weights
|
||||||
|
return super().patch_model(load_weights=load_weights, force_patch_weights=force_patch_weights)
|
||||||
|
|
||||||
|
def unpatch_model(self, device_to=None, unpatch_weights=True):
|
||||||
|
super().unpatch_model(device_to=None, unpatch_weights=False)
|
||||||
|
|
||||||
|
if unpatch_weights:
|
||||||
|
self.partially_unload_ram(1e32)
|
||||||
|
self.partially_unload(None, 1e32)
|
||||||
|
for m in self.model.modules():
|
||||||
|
move_weight_functions(m, device_to)
|
||||||
|
|
||||||
|
keys = list(self.backup.keys())
|
||||||
|
for k in keys:
|
||||||
|
bk = self.backup[k]
|
||||||
|
comfy.utils.set_attr_param(self.model, k, bk.weight)
|
||||||
|
|
||||||
|
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
|
||||||
|
assert not force_patch_weights #See above
|
||||||
|
with self.use_ejected(skip_and_inject_on_exit_only=True):
|
||||||
|
dirty = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid)
|
||||||
|
|
||||||
|
self.unpatch_model(self.offload_device, unpatch_weights=False)
|
||||||
|
self.patch_model(load_weights=False)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.load(device_to, dirty=dirty)
|
||||||
|
except Exception as e:
|
||||||
|
self.detach()
|
||||||
|
raise e
|
||||||
|
#ModelPatcher::partially_load returns a number on what got loaded but
|
||||||
|
#nothing in core uses this and we have no data in the Dynamic world. Hit
|
||||||
|
#the custom node devs with a None rather than a 0 that would mislead any
|
||||||
|
#logic they might have.
|
||||||
|
return None
|
||||||
|
|
||||||
|
def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
|
||||||
|
assert False #Should be unreachable - we dont ever cache in the new implementation
|
||||||
|
|
||||||
|
def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter):
|
||||||
|
if key not in combined_patches:
|
||||||
|
return
|
||||||
|
|
||||||
|
raise RuntimeError("Hooks not implemented in ModelPatcherDynamic. Please remove --fast arguments form ComfyUI startup")
|
||||||
|
|
||||||
|
def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
CoreModelPatcher = ModelPatcher
|
||||||
|
|||||||
235
comfy/ops.py
235
comfy/ops.py
@@ -19,10 +19,15 @@
|
|||||||
import torch
|
import torch
|
||||||
import logging
|
import logging
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
from comfy.cli_args import args, PerformanceFeature
|
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
|
||||||
import comfy.float
|
import comfy.float
|
||||||
import comfy.rmsnorm
|
|
||||||
import json
|
import json
|
||||||
|
import comfy.memory_management
|
||||||
|
import comfy.pinned_memory
|
||||||
|
import comfy.utils
|
||||||
|
|
||||||
|
import comfy_aimdo.model_vbar
|
||||||
|
import comfy_aimdo.torch
|
||||||
|
|
||||||
def run_every_op():
|
def run_every_op():
|
||||||
if torch.compiler.is_compiling():
|
if torch.compiler.is_compiling():
|
||||||
@@ -48,6 +53,8 @@ try:
|
|||||||
SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
|
SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
|
||||||
|
|
||||||
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
|
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
|
||||||
|
if q.nelement() < 1024 * 128: # arbitrary number, for small inputs cudnn attention seems slower
|
||||||
|
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
|
||||||
with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
|
with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
|
||||||
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
|
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
@@ -72,7 +79,122 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
|
|||||||
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||||
|
|
||||||
|
|
||||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
|
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
|
||||||
|
offload_stream = None
|
||||||
|
xfer_dest = None
|
||||||
|
|
||||||
|
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
|
||||||
|
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
|
||||||
|
if signature is not None:
|
||||||
|
if resident:
|
||||||
|
weight = s._v_weight
|
||||||
|
bias = s._v_bias
|
||||||
|
else:
|
||||||
|
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
|
||||||
|
|
||||||
|
if not resident:
|
||||||
|
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
|
||||||
|
cast_dest = None
|
||||||
|
|
||||||
|
xfer_source = [ s.weight, s.bias ]
|
||||||
|
|
||||||
|
pin = comfy.pinned_memory.get_pin(s)
|
||||||
|
if pin is not None:
|
||||||
|
xfer_source = [ pin ]
|
||||||
|
|
||||||
|
for data, geometry in zip([ s.weight, s.bias ], cast_geometry):
|
||||||
|
if data is None:
|
||||||
|
continue
|
||||||
|
if data.dtype != geometry.dtype:
|
||||||
|
cast_dest = xfer_dest
|
||||||
|
if cast_dest is None:
|
||||||
|
cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device)
|
||||||
|
xfer_dest = None
|
||||||
|
break
|
||||||
|
|
||||||
|
dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
|
||||||
|
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||||
|
if xfer_dest is None and offload_stream is not None:
|
||||||
|
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
|
||||||
|
if xfer_dest is None:
|
||||||
|
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||||
|
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
|
||||||
|
if xfer_dest is None:
|
||||||
|
xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device)
|
||||||
|
offload_stream = None
|
||||||
|
|
||||||
|
if signature is None and pin is None:
|
||||||
|
comfy.pinned_memory.pin_memory(s)
|
||||||
|
pin = comfy.pinned_memory.get_pin(s)
|
||||||
|
else:
|
||||||
|
pin = None
|
||||||
|
|
||||||
|
if pin is not None:
|
||||||
|
comfy.model_management.cast_to_gathered(xfer_source, pin)
|
||||||
|
xfer_source = [ pin ]
|
||||||
|
#send it over
|
||||||
|
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
|
||||||
|
comfy.model_management.sync_stream(device, offload_stream)
|
||||||
|
|
||||||
|
if cast_dest is not None:
|
||||||
|
for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest),
|
||||||
|
comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)):
|
||||||
|
if post_cast is not None:
|
||||||
|
post_cast.copy_(pre_cast)
|
||||||
|
xfer_dest = cast_dest
|
||||||
|
|
||||||
|
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
|
||||||
|
weight = params[0]
|
||||||
|
bias = params[1]
|
||||||
|
if signature is not None:
|
||||||
|
s._v_weight = weight
|
||||||
|
s._v_bias = bias
|
||||||
|
s._v_signature=signature
|
||||||
|
|
||||||
|
def post_cast(s, param_key, x, dtype, resident, update_weight):
|
||||||
|
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||||
|
fns = getattr(s, param_key + "_function", [])
|
||||||
|
|
||||||
|
orig = x
|
||||||
|
|
||||||
|
def to_dequant(tensor, dtype):
|
||||||
|
tensor = tensor.to(dtype=dtype)
|
||||||
|
if isinstance(tensor, QuantizedTensor):
|
||||||
|
tensor = tensor.dequantize()
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
if orig.dtype != dtype or len(fns) > 0:
|
||||||
|
x = to_dequant(x, dtype)
|
||||||
|
if not resident and lowvram_fn is not None:
|
||||||
|
x = to_dequant(x, dtype if compute_dtype is None else compute_dtype)
|
||||||
|
#FIXME: this is not accurate, we need to be sensitive to the compute dtype
|
||||||
|
x = lowvram_fn(x)
|
||||||
|
if (isinstance(orig, QuantizedTensor) and
|
||||||
|
(want_requant and len(fns) == 0 or update_weight)):
|
||||||
|
seed = comfy.utils.string_to_seed(s.seed_key)
|
||||||
|
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
|
||||||
|
if want_requant and len(fns) == 0:
|
||||||
|
#The layer actually wants our freshly saved QT
|
||||||
|
x = y
|
||||||
|
elif update_weight:
|
||||||
|
y = comfy.float.stochastic_rounding(x, orig.dtype, seed = comfy.utils.string_to_seed(s.seed_key))
|
||||||
|
if update_weight:
|
||||||
|
orig.copy_(y)
|
||||||
|
for f in fns:
|
||||||
|
x = f(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
update_weight = signature is not None
|
||||||
|
|
||||||
|
weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
|
||||||
|
if s.bias is not None:
|
||||||
|
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
|
||||||
|
|
||||||
|
#FIXME: weird offload return protocol
|
||||||
|
return weight, bias, (offload_stream, device if signature is not None else None, None)
|
||||||
|
|
||||||
|
|
||||||
|
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
|
||||||
# NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
|
# NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
|
||||||
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
|
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
|
||||||
# will add async-offload support to your cast and improve performance.
|
# will add async-offload support to your cast and improve performance.
|
||||||
@@ -87,22 +209,38 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
|||||||
if device is None:
|
if device is None:
|
||||||
device = input.device
|
device = input.device
|
||||||
|
|
||||||
|
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||||
|
|
||||||
|
if hasattr(s, "_v"):
|
||||||
|
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
|
||||||
|
|
||||||
if offloadable and (device != s.weight.device or
|
if offloadable and (device != s.weight.device or
|
||||||
(s.bias is not None and device != s.bias.device)):
|
(s.bias is not None and device != s.bias.device)):
|
||||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||||
else:
|
else:
|
||||||
offload_stream = None
|
offload_stream = None
|
||||||
|
|
||||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
bias = None
|
||||||
|
weight = None
|
||||||
|
|
||||||
|
if offload_stream is not None and not args.cuda_malloc:
|
||||||
|
cast_buffer_size = comfy.memory_management.vram_aligned_size([ s.weight, s.bias ])
|
||||||
|
cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s)
|
||||||
|
#The streams can be uneven in buffer capability and reject us. Retry to get the other stream
|
||||||
|
if cast_buffer is None:
|
||||||
|
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||||
|
cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s)
|
||||||
|
params = comfy.memory_management.interpret_gathered_like([ s.weight, s.bias ], cast_buffer)
|
||||||
|
weight = params[0]
|
||||||
|
bias = params[1]
|
||||||
|
|
||||||
weight_has_function = len(s.weight_function) > 0
|
weight_has_function = len(s.weight_function) > 0
|
||||||
bias_has_function = len(s.bias_function) > 0
|
bias_has_function = len(s.bias_function) > 0
|
||||||
|
|
||||||
weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
|
weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream, r=weight)
|
||||||
|
|
||||||
bias = None
|
|
||||||
if s.bias is not None:
|
if s.bias is not None:
|
||||||
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
|
bias = comfy.model_management.cast_to(s.bias, None, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream, r=bias)
|
||||||
|
|
||||||
comfy.model_management.sync_stream(device, offload_stream)
|
comfy.model_management.sync_stream(device, offload_stream)
|
||||||
|
|
||||||
@@ -110,6 +248,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
|||||||
weight_a = weight
|
weight_a = weight
|
||||||
|
|
||||||
if s.bias is not None:
|
if s.bias is not None:
|
||||||
|
bias = bias.to(dtype=bias_dtype)
|
||||||
for f in s.bias_function:
|
for f in s.bias_function:
|
||||||
bias = f(bias)
|
bias = f(bias)
|
||||||
|
|
||||||
@@ -131,14 +270,20 @@ def uncast_bias_weight(s, weight, bias, offload_stream):
|
|||||||
if offload_stream is None:
|
if offload_stream is None:
|
||||||
return
|
return
|
||||||
os, weight_a, bias_a = offload_stream
|
os, weight_a, bias_a = offload_stream
|
||||||
|
device=None
|
||||||
|
#FIXME: This is not good RTTI
|
||||||
|
if not isinstance(weight_a, torch.Tensor):
|
||||||
|
comfy_aimdo.model_vbar.vbar_unpin(s._v)
|
||||||
|
device = weight_a
|
||||||
if os is None:
|
if os is None:
|
||||||
return
|
return
|
||||||
if weight_a is not None:
|
if device is None:
|
||||||
device = weight_a.device
|
if weight_a is not None:
|
||||||
else:
|
device = weight_a.device
|
||||||
if bias_a is None:
|
else:
|
||||||
return
|
if bias_a is None:
|
||||||
device = bias_a.device
|
return
|
||||||
|
device = bias_a.device
|
||||||
os.wait_stream(comfy.model_management.current_stream(device))
|
os.wait_stream(comfy.model_management.current_stream(device))
|
||||||
|
|
||||||
|
|
||||||
@@ -149,6 +294,57 @@ class CastWeightBiasOp:
|
|||||||
|
|
||||||
class disable_weight_init:
|
class disable_weight_init:
|
||||||
class Linear(torch.nn.Linear, CastWeightBiasOp):
|
class Linear(torch.nn.Linear, CastWeightBiasOp):
|
||||||
|
|
||||||
|
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
|
||||||
|
if not comfy.model_management.WINDOWS or not enables_dynamic_vram():
|
||||||
|
super().__init__(in_features, out_features, bias, device, dtype)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Issue is with `torch.empty` still reserving the full memory for the layer.
|
||||||
|
# Windows doesn't over-commit memory so without this, We are momentarily commit
|
||||||
|
# charged for the weight even though we might zero-copy it when we load the
|
||||||
|
# state dict. If the commit charge exceeds the ceiling we can destabilize the
|
||||||
|
# system.
|
||||||
|
torch.nn.Module.__init__(self)
|
||||||
|
self.in_features = in_features
|
||||||
|
self.out_features = out_features
|
||||||
|
self.weight = None
|
||||||
|
self.bias = None
|
||||||
|
self.comfy_need_lazy_init_bias=bias
|
||||||
|
self.weight_comfy_model_dtype = dtype
|
||||||
|
self.bias_comfy_model_dtype = dtype
|
||||||
|
|
||||||
|
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||||
|
strict, missing_keys, unexpected_keys, error_msgs):
|
||||||
|
|
||||||
|
if not comfy.model_management.WINDOWS or not enables_dynamic_vram():
|
||||||
|
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
|
||||||
|
missing_keys, unexpected_keys, error_msgs)
|
||||||
|
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
|
||||||
|
prefix_len = len(prefix)
|
||||||
|
for k,v in state_dict.items():
|
||||||
|
if k[prefix_len:] == "weight":
|
||||||
|
if not assign_to_params_buffers:
|
||||||
|
v = v.clone()
|
||||||
|
self.weight = torch.nn.Parameter(v, requires_grad=False)
|
||||||
|
elif k[prefix_len:] == "bias" and v is not None:
|
||||||
|
if not assign_to_params_buffers:
|
||||||
|
v = v.clone()
|
||||||
|
self.bias = torch.nn.Parameter(v, requires_grad=False)
|
||||||
|
else:
|
||||||
|
unexpected_keys.append(k)
|
||||||
|
|
||||||
|
#Reconcile default construction of the weight if its missing.
|
||||||
|
if self.weight is None:
|
||||||
|
v = torch.zeros(self.in_features, self.out_features)
|
||||||
|
self.weight = torch.nn.Parameter(v, requires_grad=False)
|
||||||
|
missing_keys.append(prefix+"weight")
|
||||||
|
if self.bias is None and self.comfy_need_lazy_init_bias:
|
||||||
|
v = torch.zeros(self.out_features,)
|
||||||
|
self.bias = torch.nn.Parameter(v, requires_grad=False)
|
||||||
|
missing_keys.append(prefix+"bias")
|
||||||
|
|
||||||
|
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -266,7 +462,7 @@ class disable_weight_init:
|
|||||||
else:
|
else:
|
||||||
return super().forward(*args, **kwargs)
|
return super().forward(*args, **kwargs)
|
||||||
|
|
||||||
class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp):
|
class RMSNorm(torch.nn.RMSNorm, CastWeightBiasOp):
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
self.bias = None
|
self.bias = None
|
||||||
return None
|
return None
|
||||||
@@ -278,8 +474,7 @@ class disable_weight_init:
|
|||||||
weight = None
|
weight = None
|
||||||
bias = None
|
bias = None
|
||||||
offload_stream = None
|
offload_stream = None
|
||||||
x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
|
x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
|
||||||
# x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
|
|
||||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@@ -655,8 +850,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
def _forward(self, input, weight, bias):
|
def _forward(self, input, weight, bias):
|
||||||
return torch.nn.functional.linear(input, weight, bias)
|
return torch.nn.functional.linear(input, weight, bias)
|
||||||
|
|
||||||
def forward_comfy_cast_weights(self, input):
|
def forward_comfy_cast_weights(self, input, compute_dtype=None, want_requant=False):
|
||||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype, want_requant=want_requant)
|
||||||
x = self._forward(input, weight, bias)
|
x = self._forward(input, weight, bias)
|
||||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||||
return x
|
return x
|
||||||
@@ -666,6 +861,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
|
|
||||||
input_shape = input.shape
|
input_shape = input.shape
|
||||||
reshaped_3d = False
|
reshaped_3d = False
|
||||||
|
#If cast needs to apply lora, it should be done in the compute dtype
|
||||||
|
compute_dtype = input.dtype
|
||||||
|
|
||||||
if (getattr(self, 'layout_type', None) is not None and
|
if (getattr(self, 'layout_type', None) is not None and
|
||||||
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
|
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
|
||||||
@@ -684,7 +881,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
scale = comfy.model_management.cast_to_device(scale, input.device, None)
|
scale = comfy.model_management.cast_to_device(scale, input.device, None)
|
||||||
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
|
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
|
||||||
|
|
||||||
output = self.forward_comfy_cast_weights(input)
|
output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor))
|
||||||
|
|
||||||
# Reshape output back to 3D if input was 3D
|
# Reshape output back to 3D if input was 3D
|
||||||
if reshaped_3d:
|
if reshaped_3d:
|
||||||
|
|||||||
29
comfy/pinned_memory.py
Normal file
29
comfy/pinned_memory.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
import torch
|
||||||
|
import comfy.model_management
|
||||||
|
import comfy.memory_management
|
||||||
|
|
||||||
|
from comfy.cli_args import args
|
||||||
|
|
||||||
|
def get_pin(module):
|
||||||
|
return getattr(module, "_pin", None)
|
||||||
|
|
||||||
|
def pin_memory(module):
|
||||||
|
if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
|
||||||
|
return
|
||||||
|
#FIXME: This is a RAM cache trigger event
|
||||||
|
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
|
||||||
|
pin = torch.empty((size,), dtype=torch.uint8)
|
||||||
|
if comfy.model_management.pin_memory(pin):
|
||||||
|
module._pin = pin
|
||||||
|
else:
|
||||||
|
module.pin_failed = True
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def unpin_memory(module):
|
||||||
|
if get_pin(module) is None:
|
||||||
|
return 0
|
||||||
|
size = module._pin.numel() * module._pin.element_size()
|
||||||
|
comfy.model_management.unpin_memory(module._pin)
|
||||||
|
del module._pin
|
||||||
|
return size
|
||||||
@@ -1,57 +1,10 @@
|
|||||||
import torch
|
import torch
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import numbers
|
|
||||||
import logging
|
|
||||||
|
|
||||||
RMSNorm = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
rms_norm_torch = torch.nn.functional.rms_norm
|
|
||||||
RMSNorm = torch.nn.RMSNorm
|
|
||||||
except:
|
|
||||||
rms_norm_torch = None
|
|
||||||
logging.warning("Please update pytorch to use native RMSNorm")
|
|
||||||
|
|
||||||
|
RMSNorm = torch.nn.RMSNorm
|
||||||
|
|
||||||
def rms_norm(x, weight=None, eps=1e-6):
|
def rms_norm(x, weight=None, eps=1e-6):
|
||||||
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
|
if weight is None:
|
||||||
if weight is None:
|
return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)
|
||||||
return rms_norm_torch(x, (x.shape[-1],), eps=eps)
|
|
||||||
else:
|
|
||||||
return rms_norm_torch(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
|
|
||||||
else:
|
else:
|
||||||
r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
|
return torch.nn.functional.rms_norm(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
|
||||||
if weight is None:
|
|
||||||
return r
|
|
||||||
else:
|
|
||||||
return r * comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device)
|
|
||||||
|
|
||||||
|
|
||||||
if RMSNorm is None:
|
|
||||||
class RMSNorm(torch.nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
normalized_shape,
|
|
||||||
eps=1e-6,
|
|
||||||
elementwise_affine=True,
|
|
||||||
device=None,
|
|
||||||
dtype=None,
|
|
||||||
):
|
|
||||||
factory_kwargs = {"device": device, "dtype": dtype}
|
|
||||||
super().__init__()
|
|
||||||
if isinstance(normalized_shape, numbers.Integral):
|
|
||||||
# mypy error: incompatible types in assignment
|
|
||||||
normalized_shape = (normalized_shape,) # type: ignore[assignment]
|
|
||||||
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
|
|
||||||
self.eps = eps
|
|
||||||
self.elementwise_affine = elementwise_affine
|
|
||||||
if self.elementwise_affine:
|
|
||||||
self.weight = torch.nn.Parameter(
|
|
||||||
torch.empty(self.normalized_shape, **factory_kwargs)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.register_parameter("weight", None)
|
|
||||||
self.bias = None
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return rms_norm(x, self.weight, self.eps)
|
|
||||||
|
|||||||
@@ -122,20 +122,26 @@ def estimate_memory(model, noise_shape, conds):
|
|||||||
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
|
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
|
||||||
return memory_required, minimum_memory_required
|
return memory_required, minimum_memory_required
|
||||||
|
|
||||||
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
|
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
|
||||||
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
|
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
|
||||||
_prepare_sampling,
|
_prepare_sampling,
|
||||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
|
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
|
||||||
)
|
)
|
||||||
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load)
|
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload)
|
||||||
|
|
||||||
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
|
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
|
||||||
real_model: BaseModel = None
|
real_model: BaseModel = None
|
||||||
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
||||||
models += get_additional_models_from_model_options(model_options)
|
models += get_additional_models_from_model_options(model_options)
|
||||||
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
||||||
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
|
if force_offload: # In training + offload enabled, we want to force prepare sampling to trigger partial load
|
||||||
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load)
|
memory_required = 1e20
|
||||||
|
minimum_memory_required = None
|
||||||
|
else:
|
||||||
|
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
|
||||||
|
memory_required += inference_memory
|
||||||
|
minimum_memory_required += inference_memory
|
||||||
|
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
|
||||||
real_model = model.model
|
real_model = model.model
|
||||||
|
|
||||||
return real_model, conds, models
|
return real_model, conds, models
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ if TYPE_CHECKING:
|
|||||||
import torch
|
import torch
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import collections
|
import collections
|
||||||
from comfy import model_management
|
|
||||||
import math
|
import math
|
||||||
import logging
|
import logging
|
||||||
import comfy.sampler_helpers
|
import comfy.sampler_helpers
|
||||||
@@ -260,7 +259,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
|||||||
to_batch_temp.reverse()
|
to_batch_temp.reverse()
|
||||||
to_batch = to_batch_temp[:1]
|
to_batch = to_batch_temp[:1]
|
||||||
|
|
||||||
free_memory = model_management.get_free_memory(x_in.device)
|
free_memory = model.current_patcher.get_free_memory(x_in.device)
|
||||||
for i in range(1, len(to_batch_temp) + 1):
|
for i in range(1, len(to_batch_temp) + 1):
|
||||||
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
|
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
|
||||||
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
|
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user