mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-12 19:20:03 +00:00
Compare commits
480 Commits
js/drafts/
...
asset-mana
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
917177e821 | ||
|
|
fd6ac0a765 | ||
|
|
94941c50b3 | ||
|
|
fbba2e59e5 | ||
|
|
adccfb2dfd | ||
|
|
9f4c0f3afe | ||
|
|
196954ab8c | ||
|
|
1e098d6132 | ||
|
|
cd66d72b46 | ||
|
|
2103e39335 | ||
|
|
d20576e6a3 | ||
|
|
a061b06321 | ||
|
|
80718908a9 | ||
|
|
7ea173c187 | ||
|
|
76eb1d72c3 | ||
|
|
c4a46e943c | ||
|
|
2b7f9a8196 | ||
|
|
ce4cb2389c | ||
|
|
ca39552954 | ||
|
|
c8d2117f02 | ||
|
|
fccab99ec0 | ||
|
|
fd79d32f38 | ||
|
|
341b4adefd | ||
|
|
b8730510db | ||
|
|
e808790799 | ||
|
|
145b0e4f79 | ||
|
|
707b2638ec | ||
|
|
8a5ac527e6 | ||
|
|
e3206351b0 | ||
|
|
1fee8827cb | ||
|
|
27bc181c49 | ||
|
|
d1d9eb94b1 | ||
|
|
7be2b49b6b | ||
|
|
9ed3c5cc09 | ||
|
|
66241cef31 | ||
|
|
e8df53b764 | ||
|
|
852704c81a | ||
|
|
9fdf8c25ab | ||
|
|
dc95b6acc0 | ||
|
|
711bcf33ee | ||
|
|
24b0fce099 | ||
|
|
1ea8c54064 | ||
|
|
8d6653fca6 | ||
|
|
4dd843d36f | ||
|
|
46fdd636de | ||
|
|
283cd27bdc | ||
|
|
dd611a7700 | ||
|
|
1a37d1476d | ||
|
|
f9602457d6 | ||
|
|
85ef08449d | ||
|
|
5b6810a2c6 | ||
|
|
621faaa195 | ||
|
|
9288c78fc5 | ||
|
|
e42682b24e | ||
|
|
d0aa64d57b | ||
|
|
677a0e2508 | ||
|
|
31ec744317 | ||
|
|
a336c7c165 | ||
|
|
77332d3054 | ||
|
|
24a95f5ca4 | ||
|
|
a39ac59c3e | ||
|
|
1a85483da1 | ||
|
|
47a9cde5d3 | ||
|
|
0be513b213 | ||
|
|
f1fb7432a0 | ||
|
|
f3cf99d10c | ||
|
|
5f187fe6fb | ||
|
|
025fc49b4e | ||
|
|
7becb84341 | ||
|
|
dda31de690 | ||
|
|
1d970382f0 | ||
|
|
a2fc2bbae4 | ||
|
|
a7f2546558 | ||
|
|
6cfa94ec58 | ||
|
|
a2ec1f7637 | ||
|
|
0b795dc7a7 | ||
|
|
47f7c7ee8c | ||
|
|
cdd8d16075 | ||
|
|
37b81e6658 | ||
|
|
4f1f26ac6c | ||
|
|
975650060f | ||
|
|
4a713654cd | ||
|
|
f228367c5e | ||
|
|
80b7c9455b | ||
|
|
c1297f4eb3 | ||
|
|
e5e70636e7 | ||
|
|
9b8e88ba6e | ||
|
|
bb9ed04758 | ||
|
|
29bf807b0e | ||
|
|
2559dee492 | ||
|
|
a3b04de700 | ||
|
|
d7f40442f9 | ||
|
|
b149e2e1e3 | ||
|
|
581bae2af3 | ||
|
|
af99928f22 | ||
|
|
53c9c7d39a | ||
|
|
ba68e83f1c | ||
|
|
dcb8834983 | ||
|
|
f9d2e4b742 | ||
|
|
45bc1f5c00 | ||
|
|
0aa074a420 | ||
|
|
7757d5a657 | ||
|
|
e600520f8a | ||
|
|
fd2b820ec2 | ||
|
|
934377ac1e | ||
|
|
d6b977b2e6 | ||
|
|
15ec9ea958 | ||
|
|
33bd9ed9cb | ||
|
|
18de0b2830 | ||
|
|
df6850fae8 | ||
|
|
e01e99d075 | ||
|
|
72212fef66 | ||
|
|
df34f1549a | ||
|
|
9b0553809c | ||
|
|
8d7c930246 | ||
|
|
3c9bf39c20 | ||
|
|
0df1ccac6f | ||
|
|
de44b95db6 | ||
|
|
72548a8ac4 | ||
|
|
6eaed072c7 | ||
|
|
543888d3d8 | ||
|
|
70fc0425b3 | ||
|
|
85e34643f8 | ||
|
|
5c33872e2f | ||
|
|
206595f854 | ||
|
|
b288fb0db8 | ||
|
|
f73b176abd | ||
|
|
a9096f6c97 | ||
|
|
964de8a8ad | ||
|
|
1886f10e19 | ||
|
|
357193f7b5 | ||
|
|
0ef73e95fd | ||
|
|
faa1e4de17 | ||
|
|
dfb5703d40 | ||
|
|
103a12cb66 | ||
|
|
97652d26b8 | ||
|
|
bd1d9bcd5f | ||
|
|
0e9de2b7c9 | ||
|
|
e3311c9229 | ||
|
|
3fa0fc496c | ||
|
|
fb763d4333 | ||
|
|
6282d495ca | ||
|
|
b8ef9bb92c | ||
|
|
bcbd7884e3 | ||
|
|
27a0fcccc3 | ||
|
|
2d9be462d3 | ||
|
|
789a62ce35 | ||
|
|
ea6cdd2631 | ||
|
|
84384ca0b4 | ||
|
|
2ee7879a0b | ||
|
|
3493b9cb1f | ||
|
|
ce270ba090 | ||
|
|
c9ebe70072 | ||
|
|
261421e218 | ||
|
|
a9f1bb10a5 | ||
|
|
b0338e930b | ||
|
|
b71f9bcb71 | ||
|
|
72855db715 | ||
|
|
f48d05a2d1 | ||
|
|
4368d8f87f | ||
|
|
22da0a83e9 | ||
|
|
50333f1715 | ||
|
|
26d5b86da8 | ||
|
|
4f5812b937 | ||
|
|
1bcb469089 | ||
|
|
464ba1d614 | ||
|
|
e3018c2a5a | ||
|
|
3412d53b1d | ||
|
|
e2d1e5dad9 | ||
|
|
27e067ce50 | ||
|
|
9b15155972 | ||
|
|
32a627bf1f | ||
|
|
fe442fac2e | ||
|
|
d2c502e629 | ||
|
|
fea9ea8268 | ||
|
|
f949094b3c | ||
|
|
4449e14769 | ||
|
|
885015eecf | ||
|
|
bf8363ec87 | ||
|
|
a86aaa4301 | ||
|
|
2efb2cbc38 | ||
|
|
15aa9222c4 | ||
|
|
c7bb3e2bce | ||
|
|
e80a14ad50 | ||
|
|
d28b39d93d | ||
|
|
1c184c29eb | ||
|
|
edde0b5043 | ||
|
|
0063610177 | ||
|
|
ce0052c087 | ||
|
|
6b86be320a | ||
|
|
0eb821a7b6 | ||
|
|
4aa79dbf2c | ||
|
|
38f697d953 | ||
|
|
3aad339b63 | ||
|
|
491755325c | ||
|
|
496888fd68 | ||
|
|
b5ac6ed7ce | ||
|
|
bdf4ba24ce | ||
|
|
871e41aec6 | ||
|
|
eb7008a4d3 | ||
|
|
0379eff0b5 | ||
|
|
026b7f209c | ||
|
|
b20ba1f27c | ||
|
|
31a37686d0 | ||
|
|
7c1b0be496 | ||
|
|
88aee596a3 | ||
|
|
6a193ac557 | ||
|
|
47f4db3e84 | ||
|
|
6fade5da38 | ||
|
|
5352abc6d3 | ||
|
|
39aa06bd5d | ||
|
|
914c2a2973 | ||
|
|
e633a47ad1 | ||
|
|
a763cbd39d | ||
|
|
09dabf95bc | ||
|
|
f6b93d41a0 | ||
|
|
95ac7794b7 | ||
|
|
d7464e9e73 | ||
|
|
a82577f64a | ||
|
|
f2ea0bc22c | ||
|
|
0755e5320a | ||
|
|
8d46bec951 | ||
|
|
71ed4a399e | ||
|
|
3e316c6338 | ||
|
|
8be0d22ab7 | ||
|
|
5c1b5973ac | ||
|
|
f92307cd4c | ||
|
|
59eddda900 | ||
|
|
41048c69b4 | ||
|
|
fc247150fe | ||
|
|
fe31ad0276 | ||
|
|
ca4e96a8ae | ||
|
|
050c67323c | ||
|
|
497d41fb50 | ||
|
|
ff57793659 | ||
|
|
f7bd5e58dd | ||
|
|
7ed73d12d1 | ||
|
|
eb39019daa | ||
|
|
bab08f40d1 | ||
|
|
bc49106837 | ||
|
|
1b2de2642d | ||
|
|
9fa1036f60 | ||
|
|
0737b7e0d2 | ||
|
|
0963493a9c | ||
|
|
e73a9dbe30 | ||
|
|
fe01885acf | ||
|
|
7139d6d93f | ||
|
|
2f52e8f05f | ||
|
|
8d38ea3bbf | ||
|
|
5a8f502db5 | ||
|
|
7cd2c4bd6a | ||
|
|
dfa791eb4b | ||
|
|
bddd69618b | ||
|
|
54d8fdbed0 | ||
|
|
d844d8b13b | ||
|
|
07a927517c | ||
|
|
f16a70ba67 | ||
|
|
36b5127fd3 | ||
|
|
4977f203fa | ||
|
|
c708d0a433 | ||
|
|
bd2ab73976 | ||
|
|
da2efeaec6 | ||
|
|
7f3b9b16c6 | ||
|
|
d4e353a94e | ||
|
|
ed43784b0d | ||
|
|
0f2b8525bc | ||
|
|
20a84166d0 | ||
|
|
ed2e33c69a | ||
|
|
1702e6df16 | ||
|
|
c308a8840a | ||
|
|
027c63f63a | ||
|
|
e08ecfbd8a | ||
|
|
4e5c230f6a | ||
|
|
f0d5d0111f | ||
|
|
ad19a069f6 | ||
|
|
5d65d6753b | ||
|
|
deebee4ff6 | ||
|
|
fa570cbf59 | ||
|
|
644b23ac0b | ||
|
|
72fd4d22b6 | ||
|
|
e4f7ea105f | ||
|
|
c991a5da65 | ||
|
|
9df8792d4b | ||
|
|
3da5a07510 | ||
|
|
afa0a45206 | ||
|
|
615eb52049 | ||
|
|
d5c1954d5c | ||
|
|
e400f26c8f | ||
|
|
5ca8e2fac3 | ||
|
|
3294782d19 | ||
|
|
898d88e10e | ||
|
|
560d38f34c | ||
|
|
1aa089e0b6 | ||
|
|
e1d4f36d8d | ||
|
|
1e3ae1eed8 | ||
|
|
f032c1a50a | ||
|
|
f4231a80b1 | ||
|
|
3089936a2c | ||
|
|
2208aa616d | ||
|
|
629b173837 | ||
|
|
fa340add55 | ||
|
|
966f3a5206 | ||
|
|
0552de7c7d | ||
|
|
5828607ccf | ||
|
|
735bb4bdb1 | ||
|
|
cd679129e3 | ||
|
|
bf2a1b5b1e | ||
|
|
42974a448c | ||
|
|
05df2df489 | ||
|
|
37d620a6b8 | ||
|
|
32691b16f4 | ||
|
|
4c3e57b0ae | ||
|
|
9126c0cfe4 | ||
|
|
d8c51ba15a | ||
|
|
32a95bba8a | ||
|
|
da1ad9b516 | ||
|
|
d044a24398 | ||
|
|
5be6fd09ff | ||
|
|
f69609bbd6 | ||
|
|
c012400240 | ||
|
|
03895dea7c | ||
|
|
84f9759424 | ||
|
|
7991341e89 | ||
|
|
140ffc7fdc | ||
|
|
182f90b5ec | ||
|
|
d7062277a7 | ||
|
|
54cf14cbbb | ||
|
|
aebac22193 | ||
|
|
13aaa66ec2 | ||
|
|
5f582a9757 | ||
|
|
fbcc23945d | ||
|
|
3dfefc88d0 | ||
|
|
bff60b5cfc | ||
|
|
1e638a140b | ||
|
|
4696d74305 | ||
|
|
5ee381c058 | ||
|
|
4887743a2a | ||
|
|
97b8a2c26a | ||
|
|
97eb256a35 | ||
|
|
61b08d4ba6 | ||
|
|
da9dab7edd | ||
|
|
d2aaef029c | ||
|
|
0a3d062e06 | ||
|
|
2f74e17975 | ||
|
|
dca6bdd4fa | ||
|
|
7d593baf91 | ||
|
|
c60dc4177c | ||
|
|
5d4cc3ba1b | ||
|
|
9f1388c0a3 | ||
|
|
a88788dce6 | ||
|
|
d0210fe2e5 | ||
|
|
e6d9f62744 | ||
|
|
78672d0ee6 | ||
|
|
1ef70fcde4 | ||
|
|
0621d73a9c | ||
|
|
b850d9a8bb | ||
|
|
c60467a148 | ||
|
|
c0207b473f | ||
|
|
93bc2f8e4d | ||
|
|
e6e5d33b35 | ||
|
|
4293e4da21 | ||
|
|
69cb57b342 | ||
|
|
d03ae077b4 | ||
|
|
0ccc88b03f | ||
|
|
eb2f78b4e0 | ||
|
|
e729a5cc11 | ||
|
|
e78d230496 | ||
|
|
d3504e1778 | ||
|
|
a86a58c308 | ||
|
|
39dda1d40d | ||
|
|
5ad33787de | ||
|
|
255f139863 | ||
|
|
5ac9ec214b | ||
|
|
0aa1c58b04 | ||
|
|
5249e45a1c | ||
|
|
54a45b9967 | ||
|
|
9a470e073e | ||
|
|
7d627f764c | ||
|
|
a0c0785635 | ||
|
|
100c2478ea | ||
|
|
1da5639e86 | ||
|
|
1b96fae1d4 | ||
|
|
7f492522b6 | ||
|
|
650838fd6f | ||
|
|
491fafbd64 | ||
|
|
9bc2798f72 | ||
|
|
50afba747c | ||
|
|
6b8062f414 | ||
|
|
b1ae4126c3 | ||
|
|
9dabda19f0 | ||
|
|
543c24108c | ||
|
|
260a5ca5d9 | ||
|
|
861c3bbb3d | ||
|
|
9ca581c941 | ||
|
|
4831e9c2c4 | ||
|
|
480375f349 | ||
|
|
b40143984c | ||
|
|
b43916a134 | ||
|
|
7bc7dd2aa2 | ||
|
|
938d3e8216 | ||
|
|
8f05fb48ea | ||
|
|
b7ff5bd14d | ||
|
|
2b653e8c18 | ||
|
|
1fd306824d | ||
|
|
1205afc708 | ||
|
|
5612670ee4 | ||
|
|
181a9bf26d | ||
|
|
aac10ad23a | ||
|
|
974254218a | ||
|
|
c5de4955bb | ||
|
|
9fd0cd7cf7 | ||
|
|
b5e97db9ac | ||
|
|
1359c969e4 | ||
|
|
059cd38aa2 | ||
|
|
e740dfd806 | ||
|
|
7eab7d2944 | ||
|
|
75d327abd5 | ||
|
|
ee615ac269 | ||
|
|
27870ec3c3 | ||
|
|
f41f323c52 | ||
|
|
f74fc4d927 | ||
|
|
ae26cd99b5 | ||
|
|
e9af97ba1a | ||
|
|
d9277301d2 | ||
|
|
34c8eeec06 | ||
|
|
9f1069290c | ||
|
|
111f583e00 | ||
|
|
79ed752748 | ||
|
|
772de7c006 | ||
|
|
b22e97dcfa | ||
|
|
f02de13316 | ||
|
|
c46268bf60 | ||
|
|
cf49a2c5b5 | ||
|
|
170c7bb90c | ||
|
|
2a0b138feb | ||
|
|
e195c1b13f | ||
|
|
5b4eb021cb | ||
|
|
396454fa41 | ||
|
|
a3cf272522 | ||
|
|
ba9548f756 | ||
|
|
e18f53cca9 | ||
|
|
c36be0ea09 | ||
|
|
9093301a49 | ||
|
|
bd951a714f | ||
|
|
6493709d6a | ||
|
|
b976f934ae | ||
|
|
7d8cf4cacc | ||
|
|
68f4496b8e | ||
|
|
ef5266b1c1 | ||
|
|
a96e65df18 | ||
|
|
93a49a45de | ||
|
|
ec70ed6aea | ||
|
|
7a13f74220 | ||
|
|
8042eb20c6 | ||
|
|
bd9f166c12 | ||
|
|
dd94416db2 | ||
|
|
ae0e7c4dff | ||
|
|
78f79266a9 | ||
|
|
1883e70b43 | ||
|
|
31ca603ccb | ||
|
|
f7fb193712 | ||
|
|
7e9267fa77 | ||
|
|
91d40086db | ||
|
|
5b12b55e32 | ||
|
|
e9e9a031a8 | ||
|
|
d7430c529a | ||
|
|
cd88f709ab | ||
|
|
4459a17e82 | ||
|
|
483b3e62e0 | ||
|
|
8e81c507d2 | ||
|
|
e1c6dc720e | ||
|
|
7ea79ebb9d | ||
|
|
ae75a084df | ||
|
|
d6a2137fc3 | ||
|
|
53e8d8193c | ||
|
|
29596bd53f | ||
|
|
7d5160f92c | ||
|
|
7f7b3f1695 | ||
|
|
9da6aca0d0 | ||
|
|
1cb3c98947 |
@@ -4,6 +4,9 @@ if you have a NVIDIA gpu:
|
||||
|
||||
run_nvidia_gpu.bat
|
||||
|
||||
if you want to enable the fast fp16 accumulation (faster for fp16 models with slightly less quality):
|
||||
|
||||
run_nvidia_gpu_fast_fp16_accumulation.bat
|
||||
|
||||
|
||||
To run it in slow CPU mode:
|
||||
|
||||
1
.gitattributes
vendored
1
.gitattributes
vendored
@@ -1,2 +1,3 @@
|
||||
/web/assets/** linguist-generated
|
||||
/web/** linguist-vendored
|
||||
comfy_api_nodes/apis/__init__.py linguist-generated
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
2
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@@ -22,7 +22,7 @@ body:
|
||||
description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
|
||||
options:
|
||||
- label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
|
||||
required: true
|
||||
required: false
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Expected Behavior
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/user-support.yml
vendored
2
.github/ISSUE_TEMPLATE/user-support.yml
vendored
@@ -18,7 +18,7 @@ body:
|
||||
description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
|
||||
options:
|
||||
- label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
|
||||
required: true
|
||||
required: false
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Your question
|
||||
|
||||
40
.github/workflows/check-line-endings.yml
vendored
Normal file
40
.github/workflows/check-line-endings.yml
vendored
Normal file
@@ -0,0 +1,40 @@
|
||||
name: Check for Windows Line Endings
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: ['*'] # Trigger on all pull requests to any branch
|
||||
|
||||
jobs:
|
||||
check-line-endings:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0 # Fetch all history to compare changes
|
||||
|
||||
- name: Check for Windows line endings (CRLF)
|
||||
run: |
|
||||
# Get the list of changed files in the PR
|
||||
CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }})
|
||||
|
||||
# Flag to track if CRLF is found
|
||||
CRLF_FOUND=false
|
||||
|
||||
# Loop through each changed file
|
||||
for FILE in $CHANGED_FILES; do
|
||||
# Check if the file exists and is a text file
|
||||
if [ -f "$FILE" ] && file "$FILE" | grep -q "text"; then
|
||||
# Check for CRLF line endings
|
||||
if grep -UP '\r$' "$FILE"; then
|
||||
echo "Error: Windows line endings (CRLF) detected in $FILE"
|
||||
CRLF_FOUND=true
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
# Exit with error if CRLF was found
|
||||
if [ "$CRLF_FOUND" = true ]; then
|
||||
exit 1
|
||||
fi
|
||||
108
.github/workflows/release-webhook.yml
vendored
Normal file
108
.github/workflows/release-webhook.yml
vendored
Normal file
@@ -0,0 +1,108 @@
|
||||
name: Release Webhook
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
jobs:
|
||||
send-webhook:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Send release webhook
|
||||
env:
|
||||
WEBHOOK_URL: ${{ secrets.RELEASE_GITHUB_WEBHOOK_URL }}
|
||||
WEBHOOK_SECRET: ${{ secrets.RELEASE_GITHUB_WEBHOOK_SECRET }}
|
||||
run: |
|
||||
# Generate UUID for delivery ID
|
||||
DELIVERY_ID=$(uuidgen)
|
||||
HOOK_ID="release-webhook-$(date +%s)"
|
||||
|
||||
# Create webhook payload matching GitHub release webhook format
|
||||
PAYLOAD=$(cat <<EOF
|
||||
{
|
||||
"action": "published",
|
||||
"release": {
|
||||
"id": ${{ github.event.release.id }},
|
||||
"node_id": "${{ github.event.release.node_id }}",
|
||||
"url": "${{ github.event.release.url }}",
|
||||
"html_url": "${{ github.event.release.html_url }}",
|
||||
"assets_url": "${{ github.event.release.assets_url }}",
|
||||
"upload_url": "${{ github.event.release.upload_url }}",
|
||||
"tag_name": "${{ github.event.release.tag_name }}",
|
||||
"target_commitish": "${{ github.event.release.target_commitish }}",
|
||||
"name": ${{ toJSON(github.event.release.name) }},
|
||||
"body": ${{ toJSON(github.event.release.body) }},
|
||||
"draft": ${{ github.event.release.draft }},
|
||||
"prerelease": ${{ github.event.release.prerelease }},
|
||||
"created_at": "${{ github.event.release.created_at }}",
|
||||
"published_at": "${{ github.event.release.published_at }}",
|
||||
"author": {
|
||||
"login": "${{ github.event.release.author.login }}",
|
||||
"id": ${{ github.event.release.author.id }},
|
||||
"node_id": "${{ github.event.release.author.node_id }}",
|
||||
"avatar_url": "${{ github.event.release.author.avatar_url }}",
|
||||
"url": "${{ github.event.release.author.url }}",
|
||||
"html_url": "${{ github.event.release.author.html_url }}",
|
||||
"type": "${{ github.event.release.author.type }}",
|
||||
"site_admin": ${{ github.event.release.author.site_admin }}
|
||||
},
|
||||
"tarball_url": "${{ github.event.release.tarball_url }}",
|
||||
"zipball_url": "${{ github.event.release.zipball_url }}",
|
||||
"assets": ${{ toJSON(github.event.release.assets) }}
|
||||
},
|
||||
"repository": {
|
||||
"id": ${{ github.event.repository.id }},
|
||||
"node_id": "${{ github.event.repository.node_id }}",
|
||||
"name": "${{ github.event.repository.name }}",
|
||||
"full_name": "${{ github.event.repository.full_name }}",
|
||||
"private": ${{ github.event.repository.private }},
|
||||
"owner": {
|
||||
"login": "${{ github.event.repository.owner.login }}",
|
||||
"id": ${{ github.event.repository.owner.id }},
|
||||
"node_id": "${{ github.event.repository.owner.node_id }}",
|
||||
"avatar_url": "${{ github.event.repository.owner.avatar_url }}",
|
||||
"url": "${{ github.event.repository.owner.url }}",
|
||||
"html_url": "${{ github.event.repository.owner.html_url }}",
|
||||
"type": "${{ github.event.repository.owner.type }}",
|
||||
"site_admin": ${{ github.event.repository.owner.site_admin }}
|
||||
},
|
||||
"html_url": "${{ github.event.repository.html_url }}",
|
||||
"clone_url": "${{ github.event.repository.clone_url }}",
|
||||
"git_url": "${{ github.event.repository.git_url }}",
|
||||
"ssh_url": "${{ github.event.repository.ssh_url }}",
|
||||
"url": "${{ github.event.repository.url }}",
|
||||
"created_at": "${{ github.event.repository.created_at }}",
|
||||
"updated_at": "${{ github.event.repository.updated_at }}",
|
||||
"pushed_at": "${{ github.event.repository.pushed_at }}",
|
||||
"default_branch": "${{ github.event.repository.default_branch }}",
|
||||
"fork": ${{ github.event.repository.fork }}
|
||||
},
|
||||
"sender": {
|
||||
"login": "${{ github.event.sender.login }}",
|
||||
"id": ${{ github.event.sender.id }},
|
||||
"node_id": "${{ github.event.sender.node_id }}",
|
||||
"avatar_url": "${{ github.event.sender.avatar_url }}",
|
||||
"url": "${{ github.event.sender.url }}",
|
||||
"html_url": "${{ github.event.sender.html_url }}",
|
||||
"type": "${{ github.event.sender.type }}",
|
||||
"site_admin": ${{ github.event.sender.site_admin }}
|
||||
}
|
||||
}
|
||||
EOF
|
||||
)
|
||||
|
||||
# Generate HMAC-SHA256 signature
|
||||
SIGNATURE=$(echo -n "$PAYLOAD" | openssl dgst -sha256 -hmac "$WEBHOOK_SECRET" -hex | cut -d' ' -f2)
|
||||
|
||||
# Send webhook with required headers
|
||||
curl -X POST "$WEBHOOK_URL" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "X-GitHub-Event: release" \
|
||||
-H "X-GitHub-Delivery: $DELIVERY_ID" \
|
||||
-H "X-GitHub-Hook-ID: $HOOK_ID" \
|
||||
-H "X-Hub-Signature-256: sha256=$SIGNATURE" \
|
||||
-H "User-Agent: GitHub-Actions-Webhook/1.0" \
|
||||
-d "$PAYLOAD" \
|
||||
--fail --silent --show-error
|
||||
|
||||
echo "✅ Release webhook sent successfully"
|
||||
20
.github/workflows/stable-release.yml
vendored
20
.github/workflows/stable-release.yml
vendored
@@ -12,17 +12,17 @@ on:
|
||||
description: 'CUDA version'
|
||||
required: true
|
||||
type: string
|
||||
default: "128"
|
||||
default: "129"
|
||||
python_minor:
|
||||
description: 'Python minor version'
|
||||
required: true
|
||||
type: string
|
||||
default: "12"
|
||||
default: "13"
|
||||
python_patch:
|
||||
description: 'Python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "10"
|
||||
default: "6"
|
||||
|
||||
|
||||
jobs:
|
||||
@@ -66,8 +66,13 @@ jobs:
|
||||
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
|
||||
./python.exe get-pip.py
|
||||
./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/*
|
||||
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
|
||||
cd ..
|
||||
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
|
||||
|
||||
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
|
||||
rm ./Lib/site-packages/torch/lib/libprotoc.lib
|
||||
rm ./Lib/site-packages/torch/lib/libprotobuf.lib
|
||||
|
||||
cd ..
|
||||
|
||||
git clone --depth 1 https://github.com/comfyanonymous/taesd
|
||||
cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/
|
||||
@@ -85,7 +90,7 @@ jobs:
|
||||
|
||||
cd ..
|
||||
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z
|
||||
|
||||
cd ComfyUI_windows_portable
|
||||
@@ -102,5 +107,4 @@ jobs:
|
||||
file: ComfyUI_windows_portable_nvidia.7z
|
||||
tag: ${{ inputs.git_tag }}
|
||||
overwrite: true
|
||||
prerelease: true
|
||||
make_latest: false
|
||||
draft: true
|
||||
|
||||
173
.github/workflows/test-assets.yml
vendored
Normal file
173
.github/workflows/test-assets.yml
vendored
Normal file
@@ -0,0 +1,173 @@
|
||||
name: Asset System Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
paths:
|
||||
- 'app/**'
|
||||
- 'tests-assets/**'
|
||||
- '.github/workflows/test-assets.yml'
|
||||
- 'requirements.txt'
|
||||
pull_request:
|
||||
branches: [master]
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
PIP_DISABLE_PIP_VERSION_CHECK: '1'
|
||||
PYTHONUNBUFFERED: '1'
|
||||
|
||||
jobs:
|
||||
sqlite:
|
||||
name: SQLite (${{ matrix.sqlite_mode }}) • Python ${{ matrix.python }}
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 40
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python: ['3.9', '3.12']
|
||||
sqlite_mode: ['memory', 'file']
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python }}
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install -U pip wheel
|
||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
pip install -r requirements.txt
|
||||
pip install pytest pytest-aiohttp pytest-asyncio
|
||||
|
||||
- name: Set deterministic test base dir
|
||||
id: basedir
|
||||
shell: bash
|
||||
run: |
|
||||
BASE="$RUNNER_TEMP/comfyui-assets-tests-${{ matrix.python }}-${{ matrix.sqlite_mode }}-${{ github.run_id }}-${{ github.run_attempt }}"
|
||||
echo "ASSETS_TEST_BASE_DIR=$BASE" >> "$GITHUB_ENV"
|
||||
echo "ASSETS_TEST_LOGS=$BASE/logs" >> "$GITHUB_ENV"
|
||||
mkdir -p "$BASE/logs"
|
||||
echo "ASSETS_TEST_BASE_DIR=$BASE"
|
||||
|
||||
- name: Set DB URL for SQLite
|
||||
id: setdb
|
||||
shell: bash
|
||||
run: |
|
||||
if [ "${{ matrix.sqlite_mode }}" = "memory" ]; then
|
||||
echo "ASSETS_TEST_DB_URL=sqlite+aiosqlite:///:memory:" >> "$GITHUB_ENV"
|
||||
else
|
||||
DBFILE="$RUNNER_TEMP/assets-tests.sqlite"
|
||||
mkdir -p "$(dirname "$DBFILE")"
|
||||
echo "ASSETS_TEST_DB_URL=sqlite+aiosqlite:///$DBFILE" >> "$GITHUB_ENV"
|
||||
fi
|
||||
|
||||
- name: Run tests
|
||||
run: python -m pytest tests-assets
|
||||
|
||||
- name: Show ComfyUI logs
|
||||
if: always()
|
||||
shell: bash
|
||||
run: |
|
||||
echo "==== ASSETS_TEST_BASE_DIR: $ASSETS_TEST_BASE_DIR ===="
|
||||
echo "==== ASSETS_TEST_LOGS: $ASSETS_TEST_LOGS ===="
|
||||
ls -la "$ASSETS_TEST_LOGS" || true
|
||||
for f in "$ASSETS_TEST_LOGS"/stdout.log "$ASSETS_TEST_LOGS"/stderr.log; do
|
||||
if [ -f "$f" ]; then
|
||||
echo "----- BEGIN $f -----"
|
||||
sed -n '1,400p' "$f"
|
||||
echo "----- END $f -----"
|
||||
fi
|
||||
done
|
||||
|
||||
- name: Upload ComfyUI logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: asset-logs-sqlite-${{ matrix.sqlite_mode }}-py${{ matrix.python }}
|
||||
path: ${{ env.ASSETS_TEST_LOGS }}/*.log
|
||||
if-no-files-found: warn
|
||||
|
||||
postgres:
|
||||
name: PostgreSQL ${{ matrix.pgsql }} • Python ${{ matrix.python }}
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 40
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python: ['3.9', '3.12']
|
||||
pgsql: ['16', '18']
|
||||
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:${{ matrix.pgsql }}
|
||||
env:
|
||||
POSTGRES_DB: assets
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
ports:
|
||||
- 5432:5432
|
||||
options: >-
|
||||
--health-cmd "pg_isready -U postgres -d assets"
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 12
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python }}
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install -U pip wheel
|
||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
pip install -r requirements.txt
|
||||
pip install pytest pytest-aiohttp pytest-asyncio
|
||||
pip install greenlet psycopg
|
||||
|
||||
- name: Set deterministic test base dir
|
||||
id: basedir
|
||||
shell: bash
|
||||
run: |
|
||||
BASE="$RUNNER_TEMP/comfyui-assets-tests-${{ matrix.python }}-${{ matrix.sqlite_mode }}-${{ github.run_id }}-${{ github.run_attempt }}"
|
||||
echo "ASSETS_TEST_BASE_DIR=$BASE" >> "$GITHUB_ENV"
|
||||
echo "ASSETS_TEST_LOGS=$BASE/logs" >> "$GITHUB_ENV"
|
||||
mkdir -p "$BASE/logs"
|
||||
echo "ASSETS_TEST_BASE_DIR=$BASE"
|
||||
|
||||
- name: Set DB URL for PostgreSQL
|
||||
shell: bash
|
||||
run: |
|
||||
echo "ASSETS_TEST_DB_URL=postgresql+psycopg://postgres:postgres@localhost:5432/assets" >> "$GITHUB_ENV"
|
||||
|
||||
- name: Run tests
|
||||
run: python -m pytest tests-assets
|
||||
|
||||
- name: Show ComfyUI logs
|
||||
if: always()
|
||||
shell: bash
|
||||
run: |
|
||||
echo "==== ASSETS_TEST_BASE_DIR: $ASSETS_TEST_BASE_DIR ===="
|
||||
echo "==== ASSETS_TEST_LOGS: $ASSETS_TEST_LOGS ===="
|
||||
ls -la "$ASSETS_TEST_LOGS" || true
|
||||
for f in "$ASSETS_TEST_LOGS"/stdout.log "$ASSETS_TEST_LOGS"/stderr.log; do
|
||||
if [ -f "$f" ]; then
|
||||
echo "----- BEGIN $f -----"
|
||||
sed -n '1,400p' "$f"
|
||||
echo "----- END $f -----"
|
||||
fi
|
||||
done
|
||||
|
||||
- name: Upload ComfyUI logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: asset-logs-pgsql-${{ matrix.pgsql }}-py${{ matrix.python }}
|
||||
path: ${{ env.ASSETS_TEST_LOGS }}/*.log
|
||||
if-no-files-found: warn
|
||||
30
.github/workflows/test-execution.yml
vendored
Normal file
30
.github/workflows/test-execution.yml
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
name: Execution Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master ]
|
||||
pull_request:
|
||||
branches: [ main, master ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, windows-latest, macos-latest]
|
||||
runs-on: ${{ matrix.os }}
|
||||
continue-on-error: true
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.12'
|
||||
- name: Install requirements
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
pip install -r requirements.txt
|
||||
pip install -r tests-unit/requirements.txt
|
||||
- name: Run Execution Tests
|
||||
run: |
|
||||
python -m pytest tests/execution -v --skip-timing-checks
|
||||
6
.github/workflows/test-unit.yml
vendored
6
.github/workflows/test-unit.yml
vendored
@@ -10,7 +10,7 @@ jobs:
|
||||
test:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, windows-latest, macos-latest]
|
||||
os: [ubuntu-latest, windows-2022, macos-latest]
|
||||
runs-on: ${{ matrix.os }}
|
||||
continue-on-error: true
|
||||
steps:
|
||||
@@ -28,7 +28,3 @@ jobs:
|
||||
run: |
|
||||
pip install -r tests-unit/requirements.txt
|
||||
python -m pytest tests-unit
|
||||
- name: Run Execution Model Tests
|
||||
run: |
|
||||
python -m pytest tests/inference/test_execution.py
|
||||
|
||||
|
||||
@@ -17,19 +17,19 @@ on:
|
||||
description: 'cuda version'
|
||||
required: true
|
||||
type: string
|
||||
default: "128"
|
||||
default: "129"
|
||||
|
||||
python_minor:
|
||||
description: 'python minor version'
|
||||
required: true
|
||||
type: string
|
||||
default: "12"
|
||||
default: "13"
|
||||
|
||||
python_patch:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "10"
|
||||
default: "6"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
|
||||
@@ -7,7 +7,7 @@ on:
|
||||
description: 'cuda version'
|
||||
required: true
|
||||
type: string
|
||||
default: "128"
|
||||
default: "129"
|
||||
|
||||
python_minor:
|
||||
description: 'python minor version'
|
||||
@@ -19,7 +19,7 @@ on:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "2"
|
||||
default: "5"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
@@ -53,6 +53,8 @@ jobs:
|
||||
ls ../temp_wheel_dir
|
||||
./python.exe -s -m pip install --pre ../temp_wheel_dir/*
|
||||
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
|
||||
|
||||
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
|
||||
cd ..
|
||||
|
||||
git clone --depth 1 https://github.com/comfyanonymous/taesd
|
||||
|
||||
12
.github/workflows/windows_release_package.yml
vendored
12
.github/workflows/windows_release_package.yml
vendored
@@ -7,19 +7,19 @@ on:
|
||||
description: 'cuda version'
|
||||
required: true
|
||||
type: string
|
||||
default: "128"
|
||||
default: "129"
|
||||
|
||||
python_minor:
|
||||
description: 'python minor version'
|
||||
required: true
|
||||
type: string
|
||||
default: "12"
|
||||
default: "13"
|
||||
|
||||
python_patch:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "10"
|
||||
default: "6"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
@@ -64,6 +64,10 @@ jobs:
|
||||
./python.exe get-pip.py
|
||||
./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/*
|
||||
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
|
||||
|
||||
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
|
||||
rm ./Lib/site-packages/torch/lib/libprotoc.lib
|
||||
rm ./Lib/site-packages/torch/lib/libprotobuf.lib
|
||||
cd ..
|
||||
|
||||
git clone --depth 1 https://github.com/comfyanonymous/taesd
|
||||
@@ -82,7 +86,7 @@ jobs:
|
||||
|
||||
cd ..
|
||||
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||
mv ComfyUI_windows_portable.7z ComfyUI/new_ComfyUI_windows_portable_nvidia_cu${{ inputs.cu }}_or_cpu.7z
|
||||
|
||||
cd ComfyUI_windows_portable
|
||||
|
||||
23
CODEOWNERS
23
CODEOWNERS
@@ -1,24 +1,3 @@
|
||||
# Admins
|
||||
* @comfyanonymous
|
||||
|
||||
# Note: Github teams syntax cannot be used here as the repo is not owned by Comfy-Org.
|
||||
# Inlined the team members for now.
|
||||
|
||||
# Maintainers
|
||||
*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
|
||||
# Python web server
|
||||
/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
||||
/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
||||
/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
||||
|
||||
# Node developers
|
||||
/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
|
||||
/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
|
||||
* @kosinkadink
|
||||
|
||||
69
README.md
69
README.md
@@ -39,7 +39,7 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a
|
||||
## Get Started
|
||||
|
||||
#### [Desktop Application](https://www.comfy.org/download)
|
||||
- The easiest way to get started.
|
||||
- The easiest way to get started.
|
||||
- Available on Windows & macOS.
|
||||
|
||||
#### [Windows Portable Package](#installing)
|
||||
@@ -55,7 +55,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
## Features
|
||||
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
||||
- Image Models
|
||||
- SD1.x, SD2.x,
|
||||
- SD1.x, SD2.x ([unCLIP](https://comfyanonymous.github.io/ComfyUI_examples/unclip/))
|
||||
- [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
|
||||
- [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/)
|
||||
- [SD3 and SD3.5](https://comfyanonymous.github.io/ComfyUI_examples/sd3/)
|
||||
@@ -65,13 +65,20 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
|
||||
- [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
|
||||
- [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
|
||||
- [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/)
|
||||
- [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/)
|
||||
- Image Editing Models
|
||||
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
|
||||
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
|
||||
- [HiDream E1.1](https://comfyanonymous.github.io/ComfyUI_examples/hidream/#hidream-e11)
|
||||
- [Qwen Image Edit](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/#edit-model)
|
||||
- Video Models
|
||||
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
|
||||
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
|
||||
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
|
||||
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
|
||||
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/)
|
||||
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
|
||||
- [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/)
|
||||
- Audio Models
|
||||
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||
- [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||
@@ -79,9 +86,10 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
- [Hunyuan3D 2.0](https://docs.comfy.org/tutorials/3d/hunyuan3D-2)
|
||||
- Asynchronous Queue system
|
||||
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
|
||||
- Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
|
||||
- Smart memory management: can automatically run large models on GPUs with as low as 1GB vram with smart offloading.
|
||||
- Works even if you don't have a GPU with: ```--cpu``` (slow)
|
||||
- Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models.
|
||||
- Can load ckpt and safetensors: All in one checkpoints or standalone diffusion models, VAEs and CLIP models.
|
||||
- Safe loading of ckpt, pt, pth, etc.. files.
|
||||
- Embeddings/Textual inversion
|
||||
- [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
|
||||
- [Hypernetworks](https://comfyanonymous.github.io/ComfyUI_examples/hypernetworks/)
|
||||
@@ -92,12 +100,10 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
- [Inpainting](https://comfyanonymous.github.io/ComfyUI_examples/inpaint/) with both regular and inpainting models.
|
||||
- [ControlNet and T2I-Adapter](https://comfyanonymous.github.io/ComfyUI_examples/controlnet/)
|
||||
- [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/)
|
||||
- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
|
||||
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
|
||||
- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
|
||||
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
|
||||
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
|
||||
- Starts up very fast.
|
||||
- Works fully offline: core will never download anything unless you want to.
|
||||
- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview).
|
||||
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.
|
||||
@@ -106,7 +112,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
||||
|
||||
## Release Process
|
||||
|
||||
ComfyUI follows a weekly release cycle every Friday, with three interconnected repositories:
|
||||
ComfyUI follows a weekly release cycle targeting Friday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
|
||||
|
||||
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
|
||||
- Releases a new stable version (e.g., v0.7.0)
|
||||
@@ -174,10 +180,6 @@ If you have trouble extracting it, right click the file -> properties -> unblock
|
||||
|
||||
See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor.
|
||||
|
||||
## Jupyter Notebook
|
||||
|
||||
To run it on services like paperspace, kaggle or colab you can use my [Jupyter Notebook](notebooks/comfyui_colab.ipynb)
|
||||
|
||||
|
||||
## [comfy-cli](https://docs.comfy.org/comfy-cli/getting-started)
|
||||
|
||||
@@ -189,7 +191,7 @@ comfy install
|
||||
|
||||
## Manual Install (Windows, Linux)
|
||||
|
||||
python 3.13 is supported but using 3.12 is recommended because some custom nodes and their dependencies might not support it yet.
|
||||
Python 3.13 is very well supported. If you have trouble with some custom node dependencies you can try 3.12
|
||||
|
||||
Git clone this repo.
|
||||
|
||||
@@ -201,7 +203,7 @@ Put your VAE in: models/vae
|
||||
### AMD GPUs (Linux only)
|
||||
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
||||
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.3```
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
|
||||
|
||||
This is the command to install the nightly with ROCm 6.4 which might have some performance improvements:
|
||||
|
||||
@@ -209,37 +211,29 @@ This is the command to install the nightly with ROCm 6.4 which might have some p
|
||||
|
||||
### Intel GPUs (Windows and Linux)
|
||||
|
||||
(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip (currently available in PyTorch nightly builds). More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
|
||||
|
||||
1. To install PyTorch nightly, use the following command:
|
||||
(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
|
||||
|
||||
1. To install PyTorch xpu, use the following command:
|
||||
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/xpu```
|
||||
|
||||
This is the command to install the Pytorch xpu nightly which might have some performance improvements:
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu```
|
||||
|
||||
2. Launch ComfyUI by running `python main.py`
|
||||
|
||||
|
||||
(Option 2) Alternatively, Intel GPUs supported by Intel Extension for PyTorch (IPEX) can leverage IPEX for improved performance.
|
||||
|
||||
1. For Intel® Arc™ A-Series Graphics utilizing IPEX, create a conda environment and use the commands below:
|
||||
|
||||
```
|
||||
conda install libuv
|
||||
pip install torch==2.3.1.post0+cxx11.abi torchvision==0.18.1.post0+cxx11.abi torchaudio==2.3.1.post0+cxx11.abi intel-extension-for-pytorch==2.3.110.post0+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/
|
||||
```
|
||||
|
||||
For other supported Intel GPUs with IPEX, visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information.
|
||||
|
||||
Additional discussion and help can be found [here](https://github.com/comfyanonymous/ComfyUI/discussions/476).
|
||||
1. visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information.
|
||||
|
||||
### NVIDIA
|
||||
|
||||
Nvidia users should install stable pytorch using this command:
|
||||
|
||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128```
|
||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu129```
|
||||
|
||||
This is the command to install pytorch nightly instead which might have performance improvements.
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128```
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu129```
|
||||
|
||||
#### Troubleshooting
|
||||
|
||||
@@ -272,6 +266,8 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve
|
||||
|
||||
#### DirectML (AMD Cards on Windows)
|
||||
|
||||
This is very badly supported and is not recommended. There are some unofficial builds of pytorch ROCm on windows that exist that will give you a much better experience than this. This readme will be updated once official pytorch ROCm builds for windows come out.
|
||||
|
||||
```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```
|
||||
|
||||
#### Ascend NPUs
|
||||
@@ -291,6 +287,13 @@ For models compatible with Cambricon Extension for PyTorch (torch_mlu). Here's a
|
||||
2. Next, install the PyTorch(torch_mlu) following the instructions on the [Installation](https://www.cambricon.com/docs/sdk_1.15.0/cambricon_pytorch_1.17.0/user_guide_1.9/index.html)
|
||||
3. Launch ComfyUI by running `python main.py`
|
||||
|
||||
#### Iluvatar Corex
|
||||
|
||||
For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step guide tailored to your platform and installation method:
|
||||
|
||||
1. Install the Iluvatar Corex Toolkit by adhering to the platform-specific instructions on the [Installation](https://support.iluvatar.com/#/DocumentCentre?id=1&nameCenter=2&productId=520117912052801536)
|
||||
2. Launch ComfyUI by running `python main.py`
|
||||
|
||||
# Running
|
||||
|
||||
```python main.py```
|
||||
@@ -341,7 +344,7 @@ Generate a self-signed certificate (not appropriate for shared/production use) a
|
||||
|
||||
Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app will now be accessible with `https://...` instead of `http://...`.
|
||||
|
||||
> Note: Windows users can use [alexisrolland/docker-openssl](https://github.com/alexisrolland/docker-openssl) or one of the [3rd party binary distributions](https://wiki.openssl.org/index.php/Binaries) to run the command example above.
|
||||
> Note: Windows users can use [alexisrolland/docker-openssl](https://github.com/alexisrolland/docker-openssl) or one of the [3rd party binary distributions](https://wiki.openssl.org/index.php/Binaries) to run the command example above.
|
||||
<br/><br/>If you use a container, note that the volume mount `-v` can be a relative path so `... -v ".\:/openssl-certs" ...` would create the key & cert files in the current directory of your command prompt or powershell terminal.
|
||||
|
||||
## Support and dev channel
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
[alembic]
|
||||
# path to migration scripts
|
||||
# Use forward slashes (/) also on windows to provide an os agnostic path
|
||||
script_location = alembic_db
|
||||
script_location = app/alembic_db
|
||||
|
||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||
# Uncomment the line below if you want the files to be prepended with date and time
|
||||
|
||||
@@ -2,13 +2,12 @@ from sqlalchemy import engine_from_config
|
||||
from sqlalchemy import pool
|
||||
|
||||
from alembic import context
|
||||
from app.assets.database.models import Base
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
|
||||
from app.database.models import Base
|
||||
target_metadata = Base.metadata
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
175
app/alembic_db/versions/0001_assets.py
Normal file
175
app/alembic_db/versions/0001_assets.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""initial assets schema
|
||||
|
||||
Revision ID: 0001_assets
|
||||
Revises:
|
||||
Create Date: 2025-08-20 00:00:00
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision = "0001_assets"
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ASSETS: content identity
|
||||
op.create_table(
|
||||
"assets",
|
||||
sa.Column("id", sa.String(length=36), primary_key=True),
|
||||
sa.Column("hash", sa.String(length=256), nullable=True),
|
||||
sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"),
|
||||
sa.Column("mime_type", sa.String(length=255), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
|
||||
)
|
||||
op.create_index("uq_assets_hash", "assets", ["hash"], unique=True)
|
||||
op.create_index("ix_assets_mime_type", "assets", ["mime_type"])
|
||||
|
||||
# ASSETS_INFO: user-visible references
|
||||
op.create_table(
|
||||
"assets_info",
|
||||
sa.Column("id", sa.String(length=36), primary_key=True),
|
||||
sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""),
|
||||
sa.Column("name", sa.String(length=512), nullable=False),
|
||||
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False),
|
||||
sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True),
|
||||
sa.Column("user_metadata", sa.JSON(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
|
||||
)
|
||||
op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"])
|
||||
op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"])
|
||||
op.create_index("ix_assets_info_name", "assets_info", ["name"])
|
||||
op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"])
|
||||
op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"])
|
||||
op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"])
|
||||
|
||||
# TAGS: normalized tag vocabulary
|
||||
op.create_table(
|
||||
"tags",
|
||||
sa.Column("name", sa.String(length=512), primary_key=True),
|
||||
sa.Column("tag_type", sa.String(length=32), nullable=False, server_default="user"),
|
||||
sa.CheckConstraint("name = lower(name)", name="ck_tags_lowercase"),
|
||||
)
|
||||
op.create_index("ix_tags_tag_type", "tags", ["tag_type"])
|
||||
|
||||
# ASSET_INFO_TAGS: many-to-many for tags on AssetInfo
|
||||
op.create_table(
|
||||
"asset_info_tags",
|
||||
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
|
||||
sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"),
|
||||
sa.Column("added_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"),
|
||||
)
|
||||
op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])
|
||||
op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"])
|
||||
|
||||
# ASSET_CACHE_STATE: N:1 local cache rows per Asset
|
||||
op.create_table(
|
||||
"asset_cache_state",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
|
||||
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("file_path", sa.Text(), nullable=False), # absolute local path to cached file
|
||||
sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
|
||||
sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")),
|
||||
sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
|
||||
sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
|
||||
)
|
||||
op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"])
|
||||
op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"])
|
||||
|
||||
# ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting
|
||||
op.create_table(
|
||||
"asset_info_meta",
|
||||
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("key", sa.String(length=256), nullable=False),
|
||||
sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("val_str", sa.String(length=2048), nullable=True),
|
||||
sa.Column("val_num", sa.Numeric(38, 10), nullable=True),
|
||||
sa.Column("val_bool", sa.Boolean(), nullable=True),
|
||||
sa.Column("val_json", sa.JSON().with_variant(postgresql.JSONB(), 'postgresql'), nullable=True),
|
||||
sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"),
|
||||
)
|
||||
op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"])
|
||||
op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"])
|
||||
op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"])
|
||||
op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"])
|
||||
|
||||
# Tags vocabulary
|
||||
tags_table = sa.table(
|
||||
"tags",
|
||||
sa.column("name", sa.String(length=512)),
|
||||
sa.column("tag_type", sa.String()),
|
||||
)
|
||||
op.bulk_insert(
|
||||
tags_table,
|
||||
[
|
||||
{"name": "models", "tag_type": "system"},
|
||||
{"name": "input", "tag_type": "system"},
|
||||
{"name": "output", "tag_type": "system"},
|
||||
|
||||
{"name": "configs", "tag_type": "system"},
|
||||
{"name": "checkpoints", "tag_type": "system"},
|
||||
{"name": "loras", "tag_type": "system"},
|
||||
{"name": "vae", "tag_type": "system"},
|
||||
{"name": "text_encoders", "tag_type": "system"},
|
||||
{"name": "diffusion_models", "tag_type": "system"},
|
||||
{"name": "clip_vision", "tag_type": "system"},
|
||||
{"name": "style_models", "tag_type": "system"},
|
||||
{"name": "embeddings", "tag_type": "system"},
|
||||
{"name": "diffusers", "tag_type": "system"},
|
||||
{"name": "vae_approx", "tag_type": "system"},
|
||||
{"name": "controlnet", "tag_type": "system"},
|
||||
{"name": "gligen", "tag_type": "system"},
|
||||
{"name": "upscale_models", "tag_type": "system"},
|
||||
{"name": "hypernetworks", "tag_type": "system"},
|
||||
{"name": "photomaker", "tag_type": "system"},
|
||||
{"name": "classifiers", "tag_type": "system"},
|
||||
|
||||
{"name": "encoder", "tag_type": "system"},
|
||||
{"name": "decoder", "tag_type": "system"},
|
||||
|
||||
{"name": "missing", "tag_type": "system"},
|
||||
{"name": "rescan", "tag_type": "system"},
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta")
|
||||
op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta")
|
||||
op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta")
|
||||
op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta")
|
||||
op.drop_table("asset_info_meta")
|
||||
|
||||
op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state")
|
||||
op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state")
|
||||
op.drop_constraint("uq_asset_cache_state_file_path", table_name="asset_cache_state")
|
||||
op.drop_table("asset_cache_state")
|
||||
|
||||
op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags")
|
||||
op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags")
|
||||
op.drop_table("asset_info_tags")
|
||||
|
||||
op.drop_index("ix_tags_tag_type", table_name="tags")
|
||||
op.drop_table("tags")
|
||||
|
||||
op.drop_constraint("uq_assets_info_asset_owner_name", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_owner_name", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_last_access_time", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_created_at", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_name", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_asset_id", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_owner_id", table_name="assets_info")
|
||||
op.drop_table("assets_info")
|
||||
|
||||
op.drop_index("uq_assets_hash", table_name="assets")
|
||||
op.drop_index("ix_assets_mime_type", table_name="assets")
|
||||
op.drop_table("assets")
|
||||
4
app/assets/__init__.py
Normal file
4
app/assets/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .api.routes import register_assets_system
|
||||
from .scanner import sync_seed_assets
|
||||
|
||||
__all__ = ["sync_seed_assets", "register_assets_system"]
|
||||
225
app/assets/_helpers.py
Normal file
225
app/assets/_helpers.py
Normal file
@@ -0,0 +1,225 @@
|
||||
import contextlib
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional, Sequence
|
||||
|
||||
import folder_paths
|
||||
|
||||
from .api import schemas_in
|
||||
|
||||
|
||||
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
||||
"""Build a list of (folder_name, base_paths[]) categories that are configured for model locations.
|
||||
|
||||
We trust `folder_paths.folder_names_and_paths` and include a category if
|
||||
*any* of its base paths lies under the Comfy `models_dir`.
|
||||
"""
|
||||
targets: list[tuple[str, list[str]]] = []
|
||||
models_root = os.path.abspath(folder_paths.models_dir)
|
||||
for name, (paths, _exts) in folder_paths.folder_names_and_paths.items():
|
||||
if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
|
||||
targets.append((name, paths))
|
||||
return targets
|
||||
|
||||
|
||||
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
|
||||
"""Given an absolute or relative file path, determine which root category the path belongs to:
|
||||
- 'input' if the file resides under `folder_paths.get_input_directory()`
|
||||
- 'output' if the file resides under `folder_paths.get_output_directory()`
|
||||
- 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()`
|
||||
|
||||
Returns:
|
||||
(root_category, relative_path_inside_that_root)
|
||||
For 'models', the relative path is prefixed with the category name:
|
||||
e.g. ('models', 'vae/test/sub/ae.safetensors')
|
||||
|
||||
Raises:
|
||||
ValueError: if the path does not belong to input, output, or configured model bases.
|
||||
"""
|
||||
fp_abs = os.path.abspath(file_path)
|
||||
|
||||
def _is_within(child: str, parent: str) -> bool:
|
||||
try:
|
||||
return os.path.commonpath([child, parent]) == parent
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _rel(child: str, parent: str) -> str:
|
||||
return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep)
|
||||
|
||||
# 1) input
|
||||
input_base = os.path.abspath(folder_paths.get_input_directory())
|
||||
if _is_within(fp_abs, input_base):
|
||||
return "input", _rel(fp_abs, input_base)
|
||||
|
||||
# 2) output
|
||||
output_base = os.path.abspath(folder_paths.get_output_directory())
|
||||
if _is_within(fp_abs, output_base):
|
||||
return "output", _rel(fp_abs, output_base)
|
||||
|
||||
# 3) models (check deepest matching base to avoid ambiguity)
|
||||
best: Optional[tuple[int, str, str]] = None # (base_len, bucket, rel_inside_bucket)
|
||||
for bucket, bases in get_comfy_models_folders():
|
||||
for b in bases:
|
||||
base_abs = os.path.abspath(b)
|
||||
if not _is_within(fp_abs, base_abs):
|
||||
continue
|
||||
cand = (len(base_abs), bucket, _rel(fp_abs, base_abs))
|
||||
if best is None or cand[0] > best[0]:
|
||||
best = cand
|
||||
|
||||
if best is not None:
|
||||
_, bucket, rel_inside = best
|
||||
combined = os.path.join(bucket, rel_inside)
|
||||
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
|
||||
|
||||
raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}")
|
||||
|
||||
|
||||
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
|
||||
"""Return a tuple (name, tags) derived from a filesystem path.
|
||||
|
||||
Semantics:
|
||||
- Root category is determined by `get_relative_to_root_category_path_of_asset`.
|
||||
- The returned `name` is the base filename with extension from the relative path.
|
||||
- The returned `tags` are:
|
||||
[root_category] + parent folders of the relative path (in order)
|
||||
For 'models', this means:
|
||||
file '/.../ModelsDir/vae/test_tag/ae.safetensors'
|
||||
-> root_category='models', some_path='vae/test_tag/ae.safetensors'
|
||||
-> name='ae.safetensors', tags=['models', 'vae', 'test_tag']
|
||||
|
||||
Raises:
|
||||
ValueError: if the path does not belong to input, output, or configured model bases.
|
||||
"""
|
||||
root_category, some_path = get_relative_to_root_category_path_of_asset(file_path)
|
||||
p = Path(some_path)
|
||||
parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)]
|
||||
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
|
||||
|
||||
|
||||
def normalize_tags(tags: Optional[Sequence[str]]) -> list[str]:
|
||||
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||
|
||||
|
||||
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
|
||||
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
|
||||
root = tags[0]
|
||||
if root == "models":
|
||||
if len(tags) < 2:
|
||||
raise ValueError("at least two tags required for model asset")
|
||||
try:
|
||||
bases = folder_paths.folder_names_and_paths[tags[1]][0]
|
||||
except KeyError:
|
||||
raise ValueError(f"unknown model category '{tags[1]}'")
|
||||
if not bases:
|
||||
raise ValueError(f"no base path configured for category '{tags[1]}'")
|
||||
base_dir = os.path.abspath(bases[0])
|
||||
raw_subdirs = tags[2:]
|
||||
else:
|
||||
base_dir = os.path.abspath(
|
||||
folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory()
|
||||
)
|
||||
raw_subdirs = tags[1:]
|
||||
for i in raw_subdirs:
|
||||
if i in (".", ".."):
|
||||
raise ValueError("invalid path component in tags")
|
||||
|
||||
return base_dir, raw_subdirs if raw_subdirs else []
|
||||
|
||||
|
||||
def ensure_within_base(candidate: str, base: str) -> None:
|
||||
cand_abs = os.path.abspath(candidate)
|
||||
base_abs = os.path.abspath(base)
|
||||
try:
|
||||
if os.path.commonpath([cand_abs, base_abs]) != base_abs:
|
||||
raise ValueError("destination escapes base directory")
|
||||
except Exception:
|
||||
raise ValueError("invalid destination path")
|
||||
|
||||
|
||||
def compute_relative_filename(file_path: str) -> Optional[str]:
|
||||
"""
|
||||
Return the model's path relative to the last well-known folder (the model category),
|
||||
using forward slashes, eg:
|
||||
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
|
||||
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
|
||||
|
||||
For non-model paths, returns None.
|
||||
NOTE: this is a temporary helper, used only for initializing metadata["filename"] field.
|
||||
"""
|
||||
try:
|
||||
root_category, rel_path = get_relative_to_root_category_path_of_asset(file_path)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
p = Path(rel_path)
|
||||
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
|
||||
if not parts:
|
||||
return None
|
||||
|
||||
if root_category == "models":
|
||||
# parts[0] is the category ("checkpoints", "vae", etc) – drop it
|
||||
inside = parts[1:] if len(parts) > 1 else [parts[0]]
|
||||
return "/".join(inside)
|
||||
return "/".join(parts) # input/output: keep all parts
|
||||
|
||||
|
||||
def list_tree(base_dir: str) -> list[str]:
|
||||
out: list[str] = []
|
||||
base_abs = os.path.abspath(base_dir)
|
||||
if not os.path.isdir(base_abs):
|
||||
return out
|
||||
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
|
||||
for name in filenames:
|
||||
out.append(os.path.abspath(os.path.join(dirpath, name)))
|
||||
return out
|
||||
|
||||
|
||||
def prefixes_for_root(root: schemas_in.RootType) -> list[str]:
|
||||
if root == "models":
|
||||
bases: list[str] = []
|
||||
for _bucket, paths in get_comfy_models_folders():
|
||||
bases.extend(paths)
|
||||
return [os.path.abspath(p) for p in bases]
|
||||
if root == "input":
|
||||
return [os.path.abspath(folder_paths.get_input_directory())]
|
||||
if root == "output":
|
||||
return [os.path.abspath(folder_paths.get_output_directory())]
|
||||
return []
|
||||
|
||||
|
||||
def ts_to_iso(ts: Optional[float]) -> Optional[str]:
|
||||
if ts is None:
|
||||
return None
|
||||
try:
|
||||
return datetime.fromtimestamp(float(ts), tz=timezone.utc).replace(tzinfo=None).isoformat()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def new_scan_id(root: schemas_in.RootType) -> str:
|
||||
return f"scan-{root}-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
|
||||
def collect_models_files() -> list[str]:
|
||||
out: list[str] = []
|
||||
for folder_name, bases in get_comfy_models_folders():
|
||||
rel_files = folder_paths.get_filename_list(folder_name) or []
|
||||
for rel_path in rel_files:
|
||||
abs_path = folder_paths.get_full_path(folder_name, rel_path)
|
||||
if not abs_path:
|
||||
continue
|
||||
abs_path = os.path.abspath(abs_path)
|
||||
allowed = False
|
||||
for b in bases:
|
||||
base_abs = os.path.abspath(b)
|
||||
with contextlib.suppress(Exception):
|
||||
if os.path.commonpath([abs_path, base_abs]) == base_abs:
|
||||
allowed = True
|
||||
break
|
||||
if allowed:
|
||||
out.append(abs_path)
|
||||
return out
|
||||
544
app/assets/api/routes.py
Normal file
544
app/assets/api/routes.py
Normal file
@@ -0,0 +1,544 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import urllib.parse
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from aiohttp import web
|
||||
from pydantic import ValidationError
|
||||
|
||||
import folder_paths
|
||||
|
||||
from ... import user_manager
|
||||
from .. import manager, scanner
|
||||
from . import schemas_in, schemas_out
|
||||
|
||||
ROUTES = web.RouteTableDef()
|
||||
USER_MANAGER: Optional[user_manager.UserManager] = None
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
# UUID regex (canonical hyphenated form, case-insensitive)
|
||||
UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
|
||||
|
||||
|
||||
@ROUTES.head("/api/assets/hash/{hash}")
|
||||
async def head_asset_by_hash(request: web.Request) -> web.Response:
|
||||
hash_str = request.match_info.get("hash", "").strip().lower()
|
||||
if not hash_str or ":" not in hash_str:
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
algo, digest = hash_str.split(":", 1)
|
||||
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
exists = await manager.asset_exists(asset_hash=hash_str)
|
||||
return web.Response(status=200 if exists else 404)
|
||||
|
||||
|
||||
@ROUTES.get("/api/assets")
|
||||
async def list_assets(request: web.Request) -> web.Response:
|
||||
qp = request.rel_url.query
|
||||
query_dict = {}
|
||||
if "include_tags" in qp:
|
||||
query_dict["include_tags"] = qp.getall("include_tags")
|
||||
if "exclude_tags" in qp:
|
||||
query_dict["exclude_tags"] = qp.getall("exclude_tags")
|
||||
for k in ("name_contains", "metadata_filter", "limit", "offset", "sort", "order"):
|
||||
v = qp.get(k)
|
||||
if v is not None:
|
||||
query_dict[k] = v
|
||||
|
||||
try:
|
||||
q = schemas_in.ListAssetsQuery.model_validate(query_dict)
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_QUERY", ve)
|
||||
|
||||
payload = await manager.list_assets(
|
||||
include_tags=q.include_tags,
|
||||
exclude_tags=q.exclude_tags,
|
||||
name_contains=q.name_contains,
|
||||
metadata_filter=q.metadata_filter,
|
||||
limit=q.limit,
|
||||
offset=q.offset,
|
||||
sort=q.sort,
|
||||
order=q.order,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return web.json_response(payload.model_dump(mode="json"))
|
||||
|
||||
|
||||
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
|
||||
async def download_asset_content(request: web.Request) -> web.Response:
|
||||
disposition = request.query.get("disposition", "attachment").lower().strip()
|
||||
if disposition not in {"inline", "attachment"}:
|
||||
disposition = "attachment"
|
||||
|
||||
try:
|
||||
abs_path, content_type, filename = await manager.resolve_asset_content_for_download(
|
||||
asset_info_id=str(uuid.UUID(request.match_info["id"])),
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except ValueError as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve))
|
||||
except NotImplementedError as nie:
|
||||
return _error_response(501, "BACKEND_UNSUPPORTED", str(nie))
|
||||
except FileNotFoundError:
|
||||
return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.")
|
||||
|
||||
quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'")
|
||||
cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}'
|
||||
|
||||
resp = web.FileResponse(abs_path)
|
||||
resp.content_type = content_type
|
||||
resp.headers["Content-Disposition"] = cd
|
||||
return resp
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets/from-hash")
|
||||
async def create_asset_from_hash(request: web.Request) -> web.Response:
|
||||
try:
|
||||
payload = await request.json()
|
||||
body = schemas_in.CreateFromHashBody.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
result = await manager.create_asset_from_hash(
|
||||
hash_str=body.hash,
|
||||
name=body.name,
|
||||
tags=body.tags,
|
||||
user_metadata=body.user_metadata,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
if result is None:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist")
|
||||
return web.json_response(result.model_dump(mode="json"), status=201)
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets")
|
||||
async def upload_asset(request: web.Request) -> web.Response:
|
||||
"""Multipart/form-data endpoint for Asset uploads."""
|
||||
|
||||
if not (request.content_type or "").lower().startswith("multipart/"):
|
||||
return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.")
|
||||
|
||||
reader = await request.multipart()
|
||||
|
||||
file_present = False
|
||||
file_client_name: Optional[str] = None
|
||||
tags_raw: list[str] = []
|
||||
provided_name: Optional[str] = None
|
||||
user_metadata_raw: Optional[str] = None
|
||||
provided_hash: Optional[str] = None
|
||||
provided_hash_exists: Optional[bool] = None
|
||||
|
||||
file_written = 0
|
||||
tmp_path: Optional[str] = None
|
||||
while True:
|
||||
field = await reader.next()
|
||||
if field is None:
|
||||
break
|
||||
|
||||
fname = getattr(field, "name", "") or ""
|
||||
|
||||
if fname == "hash":
|
||||
try:
|
||||
s = ((await field.text()) or "").strip().lower()
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
|
||||
if s:
|
||||
if ":" not in s:
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
algo, digest = s.split(":", 1)
|
||||
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
provided_hash = f"{algo}:{digest}"
|
||||
try:
|
||||
provided_hash_exists = await manager.asset_exists(asset_hash=provided_hash)
|
||||
except Exception:
|
||||
provided_hash_exists = None # do not fail the whole request here
|
||||
|
||||
elif fname == "file":
|
||||
file_present = True
|
||||
file_client_name = (field.filename or "").strip()
|
||||
|
||||
if provided_hash and provided_hash_exists is True:
|
||||
# If client supplied a hash that we know exists, drain but do not write to disk
|
||||
try:
|
||||
while True:
|
||||
chunk = await field.read_chunk(8 * 1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
file_written += len(chunk)
|
||||
except Exception:
|
||||
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file.")
|
||||
continue # Do not create temp file; we will create AssetInfo from the existing content
|
||||
|
||||
# Otherwise, store to temp for hashing/ingest
|
||||
uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
|
||||
unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
|
||||
os.makedirs(unique_dir, exist_ok=True)
|
||||
tmp_path = os.path.join(unique_dir, ".upload.part")
|
||||
|
||||
try:
|
||||
with open(tmp_path, "wb") as f:
|
||||
while True:
|
||||
chunk = await field.read_chunk(8 * 1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
file_written += len(chunk)
|
||||
except Exception:
|
||||
try:
|
||||
if os.path.exists(tmp_path or ""):
|
||||
os.remove(tmp_path)
|
||||
finally:
|
||||
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.")
|
||||
elif fname == "tags":
|
||||
tags_raw.append((await field.text()) or "")
|
||||
elif fname == "name":
|
||||
provided_name = (await field.text()) or None
|
||||
elif fname == "user_metadata":
|
||||
user_metadata_raw = (await field.text()) or None
|
||||
|
||||
# If client did not send file, and we are not doing a from-hash fast path -> error
|
||||
if not file_present and not (provided_hash and provided_hash_exists):
|
||||
return _error_response(400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'.")
|
||||
|
||||
if file_present and file_written == 0 and not (provided_hash and provided_hash_exists):
|
||||
# Empty upload is only acceptable if we are fast-pathing from existing hash
|
||||
try:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
finally:
|
||||
return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.")
|
||||
|
||||
try:
|
||||
spec = schemas_in.UploadAssetSpec.model_validate({
|
||||
"tags": tags_raw,
|
||||
"name": provided_name,
|
||||
"user_metadata": user_metadata_raw,
|
||||
"hash": provided_hash,
|
||||
})
|
||||
except ValidationError as ve:
|
||||
try:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
finally:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
|
||||
# Validate models category against configured folders (consistent with previous behavior)
|
||||
if spec.tags and spec.tags[0] == "models":
|
||||
if len(spec.tags) < 2 or spec.tags[1] not in folder_paths.folder_names_and_paths:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
return _error_response(
|
||||
400, "INVALID_BODY", f"unknown models category '{spec.tags[1] if len(spec.tags) >= 2 else ''}'"
|
||||
)
|
||||
|
||||
owner_id = USER_MANAGER.get_request_user_id(request)
|
||||
|
||||
# Fast path: if a valid provided hash exists, create AssetInfo without writing anything
|
||||
if spec.hash and provided_hash_exists is True:
|
||||
try:
|
||||
result = await manager.create_asset_from_hash(
|
||||
hash_str=spec.hash,
|
||||
name=spec.name or (spec.hash.split(":", 1)[1]),
|
||||
tags=spec.tags,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
owner_id=owner_id,
|
||||
)
|
||||
except Exception:
|
||||
LOGGER.exception("create_asset_from_hash failed for hash=%s, owner_id=%s", spec.hash, owner_id)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
if result is None:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist")
|
||||
|
||||
# Drain temp if we accidentally saved (e.g., hash field came after file)
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
with contextlib.suppress(Exception):
|
||||
os.remove(tmp_path)
|
||||
|
||||
status = 200 if (not result.created_new) else 201
|
||||
return web.json_response(result.model_dump(mode="json"), status=status)
|
||||
|
||||
# Otherwise, we must have a temp file path to ingest
|
||||
if not tmp_path or not os.path.exists(tmp_path):
|
||||
# The only case we reach here without a temp file is: client sent a hash that does not exist and no file
|
||||
return _error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.")
|
||||
|
||||
try:
|
||||
created = await manager.upload_asset_from_temp_path(
|
||||
spec,
|
||||
temp_path=tmp_path,
|
||||
client_filename=file_client_name,
|
||||
owner_id=owner_id,
|
||||
expected_asset_hash=spec.hash,
|
||||
)
|
||||
status = 201 if created.created_new else 200
|
||||
return web.json_response(created.model_dump(mode="json"), status=status)
|
||||
except ValueError as e:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
msg = str(e)
|
||||
if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH":
|
||||
return _error_response(
|
||||
400,
|
||||
"HASH_MISMATCH",
|
||||
"Uploaded file hash does not match provided hash.",
|
||||
)
|
||||
return _error_response(400, "BAD_REQUEST", "Invalid inputs.")
|
||||
except Exception:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
LOGGER.exception("upload_asset_from_temp_path failed for tmp_path=%s, owner_id=%s", tmp_path, owner_id)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
|
||||
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
async def get_asset(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
result = await manager.get_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except ValueError as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
LOGGER.exception(
|
||||
"get_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
async def update_asset(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
body = schemas_in.UpdateAssetBody.model_validate(await request.json())
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
try:
|
||||
result = await manager.update_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
name=body.name,
|
||||
tags=body.tags,
|
||||
user_metadata=body.user_metadata,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except (ValueError, PermissionError) as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
LOGGER.exception(
|
||||
"update_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}/preview")
|
||||
async def set_asset_preview(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
body = schemas_in.SetPreviewBody.model_validate(await request.json())
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
try:
|
||||
result = await manager.set_asset_preview(
|
||||
asset_info_id=asset_info_id,
|
||||
preview_asset_id=body.preview_id,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except (PermissionError, ValueError) as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
LOGGER.exception(
|
||||
"set_asset_preview failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
async def delete_asset(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
delete_content = request.query.get("delete_content")
|
||||
delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"}
|
||||
|
||||
try:
|
||||
deleted = await manager.delete_asset_reference(
|
||||
asset_info_id=asset_info_id,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
delete_content_if_orphan=delete_content,
|
||||
)
|
||||
except Exception:
|
||||
LOGGER.exception(
|
||||
"delete_asset_reference failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
if not deleted:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.")
|
||||
return web.Response(status=204)
|
||||
|
||||
|
||||
@ROUTES.get("/api/tags")
|
||||
async def get_tags(request: web.Request) -> web.Response:
|
||||
query_map = dict(request.rel_url.query)
|
||||
|
||||
try:
|
||||
query = schemas_in.TagsListQuery.model_validate(query_map)
|
||||
except ValidationError as ve:
|
||||
return web.json_response(
|
||||
{"error": {"code": "INVALID_QUERY", "message": "Invalid query parameters", "details": ve.errors()}},
|
||||
status=400,
|
||||
)
|
||||
|
||||
result = await manager.list_tags(
|
||||
prefix=query.prefix,
|
||||
limit=query.limit,
|
||||
offset=query.offset,
|
||||
order=query.order,
|
||||
include_zero=query.include_zero,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return web.json_response(result.model_dump(mode="json"))
|
||||
|
||||
|
||||
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
||||
async def add_asset_tags(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
payload = await request.json()
|
||||
data = schemas_in.TagsAdd.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()})
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
try:
|
||||
result = await manager.add_tags_to_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
tags=data.tags,
|
||||
origin="manual",
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except (ValueError, PermissionError) as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
LOGGER.exception(
|
||||
"add_tags_to_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
||||
async def delete_asset_tags(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
payload = await request.json()
|
||||
data = schemas_in.TagsRemove.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()})
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
try:
|
||||
result = await manager.remove_tags_from_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
tags=data.tags,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except ValueError as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
LOGGER.exception(
|
||||
"remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets/scan/seed")
|
||||
async def seed_assets(request: web.Request) -> web.Response:
|
||||
try:
|
||||
payload = await request.json()
|
||||
except Exception:
|
||||
payload = {}
|
||||
|
||||
try:
|
||||
body = schemas_in.ScheduleAssetScanBody.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
|
||||
try:
|
||||
await scanner.sync_seed_assets(body.roots)
|
||||
except Exception:
|
||||
LOGGER.exception("sync_seed_assets failed for roots=%s", body.roots)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response({"synced": True, "roots": body.roots}, status=200)
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets/scan/schedule")
|
||||
async def schedule_asset_scan(request: web.Request) -> web.Response:
|
||||
try:
|
||||
payload = await request.json()
|
||||
except Exception:
|
||||
payload = {}
|
||||
|
||||
try:
|
||||
body = schemas_in.ScheduleAssetScanBody.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
|
||||
states = await scanner.schedule_scans(body.roots)
|
||||
return web.json_response(states.model_dump(mode="json"), status=202)
|
||||
|
||||
|
||||
@ROUTES.get("/api/assets/scan")
|
||||
async def get_asset_scan_status(request: web.Request) -> web.Response:
|
||||
root = request.query.get("root", "").strip().lower()
|
||||
states = scanner.current_statuses()
|
||||
if root in {"models", "input", "output"}:
|
||||
states = [s for s in states.scans if s.root == root] # type: ignore
|
||||
states = schemas_out.AssetScanStatusResponse(scans=states)
|
||||
return web.json_response(states.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None:
|
||||
global USER_MANAGER
|
||||
USER_MANAGER = user_manager_instance
|
||||
app.add_routes(ROUTES)
|
||||
|
||||
|
||||
def _error_response(status: int, code: str, message: str, details: Optional[dict] = None) -> web.Response:
|
||||
return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status)
|
||||
|
||||
|
||||
def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
|
||||
return _error_response(400, code, "Validation failed.", {"errors": ve.json()})
|
||||
297
app/assets/api/schemas_in.py
Normal file
297
app/assets/api/schemas_in.py
Normal file
@@ -0,0 +1,297 @@
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
conint,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
|
||||
class ListAssetsQuery(BaseModel):
|
||||
include_tags: list[str] = Field(default_factory=list)
|
||||
exclude_tags: list[str] = Field(default_factory=list)
|
||||
name_contains: Optional[str] = None
|
||||
|
||||
# Accept either a JSON string (query param) or a dict
|
||||
metadata_filter: Optional[dict[str, Any]] = None
|
||||
|
||||
limit: conint(ge=1, le=500) = 20
|
||||
offset: conint(ge=0) = 0
|
||||
|
||||
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at"
|
||||
order: Literal["asc", "desc"] = "desc"
|
||||
|
||||
@field_validator("include_tags", "exclude_tags", mode="before")
|
||||
@classmethod
|
||||
def _split_csv_tags(cls, v):
|
||||
# Accept "a,b,c" or ["a","b"] (we are liberal in what we accept)
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, str):
|
||||
return [t.strip() for t in v.split(",") if t.strip()]
|
||||
if isinstance(v, list):
|
||||
out: list[str] = []
|
||||
for item in v:
|
||||
if isinstance(item, str):
|
||||
out.extend([t.strip() for t in item.split(",") if t.strip()])
|
||||
return out
|
||||
return v
|
||||
|
||||
@field_validator("metadata_filter", mode="before")
|
||||
@classmethod
|
||||
def _parse_metadata_json(cls, v):
|
||||
if v is None or isinstance(v, dict):
|
||||
return v
|
||||
if isinstance(v, str) and v.strip():
|
||||
try:
|
||||
parsed = json.loads(v)
|
||||
except Exception as e:
|
||||
raise ValueError(f"metadata_filter must be JSON: {e}") from e
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValueError("metadata_filter must be a JSON object")
|
||||
return parsed
|
||||
return None
|
||||
|
||||
|
||||
class UpdateAssetBody(BaseModel):
|
||||
name: Optional[str] = None
|
||||
tags: Optional[list[str]] = None
|
||||
user_metadata: Optional[dict[str, Any]] = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _at_least_one(self):
|
||||
if self.name is None and self.tags is None and self.user_metadata is None:
|
||||
raise ValueError("Provide at least one of: name, tags, user_metadata.")
|
||||
if self.tags is not None:
|
||||
if not isinstance(self.tags, list) or not all(isinstance(t, str) for t in self.tags):
|
||||
raise ValueError("Field 'tags' must be an array of strings.")
|
||||
return self
|
||||
|
||||
|
||||
class CreateFromHashBody(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
hash: str
|
||||
name: str
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@field_validator("hash")
|
||||
@classmethod
|
||||
def _require_blake3(cls, v):
|
||||
s = (v or "").strip().lower()
|
||||
if ":" not in s:
|
||||
raise ValueError("hash must be 'blake3:<hex>'")
|
||||
algo, digest = s.split(":", 1)
|
||||
if algo != "blake3":
|
||||
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
|
||||
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
raise ValueError("hash digest must be lowercase hex")
|
||||
return s
|
||||
|
||||
@field_validator("tags", mode="before")
|
||||
@classmethod
|
||||
def _tags_norm(cls, v):
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, list):
|
||||
out = [str(t).strip().lower() for t in v if str(t).strip()]
|
||||
seen = set()
|
||||
dedup = []
|
||||
for t in out:
|
||||
if t not in seen:
|
||||
seen.add(t)
|
||||
dedup.append(t)
|
||||
return dedup
|
||||
if isinstance(v, str):
|
||||
return [t.strip().lower() for t in v.split(",") if t.strip()]
|
||||
return []
|
||||
|
||||
|
||||
class TagsListQuery(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
prefix: Optional[str] = Field(None, min_length=1, max_length=256)
|
||||
limit: int = Field(100, ge=1, le=1000)
|
||||
offset: int = Field(0, ge=0, le=10_000_000)
|
||||
order: Literal["count_desc", "name_asc"] = "count_desc"
|
||||
include_zero: bool = True
|
||||
|
||||
@field_validator("prefix")
|
||||
@classmethod
|
||||
def normalize_prefix(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v is None:
|
||||
return v
|
||||
v = v.strip()
|
||||
return v.lower() or None
|
||||
|
||||
|
||||
class TagsAdd(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
tags: list[str] = Field(..., min_length=1)
|
||||
|
||||
@field_validator("tags")
|
||||
@classmethod
|
||||
def normalize_tags(cls, v: list[str]) -> list[str]:
|
||||
out = []
|
||||
for t in v:
|
||||
if not isinstance(t, str):
|
||||
raise TypeError("tags must be strings")
|
||||
tnorm = t.strip().lower()
|
||||
if tnorm:
|
||||
out.append(tnorm)
|
||||
seen = set()
|
||||
deduplicated = []
|
||||
for x in out:
|
||||
if x not in seen:
|
||||
seen.add(x)
|
||||
deduplicated.append(x)
|
||||
return deduplicated
|
||||
|
||||
|
||||
class TagsRemove(TagsAdd):
|
||||
pass
|
||||
|
||||
|
||||
RootType = Literal["models", "input", "output"]
|
||||
ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")
|
||||
|
||||
|
||||
class ScheduleAssetScanBody(BaseModel):
|
||||
roots: list[RootType] = Field(..., min_length=1)
|
||||
|
||||
|
||||
class UploadAssetSpec(BaseModel):
|
||||
"""Upload Asset operation.
|
||||
- tags: ordered; first is root ('models'|'input'|'output');
|
||||
if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths
|
||||
- name: display name
|
||||
- user_metadata: arbitrary JSON object (optional)
|
||||
- hash: optional canonical 'blake3:<hex>' provided by the client for validation / fast-path
|
||||
|
||||
Files created via this endpoint are stored on disk using the **content hash** as the filename stem
|
||||
and the original extension is preserved when available.
|
||||
"""
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
tags: list[str] = Field(..., min_length=1)
|
||||
name: Optional[str] = Field(default=None, max_length=512, description="Display Name")
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
hash: Optional[str] = Field(default=None)
|
||||
|
||||
@field_validator("hash", mode="before")
|
||||
@classmethod
|
||||
def _parse_hash(cls, v):
|
||||
if v is None:
|
||||
return None
|
||||
s = str(v).strip().lower()
|
||||
if not s:
|
||||
return None
|
||||
if ":" not in s:
|
||||
raise ValueError("hash must be 'blake3:<hex>'")
|
||||
algo, digest = s.split(":", 1)
|
||||
if algo != "blake3":
|
||||
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
|
||||
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
raise ValueError("hash digest must be lowercase hex")
|
||||
return f"{algo}:{digest}"
|
||||
|
||||
@field_validator("tags", mode="before")
|
||||
@classmethod
|
||||
def _parse_tags(cls, v):
|
||||
"""
|
||||
Accepts a list of strings (possibly multiple form fields),
|
||||
where each string can be:
|
||||
- JSON array (e.g., '["models","loras","foo"]')
|
||||
- comma-separated ('models, loras, foo')
|
||||
- single token ('models')
|
||||
Returns a normalized, deduplicated, ordered list.
|
||||
"""
|
||||
items: list[str] = []
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, str):
|
||||
v = [v]
|
||||
|
||||
if isinstance(v, list):
|
||||
for item in v:
|
||||
if item is None:
|
||||
continue
|
||||
s = str(item).strip()
|
||||
if not s:
|
||||
continue
|
||||
if s.startswith("["):
|
||||
try:
|
||||
arr = json.loads(s)
|
||||
if isinstance(arr, list):
|
||||
items.extend(str(x) for x in arr)
|
||||
continue
|
||||
except Exception:
|
||||
pass # fallback to CSV parse below
|
||||
items.extend([p for p in s.split(",") if p.strip()])
|
||||
else:
|
||||
return []
|
||||
|
||||
# normalize + dedupe
|
||||
norm = []
|
||||
seen = set()
|
||||
for t in items:
|
||||
tnorm = str(t).strip().lower()
|
||||
if tnorm and tnorm not in seen:
|
||||
seen.add(tnorm)
|
||||
norm.append(tnorm)
|
||||
return norm
|
||||
|
||||
@field_validator("user_metadata", mode="before")
|
||||
@classmethod
|
||||
def _parse_metadata_json(cls, v):
|
||||
if v is None or isinstance(v, dict):
|
||||
return v or {}
|
||||
if isinstance(v, str):
|
||||
s = v.strip()
|
||||
if not s:
|
||||
return {}
|
||||
try:
|
||||
parsed = json.loads(s)
|
||||
except Exception as e:
|
||||
raise ValueError(f"user_metadata must be JSON: {e}") from e
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValueError("user_metadata must be a JSON object")
|
||||
return parsed
|
||||
return {}
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_order(self):
|
||||
if not self.tags:
|
||||
raise ValueError("tags must be provided and non-empty")
|
||||
root = self.tags[0]
|
||||
if root not in {"models", "input", "output"}:
|
||||
raise ValueError("first tag must be one of: models, input, output")
|
||||
if root == "models":
|
||||
if len(self.tags) < 2:
|
||||
raise ValueError("models uploads require a category tag as the second tag")
|
||||
return self
|
||||
|
||||
|
||||
class SetPreviewBody(BaseModel):
|
||||
"""Set or clear the preview for an AssetInfo. Provide an Asset.id or null."""
|
||||
preview_id: Optional[str] = None
|
||||
|
||||
@field_validator("preview_id", mode="before")
|
||||
@classmethod
|
||||
def _norm_uuid(cls, v):
|
||||
if v is None:
|
||||
return None
|
||||
s = str(v).strip()
|
||||
if not s:
|
||||
return None
|
||||
try:
|
||||
uuid.UUID(s)
|
||||
except Exception:
|
||||
raise ValueError("preview_id must be a UUID")
|
||||
return s
|
||||
115
app/assets/api/schemas_out.py
Normal file
115
app/assets/api/schemas_out.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_serializer
|
||||
|
||||
|
||||
class AssetSummary(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: Optional[str]
|
||||
size: Optional[int] = None
|
||||
mime_type: Optional[str] = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
preview_url: Optional[str] = None
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
last_access_time: Optional[datetime] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("created_at", "updated_at", "last_access_time")
|
||||
def _ser_dt(self, v: Optional[datetime], _info):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class AssetsList(BaseModel):
|
||||
assets: list[AssetSummary]
|
||||
total: int
|
||||
has_more: bool
|
||||
|
||||
|
||||
class AssetUpdated(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: Optional[str]
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("updated_at")
|
||||
def _ser_updated(self, v: Optional[datetime], _info):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class AssetDetail(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: Optional[str]
|
||||
size: Optional[int] = None
|
||||
mime_type: Optional[str] = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
preview_id: Optional[str] = None
|
||||
created_at: Optional[datetime] = None
|
||||
last_access_time: Optional[datetime] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("created_at", "last_access_time")
|
||||
def _ser_dt(self, v: Optional[datetime], _info):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class AssetCreated(AssetDetail):
|
||||
created_new: bool
|
||||
|
||||
|
||||
class TagUsage(BaseModel):
|
||||
name: str
|
||||
count: int
|
||||
type: str
|
||||
|
||||
|
||||
class TagsList(BaseModel):
|
||||
tags: list[TagUsage] = Field(default_factory=list)
|
||||
total: int
|
||||
has_more: bool
|
||||
|
||||
|
||||
class TagsAdd(BaseModel):
|
||||
model_config = ConfigDict(str_strip_whitespace=True)
|
||||
added: list[str] = Field(default_factory=list)
|
||||
already_present: list[str] = Field(default_factory=list)
|
||||
total_tags: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TagsRemove(BaseModel):
|
||||
model_config = ConfigDict(str_strip_whitespace=True)
|
||||
removed: list[str] = Field(default_factory=list)
|
||||
not_present: list[str] = Field(default_factory=list)
|
||||
total_tags: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AssetScanError(BaseModel):
|
||||
path: str
|
||||
message: str
|
||||
at: Optional[str] = Field(None, description="ISO timestamp")
|
||||
|
||||
|
||||
class AssetScanStatus(BaseModel):
|
||||
scan_id: str
|
||||
root: Literal["models", "input", "output"]
|
||||
status: Literal["scheduled", "running", "completed", "failed", "cancelled"]
|
||||
scheduled_at: Optional[str] = None
|
||||
started_at: Optional[str] = None
|
||||
finished_at: Optional[str] = None
|
||||
discovered: int = 0
|
||||
processed: int = 0
|
||||
file_errors: list[AssetScanError] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AssetScanStatusResponse(BaseModel):
|
||||
scans: list[AssetScanStatus] = Field(default_factory=list)
|
||||
0
app/assets/database/__init__.py
Normal file
0
app/assets/database/__init__.py
Normal file
25
app/assets/database/helpers/__init__.py
Normal file
25
app/assets/database/helpers/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from .bulk_ops import seed_from_paths_batch
|
||||
from .escape_like import escape_like_prefix
|
||||
from .fast_check import fast_asset_file_check
|
||||
from .filters import apply_metadata_filter, apply_tag_filters
|
||||
from .ownership import visible_owner_clause
|
||||
from .projection import is_scalar, project_kv
|
||||
from .tags import (
|
||||
add_missing_tag_for_asset_id,
|
||||
ensure_tags_exist,
|
||||
remove_missing_tag_for_asset_id,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"apply_tag_filters",
|
||||
"apply_metadata_filter",
|
||||
"escape_like_prefix",
|
||||
"fast_asset_file_check",
|
||||
"is_scalar",
|
||||
"project_kv",
|
||||
"ensure_tags_exist",
|
||||
"add_missing_tag_for_asset_id",
|
||||
"remove_missing_tag_for_asset_id",
|
||||
"seed_from_paths_batch",
|
||||
"visible_owner_clause",
|
||||
]
|
||||
230
app/assets/database/helpers/bulk_ops.py
Normal file
230
app/assets/database/helpers/bulk_ops.py
Normal file
@@ -0,0 +1,230 @@
|
||||
import os
|
||||
import uuid
|
||||
from typing import Iterable, Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql as d_pg
|
||||
from sqlalchemy.dialects import sqlite as d_sqlite
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..models import Asset, AssetCacheState, AssetInfo, AssetInfoMeta, AssetInfoTag
|
||||
from ..timeutil import utcnow
|
||||
|
||||
MAX_BIND_PARAMS = 800
|
||||
|
||||
|
||||
async def seed_from_paths_batch(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
specs: Sequence[dict],
|
||||
owner_id: str = "",
|
||||
) -> dict:
|
||||
"""Each spec is a dict with keys:
|
||||
- abs_path: str
|
||||
- size_bytes: int
|
||||
- mtime_ns: int
|
||||
- info_name: str
|
||||
- tags: list[str]
|
||||
- fname: Optional[str]
|
||||
"""
|
||||
if not specs:
|
||||
return {"inserted_infos": 0, "won_states": 0, "lost_states": 0}
|
||||
|
||||
now = utcnow()
|
||||
dialect = session.bind.dialect.name
|
||||
if dialect not in ("sqlite", "postgresql"):
|
||||
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
|
||||
|
||||
asset_rows: list[dict] = []
|
||||
state_rows: list[dict] = []
|
||||
path_to_asset: dict[str, str] = {}
|
||||
asset_to_info: dict[str, dict] = {} # asset_id -> prepared info row
|
||||
path_list: list[str] = []
|
||||
|
||||
for sp in specs:
|
||||
ap = os.path.abspath(sp["abs_path"])
|
||||
aid = str(uuid.uuid4())
|
||||
iid = str(uuid.uuid4())
|
||||
path_list.append(ap)
|
||||
path_to_asset[ap] = aid
|
||||
|
||||
asset_rows.append(
|
||||
{
|
||||
"id": aid,
|
||||
"hash": None,
|
||||
"size_bytes": sp["size_bytes"],
|
||||
"mime_type": None,
|
||||
"created_at": now,
|
||||
}
|
||||
)
|
||||
state_rows.append(
|
||||
{
|
||||
"asset_id": aid,
|
||||
"file_path": ap,
|
||||
"mtime_ns": sp["mtime_ns"],
|
||||
}
|
||||
)
|
||||
asset_to_info[aid] = {
|
||||
"id": iid,
|
||||
"owner_id": owner_id,
|
||||
"name": sp["info_name"],
|
||||
"asset_id": aid,
|
||||
"preview_id": None,
|
||||
"user_metadata": {"filename": sp["fname"]} if sp["fname"] else None,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"last_access_time": now,
|
||||
"_tags": sp["tags"],
|
||||
"_filename": sp["fname"],
|
||||
}
|
||||
|
||||
# insert all seed Assets (hash=NULL)
|
||||
ins_asset = d_sqlite.insert(Asset) if dialect == "sqlite" else d_pg.insert(Asset)
|
||||
for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)):
|
||||
await session.execute(ins_asset, chunk)
|
||||
|
||||
# try to claim AssetCacheState (file_path)
|
||||
winners_by_path: set[str] = set()
|
||||
if dialect == "sqlite":
|
||||
ins_state = (
|
||||
d_sqlite.insert(AssetCacheState)
|
||||
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
|
||||
.returning(AssetCacheState.file_path)
|
||||
)
|
||||
else:
|
||||
ins_state = (
|
||||
d_pg.insert(AssetCacheState)
|
||||
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
|
||||
.returning(AssetCacheState.file_path)
|
||||
)
|
||||
for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
|
||||
winners_by_path.update((await session.execute(ins_state, chunk)).scalars().all())
|
||||
|
||||
all_paths_set = set(path_list)
|
||||
losers_by_path = all_paths_set - winners_by_path
|
||||
lost_assets = [path_to_asset[p] for p in losers_by_path]
|
||||
if lost_assets: # losers get their Asset removed
|
||||
for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS):
|
||||
await session.execute(sa.delete(Asset).where(Asset.id.in_(id_chunk)))
|
||||
|
||||
if not winners_by_path:
|
||||
return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}
|
||||
|
||||
# insert AssetInfo only for winners
|
||||
winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
|
||||
if dialect == "sqlite":
|
||||
ins_info = (
|
||||
d_sqlite.insert(AssetInfo)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
|
||||
.returning(AssetInfo.id)
|
||||
)
|
||||
else:
|
||||
ins_info = (
|
||||
d_pg.insert(AssetInfo)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
|
||||
.returning(AssetInfo.id)
|
||||
)
|
||||
|
||||
inserted_info_ids: set[str] = set()
|
||||
for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
|
||||
inserted_info_ids.update((await session.execute(ins_info, chunk)).scalars().all())
|
||||
|
||||
# build and insert tag + meta rows for the AssetInfo
|
||||
tag_rows: list[dict] = []
|
||||
meta_rows: list[dict] = []
|
||||
if inserted_info_ids:
|
||||
for row in winner_info_rows:
|
||||
iid = row["id"]
|
||||
if iid not in inserted_info_ids:
|
||||
continue
|
||||
for t in row["_tags"]:
|
||||
tag_rows.append({
|
||||
"asset_info_id": iid,
|
||||
"tag_name": t,
|
||||
"origin": "automatic",
|
||||
"added_at": now,
|
||||
})
|
||||
if row["_filename"]:
|
||||
meta_rows.append(
|
||||
{
|
||||
"asset_info_id": iid,
|
||||
"key": "filename",
|
||||
"ordinal": 0,
|
||||
"val_str": row["_filename"],
|
||||
"val_num": None,
|
||||
"val_bool": None,
|
||||
"val_json": None,
|
||||
}
|
||||
)
|
||||
|
||||
await bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS)
|
||||
return {
|
||||
"inserted_infos": len(inserted_info_ids),
|
||||
"won_states": len(winners_by_path),
|
||||
"lost_states": len(losers_by_path),
|
||||
}
|
||||
|
||||
|
||||
async def bulk_insert_tags_and_meta(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
tag_rows: list[dict],
|
||||
meta_rows: list[dict],
|
||||
max_bind_params: int,
|
||||
) -> None:
|
||||
"""Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.
|
||||
- tag_rows keys: asset_info_id, tag_name, origin, added_at
|
||||
- meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
|
||||
"""
|
||||
dialect = session.bind.dialect.name
|
||||
if tag_rows:
|
||||
if dialect == "sqlite":
|
||||
ins_links = (
|
||||
d_sqlite.insert(AssetInfoTag)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
|
||||
)
|
||||
elif dialect == "postgresql":
|
||||
ins_links = (
|
||||
d_pg.insert(AssetInfoTag)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
|
||||
for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params):
|
||||
await session.execute(ins_links, chunk)
|
||||
if meta_rows:
|
||||
if dialect == "sqlite":
|
||||
ins_meta = (
|
||||
d_sqlite.insert(AssetInfoMeta)
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
|
||||
)
|
||||
)
|
||||
elif dialect == "postgresql":
|
||||
ins_meta = (
|
||||
d_pg.insert(AssetInfoMeta)
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
|
||||
for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params):
|
||||
await session.execute(ins_meta, chunk)
|
||||
|
||||
|
||||
def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]:
|
||||
if not rows:
|
||||
return []
|
||||
rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row))
|
||||
for i in range(0, len(rows), rows_per_stmt):
|
||||
yield rows[i:i + rows_per_stmt]
|
||||
|
||||
|
||||
def _iter_chunks(seq, n: int):
|
||||
for i in range(0, len(seq), n):
|
||||
yield seq[i:i + n]
|
||||
|
||||
|
||||
def _rows_per_stmt(cols: int) -> int:
|
||||
return max(1, MAX_BIND_PARAMS // max(1, cols))
|
||||
7
app/assets/database/helpers/escape_like.py
Normal file
7
app/assets/database/helpers/escape_like.py
Normal file
@@ -0,0 +1,7 @@
|
||||
def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]:
|
||||
"""Escapes %, _ and the escape char itself in a LIKE prefix.
|
||||
Returns (escaped_prefix, escape_char). Caller should append '%' and pass escape=escape_char to .like().
|
||||
"""
|
||||
s = s.replace(escape, escape + escape) # escape the escape char first
|
||||
s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards
|
||||
return s, escape
|
||||
19
app/assets/database/helpers/fast_check.py
Normal file
19
app/assets/database/helpers/fast_check.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def fast_asset_file_check(
|
||||
*,
|
||||
mtime_db: Optional[int],
|
||||
size_db: Optional[int],
|
||||
stat_result: os.stat_result,
|
||||
) -> bool:
|
||||
if mtime_db is None:
|
||||
return False
|
||||
actual_mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000))
|
||||
if int(mtime_db) != int(actual_mtime_ns):
|
||||
return False
|
||||
sz = int(size_db or 0)
|
||||
if sz > 0:
|
||||
return int(stat_result.st_size) == sz
|
||||
return True
|
||||
87
app/assets/database/helpers/filters.py
Normal file
87
app/assets/database/helpers/filters.py
Normal file
@@ -0,0 +1,87 @@
|
||||
from typing import Optional, Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import exists
|
||||
|
||||
from ..._helpers import normalize_tags
|
||||
from ..models import AssetInfo, AssetInfoMeta, AssetInfoTag
|
||||
|
||||
|
||||
def apply_tag_filters(
|
||||
stmt: sa.sql.Select,
|
||||
include_tags: Optional[Sequence[str]],
|
||||
exclude_tags: Optional[Sequence[str]],
|
||||
) -> sa.sql.Select:
|
||||
"""include_tags: every tag must be present; exclude_tags: none may be present."""
|
||||
include_tags = normalize_tags(include_tags)
|
||||
exclude_tags = normalize_tags(exclude_tags)
|
||||
|
||||
if include_tags:
|
||||
for tag_name in include_tags:
|
||||
stmt = stmt.where(
|
||||
exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name == tag_name)
|
||||
)
|
||||
)
|
||||
|
||||
if exclude_tags:
|
||||
stmt = stmt.where(
|
||||
~exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name.in_(exclude_tags))
|
||||
)
|
||||
)
|
||||
return stmt
|
||||
|
||||
|
||||
def apply_metadata_filter(
|
||||
stmt: sa.sql.Select,
|
||||
metadata_filter: Optional[dict],
|
||||
) -> sa.sql.Select:
|
||||
"""Apply filters using asset_info_meta projection table."""
|
||||
if not metadata_filter:
|
||||
return stmt
|
||||
|
||||
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
|
||||
return sa.exists().where(
|
||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||
AssetInfoMeta.key == key,
|
||||
*preds,
|
||||
)
|
||||
|
||||
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
|
||||
if value is None:
|
||||
no_row_for_key = sa.not_(
|
||||
sa.exists().where(
|
||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||
AssetInfoMeta.key == key,
|
||||
)
|
||||
)
|
||||
null_row = _exists_for_pred(
|
||||
key,
|
||||
AssetInfoMeta.val_json.is_(None),
|
||||
AssetInfoMeta.val_str.is_(None),
|
||||
AssetInfoMeta.val_num.is_(None),
|
||||
AssetInfoMeta.val_bool.is_(None),
|
||||
)
|
||||
return sa.or_(no_row_for_key, null_row)
|
||||
|
||||
if isinstance(value, bool):
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
|
||||
if isinstance(value, (int, float)):
|
||||
from decimal import Decimal
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
|
||||
if isinstance(value, str):
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
|
||||
|
||||
for k, v in metadata_filter.items():
|
||||
if isinstance(v, list):
|
||||
ors = [_exists_clause_for_value(k, elem) for elem in v]
|
||||
if ors:
|
||||
stmt = stmt.where(sa.or_(*ors))
|
||||
else:
|
||||
stmt = stmt.where(_exists_clause_for_value(k, v))
|
||||
return stmt
|
||||
12
app/assets/database/helpers/ownership.py
Normal file
12
app/assets/database/helpers/ownership.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import sqlalchemy as sa
|
||||
|
||||
from ..models import AssetInfo
|
||||
|
||||
|
||||
def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
|
||||
"""Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
|
||||
|
||||
owner_id = (owner_id or "").strip()
|
||||
if owner_id == "":
|
||||
return AssetInfo.owner_id == ""
|
||||
return AssetInfo.owner_id.in_(["", owner_id])
|
||||
64
app/assets/database/helpers/projection.py
Normal file
64
app/assets/database/helpers/projection.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from decimal import Decimal
|
||||
|
||||
|
||||
def is_scalar(v):
|
||||
if v is None:
|
||||
return True
|
||||
if isinstance(v, bool):
|
||||
return True
|
||||
if isinstance(v, (int, float, Decimal, str)):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def project_kv(key: str, value):
|
||||
"""
|
||||
Turn a metadata key/value into typed projection rows.
|
||||
Returns list[dict] with keys:
|
||||
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
|
||||
"""
|
||||
rows: list[dict] = []
|
||||
|
||||
def _null_row(ordinal: int) -> dict:
|
||||
return {
|
||||
"key": key, "ordinal": ordinal,
|
||||
"val_str": None, "val_num": None, "val_bool": None, "val_json": None
|
||||
}
|
||||
|
||||
if value is None:
|
||||
rows.append(_null_row(0))
|
||||
return rows
|
||||
|
||||
if is_scalar(value):
|
||||
if isinstance(value, bool):
|
||||
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
|
||||
elif isinstance(value, (int, float, Decimal)):
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
rows.append({"key": key, "ordinal": 0, "val_num": num})
|
||||
elif isinstance(value, str):
|
||||
rows.append({"key": key, "ordinal": 0, "val_str": value})
|
||||
else:
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
|
||||
if isinstance(value, list):
|
||||
if all(is_scalar(x) for x in value):
|
||||
for i, x in enumerate(value):
|
||||
if x is None:
|
||||
rows.append(_null_row(i))
|
||||
elif isinstance(x, bool):
|
||||
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
|
||||
elif isinstance(x, (int, float, Decimal)):
|
||||
num = x if isinstance(x, Decimal) else Decimal(str(x))
|
||||
rows.append({"key": key, "ordinal": i, "val_num": num})
|
||||
elif isinstance(x, str):
|
||||
rows.append({"key": key, "ordinal": i, "val_str": x})
|
||||
else:
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
for i, x in enumerate(value):
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
90
app/assets/database/helpers/tags.py
Normal file
90
app/assets/database/helpers/tags.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from typing import Iterable
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql as d_pg
|
||||
from sqlalchemy.dialects import sqlite as d_sqlite
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..._helpers import normalize_tags
|
||||
from ..models import AssetInfo, AssetInfoTag, Tag
|
||||
from ..timeutil import utcnow
|
||||
|
||||
|
||||
async def ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> None:
|
||||
wanted = normalize_tags(list(names))
|
||||
if not wanted:
|
||||
return
|
||||
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
|
||||
dialect = session.bind.dialect.name
|
||||
if dialect == "sqlite":
|
||||
ins = (
|
||||
d_sqlite.insert(Tag)
|
||||
.values(rows)
|
||||
.on_conflict_do_nothing(index_elements=[Tag.name])
|
||||
)
|
||||
elif dialect == "postgresql":
|
||||
ins = (
|
||||
d_pg.insert(Tag)
|
||||
.values(rows)
|
||||
.on_conflict_do_nothing(index_elements=[Tag.name])
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
|
||||
await session.execute(ins)
|
||||
|
||||
|
||||
async def add_missing_tag_for_asset_id(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_id: str,
|
||||
origin: str = "automatic",
|
||||
) -> None:
|
||||
select_rows = (
|
||||
sa.select(
|
||||
AssetInfo.id.label("asset_info_id"),
|
||||
sa.literal("missing").label("tag_name"),
|
||||
sa.literal(origin).label("origin"),
|
||||
sa.literal(utcnow()).label("added_at"),
|
||||
)
|
||||
.where(AssetInfo.asset_id == asset_id)
|
||||
.where(
|
||||
sa.not_(
|
||||
sa.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing"))
|
||||
)
|
||||
)
|
||||
)
|
||||
dialect = session.bind.dialect.name
|
||||
if dialect == "sqlite":
|
||||
ins = (
|
||||
d_sqlite.insert(AssetInfoTag)
|
||||
.from_select(
|
||||
["asset_info_id", "tag_name", "origin", "added_at"],
|
||||
select_rows,
|
||||
)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
|
||||
)
|
||||
elif dialect == "postgresql":
|
||||
ins = (
|
||||
d_pg.insert(AssetInfoTag)
|
||||
.from_select(
|
||||
["asset_info_id", "tag_name", "origin", "added_at"],
|
||||
select_rows,
|
||||
)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
|
||||
await session.execute(ins)
|
||||
|
||||
|
||||
async def remove_missing_tag_for_asset_id(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_id: str,
|
||||
) -> None:
|
||||
await session.execute(
|
||||
sa.delete(AssetInfoTag).where(
|
||||
AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
|
||||
AssetInfoTag.tag_name == "missing",
|
||||
)
|
||||
)
|
||||
251
app/assets/database/models.py
Normal file
251
app/assets/database/models.py
Normal file
@@ -0,0 +1,251 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
BigInteger,
|
||||
Boolean,
|
||||
CheckConstraint,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
Numeric,
|
||||
String,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, foreign, mapped_column, relationship
|
||||
|
||||
from .timeutil import utcnow
|
||||
|
||||
JSONB_V = JSON(none_as_null=True).with_variant(JSONB(none_as_null=True), 'postgresql')
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
|
||||
fields = obj.__table__.columns.keys()
|
||||
out: dict[str, Any] = {}
|
||||
for field in fields:
|
||||
val = getattr(obj, field)
|
||||
if val is None and not include_none:
|
||||
continue
|
||||
if isinstance(val, datetime):
|
||||
out[field] = val.isoformat()
|
||||
else:
|
||||
out[field] = val
|
||||
return out
|
||||
|
||||
|
||||
class Asset(Base):
|
||||
__tablename__ = "assets"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
hash: Mapped[Optional[str]] = mapped_column(String(256), nullable=True)
|
||||
size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
|
||||
mime_type: Mapped[Optional[str]] = mapped_column(String(255))
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, default=utcnow
|
||||
)
|
||||
|
||||
infos: Mapped[list["AssetInfo"]] = relationship(
|
||||
"AssetInfo",
|
||||
back_populates="asset",
|
||||
primaryjoin=lambda: Asset.id == foreign(AssetInfo.asset_id),
|
||||
foreign_keys=lambda: [AssetInfo.asset_id],
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
preview_of: Mapped[list["AssetInfo"]] = relationship(
|
||||
"AssetInfo",
|
||||
back_populates="preview_asset",
|
||||
primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id),
|
||||
foreign_keys=lambda: [AssetInfo.preview_id],
|
||||
viewonly=True,
|
||||
)
|
||||
|
||||
cache_states: Mapped[list["AssetCacheState"]] = relationship(
|
||||
back_populates="asset",
|
||||
cascade="all, delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("uq_assets_hash", "hash", unique=True),
|
||||
Index("ix_assets_mime_type", "mime_type"),
|
||||
CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
|
||||
)
|
||||
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
return to_dict(self, include_none=include_none)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Asset id={self.id} hash={(self.hash or '')[:12]}>"
|
||||
|
||||
|
||||
class AssetCacheState(Base):
|
||||
__tablename__ = "asset_cache_state"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False)
|
||||
file_path: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
mtime_ns: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True)
|
||||
needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
asset: Mapped["Asset"] = relationship(back_populates="cache_states")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_asset_cache_state_file_path", "file_path"),
|
||||
Index("ix_asset_cache_state_asset_id", "asset_id"),
|
||||
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
|
||||
UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
|
||||
)
|
||||
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
return to_dict(self, include_none=include_none)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<AssetCacheState id={self.id} asset_id={self.asset_id} path={self.file_path!r}>"
|
||||
|
||||
|
||||
class AssetInfo(Base):
|
||||
__tablename__ = "assets_info"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
|
||||
name: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False)
|
||||
preview_id: Mapped[Optional[str]] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL"))
|
||||
user_metadata: Mapped[Optional[dict[str, Any]]] = mapped_column(JSON(none_as_null=True))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
|
||||
last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
|
||||
|
||||
asset: Mapped[Asset] = relationship(
|
||||
"Asset",
|
||||
back_populates="infos",
|
||||
foreign_keys=[asset_id],
|
||||
lazy="selectin",
|
||||
)
|
||||
preview_asset: Mapped[Optional[Asset]] = relationship(
|
||||
"Asset",
|
||||
back_populates="preview_of",
|
||||
foreign_keys=[preview_id],
|
||||
)
|
||||
|
||||
metadata_entries: Mapped[list["AssetInfoMeta"]] = relationship(
|
||||
back_populates="asset_info",
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
tag_links: Mapped[list["AssetInfoTag"]] = relationship(
|
||||
back_populates="asset_info",
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
overlaps="tags,asset_infos",
|
||||
)
|
||||
|
||||
tags: Mapped[list["Tag"]] = relationship(
|
||||
secondary="asset_info_tags",
|
||||
back_populates="asset_infos",
|
||||
lazy="selectin",
|
||||
viewonly=True,
|
||||
overlaps="tag_links,asset_info_links,asset_infos,tag",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
|
||||
Index("ix_assets_info_owner_name", "owner_id", "name"),
|
||||
Index("ix_assets_info_owner_id", "owner_id"),
|
||||
Index("ix_assets_info_asset_id", "asset_id"),
|
||||
Index("ix_assets_info_name", "name"),
|
||||
Index("ix_assets_info_created_at", "created_at"),
|
||||
Index("ix_assets_info_last_access_time", "last_access_time"),
|
||||
)
|
||||
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
data = to_dict(self, include_none=include_none)
|
||||
data["tags"] = [t.name for t in self.tags]
|
||||
return data
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<AssetInfo id={self.id} name={self.name!r} asset_id={self.asset_id}>"
|
||||
|
||||
|
||||
class AssetInfoMeta(Base):
|
||||
__tablename__ = "asset_info_meta"
|
||||
|
||||
asset_info_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
key: Mapped[str] = mapped_column(String(256), primary_key=True)
|
||||
ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0)
|
||||
|
||||
val_str: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True)
|
||||
val_num: Mapped[Optional[float]] = mapped_column(Numeric(38, 10), nullable=True)
|
||||
val_bool: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True)
|
||||
val_json: Mapped[Optional[Any]] = mapped_column(JSONB_V, nullable=True)
|
||||
|
||||
asset_info: Mapped["AssetInfo"] = relationship(back_populates="metadata_entries")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_asset_info_meta_key", "key"),
|
||||
Index("ix_asset_info_meta_key_val_str", "key", "val_str"),
|
||||
Index("ix_asset_info_meta_key_val_num", "key", "val_num"),
|
||||
Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"),
|
||||
)
|
||||
|
||||
|
||||
class AssetInfoTag(Base):
|
||||
__tablename__ = "asset_info_tags"
|
||||
|
||||
asset_info_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
tag_name: Mapped[str] = mapped_column(
|
||||
String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
|
||||
)
|
||||
origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
|
||||
added_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, default=utcnow
|
||||
)
|
||||
|
||||
asset_info: Mapped["AssetInfo"] = relationship(back_populates="tag_links")
|
||||
tag: Mapped["Tag"] = relationship(back_populates="asset_info_links")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_asset_info_tags_tag_name", "tag_name"),
|
||||
Index("ix_asset_info_tags_asset_info_id", "asset_info_id"),
|
||||
)
|
||||
|
||||
|
||||
class Tag(Base):
|
||||
__tablename__ = "tags"
|
||||
|
||||
name: Mapped[str] = mapped_column(String(512), primary_key=True)
|
||||
tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")
|
||||
|
||||
asset_info_links: Mapped[list["AssetInfoTag"]] = relationship(
|
||||
back_populates="tag",
|
||||
overlaps="asset_infos,tags",
|
||||
)
|
||||
asset_infos: Mapped[list["AssetInfo"]] = relationship(
|
||||
secondary="asset_info_tags",
|
||||
back_populates="tags",
|
||||
viewonly=True,
|
||||
overlaps="asset_info_links,tag_links,tags,asset_info",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_tags_tag_type", "tag_type"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Tag {self.name}>"
|
||||
57
app/assets/database/services/__init__.py
Normal file
57
app/assets/database/services/__init__.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from .content import (
|
||||
check_fs_asset_exists_quick,
|
||||
compute_hash_and_dedup_for_cache_state,
|
||||
ingest_fs_asset,
|
||||
list_cache_states_with_asset_under_prefixes,
|
||||
list_unhashed_candidates_under_prefixes,
|
||||
list_verify_candidates_under_prefixes,
|
||||
redirect_all_references_then_delete_asset,
|
||||
touch_asset_infos_by_fs_path,
|
||||
)
|
||||
from .info import (
|
||||
add_tags_to_asset_info,
|
||||
create_asset_info_for_existing_asset,
|
||||
delete_asset_info_by_id,
|
||||
fetch_asset_info_and_asset,
|
||||
fetch_asset_info_asset_and_tags,
|
||||
get_asset_tags,
|
||||
list_asset_infos_page,
|
||||
list_tags_with_usage,
|
||||
remove_tags_from_asset_info,
|
||||
replace_asset_info_metadata_projection,
|
||||
set_asset_info_preview,
|
||||
set_asset_info_tags,
|
||||
touch_asset_info_by_id,
|
||||
update_asset_info_full,
|
||||
)
|
||||
from .queries import (
|
||||
asset_exists_by_hash,
|
||||
asset_info_exists_for_asset_id,
|
||||
get_asset_by_hash,
|
||||
get_asset_info_by_id,
|
||||
get_cache_state_by_asset_id,
|
||||
list_cache_states_by_asset_id,
|
||||
pick_best_live_path,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# queries
|
||||
"asset_exists_by_hash", "get_asset_by_hash", "get_asset_info_by_id", "asset_info_exists_for_asset_id",
|
||||
"get_cache_state_by_asset_id",
|
||||
"list_cache_states_by_asset_id",
|
||||
"pick_best_live_path",
|
||||
# info
|
||||
"list_asset_infos_page", "create_asset_info_for_existing_asset", "set_asset_info_tags",
|
||||
"update_asset_info_full", "replace_asset_info_metadata_projection",
|
||||
"touch_asset_info_by_id", "delete_asset_info_by_id",
|
||||
"add_tags_to_asset_info", "remove_tags_from_asset_info",
|
||||
"get_asset_tags", "list_tags_with_usage", "set_asset_info_preview",
|
||||
"fetch_asset_info_and_asset", "fetch_asset_info_asset_and_tags",
|
||||
# content
|
||||
"check_fs_asset_exists_quick",
|
||||
"redirect_all_references_then_delete_asset",
|
||||
"compute_hash_and_dedup_for_cache_state",
|
||||
"list_unhashed_candidates_under_prefixes", "list_verify_candidates_under_prefixes",
|
||||
"ingest_fs_asset", "touch_asset_infos_by_fs_path",
|
||||
"list_cache_states_with_asset_under_prefixes",
|
||||
]
|
||||
721
app/assets/database/services/content.py
Normal file
721
app/assets/database/services/content.py
Normal file
@@ -0,0 +1,721 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional, Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects import postgresql as d_pg
|
||||
from sqlalchemy.dialects import sqlite as d_sqlite
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import noload
|
||||
|
||||
from ..._helpers import compute_relative_filename
|
||||
from ...storage import hashing as hashing_mod
|
||||
from ..helpers import (
|
||||
ensure_tags_exist,
|
||||
escape_like_prefix,
|
||||
remove_missing_tag_for_asset_id,
|
||||
)
|
||||
from ..models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, Tag
|
||||
from ..timeutil import utcnow
|
||||
from .info import replace_asset_info_metadata_projection
|
||||
from .queries import list_cache_states_by_asset_id, pick_best_live_path
|
||||
|
||||
|
||||
async def check_fs_asset_exists_quick(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
file_path: str,
|
||||
size_bytes: Optional[int] = None,
|
||||
mtime_ns: Optional[int] = None,
|
||||
) -> bool:
|
||||
"""Returns True if we already track this absolute path with a HASHED asset and the cached mtime/size match."""
|
||||
locator = os.path.abspath(file_path)
|
||||
|
||||
stmt = (
|
||||
sa.select(sa.literal(True))
|
||||
.select_from(AssetCacheState)
|
||||
.join(Asset, Asset.id == AssetCacheState.asset_id)
|
||||
.where(
|
||||
AssetCacheState.file_path == locator,
|
||||
Asset.hash.isnot(None),
|
||||
AssetCacheState.needs_verify.is_(False),
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
conds = []
|
||||
if mtime_ns is not None:
|
||||
conds.append(AssetCacheState.mtime_ns == int(mtime_ns))
|
||||
if size_bytes is not None:
|
||||
conds.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes)))
|
||||
if conds:
|
||||
stmt = stmt.where(*conds)
|
||||
return (await session.execute(stmt)).first() is not None
|
||||
|
||||
|
||||
async def redirect_all_references_then_delete_asset(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
duplicate_asset_id: str,
|
||||
canonical_asset_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
Safely migrate all references from duplicate_asset_id to canonical_asset_id.
|
||||
|
||||
- If an AssetInfo for (owner_id, name) already exists on the canonical asset,
|
||||
merge tags, metadata, times, and preview, then delete the duplicate AssetInfo.
|
||||
- Otherwise, simply repoint the AssetInfo.asset_id.
|
||||
- Always retarget AssetCacheState rows.
|
||||
- Finally delete the duplicate Asset row.
|
||||
"""
|
||||
if duplicate_asset_id == canonical_asset_id:
|
||||
return
|
||||
|
||||
# 1) Migrate AssetInfo rows one-by-one to avoid UNIQUE conflicts.
|
||||
dup_infos = (
|
||||
await session.execute(
|
||||
select(AssetInfo).options(noload(AssetInfo.tags)).where(AssetInfo.asset_id == duplicate_asset_id)
|
||||
)
|
||||
).unique().scalars().all()
|
||||
|
||||
for info in dup_infos:
|
||||
# Try to find an existing collision on canonical
|
||||
existing = (
|
||||
await session.execute(
|
||||
select(AssetInfo)
|
||||
.options(noload(AssetInfo.tags))
|
||||
.where(
|
||||
AssetInfo.asset_id == canonical_asset_id,
|
||||
AssetInfo.owner_id == info.owner_id,
|
||||
AssetInfo.name == info.name,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
).unique().scalars().first()
|
||||
|
||||
if existing:
|
||||
merged_meta = dict(existing.user_metadata or {})
|
||||
other_meta = info.user_metadata or {}
|
||||
for k, v in other_meta.items():
|
||||
if k not in merged_meta:
|
||||
merged_meta[k] = v
|
||||
if merged_meta != (existing.user_metadata or {}):
|
||||
await replace_asset_info_metadata_projection(
|
||||
session,
|
||||
asset_info_id=existing.id,
|
||||
user_metadata=merged_meta,
|
||||
)
|
||||
|
||||
existing_tags = {
|
||||
t for (t,) in (
|
||||
await session.execute(
|
||||
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == existing.id)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
from_tags = {
|
||||
t for (t,) in (
|
||||
await session.execute(
|
||||
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == info.id)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
to_add = sorted(from_tags - existing_tags)
|
||||
if to_add:
|
||||
await ensure_tags_exist(session, to_add, tag_type="user")
|
||||
now = utcnow()
|
||||
session.add_all([
|
||||
AssetInfoTag(asset_info_id=existing.id, tag_name=t, origin="automatic", added_at=now)
|
||||
for t in to_add
|
||||
])
|
||||
await session.flush()
|
||||
|
||||
if existing.preview_id is None and info.preview_id is not None:
|
||||
existing.preview_id = info.preview_id
|
||||
if info.last_access_time and (
|
||||
existing.last_access_time is None or info.last_access_time > existing.last_access_time
|
||||
):
|
||||
existing.last_access_time = info.last_access_time
|
||||
existing.updated_at = utcnow()
|
||||
await session.flush()
|
||||
|
||||
# Delete the duplicate AssetInfo (cascades will clean its tags/meta)
|
||||
await session.delete(info)
|
||||
await session.flush()
|
||||
else:
|
||||
# Simple retarget
|
||||
info.asset_id = canonical_asset_id
|
||||
info.updated_at = utcnow()
|
||||
await session.flush()
|
||||
|
||||
# 2) Repoint cache states and previews
|
||||
await session.execute(
|
||||
sa.update(AssetCacheState)
|
||||
.where(AssetCacheState.asset_id == duplicate_asset_id)
|
||||
.values(asset_id=canonical_asset_id)
|
||||
)
|
||||
await session.execute(
|
||||
sa.update(AssetInfo)
|
||||
.where(AssetInfo.preview_id == duplicate_asset_id)
|
||||
.values(preview_id=canonical_asset_id)
|
||||
)
|
||||
|
||||
# 3) Remove duplicate Asset
|
||||
dup = await session.get(Asset, duplicate_asset_id)
|
||||
if dup:
|
||||
await session.delete(dup)
|
||||
await session.flush()
|
||||
|
||||
|
||||
async def compute_hash_and_dedup_for_cache_state(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
state_id: int,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Compute hash for the given cache state, deduplicate, and settle verify cases.
|
||||
|
||||
Returns the asset_id that this state ends up pointing to, or None if file disappeared.
|
||||
"""
|
||||
state = await session.get(AssetCacheState, state_id)
|
||||
if not state:
|
||||
return None
|
||||
|
||||
path = state.file_path
|
||||
try:
|
||||
if not os.path.isfile(path):
|
||||
# File vanished: drop the state. If the Asset has hash=NULL and has no other states, drop the Asset too.
|
||||
asset = await session.get(Asset, state.asset_id)
|
||||
await session.delete(state)
|
||||
await session.flush()
|
||||
|
||||
if asset and asset.hash is None:
|
||||
remaining = (
|
||||
await session.execute(
|
||||
sa.select(sa.func.count())
|
||||
.select_from(AssetCacheState)
|
||||
.where(AssetCacheState.asset_id == asset.id)
|
||||
)
|
||||
).scalar_one()
|
||||
if int(remaining or 0) == 0:
|
||||
await session.delete(asset)
|
||||
await session.flush()
|
||||
else:
|
||||
await _recompute_and_apply_filename_for_asset(session, asset_id=asset.id)
|
||||
return None
|
||||
|
||||
digest = await hashing_mod.blake3_hash(path)
|
||||
new_hash = f"blake3:{digest}"
|
||||
|
||||
st = os.stat(path, follow_symlinks=True)
|
||||
new_size = int(st.st_size)
|
||||
mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
|
||||
|
||||
# Current asset of this state
|
||||
this_asset = await session.get(Asset, state.asset_id)
|
||||
|
||||
# If the state got orphaned somehow (race), just reattach appropriately.
|
||||
if not this_asset:
|
||||
canonical = (
|
||||
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
|
||||
).scalars().first()
|
||||
if canonical:
|
||||
state.asset_id = canonical.id
|
||||
else:
|
||||
now = utcnow()
|
||||
new_asset = Asset(hash=new_hash, size_bytes=new_size, mime_type=None, created_at=now)
|
||||
session.add(new_asset)
|
||||
await session.flush()
|
||||
state.asset_id = new_asset.id
|
||||
state.mtime_ns = mtime_ns
|
||||
state.needs_verify = False
|
||||
with contextlib.suppress(Exception):
|
||||
await remove_missing_tag_for_asset_id(session, asset_id=state.asset_id)
|
||||
await session.flush()
|
||||
return state.asset_id
|
||||
|
||||
# 1) Seed asset case (hash is NULL): claim or merge into canonical
|
||||
if this_asset.hash is None:
|
||||
canonical = (
|
||||
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
|
||||
).scalars().first()
|
||||
|
||||
if canonical and canonical.id != this_asset.id:
|
||||
# Merge seed asset into canonical (safe, collision-aware)
|
||||
await redirect_all_references_then_delete_asset(
|
||||
session,
|
||||
duplicate_asset_id=this_asset.id,
|
||||
canonical_asset_id=canonical.id,
|
||||
)
|
||||
state = await session.get(AssetCacheState, state_id)
|
||||
if state:
|
||||
state.mtime_ns = mtime_ns
|
||||
state.needs_verify = False
|
||||
with contextlib.suppress(Exception):
|
||||
await remove_missing_tag_for_asset_id(session, asset_id=canonical.id)
|
||||
await _recompute_and_apply_filename_for_asset(session, asset_id=canonical.id)
|
||||
await session.flush()
|
||||
return canonical.id
|
||||
|
||||
# No canonical: try to claim the hash; handle races with a SAVEPOINT
|
||||
try:
|
||||
async with session.begin_nested():
|
||||
this_asset.hash = new_hash
|
||||
if int(this_asset.size_bytes or 0) == 0 and new_size > 0:
|
||||
this_asset.size_bytes = new_size
|
||||
await session.flush()
|
||||
except IntegrityError:
|
||||
# Someone else claimed it concurrently; fetch canonical and merge
|
||||
canonical = (
|
||||
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
|
||||
).scalars().first()
|
||||
if canonical and canonical.id != this_asset.id:
|
||||
await redirect_all_references_then_delete_asset(
|
||||
session,
|
||||
duplicate_asset_id=this_asset.id,
|
||||
canonical_asset_id=canonical.id,
|
||||
)
|
||||
state = await session.get(AssetCacheState, state_id)
|
||||
if state:
|
||||
state.mtime_ns = mtime_ns
|
||||
state.needs_verify = False
|
||||
with contextlib.suppress(Exception):
|
||||
await remove_missing_tag_for_asset_id(session, asset_id=canonical.id)
|
||||
await _recompute_and_apply_filename_for_asset(session, asset_id=canonical.id)
|
||||
await session.flush()
|
||||
return canonical.id
|
||||
# If we got here, the integrity error was not about hash uniqueness
|
||||
raise
|
||||
|
||||
# Claimed successfully
|
||||
state.mtime_ns = mtime_ns
|
||||
state.needs_verify = False
|
||||
with contextlib.suppress(Exception):
|
||||
await remove_missing_tag_for_asset_id(session, asset_id=this_asset.id)
|
||||
await _recompute_and_apply_filename_for_asset(session, asset_id=this_asset.id)
|
||||
await session.flush()
|
||||
return this_asset.id
|
||||
|
||||
# 2) Verify case for hashed assets
|
||||
if this_asset.hash == new_hash:
|
||||
if int(this_asset.size_bytes or 0) == 0 and new_size > 0:
|
||||
this_asset.size_bytes = new_size
|
||||
state.mtime_ns = mtime_ns
|
||||
state.needs_verify = False
|
||||
with contextlib.suppress(Exception):
|
||||
await remove_missing_tag_for_asset_id(session, asset_id=this_asset.id)
|
||||
await _recompute_and_apply_filename_for_asset(session, asset_id=this_asset.id)
|
||||
await session.flush()
|
||||
return this_asset.id
|
||||
|
||||
# Content changed on this path only: retarget THIS state, do not move AssetInfo rows
|
||||
canonical = (
|
||||
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
|
||||
).scalars().first()
|
||||
if canonical:
|
||||
target_id = canonical.id
|
||||
else:
|
||||
now = utcnow()
|
||||
new_asset = Asset(hash=new_hash, size_bytes=new_size, mime_type=None, created_at=now)
|
||||
session.add(new_asset)
|
||||
await session.flush()
|
||||
target_id = new_asset.id
|
||||
|
||||
state.asset_id = target_id
|
||||
state.mtime_ns = mtime_ns
|
||||
state.needs_verify = False
|
||||
with contextlib.suppress(Exception):
|
||||
await remove_missing_tag_for_asset_id(session, asset_id=target_id)
|
||||
await _recompute_and_apply_filename_for_asset(session, asset_id=target_id)
|
||||
await session.flush()
|
||||
return target_id
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
|
||||
async def list_unhashed_candidates_under_prefixes(session: AsyncSession, *, prefixes: list[str]) -> list[int]:
|
||||
if not prefixes:
|
||||
return []
|
||||
|
||||
conds = []
|
||||
for p in prefixes:
|
||||
base = os.path.abspath(p)
|
||||
if not base.endswith(os.sep):
|
||||
base += os.sep
|
||||
escaped, esc = escape_like_prefix(base)
|
||||
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
|
||||
|
||||
path_filter = sa.or_(*conds) if len(conds) > 1 else conds[0]
|
||||
if session.bind.dialect.name == "postgresql":
|
||||
stmt = (
|
||||
sa.select(AssetCacheState.id)
|
||||
.join(Asset, Asset.id == AssetCacheState.asset_id)
|
||||
.where(Asset.hash.is_(None), path_filter)
|
||||
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
|
||||
.distinct(AssetCacheState.asset_id)
|
||||
)
|
||||
else:
|
||||
first_id = sa.func.min(AssetCacheState.id).label("first_id")
|
||||
stmt = (
|
||||
sa.select(first_id)
|
||||
.join(Asset, Asset.id == AssetCacheState.asset_id)
|
||||
.where(Asset.hash.is_(None), path_filter)
|
||||
.group_by(AssetCacheState.asset_id)
|
||||
.order_by(first_id.asc())
|
||||
)
|
||||
return [int(x) for x in (await session.execute(stmt)).scalars().all()]
|
||||
|
||||
|
||||
async def list_verify_candidates_under_prefixes(
|
||||
session: AsyncSession, *, prefixes: Sequence[str]
|
||||
) -> Union[list[int], Sequence[int]]:
|
||||
if not prefixes:
|
||||
return []
|
||||
conds = []
|
||||
for p in prefixes:
|
||||
base = os.path.abspath(p)
|
||||
if not base.endswith(os.sep):
|
||||
base += os.sep
|
||||
escaped, esc = escape_like_prefix(base)
|
||||
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
|
||||
|
||||
return (
|
||||
await session.execute(
|
||||
sa.select(AssetCacheState.id)
|
||||
.where(AssetCacheState.needs_verify.is_(True))
|
||||
.where(sa.or_(*conds))
|
||||
.order_by(AssetCacheState.id.asc())
|
||||
)
|
||||
).scalars().all()
|
||||
|
||||
|
||||
async def ingest_fs_asset(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_hash: str,
|
||||
abs_path: str,
|
||||
size_bytes: int,
|
||||
mtime_ns: int,
|
||||
mime_type: Optional[str] = None,
|
||||
info_name: Optional[str] = None,
|
||||
owner_id: str = "",
|
||||
preview_id: Optional[str] = None,
|
||||
user_metadata: Optional[dict] = None,
|
||||
tags: Sequence[str] = (),
|
||||
tag_origin: str = "manual",
|
||||
require_existing_tags: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Idempotently upsert:
|
||||
- Asset by content hash (create if missing)
|
||||
- AssetCacheState(file_path) pointing to asset_id
|
||||
- Optionally AssetInfo + tag links and metadata projection
|
||||
Returns flags and ids.
|
||||
"""
|
||||
locator = os.path.abspath(abs_path)
|
||||
now = utcnow()
|
||||
dialect = session.bind.dialect.name
|
||||
|
||||
if preview_id:
|
||||
if not await session.get(Asset, preview_id):
|
||||
preview_id = None
|
||||
|
||||
out: dict[str, Any] = {
|
||||
"asset_created": False,
|
||||
"asset_updated": False,
|
||||
"state_created": False,
|
||||
"state_updated": False,
|
||||
"asset_info_id": None,
|
||||
}
|
||||
|
||||
# 1) Asset by hash
|
||||
asset = (
|
||||
await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
|
||||
).scalars().first()
|
||||
if not asset:
|
||||
vals = {
|
||||
"hash": asset_hash,
|
||||
"size_bytes": int(size_bytes),
|
||||
"mime_type": mime_type,
|
||||
"created_at": now,
|
||||
}
|
||||
if dialect == "sqlite":
|
||||
res = await session.execute(
|
||||
d_sqlite.insert(Asset)
|
||||
.values(**vals)
|
||||
.on_conflict_do_nothing(index_elements=[Asset.hash])
|
||||
)
|
||||
if int(res.rowcount or 0) > 0:
|
||||
out["asset_created"] = True
|
||||
asset = (
|
||||
await session.execute(
|
||||
select(Asset).where(Asset.hash == asset_hash).limit(1)
|
||||
)
|
||||
).scalars().first()
|
||||
elif dialect == "postgresql":
|
||||
res = await session.execute(
|
||||
d_pg.insert(Asset)
|
||||
.values(**vals)
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=[Asset.hash],
|
||||
index_where=Asset.__table__.c.hash.isnot(None),
|
||||
)
|
||||
.returning(Asset.id)
|
||||
)
|
||||
inserted_id = res.scalar_one_or_none()
|
||||
if inserted_id:
|
||||
out["asset_created"] = True
|
||||
asset = await session.get(Asset, inserted_id)
|
||||
else:
|
||||
asset = (
|
||||
await session.execute(
|
||||
select(Asset).where(Asset.hash == asset_hash).limit(1)
|
||||
)
|
||||
).scalars().first()
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
|
||||
if not asset:
|
||||
raise RuntimeError("Asset row not found after upsert.")
|
||||
else:
|
||||
changed = False
|
||||
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
|
||||
asset.size_bytes = int(size_bytes)
|
||||
changed = True
|
||||
if mime_type and asset.mime_type != mime_type:
|
||||
asset.mime_type = mime_type
|
||||
changed = True
|
||||
if changed:
|
||||
out["asset_updated"] = True
|
||||
|
||||
# 2) AssetCacheState upsert by file_path (unique)
|
||||
vals = {
|
||||
"asset_id": asset.id,
|
||||
"file_path": locator,
|
||||
"mtime_ns": int(mtime_ns),
|
||||
}
|
||||
if dialect == "sqlite":
|
||||
ins = (
|
||||
d_sqlite.insert(AssetCacheState)
|
||||
.values(**vals)
|
||||
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
|
||||
)
|
||||
elif dialect == "postgresql":
|
||||
ins = (
|
||||
d_pg.insert(AssetCacheState)
|
||||
.values(**vals)
|
||||
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
|
||||
|
||||
res = await session.execute(ins)
|
||||
if int(res.rowcount or 0) > 0:
|
||||
out["state_created"] = True
|
||||
else:
|
||||
upd = (
|
||||
sa.update(AssetCacheState)
|
||||
.where(AssetCacheState.file_path == locator)
|
||||
.where(
|
||||
sa.or_(
|
||||
AssetCacheState.asset_id != asset.id,
|
||||
AssetCacheState.mtime_ns.is_(None),
|
||||
AssetCacheState.mtime_ns != int(mtime_ns),
|
||||
)
|
||||
)
|
||||
.values(asset_id=asset.id, mtime_ns=int(mtime_ns))
|
||||
)
|
||||
res2 = await session.execute(upd)
|
||||
if int(res2.rowcount or 0) > 0:
|
||||
out["state_updated"] = True
|
||||
|
||||
# 3) Optional AssetInfo + tags + metadata
|
||||
if info_name:
|
||||
try:
|
||||
async with session.begin_nested():
|
||||
info = AssetInfo(
|
||||
owner_id=owner_id,
|
||||
name=info_name,
|
||||
asset_id=asset.id,
|
||||
preview_id=preview_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_access_time=now,
|
||||
)
|
||||
session.add(info)
|
||||
await session.flush()
|
||||
out["asset_info_id"] = info.id
|
||||
except IntegrityError:
|
||||
pass
|
||||
|
||||
existing_info = (
|
||||
await session.execute(
|
||||
select(AssetInfo)
|
||||
.where(
|
||||
AssetInfo.asset_id == asset.id,
|
||||
AssetInfo.name == info_name,
|
||||
(AssetInfo.owner_id == owner_id),
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
).unique().scalar_one_or_none()
|
||||
if not existing_info:
|
||||
raise RuntimeError("Failed to update or insert AssetInfo.")
|
||||
|
||||
if preview_id and existing_info.preview_id != preview_id:
|
||||
existing_info.preview_id = preview_id
|
||||
|
||||
existing_info.updated_at = now
|
||||
if existing_info.last_access_time < now:
|
||||
existing_info.last_access_time = now
|
||||
await session.flush()
|
||||
out["asset_info_id"] = existing_info.id
|
||||
|
||||
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||
if norm and out["asset_info_id"] is not None:
|
||||
if not require_existing_tags:
|
||||
await ensure_tags_exist(session, norm, tag_type="user")
|
||||
|
||||
existing_tag_names = set(
|
||||
name for (name,) in (await session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
|
||||
)
|
||||
missing = [t for t in norm if t not in existing_tag_names]
|
||||
if missing and require_existing_tags:
|
||||
raise ValueError(f"Unknown tags: {missing}")
|
||||
|
||||
existing_links = set(
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
await session.execute(
|
||||
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
|
||||
)
|
||||
).all()
|
||||
)
|
||||
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
|
||||
if to_add:
|
||||
session.add_all(
|
||||
[
|
||||
AssetInfoTag(
|
||||
asset_info_id=out["asset_info_id"],
|
||||
tag_name=t,
|
||||
origin=tag_origin,
|
||||
added_at=now,
|
||||
)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
await session.flush()
|
||||
|
||||
# metadata["filename"] hack
|
||||
if out["asset_info_id"] is not None:
|
||||
primary_path = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=asset.id))
|
||||
computed_filename = compute_relative_filename(primary_path) if primary_path else None
|
||||
|
||||
current_meta = existing_info.user_metadata or {}
|
||||
new_meta = dict(current_meta)
|
||||
if user_metadata is not None:
|
||||
for k, v in user_metadata.items():
|
||||
new_meta[k] = v
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
|
||||
if new_meta != current_meta:
|
||||
await replace_asset_info_metadata_projection(
|
||||
session,
|
||||
asset_info_id=out["asset_info_id"],
|
||||
user_metadata=new_meta,
|
||||
)
|
||||
|
||||
try:
|
||||
await remove_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
except Exception:
|
||||
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
|
||||
return out
|
||||
|
||||
|
||||
async def touch_asset_infos_by_fs_path(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
file_path: str,
|
||||
ts: Optional[datetime] = None,
|
||||
only_if_newer: bool = True,
|
||||
) -> None:
|
||||
locator = os.path.abspath(file_path)
|
||||
ts = ts or utcnow()
|
||||
stmt = sa.update(AssetInfo).where(
|
||||
sa.exists(
|
||||
sa.select(sa.literal(1))
|
||||
.select_from(AssetCacheState)
|
||||
.where(
|
||||
AssetCacheState.asset_id == AssetInfo.asset_id,
|
||||
AssetCacheState.file_path == locator,
|
||||
)
|
||||
)
|
||||
)
|
||||
if only_if_newer:
|
||||
stmt = stmt.where(
|
||||
sa.or_(
|
||||
AssetInfo.last_access_time.is_(None),
|
||||
AssetInfo.last_access_time < ts,
|
||||
)
|
||||
)
|
||||
await session.execute(stmt.values(last_access_time=ts))
|
||||
|
||||
|
||||
async def list_cache_states_with_asset_under_prefixes(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
prefixes: Sequence[str],
|
||||
) -> list[tuple[AssetCacheState, Optional[str], int]]:
|
||||
"""Return (AssetCacheState, asset_hash, size_bytes) for rows under any prefix."""
|
||||
if not prefixes:
|
||||
return []
|
||||
|
||||
conds = []
|
||||
for p in prefixes:
|
||||
if not p:
|
||||
continue
|
||||
base = os.path.abspath(p)
|
||||
if not base.endswith(os.sep):
|
||||
base = base + os.sep
|
||||
escaped, esc = escape_like_prefix(base)
|
||||
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
|
||||
|
||||
if not conds:
|
||||
return []
|
||||
|
||||
rows = (
|
||||
await session.execute(
|
||||
select(AssetCacheState, Asset.hash, Asset.size_bytes)
|
||||
.join(Asset, Asset.id == AssetCacheState.asset_id)
|
||||
.where(sa.or_(*conds))
|
||||
.order_by(AssetCacheState.id.asc())
|
||||
)
|
||||
).all()
|
||||
return [(r[0], r[1], int(r[2] or 0)) for r in rows]
|
||||
|
||||
|
||||
async def _recompute_and_apply_filename_for_asset(session: AsyncSession, *, asset_id: str) -> None:
|
||||
"""Compute filename from the first *existing* cache state path and apply it to all AssetInfo (if changed)."""
|
||||
try:
|
||||
primary_path = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=asset_id))
|
||||
if not primary_path:
|
||||
return
|
||||
new_filename = compute_relative_filename(primary_path)
|
||||
if not new_filename:
|
||||
return
|
||||
infos = (
|
||||
await session.execute(select(AssetInfo).where(AssetInfo.asset_id == asset_id))
|
||||
).scalars().all()
|
||||
for info in infos:
|
||||
current_meta = info.user_metadata or {}
|
||||
if current_meta.get("filename") == new_filename:
|
||||
continue
|
||||
updated = dict(current_meta)
|
||||
updated["filename"] = new_filename
|
||||
await replace_asset_info_metadata_projection(session, asset_info_id=info.id, user_metadata=updated)
|
||||
except Exception:
|
||||
logging.exception("Failed to recompute filename metadata for asset %s", asset_id)
|
||||
586
app/assets/database/services/info.py
Normal file
586
app/assets/database/services/info.py
Normal file
@@ -0,0 +1,586 @@
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional, Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import contains_eager, noload
|
||||
|
||||
from ..._helpers import compute_relative_filename, normalize_tags
|
||||
from ..helpers import (
|
||||
apply_metadata_filter,
|
||||
apply_tag_filters,
|
||||
ensure_tags_exist,
|
||||
escape_like_prefix,
|
||||
project_kv,
|
||||
visible_owner_clause,
|
||||
)
|
||||
from ..models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
|
||||
from ..timeutil import utcnow
|
||||
from .queries import (
|
||||
get_asset_by_hash,
|
||||
list_cache_states_by_asset_id,
|
||||
pick_best_live_path,
|
||||
)
|
||||
|
||||
|
||||
async def list_asset_infos_page(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
owner_id: str = "",
|
||||
include_tags: Optional[Sequence[str]] = None,
|
||||
exclude_tags: Optional[Sequence[str]] = None,
|
||||
name_contains: Optional[str] = None,
|
||||
metadata_filter: Optional[dict] = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
) -> tuple[list[AssetInfo], dict[str, list[str]], int]:
|
||||
base = (
|
||||
select(AssetInfo)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
|
||||
.where(visible_owner_clause(owner_id))
|
||||
)
|
||||
|
||||
if name_contains:
|
||||
escaped, esc = escape_like_prefix(name_contains)
|
||||
base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
|
||||
|
||||
base = apply_tag_filters(base, include_tags, exclude_tags)
|
||||
base = apply_metadata_filter(base, metadata_filter)
|
||||
|
||||
sort = (sort or "created_at").lower()
|
||||
order = (order or "desc").lower()
|
||||
sort_map = {
|
||||
"name": AssetInfo.name,
|
||||
"created_at": AssetInfo.created_at,
|
||||
"updated_at": AssetInfo.updated_at,
|
||||
"last_access_time": AssetInfo.last_access_time,
|
||||
"size": Asset.size_bytes,
|
||||
}
|
||||
sort_col = sort_map.get(sort, AssetInfo.created_at)
|
||||
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
|
||||
|
||||
base = base.order_by(sort_exp).limit(limit).offset(offset)
|
||||
|
||||
count_stmt = (
|
||||
select(func.count())
|
||||
.select_from(AssetInfo)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.where(visible_owner_clause(owner_id))
|
||||
)
|
||||
if name_contains:
|
||||
escaped, esc = escape_like_prefix(name_contains)
|
||||
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
|
||||
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
||||
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
|
||||
|
||||
total = int((await session.execute(count_stmt)).scalar_one() or 0)
|
||||
|
||||
infos = (await session.execute(base)).unique().scalars().all()
|
||||
|
||||
id_list: list[str] = [i.id for i in infos]
|
||||
tag_map: dict[str, list[str]] = defaultdict(list)
|
||||
if id_list:
|
||||
rows = await session.execute(
|
||||
select(AssetInfoTag.asset_info_id, Tag.name)
|
||||
.join(Tag, Tag.name == AssetInfoTag.tag_name)
|
||||
.where(AssetInfoTag.asset_info_id.in_(id_list))
|
||||
)
|
||||
for aid, tag_name in rows.all():
|
||||
tag_map[aid].append(tag_name)
|
||||
|
||||
return infos, tag_map, total
|
||||
|
||||
|
||||
async def fetch_asset_info_and_asset(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> Optional[tuple[AssetInfo, Asset]]:
|
||||
stmt = (
|
||||
select(AssetInfo, Asset)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
.limit(1)
|
||||
.options(noload(AssetInfo.tags))
|
||||
)
|
||||
row = await session.execute(stmt)
|
||||
pair = row.first()
|
||||
if not pair:
|
||||
return None
|
||||
return pair[0], pair[1]
|
||||
|
||||
|
||||
async def fetch_asset_info_asset_and_tags(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> Optional[tuple[AssetInfo, Asset, list[str]]]:
|
||||
stmt = (
|
||||
select(AssetInfo, Asset, Tag.name)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
|
||||
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
|
||||
.where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
.options(noload(AssetInfo.tags))
|
||||
.order_by(Tag.name.asc())
|
||||
)
|
||||
|
||||
rows = (await session.execute(stmt)).all()
|
||||
if not rows:
|
||||
return None
|
||||
|
||||
first_info, first_asset, _ = rows[0]
|
||||
tags: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for _info, _asset, tag_name in rows:
|
||||
if tag_name and tag_name not in seen:
|
||||
seen.add(tag_name)
|
||||
tags.append(tag_name)
|
||||
return first_info, first_asset, tags
|
||||
|
||||
|
||||
async def create_asset_info_for_existing_asset(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_hash: str,
|
||||
name: str,
|
||||
user_metadata: Optional[dict] = None,
|
||||
tags: Optional[Sequence[str]] = None,
|
||||
tag_origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> AssetInfo:
|
||||
"""Create or return an existing AssetInfo for an Asset identified by asset_hash."""
|
||||
now = utcnow()
|
||||
asset = await get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
if not asset:
|
||||
raise ValueError(f"Unknown asset hash {asset_hash}")
|
||||
|
||||
info = AssetInfo(
|
||||
owner_id=owner_id,
|
||||
name=name,
|
||||
asset_id=asset.id,
|
||||
preview_id=None,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_access_time=now,
|
||||
)
|
||||
try:
|
||||
async with session.begin_nested():
|
||||
session.add(info)
|
||||
await session.flush()
|
||||
except IntegrityError:
|
||||
existing = (
|
||||
await session.execute(
|
||||
select(AssetInfo)
|
||||
.options(noload(AssetInfo.tags))
|
||||
.where(
|
||||
AssetInfo.asset_id == asset.id,
|
||||
AssetInfo.name == name,
|
||||
AssetInfo.owner_id == owner_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
).unique().scalars().first()
|
||||
if not existing:
|
||||
raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
|
||||
return existing
|
||||
|
||||
# metadata["filename"] hack
|
||||
new_meta = dict(user_metadata or {})
|
||||
computed_filename = None
|
||||
try:
|
||||
p = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=asset.id))
|
||||
if p:
|
||||
computed_filename = compute_relative_filename(p)
|
||||
except Exception:
|
||||
computed_filename = None
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
if new_meta:
|
||||
await replace_asset_info_metadata_projection(
|
||||
session,
|
||||
asset_info_id=info.id,
|
||||
user_metadata=new_meta,
|
||||
)
|
||||
|
||||
if tags is not None:
|
||||
await set_asset_info_tags(
|
||||
session,
|
||||
asset_info_id=info.id,
|
||||
tags=tags,
|
||||
origin=tag_origin,
|
||||
)
|
||||
return info
|
||||
|
||||
|
||||
async def set_asset_info_tags(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
) -> dict:
|
||||
desired = normalize_tags(tags)
|
||||
|
||||
current = set(
|
||||
tag_name for (tag_name,) in (
|
||||
await session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
|
||||
).all()
|
||||
)
|
||||
|
||||
to_add = [t for t in desired if t not in current]
|
||||
to_remove = [t for t in current if t not in desired]
|
||||
|
||||
if to_add:
|
||||
await ensure_tags_exist(session, to_add, tag_type="user")
|
||||
session.add_all([
|
||||
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow())
|
||||
for t in to_add
|
||||
])
|
||||
await session.flush()
|
||||
|
||||
if to_remove:
|
||||
await session.execute(
|
||||
delete(AssetInfoTag)
|
||||
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
|
||||
)
|
||||
await session.flush()
|
||||
|
||||
return {"added": to_add, "removed": to_remove, "total": desired}
|
||||
|
||||
|
||||
async def update_asset_info_full(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
name: Optional[str] = None,
|
||||
tags: Optional[Sequence[str]] = None,
|
||||
user_metadata: Optional[dict] = None,
|
||||
tag_origin: str = "manual",
|
||||
asset_info_row: Any = None,
|
||||
) -> AssetInfo:
|
||||
if not asset_info_row:
|
||||
info = await session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
else:
|
||||
info = asset_info_row
|
||||
|
||||
touched = False
|
||||
if name is not None and name != info.name:
|
||||
info.name = name
|
||||
touched = True
|
||||
|
||||
computed_filename = None
|
||||
try:
|
||||
p = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=info.asset_id))
|
||||
if p:
|
||||
computed_filename = compute_relative_filename(p)
|
||||
except Exception:
|
||||
computed_filename = None
|
||||
|
||||
if user_metadata is not None:
|
||||
new_meta = dict(user_metadata)
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
await replace_asset_info_metadata_projection(
|
||||
session, asset_info_id=asset_info_id, user_metadata=new_meta
|
||||
)
|
||||
touched = True
|
||||
else:
|
||||
if computed_filename:
|
||||
current_meta = info.user_metadata or {}
|
||||
if current_meta.get("filename") != computed_filename:
|
||||
new_meta = dict(current_meta)
|
||||
new_meta["filename"] = computed_filename
|
||||
await replace_asset_info_metadata_projection(
|
||||
session, asset_info_id=asset_info_id, user_metadata=new_meta
|
||||
)
|
||||
touched = True
|
||||
|
||||
if tags is not None:
|
||||
await set_asset_info_tags(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
origin=tag_origin,
|
||||
)
|
||||
touched = True
|
||||
|
||||
if touched and user_metadata is None:
|
||||
info.updated_at = utcnow()
|
||||
await session.flush()
|
||||
|
||||
return info
|
||||
|
||||
|
||||
async def replace_asset_info_metadata_projection(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
user_metadata: Optional[dict],
|
||||
) -> None:
|
||||
info = await session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
info.user_metadata = user_metadata or {}
|
||||
info.updated_at = utcnow()
|
||||
await session.flush()
|
||||
|
||||
await session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
|
||||
await session.flush()
|
||||
|
||||
if not user_metadata:
|
||||
return
|
||||
|
||||
rows: list[AssetInfoMeta] = []
|
||||
for k, v in user_metadata.items():
|
||||
for r in project_kv(k, v):
|
||||
rows.append(
|
||||
AssetInfoMeta(
|
||||
asset_info_id=asset_info_id,
|
||||
key=r["key"],
|
||||
ordinal=int(r["ordinal"]),
|
||||
val_str=r.get("val_str"),
|
||||
val_num=r.get("val_num"),
|
||||
val_bool=r.get("val_bool"),
|
||||
val_json=r.get("val_json"),
|
||||
)
|
||||
)
|
||||
if rows:
|
||||
session.add_all(rows)
|
||||
await session.flush()
|
||||
|
||||
|
||||
async def touch_asset_info_by_id(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
ts: Optional[datetime] = None,
|
||||
only_if_newer: bool = True,
|
||||
) -> None:
|
||||
ts = ts or utcnow()
|
||||
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
|
||||
if only_if_newer:
|
||||
stmt = stmt.where(
|
||||
sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts)
|
||||
)
|
||||
await session.execute(stmt.values(last_access_time=ts))
|
||||
|
||||
|
||||
async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: str, owner_id: str) -> bool:
|
||||
stmt = sa.delete(AssetInfo).where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
return int((await session.execute(stmt)).rowcount or 0) > 0
|
||||
|
||||
|
||||
async def add_tags_to_asset_info(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
create_if_missing: bool = True,
|
||||
asset_info_row: Any = None,
|
||||
) -> dict:
|
||||
if not asset_info_row:
|
||||
info = await session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
norm = normalize_tags(tags)
|
||||
if not norm:
|
||||
total = await get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"added": [], "already_present": [], "total_tags": total}
|
||||
|
||||
if create_if_missing:
|
||||
await ensure_tags_exist(session, norm, tag_type="user")
|
||||
|
||||
current = {
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
await session.execute(
|
||||
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
|
||||
want = set(norm)
|
||||
to_add = sorted(want - current)
|
||||
|
||||
if to_add:
|
||||
async with session.begin_nested() as nested:
|
||||
try:
|
||||
session.add_all(
|
||||
[
|
||||
AssetInfoTag(
|
||||
asset_info_id=asset_info_id,
|
||||
tag_name=t,
|
||||
origin=origin,
|
||||
added_at=utcnow(),
|
||||
)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
await session.flush()
|
||||
except IntegrityError:
|
||||
await nested.rollback()
|
||||
|
||||
after = set(await get_asset_tags(session, asset_info_id=asset_info_id))
|
||||
return {
|
||||
"added": sorted(((after - current) & want)),
|
||||
"already_present": sorted(want & current),
|
||||
"total_tags": sorted(after),
|
||||
}
|
||||
|
||||
|
||||
async def remove_tags_from_asset_info(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
) -> dict:
|
||||
info = await session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
norm = normalize_tags(tags)
|
||||
if not norm:
|
||||
total = await get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"removed": [], "not_present": [], "total_tags": total}
|
||||
|
||||
existing = {
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
await session.execute(
|
||||
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
|
||||
to_remove = sorted(set(t for t in norm if t in existing))
|
||||
not_present = sorted(set(t for t in norm if t not in existing))
|
||||
|
||||
if to_remove:
|
||||
await session.execute(
|
||||
delete(AssetInfoTag)
|
||||
.where(
|
||||
AssetInfoTag.asset_info_id == asset_info_id,
|
||||
AssetInfoTag.tag_name.in_(to_remove),
|
||||
)
|
||||
)
|
||||
await session.flush()
|
||||
|
||||
total = await get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
|
||||
|
||||
|
||||
async def list_tags_with_usage(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
prefix: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
include_zero: bool = True,
|
||||
order: str = "count_desc",
|
||||
owner_id: str = "",
|
||||
) -> tuple[list[tuple[str, str, int]], int]:
|
||||
counts_sq = (
|
||||
select(
|
||||
AssetInfoTag.tag_name.label("tag_name"),
|
||||
func.count(AssetInfoTag.asset_info_id).label("cnt"),
|
||||
)
|
||||
.select_from(AssetInfoTag)
|
||||
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
|
||||
.where(visible_owner_clause(owner_id))
|
||||
.group_by(AssetInfoTag.tag_name)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
q = (
|
||||
select(
|
||||
Tag.name,
|
||||
Tag.tag_type,
|
||||
func.coalesce(counts_sq.c.cnt, 0).label("count"),
|
||||
)
|
||||
.select_from(Tag)
|
||||
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
|
||||
)
|
||||
|
||||
if prefix:
|
||||
escaped, esc = escape_like_prefix(prefix.strip().lower())
|
||||
q = q.where(Tag.name.like(escaped + "%", escape=esc))
|
||||
|
||||
if not include_zero:
|
||||
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
|
||||
|
||||
if order == "name_asc":
|
||||
q = q.order_by(Tag.name.asc())
|
||||
else:
|
||||
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
|
||||
|
||||
total_q = select(func.count()).select_from(Tag)
|
||||
if prefix:
|
||||
escaped, esc = escape_like_prefix(prefix.strip().lower())
|
||||
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
|
||||
if not include_zero:
|
||||
total_q = total_q.where(
|
||||
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
|
||||
)
|
||||
|
||||
rows = (await session.execute(q.limit(limit).offset(offset))).all()
|
||||
total = (await session.execute(total_q)).scalar_one()
|
||||
|
||||
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
|
||||
return rows_norm, int(total or 0)
|
||||
|
||||
|
||||
async def get_asset_tags(session: AsyncSession, *, asset_info_id: str) -> list[str]:
|
||||
return [
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
await session.execute(
|
||||
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
]
|
||||
|
||||
|
||||
async def set_asset_info_preview(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
preview_asset_id: Optional[str],
|
||||
) -> None:
|
||||
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
|
||||
info = await session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
if preview_asset_id is None:
|
||||
info.preview_id = None
|
||||
else:
|
||||
# validate preview asset exists
|
||||
if not await session.get(Asset, preview_asset_id):
|
||||
raise ValueError(f"Preview Asset {preview_asset_id} not found")
|
||||
info.preview_id = preview_asset_id
|
||||
|
||||
info.updated_at = utcnow()
|
||||
await session.flush()
|
||||
76
app/assets/database/services/queries.py
Normal file
76
app/assets/database/services/queries.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import os
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..models import Asset, AssetCacheState, AssetInfo
|
||||
|
||||
|
||||
async def asset_exists_by_hash(session: AsyncSession, *, asset_hash: str) -> bool:
|
||||
row = (
|
||||
await session.execute(
|
||||
select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1)
|
||||
)
|
||||
).first()
|
||||
return row is not None
|
||||
|
||||
|
||||
async def get_asset_by_hash(session: AsyncSession, *, asset_hash: str) -> Optional[Asset]:
|
||||
return (
|
||||
await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
|
||||
).scalars().first()
|
||||
|
||||
|
||||
async def get_asset_info_by_id(session: AsyncSession, *, asset_info_id: str) -> Optional[AssetInfo]:
|
||||
return await session.get(AssetInfo, asset_info_id)
|
||||
|
||||
|
||||
async def asset_info_exists_for_asset_id(session: AsyncSession, *, asset_id: str) -> bool:
|
||||
q = (
|
||||
select(sa.literal(True))
|
||||
.select_from(AssetInfo)
|
||||
.where(AssetInfo.asset_id == asset_id)
|
||||
.limit(1)
|
||||
)
|
||||
return (await session.execute(q)).first() is not None
|
||||
|
||||
|
||||
async def get_cache_state_by_asset_id(session: AsyncSession, *, asset_id: str) -> Optional[AssetCacheState]:
|
||||
return (
|
||||
await session.execute(
|
||||
select(AssetCacheState)
|
||||
.where(AssetCacheState.asset_id == asset_id)
|
||||
.order_by(AssetCacheState.id.asc())
|
||||
.limit(1)
|
||||
)
|
||||
).scalars().first()
|
||||
|
||||
|
||||
async def list_cache_states_by_asset_id(
|
||||
session: AsyncSession, *, asset_id: str
|
||||
) -> Union[list[AssetCacheState], Sequence[AssetCacheState]]:
|
||||
return (
|
||||
await session.execute(
|
||||
select(AssetCacheState)
|
||||
.where(AssetCacheState.asset_id == asset_id)
|
||||
.order_by(AssetCacheState.id.asc())
|
||||
)
|
||||
).scalars().all()
|
||||
|
||||
|
||||
def pick_best_live_path(states: Union[list[AssetCacheState], Sequence[AssetCacheState]]) -> str:
|
||||
"""
|
||||
Return the best on-disk path among cache states:
|
||||
1) Prefer a path that exists with needs_verify == False (already verified).
|
||||
2) Otherwise, pick the first path that exists.
|
||||
3) Otherwise return empty string.
|
||||
"""
|
||||
alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)]
|
||||
if not alive:
|
||||
return ""
|
||||
for s in alive:
|
||||
if not getattr(s, "needs_verify", False):
|
||||
return s.file_path
|
||||
return alive[0].file_path
|
||||
6
app/assets/database/timeutil.py
Normal file
6
app/assets/database/timeutil.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
def utcnow() -> datetime:
|
||||
"""Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC."""
|
||||
return datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
556
app/assets/manager.py
Normal file
556
app/assets/manager.py
Normal file
@@ -0,0 +1,556 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from comfy_api.internal import async_to_sync
|
||||
|
||||
from ..db import create_session
|
||||
from ._helpers import (
|
||||
ensure_within_base,
|
||||
get_name_and_tags_from_asset_path,
|
||||
resolve_destination_from_tags,
|
||||
)
|
||||
from .api import schemas_in, schemas_out
|
||||
from .database.models import Asset
|
||||
from .database.services import (
|
||||
add_tags_to_asset_info,
|
||||
asset_exists_by_hash,
|
||||
asset_info_exists_for_asset_id,
|
||||
check_fs_asset_exists_quick,
|
||||
create_asset_info_for_existing_asset,
|
||||
delete_asset_info_by_id,
|
||||
fetch_asset_info_and_asset,
|
||||
fetch_asset_info_asset_and_tags,
|
||||
get_asset_by_hash,
|
||||
get_asset_info_by_id,
|
||||
get_asset_tags,
|
||||
ingest_fs_asset,
|
||||
list_asset_infos_page,
|
||||
list_cache_states_by_asset_id,
|
||||
list_tags_with_usage,
|
||||
pick_best_live_path,
|
||||
remove_tags_from_asset_info,
|
||||
set_asset_info_preview,
|
||||
touch_asset_info_by_id,
|
||||
touch_asset_infos_by_fs_path,
|
||||
update_asset_info_full,
|
||||
)
|
||||
from .storage import hashing
|
||||
|
||||
|
||||
async def asset_exists(*, asset_hash: str) -> bool:
|
||||
async with await create_session() as session:
|
||||
return await asset_exists_by_hash(session, asset_hash=asset_hash)
|
||||
|
||||
|
||||
def populate_db_with_asset(file_path: str, tags: Optional[list[str]] = None) -> None:
|
||||
if tags is None:
|
||||
tags = []
|
||||
try:
|
||||
asset_name, path_tags = get_name_and_tags_from_asset_path(file_path)
|
||||
async_to_sync.AsyncToSyncConverter.run_async_in_thread(
|
||||
add_local_asset,
|
||||
tags=list(dict.fromkeys([*path_tags, *tags])),
|
||||
file_name=asset_name,
|
||||
file_path=file_path,
|
||||
)
|
||||
except ValueError as e:
|
||||
logging.warning("Skipping non-asset path %s: %s", file_path, e)
|
||||
|
||||
|
||||
async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> None:
|
||||
abs_path = os.path.abspath(file_path)
|
||||
size_bytes, mtime_ns = _get_size_mtime_ns(abs_path)
|
||||
if not size_bytes:
|
||||
return
|
||||
|
||||
async with await create_session() as session:
|
||||
if await check_fs_asset_exists_quick(session, file_path=abs_path, size_bytes=size_bytes, mtime_ns=mtime_ns):
|
||||
await touch_asset_infos_by_fs_path(session, file_path=abs_path)
|
||||
await session.commit()
|
||||
return
|
||||
|
||||
asset_hash = hashing.blake3_hash_sync(abs_path)
|
||||
|
||||
async with await create_session() as session:
|
||||
await ingest_fs_asset(
|
||||
session,
|
||||
asset_hash="blake3:" + asset_hash,
|
||||
abs_path=abs_path,
|
||||
size_bytes=size_bytes,
|
||||
mtime_ns=mtime_ns,
|
||||
mime_type=None,
|
||||
info_name=file_name,
|
||||
tag_origin="automatic",
|
||||
tags=tags,
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def list_assets(
|
||||
*,
|
||||
include_tags: Optional[Sequence[str]] = None,
|
||||
exclude_tags: Optional[Sequence[str]] = None,
|
||||
name_contains: Optional[str] = None,
|
||||
metadata_filter: Optional[dict] = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetsList:
|
||||
sort = _safe_sort_field(sort)
|
||||
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
|
||||
|
||||
async with await create_session() as session:
|
||||
infos, tag_map, total = await list_asset_infos_page(
|
||||
session,
|
||||
owner_id=owner_id,
|
||||
include_tags=include_tags,
|
||||
exclude_tags=exclude_tags,
|
||||
name_contains=name_contains,
|
||||
metadata_filter=metadata_filter,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
sort=sort,
|
||||
order=order,
|
||||
)
|
||||
|
||||
summaries: list[schemas_out.AssetSummary] = []
|
||||
for info in infos:
|
||||
asset = info.asset
|
||||
tags = tag_map.get(info.id, [])
|
||||
summaries.append(
|
||||
schemas_out.AssetSummary(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tags,
|
||||
preview_url=f"/api/assets/{info.id}/content",
|
||||
created_at=info.created_at,
|
||||
updated_at=info.updated_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
)
|
||||
|
||||
return schemas_out.AssetsList(
|
||||
assets=summaries,
|
||||
total=total,
|
||||
has_more=(offset + len(summaries)) < total,
|
||||
)
|
||||
|
||||
|
||||
async def get_asset(*, asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail:
|
||||
async with await create_session() as session:
|
||||
res = await fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not res:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
info, asset, tag_names = res
|
||||
preview_id = info.preview_id
|
||||
|
||||
return schemas_out.AssetDetail(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
|
||||
|
||||
async def resolve_asset_content_for_download(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> tuple[str, str, str]:
|
||||
async with await create_session() as session:
|
||||
pair = await fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not pair:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
info, asset = pair
|
||||
states = await list_cache_states_by_asset_id(session, asset_id=asset.id)
|
||||
abs_path = pick_best_live_path(states)
|
||||
if not abs_path:
|
||||
raise FileNotFoundError
|
||||
|
||||
await touch_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
await session.commit()
|
||||
|
||||
ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
|
||||
download_name = info.name or os.path.basename(abs_path)
|
||||
return abs_path, ctype, download_name
|
||||
|
||||
|
||||
async def upload_asset_from_temp_path(
|
||||
spec: schemas_in.UploadAssetSpec,
|
||||
*,
|
||||
temp_path: str,
|
||||
client_filename: Optional[str] = None,
|
||||
owner_id: str = "",
|
||||
expected_asset_hash: Optional[str] = None,
|
||||
) -> schemas_out.AssetCreated:
|
||||
try:
|
||||
digest = await hashing.blake3_hash(temp_path)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to hash uploaded file: {e}")
|
||||
asset_hash = "blake3:" + digest
|
||||
|
||||
if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower():
|
||||
raise ValueError("HASH_MISMATCH")
|
||||
|
||||
async with await create_session() as session:
|
||||
existing = await get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
if existing is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
if temp_path and os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest)
|
||||
info = await create_asset_info_for_existing_asset(
|
||||
session,
|
||||
asset_hash=asset_hash,
|
||||
name=display_name,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
tags=spec.tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
tag_names = await get_asset_tags(session, asset_info_id=info.id)
|
||||
await session.commit()
|
||||
|
||||
return schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=existing.hash,
|
||||
size=int(existing.size_bytes) if existing.size_bytes is not None else None,
|
||||
mime_type=existing.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=False,
|
||||
)
|
||||
|
||||
base_dir, subdirs = resolve_destination_from_tags(spec.tags)
|
||||
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
|
||||
src_for_ext = (client_filename or spec.name or "").strip()
|
||||
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
|
||||
ext = _ext if 0 < len(_ext) <= 16 else ""
|
||||
hashed_basename = f"{digest}{ext}"
|
||||
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
|
||||
ensure_within_base(dest_abs, base_dir)
|
||||
|
||||
content_type = (
|
||||
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
|
||||
or mimetypes.guess_type(hashed_basename, strict=False)[0]
|
||||
or "application/octet-stream"
|
||||
)
|
||||
|
||||
try:
|
||||
os.replace(temp_path, dest_abs)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to move uploaded file into place: {e}")
|
||||
|
||||
try:
|
||||
size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs)
|
||||
except OSError as e:
|
||||
raise RuntimeError(f"failed to stat destination file: {e}")
|
||||
|
||||
async with await create_session() as session:
|
||||
result = await ingest_fs_asset(
|
||||
session,
|
||||
asset_hash=asset_hash,
|
||||
abs_path=dest_abs,
|
||||
size_bytes=size_bytes,
|
||||
mtime_ns=mtime_ns,
|
||||
mime_type=content_type,
|
||||
info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest),
|
||||
owner_id=owner_id,
|
||||
preview_id=None,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
tags=spec.tags,
|
||||
tag_origin="manual",
|
||||
require_existing_tags=False,
|
||||
)
|
||||
info_id = result["asset_info_id"]
|
||||
if not info_id:
|
||||
raise RuntimeError("failed to create asset metadata")
|
||||
|
||||
pair = await fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id)
|
||||
if not pair:
|
||||
raise RuntimeError("inconsistent DB state after ingest")
|
||||
info, asset = pair
|
||||
tag_names = await get_asset_tags(session, asset_info_id=info.id)
|
||||
await session.commit()
|
||||
|
||||
return schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash,
|
||||
size=int(asset.size_bytes),
|
||||
mime_type=asset.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=result["asset_created"],
|
||||
)
|
||||
|
||||
|
||||
async def update_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
name: Optional[str] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
user_metadata: Optional[dict] = None,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetUpdated:
|
||||
async with await create_session() as session:
|
||||
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
info = await update_asset_info_full(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
name=name,
|
||||
tags=tags,
|
||||
user_metadata=user_metadata,
|
||||
tag_origin="manual",
|
||||
asset_info_row=info_row,
|
||||
)
|
||||
|
||||
tag_names = await get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
await session.commit()
|
||||
|
||||
return schemas_out.AssetUpdated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=info.asset.hash if info.asset else None,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
updated_at=info.updated_at,
|
||||
)
|
||||
|
||||
|
||||
async def set_asset_preview(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
preview_asset_id: Optional[str],
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetDetail:
|
||||
async with await create_session() as session:
|
||||
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
await set_asset_info_preview(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
preview_asset_id=preview_asset_id,
|
||||
)
|
||||
|
||||
res = await fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not res:
|
||||
raise RuntimeError("State changed during preview update")
|
||||
info, asset, tags = res
|
||||
await session.commit()
|
||||
|
||||
return schemas_out.AssetDetail(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tags,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
|
||||
|
||||
async def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
|
||||
async with await create_session() as session:
|
||||
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
asset_id = info_row.asset_id if info_row else None
|
||||
deleted = await delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not deleted:
|
||||
await session.commit()
|
||||
return False
|
||||
|
||||
if not delete_content_if_orphan or not asset_id:
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
still_exists = await asset_info_exists_for_asset_id(session, asset_id=asset_id)
|
||||
if still_exists:
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
states = await list_cache_states_by_asset_id(session, asset_id=asset_id)
|
||||
file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]
|
||||
|
||||
asset_row = await session.get(Asset, asset_id)
|
||||
if asset_row is not None:
|
||||
await session.delete(asset_row)
|
||||
|
||||
await session.commit()
|
||||
for p in file_paths:
|
||||
with contextlib.suppress(Exception):
|
||||
if p and os.path.isfile(p):
|
||||
os.remove(p)
|
||||
return True
|
||||
|
||||
|
||||
async def create_asset_from_hash(
|
||||
*,
|
||||
hash_str: str,
|
||||
name: str,
|
||||
tags: Optional[list[str]] = None,
|
||||
user_metadata: Optional[dict] = None,
|
||||
owner_id: str = "",
|
||||
) -> Optional[schemas_out.AssetCreated]:
|
||||
canonical = hash_str.strip().lower()
|
||||
async with await create_session() as session:
|
||||
asset = await get_asset_by_hash(session, asset_hash=canonical)
|
||||
if not asset:
|
||||
return None
|
||||
|
||||
info = await create_asset_info_for_existing_asset(
|
||||
session,
|
||||
asset_hash=canonical,
|
||||
name=_safe_filename(name, fallback=canonical.split(":", 1)[1]),
|
||||
user_metadata=user_metadata or {},
|
||||
tags=tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
tag_names = await get_asset_tags(session, asset_info_id=info.id)
|
||||
await session.commit()
|
||||
|
||||
return schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash,
|
||||
size=int(asset.size_bytes),
|
||||
mime_type=asset.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=False,
|
||||
)
|
||||
|
||||
|
||||
async def list_tags(
|
||||
*,
|
||||
prefix: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
order: str = "count_desc",
|
||||
include_zero: bool = True,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsList:
|
||||
limit = max(1, min(1000, limit))
|
||||
offset = max(0, offset)
|
||||
|
||||
async with await create_session() as session:
|
||||
rows, total = await list_tags_with_usage(
|
||||
session,
|
||||
prefix=prefix,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
include_zero=include_zero,
|
||||
order=order,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
|
||||
tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows]
|
||||
return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total)
|
||||
|
||||
|
||||
async def add_tags_to_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: list[str],
|
||||
origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsAdd:
|
||||
async with await create_session() as session:
|
||||
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
data = await add_tags_to_asset_info(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
origin=origin,
|
||||
create_if_missing=True,
|
||||
asset_info_row=info_row,
|
||||
)
|
||||
await session.commit()
|
||||
return schemas_out.TagsAdd(**data)
|
||||
|
||||
|
||||
async def remove_tags_from_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: list[str],
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsRemove:
|
||||
async with await create_session() as session:
|
||||
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
data = await remove_tags_from_asset_info(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
)
|
||||
await session.commit()
|
||||
return schemas_out.TagsRemove(**data)
|
||||
|
||||
|
||||
def _safe_sort_field(requested: Optional[str]) -> str:
|
||||
if not requested:
|
||||
return "created_at"
|
||||
v = requested.lower()
|
||||
if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
|
||||
return v
|
||||
return "created_at"
|
||||
|
||||
|
||||
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
|
||||
st = os.stat(path, follow_symlinks=True)
|
||||
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
|
||||
|
||||
|
||||
def _safe_filename(name: Optional[str], fallback: str) -> str:
|
||||
n = os.path.basename((name or "").strip() or fallback)
|
||||
if n:
|
||||
return n
|
||||
return fallback
|
||||
501
app/assets/scanner.py
Normal file
501
app/assets/scanner.py
Normal file
@@ -0,0 +1,501 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Optional
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
import folder_paths
|
||||
|
||||
from ..db import create_session
|
||||
from ._helpers import (
|
||||
collect_models_files,
|
||||
compute_relative_filename,
|
||||
get_comfy_models_folders,
|
||||
get_name_and_tags_from_asset_path,
|
||||
list_tree,
|
||||
new_scan_id,
|
||||
prefixes_for_root,
|
||||
ts_to_iso,
|
||||
)
|
||||
from .api import schemas_in, schemas_out
|
||||
from .database.helpers import (
|
||||
add_missing_tag_for_asset_id,
|
||||
ensure_tags_exist,
|
||||
escape_like_prefix,
|
||||
fast_asset_file_check,
|
||||
remove_missing_tag_for_asset_id,
|
||||
seed_from_paths_batch,
|
||||
)
|
||||
from .database.models import Asset, AssetCacheState, AssetInfo
|
||||
from .database.services import (
|
||||
compute_hash_and_dedup_for_cache_state,
|
||||
list_cache_states_by_asset_id,
|
||||
list_cache_states_with_asset_under_prefixes,
|
||||
list_unhashed_candidates_under_prefixes,
|
||||
list_verify_candidates_under_prefixes,
|
||||
)
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
SLOW_HASH_CONCURRENCY = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScanProgress:
|
||||
scan_id: str
|
||||
root: schemas_in.RootType
|
||||
status: Literal["scheduled", "running", "completed", "failed", "cancelled"] = "scheduled"
|
||||
scheduled_at: float = field(default_factory=lambda: time.time())
|
||||
started_at: Optional[float] = None
|
||||
finished_at: Optional[float] = None
|
||||
discovered: int = 0
|
||||
processed: int = 0
|
||||
file_errors: list[dict] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SlowQueueState:
|
||||
queue: asyncio.Queue
|
||||
workers: list[asyncio.Task] = field(default_factory=list)
|
||||
closed: bool = False
|
||||
|
||||
|
||||
RUNNING_TASKS: dict[schemas_in.RootType, asyncio.Task] = {}
|
||||
PROGRESS_BY_ROOT: dict[schemas_in.RootType, ScanProgress] = {}
|
||||
SLOW_STATE_BY_ROOT: dict[schemas_in.RootType, SlowQueueState] = {}
|
||||
|
||||
|
||||
def current_statuses() -> schemas_out.AssetScanStatusResponse:
|
||||
scans = []
|
||||
for root in schemas_in.ALLOWED_ROOTS:
|
||||
prog = PROGRESS_BY_ROOT.get(root)
|
||||
if not prog:
|
||||
continue
|
||||
scans.append(_scan_progress_to_scan_status_model(prog))
|
||||
return schemas_out.AssetScanStatusResponse(scans=scans)
|
||||
|
||||
|
||||
async def schedule_scans(roots: list[schemas_in.RootType]) -> schemas_out.AssetScanStatusResponse:
|
||||
results: list[ScanProgress] = []
|
||||
for root in roots:
|
||||
if root in RUNNING_TASKS and not RUNNING_TASKS[root].done():
|
||||
results.append(PROGRESS_BY_ROOT[root])
|
||||
continue
|
||||
|
||||
prog = ScanProgress(scan_id=new_scan_id(root), root=root, status="scheduled")
|
||||
PROGRESS_BY_ROOT[root] = prog
|
||||
state = SlowQueueState(queue=asyncio.Queue())
|
||||
SLOW_STATE_BY_ROOT[root] = state
|
||||
RUNNING_TASKS[root] = asyncio.create_task(
|
||||
_run_hash_verify_pipeline(root, prog, state),
|
||||
name=f"asset-scan:{root}",
|
||||
)
|
||||
results.append(prog)
|
||||
return _status_response_for(results)
|
||||
|
||||
|
||||
async def sync_seed_assets(roots: list[schemas_in.RootType]) -> None:
|
||||
t_total = time.perf_counter()
|
||||
created = 0
|
||||
skipped_existing = 0
|
||||
paths: list[str] = []
|
||||
try:
|
||||
existing_paths: set[str] = set()
|
||||
for r in roots:
|
||||
try:
|
||||
survivors = await _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True)
|
||||
if survivors:
|
||||
existing_paths.update(survivors)
|
||||
except Exception as ex:
|
||||
LOGGER.exception("fast DB reconciliation failed for %s: %s", r, ex)
|
||||
|
||||
if "models" in roots:
|
||||
paths.extend(collect_models_files())
|
||||
if "input" in roots:
|
||||
paths.extend(list_tree(folder_paths.get_input_directory()))
|
||||
if "output" in roots:
|
||||
paths.extend(list_tree(folder_paths.get_output_directory()))
|
||||
|
||||
specs: list[dict] = []
|
||||
tag_pool: set[str] = set()
|
||||
for p in paths:
|
||||
ap = os.path.abspath(p)
|
||||
if ap in existing_paths:
|
||||
skipped_existing += 1
|
||||
continue
|
||||
try:
|
||||
st = os.stat(ap, follow_symlinks=True)
|
||||
except OSError:
|
||||
continue
|
||||
if not st.st_size:
|
||||
continue
|
||||
name, tags = get_name_and_tags_from_asset_path(ap)
|
||||
specs.append(
|
||||
{
|
||||
"abs_path": ap,
|
||||
"size_bytes": st.st_size,
|
||||
"mtime_ns": getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)),
|
||||
"info_name": name,
|
||||
"tags": tags,
|
||||
"fname": compute_relative_filename(ap),
|
||||
}
|
||||
)
|
||||
for t in tags:
|
||||
tag_pool.add(t)
|
||||
|
||||
if not specs:
|
||||
return
|
||||
async with await create_session() as sess:
|
||||
if tag_pool:
|
||||
await ensure_tags_exist(sess, tag_pool, tag_type="user")
|
||||
|
||||
result = await seed_from_paths_batch(sess, specs=specs, owner_id="")
|
||||
created += result["inserted_infos"]
|
||||
await sess.commit()
|
||||
finally:
|
||||
LOGGER.info(
|
||||
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, total_seen=%d)",
|
||||
roots,
|
||||
time.perf_counter() - t_total,
|
||||
created,
|
||||
skipped_existing,
|
||||
len(paths),
|
||||
)
|
||||
|
||||
|
||||
def _status_response_for(progresses: list[ScanProgress]) -> schemas_out.AssetScanStatusResponse:
|
||||
return schemas_out.AssetScanStatusResponse(scans=[_scan_progress_to_scan_status_model(p) for p in progresses])
|
||||
|
||||
|
||||
def _scan_progress_to_scan_status_model(progress: ScanProgress) -> schemas_out.AssetScanStatus:
|
||||
return schemas_out.AssetScanStatus(
|
||||
scan_id=progress.scan_id,
|
||||
root=progress.root,
|
||||
status=progress.status,
|
||||
scheduled_at=ts_to_iso(progress.scheduled_at),
|
||||
started_at=ts_to_iso(progress.started_at),
|
||||
finished_at=ts_to_iso(progress.finished_at),
|
||||
discovered=progress.discovered,
|
||||
processed=progress.processed,
|
||||
file_errors=[
|
||||
schemas_out.AssetScanError(
|
||||
path=e.get("path", ""),
|
||||
message=e.get("message", ""),
|
||||
at=e.get("at"),
|
||||
)
|
||||
for e in (progress.file_errors or [])
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
async def _run_hash_verify_pipeline(root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState) -> None:
|
||||
prog.status = "running"
|
||||
prog.started_at = time.time()
|
||||
try:
|
||||
prefixes = prefixes_for_root(root)
|
||||
|
||||
await _fast_db_consistency_pass(root)
|
||||
|
||||
# collect candidates from DB
|
||||
async with await create_session() as sess:
|
||||
verify_ids = await list_verify_candidates_under_prefixes(sess, prefixes=prefixes)
|
||||
unhashed_ids = await list_unhashed_candidates_under_prefixes(sess, prefixes=prefixes)
|
||||
# dedupe: prioritize verification first
|
||||
seen = set()
|
||||
ordered: list[int] = []
|
||||
for lst in (verify_ids, unhashed_ids):
|
||||
for sid in lst:
|
||||
if sid not in seen:
|
||||
seen.add(sid)
|
||||
ordered.append(sid)
|
||||
|
||||
prog.discovered = len(ordered)
|
||||
|
||||
# queue up work
|
||||
for sid in ordered:
|
||||
await state.queue.put(sid)
|
||||
state.closed = True
|
||||
_start_state_workers(root, prog, state)
|
||||
await _await_state_workers_then_finish(root, prog, state)
|
||||
except asyncio.CancelledError:
|
||||
prog.status = "cancelled"
|
||||
raise
|
||||
except Exception as exc:
|
||||
_append_error(prog, path="", message=str(exc))
|
||||
prog.status = "failed"
|
||||
prog.finished_at = time.time()
|
||||
LOGGER.exception("Asset scan failed for %s", root)
|
||||
finally:
|
||||
RUNNING_TASKS.pop(root, None)
|
||||
|
||||
|
||||
async def _reconcile_missing_tags_for_root(root: schemas_in.RootType, prog: ScanProgress) -> None:
|
||||
"""
|
||||
Detect missing files quickly and toggle 'missing' tag per asset_id.
|
||||
|
||||
Rules:
|
||||
- Only hashed assets (assets.hash != NULL) participate in missing tagging.
|
||||
- We consider ALL cache states of the asset (across roots) before tagging.
|
||||
"""
|
||||
if root == "models":
|
||||
bases: list[str] = []
|
||||
for _bucket, paths in get_comfy_models_folders():
|
||||
bases.extend(paths)
|
||||
elif root == "input":
|
||||
bases = [folder_paths.get_input_directory()]
|
||||
else:
|
||||
bases = [folder_paths.get_output_directory()]
|
||||
|
||||
try:
|
||||
async with await create_session() as sess:
|
||||
# state + hash + size for the current root
|
||||
rows = await list_cache_states_with_asset_under_prefixes(sess, prefixes=bases)
|
||||
|
||||
# Track fast_ok within the scanned root and whether the asset is hashed
|
||||
by_asset: dict[str, dict[str, bool]] = {}
|
||||
for state, a_hash, size_db in rows:
|
||||
aid = state.asset_id
|
||||
acc = by_asset.get(aid)
|
||||
if acc is None:
|
||||
acc = {"any_fast_ok_here": False, "hashed": (a_hash is not None), "size_db": int(size_db or 0)}
|
||||
by_asset[aid] = acc
|
||||
try:
|
||||
if acc["hashed"]:
|
||||
st = os.stat(state.file_path, follow_symlinks=True)
|
||||
if fast_asset_file_check(mtime_db=state.mtime_ns, size_db=acc["size_db"], stat_result=st):
|
||||
acc["any_fast_ok_here"] = True
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except OSError as e:
|
||||
_append_error(prog, path=state.file_path, message=str(e))
|
||||
|
||||
# Decide per asset, considering ALL its states (not just this root)
|
||||
for aid, acc in by_asset.items():
|
||||
try:
|
||||
if not acc["hashed"]:
|
||||
# Never tag seed assets as missing
|
||||
continue
|
||||
|
||||
any_fast_ok_global = acc["any_fast_ok_here"]
|
||||
if not any_fast_ok_global:
|
||||
# Check other states outside this root
|
||||
others = await list_cache_states_by_asset_id(sess, asset_id=aid)
|
||||
for st in others:
|
||||
try:
|
||||
any_fast_ok_global = fast_asset_file_check(
|
||||
mtime_db=st.mtime_ns,
|
||||
size_db=acc["size_db"],
|
||||
stat_result=os.stat(st.file_path, follow_symlinks=True),
|
||||
)
|
||||
except OSError:
|
||||
continue
|
||||
|
||||
if any_fast_ok_global:
|
||||
await remove_missing_tag_for_asset_id(sess, asset_id=aid)
|
||||
else:
|
||||
await add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
|
||||
except Exception as ex:
|
||||
_append_error(prog, path="", message=f"reconcile {aid[:8]}: {ex}")
|
||||
|
||||
await sess.commit()
|
||||
except Exception as e:
|
||||
_append_error(prog, path="", message=f"reconcile failed: {e}")
|
||||
|
||||
|
||||
def _start_state_workers(root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState) -> None:
|
||||
if state.workers:
|
||||
return
|
||||
|
||||
async def _worker(_wid: int):
|
||||
while True:
|
||||
sid = await state.queue.get()
|
||||
try:
|
||||
if sid is None:
|
||||
return
|
||||
try:
|
||||
async with await create_session() as sess:
|
||||
# Optional: fetch path for better error messages
|
||||
st = await sess.get(AssetCacheState, sid)
|
||||
try:
|
||||
await compute_hash_and_dedup_for_cache_state(sess, state_id=sid)
|
||||
await sess.commit()
|
||||
except Exception as e:
|
||||
path = st.file_path if st else f"state:{sid}"
|
||||
_append_error(prog, path=path, message=str(e))
|
||||
raise
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
prog.processed += 1
|
||||
finally:
|
||||
state.queue.task_done()
|
||||
|
||||
state.workers = [
|
||||
asyncio.create_task(_worker(i), name=f"asset-hash:{root}:{i}")
|
||||
for i in range(SLOW_HASH_CONCURRENCY)
|
||||
]
|
||||
|
||||
async def _close_when_ready():
|
||||
while not state.closed:
|
||||
await asyncio.sleep(0.05)
|
||||
for _ in range(SLOW_HASH_CONCURRENCY):
|
||||
await state.queue.put(None)
|
||||
|
||||
asyncio.create_task(_close_when_ready())
|
||||
|
||||
|
||||
async def _await_state_workers_then_finish(
|
||||
root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState
|
||||
) -> None:
|
||||
if state.workers:
|
||||
await asyncio.gather(*state.workers, return_exceptions=True)
|
||||
await _reconcile_missing_tags_for_root(root, prog)
|
||||
prog.finished_at = time.time()
|
||||
prog.status = "completed"
|
||||
|
||||
|
||||
def _append_error(prog: ScanProgress, *, path: str, message: str) -> None:
|
||||
prog.file_errors.append({
|
||||
"path": path,
|
||||
"message": message,
|
||||
"at": ts_to_iso(time.time()),
|
||||
})
|
||||
|
||||
|
||||
async def _fast_db_consistency_pass(
|
||||
root: schemas_in.RootType,
|
||||
*,
|
||||
collect_existing_paths: bool = False,
|
||||
update_missing_tags: bool = False,
|
||||
) -> Optional[set[str]]:
|
||||
"""Fast DB+FS pass for a root:
|
||||
- Toggle needs_verify per state using fast check
|
||||
- For hashed assets with at least one fast-ok state in this root: delete stale missing states
|
||||
- For seed assets with all states missing: delete Asset and its AssetInfos
|
||||
- Optionally add/remove 'missing' tags based on fast-ok in this root
|
||||
- Optionally return surviving absolute paths
|
||||
"""
|
||||
prefixes = prefixes_for_root(root)
|
||||
if not prefixes:
|
||||
return set() if collect_existing_paths else None
|
||||
|
||||
conds = []
|
||||
for p in prefixes:
|
||||
base = os.path.abspath(p)
|
||||
if not base.endswith(os.sep):
|
||||
base += os.sep
|
||||
escaped, esc = escape_like_prefix(base)
|
||||
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
|
||||
|
||||
async with await create_session() as sess:
|
||||
rows = (
|
||||
await sess.execute(
|
||||
sa.select(
|
||||
AssetCacheState.id,
|
||||
AssetCacheState.file_path,
|
||||
AssetCacheState.mtime_ns,
|
||||
AssetCacheState.needs_verify,
|
||||
AssetCacheState.asset_id,
|
||||
Asset.hash,
|
||||
Asset.size_bytes,
|
||||
)
|
||||
.join(Asset, Asset.id == AssetCacheState.asset_id)
|
||||
.where(sa.or_(*conds))
|
||||
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
|
||||
)
|
||||
).all()
|
||||
|
||||
by_asset: dict[str, dict] = {}
|
||||
for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows:
|
||||
acc = by_asset.get(aid)
|
||||
if acc is None:
|
||||
acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []}
|
||||
by_asset[aid] = acc
|
||||
|
||||
fast_ok = False
|
||||
try:
|
||||
exists = True
|
||||
fast_ok = fast_asset_file_check(
|
||||
mtime_db=mtime_db,
|
||||
size_db=acc["size_db"],
|
||||
stat_result=os.stat(fp, follow_symlinks=True),
|
||||
)
|
||||
except FileNotFoundError:
|
||||
exists = False
|
||||
except OSError:
|
||||
exists = False
|
||||
|
||||
acc["states"].append({
|
||||
"sid": sid,
|
||||
"fp": fp,
|
||||
"exists": exists,
|
||||
"fast_ok": fast_ok,
|
||||
"needs_verify": bool(needs_verify),
|
||||
})
|
||||
|
||||
to_set_verify: list[int] = []
|
||||
to_clear_verify: list[int] = []
|
||||
stale_state_ids: list[int] = []
|
||||
survivors: set[str] = set()
|
||||
|
||||
for aid, acc in by_asset.items():
|
||||
a_hash = acc["hash"]
|
||||
states = acc["states"]
|
||||
any_fast_ok = any(s["fast_ok"] for s in states)
|
||||
all_missing = all(not s["exists"] for s in states)
|
||||
|
||||
for s in states:
|
||||
if not s["exists"]:
|
||||
continue
|
||||
if s["fast_ok"] and s["needs_verify"]:
|
||||
to_clear_verify.append(s["sid"])
|
||||
if not s["fast_ok"] and not s["needs_verify"]:
|
||||
to_set_verify.append(s["sid"])
|
||||
|
||||
if a_hash is None:
|
||||
if states and all_missing: # remove seed Asset completely, if no valid AssetCache exists
|
||||
await sess.execute(sa.delete(AssetInfo).where(AssetInfo.asset_id == aid))
|
||||
asset = await sess.get(Asset, aid)
|
||||
if asset:
|
||||
await sess.delete(asset)
|
||||
else:
|
||||
for s in states:
|
||||
if s["exists"]:
|
||||
survivors.add(os.path.abspath(s["fp"]))
|
||||
continue
|
||||
|
||||
if any_fast_ok: # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records
|
||||
for s in states:
|
||||
if not s["exists"]:
|
||||
stale_state_ids.append(s["sid"])
|
||||
if update_missing_tags:
|
||||
with contextlib.suppress(Exception):
|
||||
await remove_missing_tag_for_asset_id(sess, asset_id=aid)
|
||||
elif update_missing_tags:
|
||||
with contextlib.suppress(Exception):
|
||||
await add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
|
||||
|
||||
for s in states:
|
||||
if s["exists"]:
|
||||
survivors.add(os.path.abspath(s["fp"]))
|
||||
|
||||
if stale_state_ids:
|
||||
await sess.execute(sa.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids)))
|
||||
if to_set_verify:
|
||||
await sess.execute(
|
||||
sa.update(AssetCacheState)
|
||||
.where(AssetCacheState.id.in_(to_set_verify))
|
||||
.values(needs_verify=True)
|
||||
)
|
||||
if to_clear_verify:
|
||||
await sess.execute(
|
||||
sa.update(AssetCacheState)
|
||||
.where(AssetCacheState.id.in_(to_clear_verify))
|
||||
.values(needs_verify=False)
|
||||
)
|
||||
await sess.commit()
|
||||
return survivors if collect_existing_paths else None
|
||||
0
app/assets/storage/__init__.py
Normal file
0
app/assets/storage/__init__.py
Normal file
72
app/assets/storage/hashing.py
Normal file
72
app/assets/storage/hashing.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import asyncio
|
||||
import os
|
||||
from typing import IO, Union
|
||||
|
||||
from blake3 import blake3
|
||||
|
||||
DEFAULT_CHUNK = 8 * 1024 * 1024 # 8 MiB
|
||||
|
||||
|
||||
def _hash_file_obj_sync(file_obj: IO[bytes], chunk_size: int) -> str:
|
||||
"""Hash an already-open binary file object by streaming in chunks.
|
||||
- Seeks to the beginning before reading (if supported).
|
||||
- Restores the original position afterward (if tell/seek are supported).
|
||||
"""
|
||||
if chunk_size <= 0:
|
||||
chunk_size = DEFAULT_CHUNK
|
||||
|
||||
orig_pos = None
|
||||
if hasattr(file_obj, "tell"):
|
||||
orig_pos = file_obj.tell()
|
||||
|
||||
try:
|
||||
if hasattr(file_obj, "seek"):
|
||||
file_obj.seek(0)
|
||||
|
||||
h = blake3()
|
||||
while True:
|
||||
chunk = file_obj.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
h.update(chunk)
|
||||
return h.hexdigest()
|
||||
finally:
|
||||
if hasattr(file_obj, "seek") and orig_pos is not None:
|
||||
file_obj.seek(orig_pos)
|
||||
|
||||
|
||||
def blake3_hash_sync(
|
||||
fp: Union[str, bytes, os.PathLike[str], os.PathLike[bytes], IO[bytes]],
|
||||
chunk_size: int = DEFAULT_CHUNK,
|
||||
) -> str:
|
||||
"""Returns a BLAKE3 hex digest for ``fp``, which may be:
|
||||
- a filename (str/bytes) or PathLike
|
||||
- an open binary file object
|
||||
|
||||
If ``fp`` is a file object, it must be opened in **binary** mode and support
|
||||
``read``, ``seek``, and ``tell``. The function will seek to the start before
|
||||
reading and will attempt to restore the original position afterward.
|
||||
"""
|
||||
if hasattr(fp, "read"):
|
||||
return _hash_file_obj_sync(fp, chunk_size)
|
||||
|
||||
with open(os.fspath(fp), "rb") as f:
|
||||
return _hash_file_obj_sync(f, chunk_size)
|
||||
|
||||
|
||||
async def blake3_hash(
|
||||
fp: Union[str, bytes, os.PathLike[str], os.PathLike[bytes], IO[bytes]],
|
||||
chunk_size: int = DEFAULT_CHUNK,
|
||||
) -> str:
|
||||
"""Async wrapper for ``blake3_hash_sync``.
|
||||
Uses a worker thread so the event loop remains responsive.
|
||||
"""
|
||||
# If it is a path, open inside the worker thread to keep I/O off the loop.
|
||||
if hasattr(fp, "read"):
|
||||
return await asyncio.to_thread(blake3_hash_sync, fp, chunk_size)
|
||||
|
||||
def _worker() -> str:
|
||||
with open(os.fspath(fp), "rb") as f:
|
||||
return _hash_file_obj_sync(f, chunk_size)
|
||||
|
||||
return await asyncio.to_thread(_worker)
|
||||
@@ -1,112 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from app.logger import log_startup_warning
|
||||
from utils.install_util import get_missing_requirements_message
|
||||
from comfy.cli_args import args
|
||||
|
||||
_DB_AVAILABLE = False
|
||||
Session = None
|
||||
|
||||
|
||||
try:
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
from alembic.runtime.migration import MigrationContext
|
||||
from alembic.script import ScriptDirectory
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
_DB_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
log_startup_warning(
|
||||
f"""
|
||||
------------------------------------------------------------------------
|
||||
Error importing dependencies: {e}
|
||||
{get_missing_requirements_message()}
|
||||
This error is happening because ComfyUI now uses a local sqlite database.
|
||||
------------------------------------------------------------------------
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
||||
def dependencies_available():
|
||||
"""
|
||||
Temporary function to check if the dependencies are available
|
||||
"""
|
||||
return _DB_AVAILABLE
|
||||
|
||||
|
||||
def can_create_session():
|
||||
"""
|
||||
Temporary function to check if the database is available to create a session
|
||||
During initial release there may be environmental issues (or missing dependencies) that prevent the database from being created
|
||||
"""
|
||||
return dependencies_available() and Session is not None
|
||||
|
||||
|
||||
def get_alembic_config():
|
||||
root_path = os.path.join(os.path.dirname(__file__), "../..")
|
||||
config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
|
||||
scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
|
||||
|
||||
config = Config(config_path)
|
||||
config.set_main_option("script_location", scripts_path)
|
||||
config.set_main_option("sqlalchemy.url", args.database_url)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def get_db_path():
|
||||
url = args.database_url
|
||||
if url.startswith("sqlite:///"):
|
||||
return url.split("///")[1]
|
||||
else:
|
||||
raise ValueError(f"Unsupported database URL '{url}'.")
|
||||
|
||||
|
||||
def init_db():
|
||||
db_url = args.database_url
|
||||
logging.debug(f"Database URL: {db_url}")
|
||||
db_path = get_db_path()
|
||||
db_exists = os.path.exists(db_path)
|
||||
|
||||
config = get_alembic_config()
|
||||
|
||||
# Check if we need to upgrade
|
||||
engine = create_engine(db_url)
|
||||
conn = engine.connect()
|
||||
|
||||
context = MigrationContext.configure(conn)
|
||||
current_rev = context.get_current_revision()
|
||||
|
||||
script = ScriptDirectory.from_config(config)
|
||||
target_rev = script.get_current_head()
|
||||
|
||||
if target_rev is None:
|
||||
logging.warning("No target revision found.")
|
||||
elif current_rev != target_rev:
|
||||
# Backup the database pre upgrade
|
||||
backup_path = db_path + ".bkp"
|
||||
if db_exists:
|
||||
shutil.copy(db_path, backup_path)
|
||||
else:
|
||||
backup_path = None
|
||||
|
||||
try:
|
||||
command.upgrade(config, target_rev)
|
||||
logging.info(f"Database upgraded from {current_rev} to {target_rev}")
|
||||
except Exception as e:
|
||||
if backup_path:
|
||||
# Restore the database from backup if upgrade fails
|
||||
shutil.copy(backup_path, db_path)
|
||||
os.remove(backup_path)
|
||||
logging.exception("Error upgrading database: ")
|
||||
raise e
|
||||
|
||||
global Session
|
||||
Session = sessionmaker(bind=engine)
|
||||
|
||||
|
||||
def create_session():
|
||||
return Session()
|
||||
@@ -1,14 +0,0 @@
|
||||
from sqlalchemy.orm import declarative_base
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
def to_dict(obj):
|
||||
fields = obj.__table__.columns.keys()
|
||||
return {
|
||||
field: (val.to_dict() if hasattr(val, "to_dict") else val)
|
||||
for field in fields
|
||||
if (val := getattr(obj, field))
|
||||
}
|
||||
|
||||
# TODO: Define models here
|
||||
255
app/db.py
Normal file
255
app/db.py
Normal file
@@ -0,0 +1,255 @@
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional
|
||||
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
from alembic.runtime.migration import MigrationContext
|
||||
from alembic.script import ScriptDirectory
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.engine import make_url
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
ENGINE: Optional[AsyncEngine] = None
|
||||
SESSION: Optional[async_sessionmaker] = None
|
||||
|
||||
|
||||
def _root_paths():
|
||||
"""Resolve alembic.ini and migrations script folder."""
|
||||
root_path = os.path.abspath(os.path.dirname(__file__))
|
||||
config_path = os.path.abspath(os.path.join(root_path, "../alembic.ini"))
|
||||
scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
|
||||
return config_path, scripts_path
|
||||
|
||||
|
||||
def _absolutize_sqlite_url(db_url: str) -> str:
|
||||
"""Make SQLite database path absolute. No-op for non-SQLite URLs."""
|
||||
try:
|
||||
u = make_url(db_url)
|
||||
except Exception:
|
||||
return db_url
|
||||
|
||||
if not u.drivername.startswith("sqlite"):
|
||||
return db_url
|
||||
|
||||
db_path: str = u.database or ""
|
||||
if isinstance(db_path, str) and db_path.startswith("file:"):
|
||||
return str(u) # Do not touch SQLite URI databases like: "file:xxx?mode=memory&cache=shared"
|
||||
if not os.path.isabs(db_path):
|
||||
db_path = os.path.abspath(os.path.join(os.getcwd(), db_path))
|
||||
u = u.set(database=db_path)
|
||||
return str(u)
|
||||
|
||||
|
||||
def _normalize_sqlite_memory_url(db_url: str) -> tuple[str, bool]:
|
||||
"""
|
||||
If db_url points at an in-memory SQLite DB (":memory:" or file:... mode=memory),
|
||||
rewrite it to a *named* shared in-memory URI and ensure 'uri=true' is present.
|
||||
Returns: (normalized_url, is_memory)
|
||||
"""
|
||||
try:
|
||||
u = make_url(db_url)
|
||||
except Exception:
|
||||
return db_url, False
|
||||
if not u.drivername.startswith("sqlite"):
|
||||
return db_url, False
|
||||
|
||||
db = u.database or ""
|
||||
if db == ":memory:":
|
||||
u = u.set(database=f"file:comfyui_db_{os.getpid()}?mode=memory&cache=shared&uri=true")
|
||||
return str(u), True
|
||||
if isinstance(db, str) and db.startswith("file:") and "mode=memory" in db:
|
||||
if "uri=true" not in db:
|
||||
u = u.set(database=(db + ("&" if "?" in db else "?") + "uri=true"))
|
||||
return str(u), True
|
||||
return str(u), False
|
||||
|
||||
|
||||
def _get_sqlite_file_path(sync_url: str) -> Optional[str]:
|
||||
"""Return the on-disk path for a SQLite URL, else None."""
|
||||
try:
|
||||
u = make_url(sync_url)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
if not u.drivername.startswith("sqlite"):
|
||||
return None
|
||||
db_path = u.database
|
||||
if isinstance(db_path, str) and db_path.startswith("file:"):
|
||||
return None # Not a real file if it is a URI like "file:...?"
|
||||
return db_path
|
||||
|
||||
|
||||
def _get_alembic_config(sync_url: str) -> Config:
|
||||
"""Prepare Alembic Config with script location and DB URL."""
|
||||
config_path, scripts_path = _root_paths()
|
||||
cfg = Config(config_path)
|
||||
cfg.set_main_option("script_location", scripts_path)
|
||||
cfg.set_main_option("sqlalchemy.url", sync_url)
|
||||
return cfg
|
||||
|
||||
|
||||
async def init_db_engine() -> None:
|
||||
"""Initialize async engine + sessionmaker and run migrations to head.
|
||||
|
||||
This must be called once on application startup before any DB usage.
|
||||
"""
|
||||
global ENGINE, SESSION
|
||||
|
||||
if ENGINE is not None:
|
||||
return
|
||||
|
||||
raw_url = args.database_url
|
||||
if not raw_url:
|
||||
raise RuntimeError("Database URL is not configured.")
|
||||
|
||||
db_url, is_mem = _normalize_sqlite_memory_url(raw_url)
|
||||
db_url = _absolutize_sqlite_url(db_url)
|
||||
|
||||
# Prepare async engine
|
||||
connect_args = {}
|
||||
if db_url.startswith("sqlite"):
|
||||
connect_args = {
|
||||
"check_same_thread": False,
|
||||
"timeout": 12,
|
||||
}
|
||||
if is_mem:
|
||||
connect_args["uri"] = True
|
||||
|
||||
ENGINE = create_async_engine(
|
||||
db_url,
|
||||
connect_args=connect_args,
|
||||
pool_pre_ping=True,
|
||||
future=True,
|
||||
)
|
||||
|
||||
# Enforce SQLite pragmas on the async engine
|
||||
if db_url.startswith("sqlite"):
|
||||
async with ENGINE.begin() as conn:
|
||||
if not is_mem:
|
||||
# WAL for concurrency and durability, Foreign Keys for referential integrity
|
||||
current_mode = (await conn.execute(text("PRAGMA journal_mode;"))).scalar()
|
||||
if str(current_mode).lower() != "wal":
|
||||
new_mode = (await conn.execute(text("PRAGMA journal_mode=WAL;"))).scalar()
|
||||
if str(new_mode).lower() != "wal":
|
||||
raise RuntimeError("Failed to set SQLite journal mode to WAL.")
|
||||
LOGGER.info("SQLite journal mode set to WAL.")
|
||||
|
||||
await conn.execute(text("PRAGMA foreign_keys = ON;"))
|
||||
await conn.execute(text("PRAGMA synchronous = NORMAL;"))
|
||||
|
||||
await _run_migrations(database_url=db_url, connect_args=connect_args)
|
||||
|
||||
SESSION = async_sessionmaker(
|
||||
bind=ENGINE,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
autocommit=False,
|
||||
)
|
||||
|
||||
|
||||
async def _run_migrations(database_url: str, connect_args: dict) -> None:
|
||||
if database_url.find("postgresql+psycopg") == -1:
|
||||
"""SQLite: Convert an async SQLAlchemy URL to a sync URL for Alembic."""
|
||||
u = make_url(database_url)
|
||||
driver = u.drivername
|
||||
if not driver.startswith("sqlite+aiosqlite"):
|
||||
raise ValueError(f"Unsupported DB driver: {driver}")
|
||||
database_url, is_mem = _normalize_sqlite_memory_url(str(u.set(drivername="sqlite")))
|
||||
database_url = _absolutize_sqlite_url(database_url)
|
||||
|
||||
cfg = _get_alembic_config(database_url)
|
||||
engine = create_engine(database_url, future=True, connect_args=connect_args)
|
||||
with engine.connect() as conn:
|
||||
context = MigrationContext.configure(conn)
|
||||
current_rev = context.get_current_revision()
|
||||
|
||||
script = ScriptDirectory.from_config(cfg)
|
||||
target_rev = script.get_current_head()
|
||||
|
||||
if target_rev is None:
|
||||
LOGGER.warning("Alembic: no target revision found.")
|
||||
return
|
||||
|
||||
if current_rev == target_rev:
|
||||
LOGGER.debug("Alembic: database already at head %s", target_rev)
|
||||
return
|
||||
|
||||
LOGGER.info("Alembic: upgrading database from %s to %s", current_rev, target_rev)
|
||||
|
||||
# Optional backup for SQLite file DBs
|
||||
backup_path = None
|
||||
sqlite_path = _get_sqlite_file_path(database_url)
|
||||
if sqlite_path and os.path.exists(sqlite_path):
|
||||
backup_path = sqlite_path + ".bkp"
|
||||
try:
|
||||
shutil.copy(sqlite_path, backup_path)
|
||||
except Exception as exc:
|
||||
LOGGER.warning("Failed to create SQLite backup before migration: %s", exc)
|
||||
|
||||
try:
|
||||
command.upgrade(cfg, target_rev)
|
||||
except Exception:
|
||||
if backup_path and os.path.exists(backup_path):
|
||||
LOGGER.exception("Error upgrading database, attempting restore from backup.")
|
||||
try:
|
||||
shutil.copy(backup_path, sqlite_path) # restore
|
||||
os.remove(backup_path)
|
||||
except Exception as re:
|
||||
LOGGER.error("Failed to restore SQLite backup: %s", re)
|
||||
else:
|
||||
LOGGER.exception("Error upgrading database, backup is not available.")
|
||||
raise
|
||||
|
||||
|
||||
def get_engine():
|
||||
"""Return the global async engine (initialized after init_db_engine())."""
|
||||
if ENGINE is None:
|
||||
raise RuntimeError("Engine is not initialized. Call init_db_engine() first.")
|
||||
return ENGINE
|
||||
|
||||
|
||||
def get_session_maker():
|
||||
"""Return the global async_sessionmaker (initialized after init_db_engine())."""
|
||||
if SESSION is None:
|
||||
raise RuntimeError("Session maker is not initialized. Call init_db_engine() first.")
|
||||
return SESSION
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def session_scope():
|
||||
"""Async context manager for a unit of work:
|
||||
|
||||
async with session_scope() as sess:
|
||||
... use sess ...
|
||||
"""
|
||||
maker = get_session_maker()
|
||||
async with maker() as sess:
|
||||
try:
|
||||
yield sess
|
||||
await sess.commit()
|
||||
except Exception:
|
||||
await sess.rollback()
|
||||
raise
|
||||
|
||||
|
||||
async def create_session():
|
||||
"""Convenience helper to acquire a single AsyncSession instance.
|
||||
|
||||
Typical usage:
|
||||
async with (await create_session()) as sess:
|
||||
...
|
||||
"""
|
||||
maker = get_session_maker()
|
||||
return maker()
|
||||
@@ -29,18 +29,48 @@ def frontend_install_warning_message():
|
||||
This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.
|
||||
""".strip()
|
||||
|
||||
def parse_version(version: str) -> tuple[int, int, int]:
|
||||
return tuple(map(int, version.split(".")))
|
||||
|
||||
def is_valid_version(version: str) -> bool:
|
||||
"""Validate if a string is a valid semantic version (X.Y.Z format)."""
|
||||
pattern = r"^(\d+)\.(\d+)\.(\d+)$"
|
||||
return bool(re.match(pattern, version))
|
||||
|
||||
def get_installed_frontend_version():
|
||||
"""Get the currently installed frontend package version."""
|
||||
frontend_version_str = version("comfyui-frontend-package")
|
||||
return frontend_version_str
|
||||
|
||||
def get_required_frontend_version():
|
||||
"""Get the required frontend version from requirements.txt."""
|
||||
try:
|
||||
with open(requirements_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line.startswith("comfyui-frontend-package=="):
|
||||
version_str = line.split("==")[-1]
|
||||
if not is_valid_version(version_str):
|
||||
logging.error(f"Invalid version format in requirements.txt: {version_str}")
|
||||
return None
|
||||
return version_str
|
||||
logging.error("comfyui-frontend-package not found in requirements.txt")
|
||||
return None
|
||||
except FileNotFoundError:
|
||||
logging.error("requirements.txt not found. Cannot determine required frontend version.")
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"Error reading requirements.txt: {e}")
|
||||
return None
|
||||
|
||||
def check_frontend_version():
|
||||
"""Check if the frontend version is up to date."""
|
||||
|
||||
def parse_version(version: str) -> tuple[int, int, int]:
|
||||
return tuple(map(int, version.split(".")))
|
||||
|
||||
try:
|
||||
frontend_version_str = version("comfyui-frontend-package")
|
||||
frontend_version_str = get_installed_frontend_version()
|
||||
frontend_version = parse_version(frontend_version_str)
|
||||
with open(requirements_path, "r", encoding="utf-8") as f:
|
||||
required_frontend = parse_version(f.readline().split("=")[-1])
|
||||
required_frontend_str = get_required_frontend_version()
|
||||
required_frontend = parse_version(required_frontend_str)
|
||||
if frontend_version < required_frontend:
|
||||
app.logger.log_startup_warning(
|
||||
f"""
|
||||
@@ -166,10 +196,35 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
|
||||
|
||||
|
||||
class FrontendManager:
|
||||
"""
|
||||
A class to manage ComfyUI frontend versions and installations.
|
||||
|
||||
This class handles the initialization and management of different frontend versions,
|
||||
including the default frontend from the pip package and custom frontend versions
|
||||
from GitHub repositories.
|
||||
|
||||
Attributes:
|
||||
CUSTOM_FRONTENDS_ROOT (str): The root directory where custom frontend versions are stored.
|
||||
"""
|
||||
|
||||
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
|
||||
|
||||
@classmethod
|
||||
def get_required_frontend_version(cls) -> str:
|
||||
"""Get the required frontend package version."""
|
||||
return get_required_frontend_version()
|
||||
|
||||
@classmethod
|
||||
def default_frontend_path(cls) -> str:
|
||||
"""
|
||||
Get the path to the default frontend installation from the pip package.
|
||||
|
||||
Returns:
|
||||
str: The path to the default frontend static files.
|
||||
|
||||
Raises:
|
||||
SystemExit: If the comfyui-frontend-package is not installed.
|
||||
"""
|
||||
try:
|
||||
import comfyui_frontend_package
|
||||
|
||||
@@ -190,6 +245,15 @@ comfyui-frontend-package is not installed.
|
||||
|
||||
@classmethod
|
||||
def templates_path(cls) -> str:
|
||||
"""
|
||||
Get the path to the workflow templates.
|
||||
|
||||
Returns:
|
||||
str: The path to the workflow templates directory.
|
||||
|
||||
Raises:
|
||||
SystemExit: If the comfyui-workflow-templates package is not installed.
|
||||
"""
|
||||
try:
|
||||
import comfyui_workflow_templates
|
||||
|
||||
@@ -225,11 +289,16 @@ comfyui-workflow-templates is not installed.
|
||||
@classmethod
|
||||
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
|
||||
"""
|
||||
Parse a version string into its components.
|
||||
|
||||
The version string should be in the format: 'owner/repo@version'
|
||||
where version can be either a semantic version (v1.2.3) or 'latest'.
|
||||
|
||||
Args:
|
||||
value (str): The version string to parse.
|
||||
|
||||
Returns:
|
||||
tuple[str, str]: A tuple containing provider name and version.
|
||||
tuple[str, str, str]: A tuple containing (owner, repo, version).
|
||||
|
||||
Raises:
|
||||
argparse.ArgumentTypeError: If the version string is invalid.
|
||||
@@ -246,18 +315,22 @@ comfyui-workflow-templates is not installed.
|
||||
cls, version_string: str, provider: Optional[FrontEndProvider] = None
|
||||
) -> str:
|
||||
"""
|
||||
Initializes the frontend for the specified version.
|
||||
Initialize a frontend version without error handling.
|
||||
|
||||
This method attempts to initialize a specific frontend version, either from
|
||||
the default pip package or from a custom GitHub repository. It will download
|
||||
and extract the frontend files if necessary.
|
||||
|
||||
Args:
|
||||
version_string (str): The version string.
|
||||
provider (FrontEndProvider, optional): The provider to use. Defaults to None.
|
||||
version_string (str): The version string specifying which frontend to use.
|
||||
provider (FrontEndProvider, optional): The provider to use for custom frontends.
|
||||
|
||||
Returns:
|
||||
str: The path to the initialized frontend.
|
||||
|
||||
Raises:
|
||||
Exception: If there is an error during the initialization process.
|
||||
main error source might be request timeout or invalid URL.
|
||||
Exception: If there is an error during initialization (e.g., network timeout,
|
||||
invalid URL, or missing assets).
|
||||
"""
|
||||
if version_string == DEFAULT_VERSION_STRING:
|
||||
check_frontend_version()
|
||||
@@ -309,13 +382,17 @@ comfyui-workflow-templates is not installed.
|
||||
@classmethod
|
||||
def init_frontend(cls, version_string: str) -> str:
|
||||
"""
|
||||
Initializes the frontend with the specified version string.
|
||||
Initialize a frontend version with error handling.
|
||||
|
||||
This is the main method to initialize a frontend version. It wraps init_frontend_unsafe
|
||||
with error handling, falling back to the default frontend if initialization fails.
|
||||
|
||||
Args:
|
||||
version_string (str): The version string to initialize the frontend with.
|
||||
version_string (str): The version string specifying which frontend to use.
|
||||
|
||||
Returns:
|
||||
str: The path of the initialized frontend.
|
||||
str: The path to the initialized frontend. If initialization fails,
|
||||
returns the path to the default frontend.
|
||||
"""
|
||||
try:
|
||||
return cls.init_frontend_unsafe(version_string)
|
||||
|
||||
@@ -130,10 +130,21 @@ class ModelFileManager:
|
||||
|
||||
for file_name in filenames:
|
||||
try:
|
||||
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
|
||||
result.append(relative_path)
|
||||
except:
|
||||
logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
|
||||
full_path = os.path.join(dirpath, file_name)
|
||||
relative_path = os.path.relpath(full_path, directory)
|
||||
|
||||
# Get file metadata
|
||||
file_info = {
|
||||
"name": relative_path,
|
||||
"pathIndex": pathIndex,
|
||||
"modified": os.path.getmtime(full_path), # Add modification time
|
||||
"created": os.path.getctime(full_path), # Add creation time
|
||||
"size": os.path.getsize(full_path) # Add file size
|
||||
}
|
||||
result.append(file_info)
|
||||
|
||||
except Exception as e:
|
||||
logging.warning(f"Warning: Unable to access {file_name}. Error: {e}. Skipping this file.")
|
||||
continue
|
||||
|
||||
for d in subdirs:
|
||||
@@ -144,7 +155,7 @@ class ModelFileManager:
|
||||
logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
|
||||
continue
|
||||
|
||||
return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()
|
||||
return result, dirs, time.perf_counter()
|
||||
|
||||
def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
|
||||
dirname = os.path.dirname(filepath)
|
||||
|
||||
@@ -20,13 +20,15 @@ class FileInfo(TypedDict):
|
||||
path: str
|
||||
size: int
|
||||
modified: int
|
||||
created: int
|
||||
|
||||
|
||||
def get_file_info(path: str, relative_to: str) -> FileInfo:
|
||||
return {
|
||||
"path": os.path.relpath(path, relative_to).replace(os.sep, '/'),
|
||||
"size": os.path.getsize(path),
|
||||
"modified": os.path.getmtime(path)
|
||||
"modified": os.path.getmtime(path),
|
||||
"created": os.path.getctime(path)
|
||||
}
|
||||
|
||||
|
||||
@@ -361,10 +363,17 @@ class UserManager():
|
||||
if not overwrite and os.path.exists(path):
|
||||
return web.Response(status=409, text="File already exists")
|
||||
|
||||
body = await request.read()
|
||||
try:
|
||||
body = await request.read()
|
||||
|
||||
with open(path, "wb") as f:
|
||||
f.write(body)
|
||||
with open(path, "wb") as f:
|
||||
f.write(body)
|
||||
except OSError as e:
|
||||
logging.warning(f"Error saving file '{path}': {e}")
|
||||
return web.Response(
|
||||
status=400,
|
||||
reason="Invalid filename. Please avoid special characters like :\\/*?\"<>|"
|
||||
)
|
||||
|
||||
user_path = self.get_request_user_filepath(request, None)
|
||||
if full_info:
|
||||
|
||||
91
comfy/audio_encoders/audio_encoders.py
Normal file
91
comfy/audio_encoders/audio_encoders.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from .wav2vec2 import Wav2Vec2Model
|
||||
from .whisper import WhisperLargeV3
|
||||
import comfy.model_management
|
||||
import comfy.ops
|
||||
import comfy.utils
|
||||
import logging
|
||||
import torchaudio
|
||||
|
||||
|
||||
class AudioEncoderModel():
|
||||
def __init__(self, config):
|
||||
self.load_device = comfy.model_management.text_encoder_device()
|
||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
||||
model_type = config.pop("model_type")
|
||||
model_config = dict(config)
|
||||
model_config.update({
|
||||
"dtype": self.dtype,
|
||||
"device": offload_device,
|
||||
"operations": comfy.ops.manual_cast
|
||||
})
|
||||
|
||||
if model_type == "wav2vec2":
|
||||
self.model = Wav2Vec2Model(**model_config)
|
||||
elif model_type == "whisper3":
|
||||
self.model = WhisperLargeV3(**model_config)
|
||||
self.model.eval()
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
self.model_sample_rate = 16000
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.model.load_state_dict(sd, strict=False)
|
||||
|
||||
def get_sd(self):
|
||||
return self.model.state_dict()
|
||||
|
||||
def encode_audio(self, audio, sample_rate):
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
audio = torchaudio.functional.resample(audio, sample_rate, self.model_sample_rate)
|
||||
out, all_layers = self.model(audio.to(self.load_device))
|
||||
outputs = {}
|
||||
outputs["encoded_audio"] = out
|
||||
outputs["encoded_audio_all_layers"] = all_layers
|
||||
outputs["audio_samples"] = audio.shape[2]
|
||||
return outputs
|
||||
|
||||
|
||||
def load_audio_encoder_from_sd(sd, prefix=""):
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""})
|
||||
if "encoder.layer_norm.bias" in sd: #wav2vec2
|
||||
embed_dim = sd["encoder.layer_norm.bias"].shape[0]
|
||||
if embed_dim == 1024:# large
|
||||
config = {
|
||||
"model_type": "wav2vec2",
|
||||
"embed_dim": 1024,
|
||||
"num_heads": 16,
|
||||
"num_layers": 24,
|
||||
"conv_norm": True,
|
||||
"conv_bias": True,
|
||||
"do_normalize": True,
|
||||
"do_stable_layer_norm": True
|
||||
}
|
||||
elif embed_dim == 768: # base
|
||||
config = {
|
||||
"model_type": "wav2vec2",
|
||||
"embed_dim": 768,
|
||||
"num_heads": 12,
|
||||
"num_layers": 12,
|
||||
"conv_norm": False,
|
||||
"conv_bias": False,
|
||||
"do_normalize": False, # chinese-wav2vec2-base has this False
|
||||
"do_stable_layer_norm": False
|
||||
}
|
||||
else:
|
||||
raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim))
|
||||
elif "model.encoder.embed_positions.weight" in sd:
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"model.": ""})
|
||||
config = {
|
||||
"model_type": "whisper3",
|
||||
}
|
||||
else:
|
||||
raise RuntimeError("ERROR: audio encoder not supported.")
|
||||
|
||||
audio_encoder = AudioEncoderModel(config)
|
||||
m, u = audio_encoder.load_sd(sd)
|
||||
if len(m) > 0:
|
||||
logging.warning("missing audio encoder: {}".format(m))
|
||||
if len(u) > 0:
|
||||
logging.warning("unexpected audio encoder: {}".format(u))
|
||||
|
||||
return audio_encoder
|
||||
252
comfy/audio_encoders/wav2vec2.py
Normal file
252
comfy/audio_encoders/wav2vec2.py
Normal file
@@ -0,0 +1,252 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
|
||||
|
||||
class LayerNormConv(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
|
||||
self.layer_norm = operations.LayerNorm(out_channels, elementwise_affine=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return torch.nn.functional.gelu(self.layer_norm(x.transpose(-2, -1)).transpose(-2, -1))
|
||||
|
||||
class LayerGroupNormConv(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
|
||||
self.layer_norm = operations.GroupNorm(num_groups=out_channels, num_channels=out_channels, affine=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return torch.nn.functional.gelu(self.layer_norm(x))
|
||||
|
||||
class ConvNoNorm(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return torch.nn.functional.gelu(x)
|
||||
|
||||
|
||||
class ConvFeatureEncoder(nn.Module):
|
||||
def __init__(self, conv_dim, conv_bias=False, conv_norm=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
if conv_norm:
|
||||
self.conv_layers = nn.ModuleList([
|
||||
LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
])
|
||||
else:
|
||||
self.conv_layers = nn.ModuleList([
|
||||
LayerGroupNormConv(1, conv_dim, kernel_size=10, stride=5, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
ConvNoNorm(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
ConvNoNorm(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
])
|
||||
|
||||
def forward(self, x):
|
||||
x = x.unsqueeze(1)
|
||||
|
||||
for conv in self.conv_layers:
|
||||
x = conv(x)
|
||||
|
||||
return x.transpose(1, 2)
|
||||
|
||||
|
||||
class FeatureProjection(nn.Module):
|
||||
def __init__(self, conv_dim, embed_dim, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.layer_norm = operations.LayerNorm(conv_dim, eps=1e-05, device=device, dtype=dtype)
|
||||
self.projection = operations.Linear(conv_dim, embed_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.layer_norm(x)
|
||||
x = self.projection(x)
|
||||
return x
|
||||
|
||||
|
||||
class PositionalConvEmbedding(nn.Module):
|
||||
def __init__(self, embed_dim=768, kernel_size=128, groups=16):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv1d(
|
||||
embed_dim,
|
||||
embed_dim,
|
||||
kernel_size=kernel_size,
|
||||
padding=kernel_size // 2,
|
||||
groups=groups,
|
||||
)
|
||||
self.conv = torch.nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
|
||||
self.activation = nn.GELU()
|
||||
|
||||
def forward(self, x):
|
||||
x = x.transpose(1, 2)
|
||||
x = self.conv(x)[:, :, :-1]
|
||||
x = self.activation(x)
|
||||
x = x.transpose(1, 2)
|
||||
return x
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim=768,
|
||||
num_heads=12,
|
||||
num_layers=12,
|
||||
mlp_ratio=4.0,
|
||||
do_stable_layer_norm=True,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.pos_conv_embed = PositionalConvEmbedding(embed_dim=embed_dim)
|
||||
self.layers = nn.ModuleList([
|
||||
TransformerEncoderLayer(
|
||||
embed_dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
do_stable_layer_norm=do_stable_layer_norm,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
self.layer_norm = operations.LayerNorm(embed_dim, eps=1e-05, device=device, dtype=dtype)
|
||||
self.do_stable_layer_norm = do_stable_layer_norm
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
x = x + self.pos_conv_embed(x)
|
||||
all_x = ()
|
||||
if not self.do_stable_layer_norm:
|
||||
x = self.layer_norm(x)
|
||||
for layer in self.layers:
|
||||
all_x += (x,)
|
||||
x = layer(x, mask)
|
||||
if self.do_stable_layer_norm:
|
||||
x = self.layer_norm(x)
|
||||
all_x += (x,)
|
||||
return x, all_x
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, embed_dim, num_heads, bias=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = embed_dim // num_heads
|
||||
|
||||
self.k_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
|
||||
self.v_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
|
||||
self.q_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
|
||||
self.out_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
assert (mask is None) # TODO?
|
||||
q = self.q_proj(x)
|
||||
k = self.k_proj(x)
|
||||
v = self.v_proj(x)
|
||||
|
||||
out = optimized_attention_masked(q, k, v, self.num_heads)
|
||||
return self.out_proj(out)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, embed_dim, mlp_ratio, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.intermediate_dense = operations.Linear(embed_dim, int(embed_dim * mlp_ratio), device=device, dtype=dtype)
|
||||
self.output_dense = operations.Linear(int(embed_dim * mlp_ratio), embed_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.intermediate_dense(x)
|
||||
x = torch.nn.functional.gelu(x)
|
||||
x = self.output_dense(x)
|
||||
return x
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim=768,
|
||||
num_heads=12,
|
||||
mlp_ratio=4.0,
|
||||
do_stable_layer_norm=True,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attention = Attention(embed_dim, num_heads, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
|
||||
self.feed_forward = FeedForward(embed_dim, mlp_ratio, device=device, dtype=dtype, operations=operations)
|
||||
self.final_layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
|
||||
self.do_stable_layer_norm = do_stable_layer_norm
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
residual = x
|
||||
if self.do_stable_layer_norm:
|
||||
x = self.layer_norm(x)
|
||||
x = self.attention(x, mask=mask)
|
||||
x = residual + x
|
||||
if not self.do_stable_layer_norm:
|
||||
x = self.layer_norm(x)
|
||||
return self.final_layer_norm(x + self.feed_forward(x))
|
||||
else:
|
||||
return x + self.feed_forward(self.final_layer_norm(x))
|
||||
|
||||
|
||||
class Wav2Vec2Model(nn.Module):
|
||||
"""Complete Wav2Vec 2.0 model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim=1024,
|
||||
final_dim=256,
|
||||
num_heads=16,
|
||||
num_layers=24,
|
||||
conv_norm=True,
|
||||
conv_bias=True,
|
||||
do_normalize=True,
|
||||
do_stable_layer_norm=True,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
conv_dim = 512
|
||||
self.feature_extractor = ConvFeatureEncoder(conv_dim, conv_norm=conv_norm, conv_bias=conv_bias, device=device, dtype=dtype, operations=operations)
|
||||
self.feature_projection = FeatureProjection(conv_dim, embed_dim, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.masked_spec_embed = nn.Parameter(torch.empty(embed_dim, device=device, dtype=dtype))
|
||||
self.do_normalize = do_normalize
|
||||
|
||||
self.encoder = TransformerEncoder(
|
||||
embed_dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
num_layers=num_layers,
|
||||
do_stable_layer_norm=do_stable_layer_norm,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
|
||||
def forward(self, x, mask_time_indices=None, return_dict=False):
|
||||
x = torch.mean(x, dim=1)
|
||||
|
||||
if self.do_normalize:
|
||||
x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7)
|
||||
|
||||
features = self.feature_extractor(x)
|
||||
features = self.feature_projection(features)
|
||||
batch_size, seq_len, _ = features.shape
|
||||
|
||||
x, all_x = self.encoder(features)
|
||||
return x, all_x
|
||||
186
comfy/audio_encoders/whisper.py
Executable file
186
comfy/audio_encoders/whisper.py
Executable file
@@ -0,0 +1,186 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
from typing import Optional
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
import comfy.ops
|
||||
|
||||
class WhisperFeatureExtractor(nn.Module):
|
||||
def __init__(self, n_mels=128, device=None):
|
||||
super().__init__()
|
||||
self.sample_rate = 16000
|
||||
self.n_fft = 400
|
||||
self.hop_length = 160
|
||||
self.n_mels = n_mels
|
||||
self.chunk_length = 30
|
||||
self.n_samples = 480000
|
||||
|
||||
self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
|
||||
sample_rate=self.sample_rate,
|
||||
n_fft=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
n_mels=self.n_mels,
|
||||
f_min=0,
|
||||
f_max=8000,
|
||||
norm="slaney",
|
||||
mel_scale="slaney",
|
||||
).to(device)
|
||||
|
||||
def __call__(self, audio):
|
||||
audio = torch.mean(audio, dim=1)
|
||||
batch_size = audio.shape[0]
|
||||
processed_audio = []
|
||||
|
||||
for i in range(batch_size):
|
||||
aud = audio[i]
|
||||
if aud.shape[0] > self.n_samples:
|
||||
aud = aud[:self.n_samples]
|
||||
elif aud.shape[0] < self.n_samples:
|
||||
aud = F.pad(aud, (0, self.n_samples - aud.shape[0]))
|
||||
processed_audio.append(aud)
|
||||
|
||||
audio = torch.stack(processed_audio)
|
||||
|
||||
mel_spec = self.mel_spectrogram(audio.to(self.mel_spectrogram.spectrogram.window.device))[:, :, :-1].to(audio.device)
|
||||
|
||||
log_mel_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
||||
log_mel_spec = torch.maximum(log_mel_spec, log_mel_spec.max() - 8.0)
|
||||
log_mel_spec = (log_mel_spec + 4.0) / 4.0
|
||||
|
||||
return log_mel_spec
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, d_model: int, n_heads: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
assert d_model % n_heads == 0
|
||||
|
||||
self.d_model = d_model
|
||||
self.n_heads = n_heads
|
||||
self.d_k = d_model // n_heads
|
||||
|
||||
self.q_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
|
||||
self.k_proj = operations.Linear(d_model, d_model, bias=False, dtype=dtype, device=device)
|
||||
self.v_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
|
||||
self.out_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size, seq_len, _ = query.shape
|
||||
|
||||
q = self.q_proj(query)
|
||||
k = self.k_proj(key)
|
||||
v = self.v_proj(value)
|
||||
|
||||
attn_output = optimized_attention_masked(q, k, v, self.n_heads, mask)
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
class EncoderLayer(nn.Module):
|
||||
def __init__(self, d_model: int, n_heads: int, d_ff: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
self.self_attn = MultiHeadAttention(d_model, n_heads, dtype=dtype, device=device, operations=operations)
|
||||
self.self_attn_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device)
|
||||
|
||||
self.fc1 = operations.Linear(d_model, d_ff, dtype=dtype, device=device)
|
||||
self.fc2 = operations.Linear(d_ff, d_model, dtype=dtype, device=device)
|
||||
self.final_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
residual = x
|
||||
x = self.self_attn_layer_norm(x)
|
||||
x = self.self_attn(x, x, x, attention_mask)
|
||||
x = residual + x
|
||||
|
||||
residual = x
|
||||
x = self.final_layer_norm(x)
|
||||
x = self.fc1(x)
|
||||
x = F.gelu(x)
|
||||
x = self.fc2(x)
|
||||
x = residual + x
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class AudioEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_mels: int = 128,
|
||||
n_ctx: int = 1500,
|
||||
n_state: int = 1280,
|
||||
n_head: int = 20,
|
||||
n_layer: int = 32,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.conv1 = operations.Conv1d(n_mels, n_state, kernel_size=3, padding=1, dtype=dtype, device=device)
|
||||
self.conv2 = operations.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1, dtype=dtype, device=device)
|
||||
|
||||
self.embed_positions = operations.Embedding(n_ctx, n_state, dtype=dtype, device=device)
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
EncoderLayer(n_state, n_head, n_state * 4, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(n_layer)
|
||||
])
|
||||
|
||||
self.layer_norm = operations.LayerNorm(n_state, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = F.gelu(self.conv1(x))
|
||||
x = F.gelu(self.conv2(x))
|
||||
|
||||
x = x.transpose(1, 2)
|
||||
|
||||
x = x + comfy.ops.cast_to_input(self.embed_positions.weight[:, :x.shape[1]], x)
|
||||
|
||||
all_x = ()
|
||||
for layer in self.layers:
|
||||
all_x += (x,)
|
||||
x = layer(x)
|
||||
|
||||
x = self.layer_norm(x)
|
||||
all_x += (x,)
|
||||
return x, all_x
|
||||
|
||||
|
||||
class WhisperLargeV3(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_mels: int = 128,
|
||||
n_audio_ctx: int = 1500,
|
||||
n_audio_state: int = 1280,
|
||||
n_audio_head: int = 20,
|
||||
n_audio_layer: int = 32,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.feature_extractor = WhisperFeatureExtractor(n_mels=n_mels, device=device)
|
||||
|
||||
self.encoder = AudioEncoder(
|
||||
n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
def forward(self, audio):
|
||||
mel = self.feature_extractor(audio)
|
||||
x, all_x = self.encoder(mel)
|
||||
return x, all_x
|
||||
@@ -49,7 +49,8 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co
|
||||
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
|
||||
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
|
||||
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
|
||||
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
|
||||
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use. All other devices will not be visible.")
|
||||
parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
|
||||
cm_group = parser.add_mutually_exclusive_group()
|
||||
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
|
||||
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
|
||||
@@ -131,6 +132,8 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am
|
||||
|
||||
parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")
|
||||
|
||||
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
|
||||
|
||||
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
|
||||
|
||||
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
|
||||
@@ -140,10 +143,12 @@ class PerformanceFeature(enum.Enum):
|
||||
Fp16Accumulation = "fp16_accumulation"
|
||||
Fp8MatrixMultiplication = "fp8_matrix_mult"
|
||||
CublasOps = "cublas_ops"
|
||||
AutoTune = "autotune"
|
||||
|
||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
|
||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
|
||||
|
||||
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
|
||||
parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")
|
||||
|
||||
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
|
||||
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
|
||||
@@ -151,6 +156,7 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win
|
||||
|
||||
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
|
||||
parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
|
||||
parser.add_argument("--whitelist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.")
|
||||
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.")
|
||||
|
||||
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
||||
@@ -206,7 +212,8 @@ parser.add_argument(
|
||||
database_default_path = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
|
||||
)
|
||||
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
|
||||
parser.add_argument("--database-url", type=str, default=f"sqlite+aiosqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite+aiosqlite:///:memory:'.")
|
||||
parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.")
|
||||
|
||||
if comfy.options.args_parsing:
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -61,8 +61,12 @@ class CLIPEncoder(torch.nn.Module):
|
||||
def forward(self, x, mask=None, intermediate_output=None):
|
||||
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
|
||||
|
||||
all_intermediate = None
|
||||
if intermediate_output is not None:
|
||||
if intermediate_output < 0:
|
||||
if intermediate_output == "all":
|
||||
all_intermediate = []
|
||||
intermediate_output = None
|
||||
elif intermediate_output < 0:
|
||||
intermediate_output = len(self.layers) + intermediate_output
|
||||
|
||||
intermediate = None
|
||||
@@ -70,6 +74,12 @@ class CLIPEncoder(torch.nn.Module):
|
||||
x = l(x, mask, optimized_attention)
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
if all_intermediate is not None:
|
||||
all_intermediate.append(x.unsqueeze(1).clone())
|
||||
|
||||
if all_intermediate is not None:
|
||||
intermediate = torch.cat(all_intermediate, dim=1)
|
||||
|
||||
return x, intermediate
|
||||
|
||||
class CLIPEmbeddings(torch.nn.Module):
|
||||
@@ -97,7 +107,7 @@ class CLIPTextModel_(torch.nn.Module):
|
||||
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
|
||||
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
|
||||
def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32, embeds_info=[]):
|
||||
if embeds is not None:
|
||||
x = embeds + comfy.ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device)
|
||||
else:
|
||||
|
||||
@@ -50,7 +50,13 @@ class ClipVisionModel():
|
||||
self.image_size = config.get("image_size", 224)
|
||||
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
|
||||
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
|
||||
model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model"))
|
||||
model_type = config.get("model_type", "clip_vision_model")
|
||||
model_class = IMAGE_ENCODERS.get(model_type)
|
||||
if model_type == "siglip_vision_model":
|
||||
self.return_all_hidden_states = True
|
||||
else:
|
||||
self.return_all_hidden_states = False
|
||||
|
||||
self.load_device = comfy.model_management.text_encoder_device()
|
||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
||||
@@ -68,12 +74,18 @@ class ClipVisionModel():
|
||||
def encode_image(self, image, crop=True):
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||
out = self.model(pixel_values=pixel_values, intermediate_output=-2)
|
||||
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
|
||||
|
||||
outputs = Output()
|
||||
outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
|
||||
outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
|
||||
outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
|
||||
if self.return_all_hidden_states:
|
||||
all_hs = out[1].to(comfy.model_management.intermediate_device())
|
||||
outputs["penultimate_hidden_states"] = all_hs[:, -2]
|
||||
outputs["all_hidden_states"] = all_hs
|
||||
else:
|
||||
outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
|
||||
|
||||
outputs["mm_projected"] = out[3]
|
||||
return outputs
|
||||
|
||||
@@ -124,8 +136,12 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
|
||||
else:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
||||
elif "embeddings.patch_embeddings.projection.weight" in sd:
|
||||
|
||||
# Dinov2
|
||||
elif 'encoder.layer.39.layer_scale2.lambda1' in sd:
|
||||
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
|
||||
elif 'encoder.layer.23.layer_scale2.lambda1' in sd:
|
||||
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json")
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import math
|
||||
import comfy.utils
|
||||
import logging
|
||||
|
||||
|
||||
class CONDRegular:
|
||||
@@ -10,12 +11,15 @@ class CONDRegular:
|
||||
def _copy_with(self, cond):
|
||||
return self.__class__(cond)
|
||||
|
||||
def process_cond(self, batch_size, device, **kwargs):
|
||||
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device))
|
||||
def process_cond(self, batch_size, **kwargs):
|
||||
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size))
|
||||
|
||||
def can_concat(self, other):
|
||||
if self.cond.shape != other.cond.shape:
|
||||
return False
|
||||
if self.cond.device != other.cond.device:
|
||||
logging.warning("WARNING: conds not on same device, skipping concat.")
|
||||
return False
|
||||
return True
|
||||
|
||||
def concat(self, others):
|
||||
@@ -29,14 +33,14 @@ class CONDRegular:
|
||||
|
||||
|
||||
class CONDNoiseShape(CONDRegular):
|
||||
def process_cond(self, batch_size, device, area, **kwargs):
|
||||
def process_cond(self, batch_size, area, **kwargs):
|
||||
data = self.cond
|
||||
if area is not None:
|
||||
dims = len(area) // 2
|
||||
for i in range(dims):
|
||||
data = data.narrow(i + 2, area[i + dims], area[i])
|
||||
|
||||
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device))
|
||||
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size))
|
||||
|
||||
|
||||
class CONDCrossAttn(CONDRegular):
|
||||
@@ -51,6 +55,9 @@ class CONDCrossAttn(CONDRegular):
|
||||
diff = mult_min // min(s1[1], s2[1])
|
||||
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
|
||||
return False
|
||||
if self.cond.device != other.cond.device:
|
||||
logging.warning("WARNING: conds not on same device: skipping concat.")
|
||||
return False
|
||||
return True
|
||||
|
||||
def concat(self, others):
|
||||
@@ -73,7 +80,7 @@ class CONDConstant(CONDRegular):
|
||||
def __init__(self, cond):
|
||||
self.cond = cond
|
||||
|
||||
def process_cond(self, batch_size, device, **kwargs):
|
||||
def process_cond(self, batch_size, **kwargs):
|
||||
return self._copy_with(self.cond)
|
||||
|
||||
def can_concat(self, other):
|
||||
@@ -92,10 +99,10 @@ class CONDList(CONDRegular):
|
||||
def __init__(self, cond):
|
||||
self.cond = cond
|
||||
|
||||
def process_cond(self, batch_size, device, **kwargs):
|
||||
def process_cond(self, batch_size, **kwargs):
|
||||
out = []
|
||||
for c in self.cond:
|
||||
out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device))
|
||||
out.append(comfy.utils.repeat_to_batch_size(c, batch_size))
|
||||
|
||||
return self._copy_with(out)
|
||||
|
||||
|
||||
540
comfy/context_windows.py
Normal file
540
comfy/context_windows.py
Normal file
@@ -0,0 +1,540 @@
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
import torch
|
||||
import numpy as np
|
||||
import collections
|
||||
from dataclasses import dataclass
|
||||
from abc import ABC, abstractmethod
|
||||
import logging
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.controlnet import ControlBase
|
||||
|
||||
|
||||
class ContextWindowABC(ABC):
|
||||
def __init__(self):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_tensor(self, full: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Get torch.Tensor applicable to current window.
|
||||
"""
|
||||
raise NotImplementedError("Not implemented.")
|
||||
|
||||
@abstractmethod
|
||||
def add_window(self, full: torch.Tensor, to_add: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Apply torch.Tensor of window to the full tensor, in place. Returns reference to updated full tensor, not a copy.
|
||||
"""
|
||||
raise NotImplementedError("Not implemented.")
|
||||
|
||||
class ContextHandlerABC(ABC):
|
||||
def __init__(self):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
|
||||
raise NotImplementedError("Not implemented.")
|
||||
|
||||
@abstractmethod
|
||||
def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: ContextWindowABC, device=None) -> list:
|
||||
raise NotImplementedError("Not implemented.")
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
raise NotImplementedError("Not implemented.")
|
||||
|
||||
|
||||
|
||||
class IndexListContextWindow(ContextWindowABC):
|
||||
def __init__(self, index_list: list[int], dim: int=0):
|
||||
self.index_list = index_list
|
||||
self.context_length = len(index_list)
|
||||
self.dim = dim
|
||||
|
||||
def get_tensor(self, full: torch.Tensor, device=None, dim=None) -> torch.Tensor:
|
||||
if dim is None:
|
||||
dim = self.dim
|
||||
if dim == 0 and full.shape[dim] == 1:
|
||||
return full
|
||||
idx = [slice(None)] * dim + [self.index_list]
|
||||
return full[idx].to(device)
|
||||
|
||||
def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor:
|
||||
if dim is None:
|
||||
dim = self.dim
|
||||
idx = [slice(None)] * dim + [self.index_list]
|
||||
full[idx] += to_add
|
||||
return full
|
||||
|
||||
|
||||
class IndexListCallbacks:
|
||||
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
|
||||
COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results"
|
||||
EXECUTE_START = "execute_start"
|
||||
EXECUTE_CLEANUP = "execute_cleanup"
|
||||
|
||||
def init_callbacks(self):
|
||||
return {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextSchedule:
|
||||
name: str
|
||||
func: Callable
|
||||
|
||||
@dataclass
|
||||
class ContextFuseMethod:
|
||||
name: str
|
||||
func: Callable
|
||||
|
||||
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
|
||||
class IndexListContextHandler(ContextHandlerABC):
|
||||
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop=False, dim=0):
|
||||
self.context_schedule = context_schedule
|
||||
self.fuse_method = fuse_method
|
||||
self.context_length = context_length
|
||||
self.context_overlap = context_overlap
|
||||
self.context_stride = context_stride
|
||||
self.closed_loop = closed_loop
|
||||
self.dim = dim
|
||||
self._step = 0
|
||||
|
||||
self.callbacks = {}
|
||||
|
||||
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
|
||||
# for now, assume first dim is batch - should have stored on BaseModel in actual implementation
|
||||
if x_in.size(self.dim) > self.context_length:
|
||||
logging.info(f"Using context windows {self.context_length} for {x_in.size(self.dim)} frames.")
|
||||
return True
|
||||
return False
|
||||
|
||||
def prepare_control_objects(self, control: ControlBase, device=None) -> ControlBase:
|
||||
if control.previous_controlnet is not None:
|
||||
self.prepare_control_objects(control.previous_controlnet, device)
|
||||
return control
|
||||
|
||||
def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: IndexListContextWindow, device=None) -> list:
|
||||
if cond_in is None:
|
||||
return None
|
||||
# reuse or resize cond items to match context requirements
|
||||
resized_cond = []
|
||||
# cond object is a list containing a dict - outer list is irrelevant, so just loop through it
|
||||
for actual_cond in cond_in:
|
||||
resized_actual_cond = actual_cond.copy()
|
||||
# now we are in the inner dict - "pooled_output" is a tensor, "control" is a ControlBase object, "model_conds" is dictionary
|
||||
for key in actual_cond:
|
||||
try:
|
||||
cond_item = actual_cond[key]
|
||||
if isinstance(cond_item, torch.Tensor):
|
||||
# check that tensor is the expected length - x.size(0)
|
||||
if self.dim < cond_item.ndim and cond_item.size(self.dim) == x_in.size(self.dim):
|
||||
# if so, it's subsetting time - tell controls the expected indeces so they can handle them
|
||||
actual_cond_item = window.get_tensor(cond_item)
|
||||
resized_actual_cond[key] = actual_cond_item.to(device)
|
||||
else:
|
||||
resized_actual_cond[key] = cond_item.to(device)
|
||||
# look for control
|
||||
elif key == "control":
|
||||
resized_actual_cond[key] = self.prepare_control_objects(cond_item, device)
|
||||
elif isinstance(cond_item, dict):
|
||||
new_cond_item = cond_item.copy()
|
||||
# when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
|
||||
for cond_key, cond_value in new_cond_item.items():
|
||||
if isinstance(cond_value, torch.Tensor):
|
||||
if cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim):
|
||||
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
|
||||
# if has cond that is a Tensor, check if needs to be subset
|
||||
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||
if cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim):
|
||||
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device))
|
||||
elif cond_key == "num_video_frames": # for SVD
|
||||
new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
|
||||
new_cond_item[cond_key].cond = window.context_length
|
||||
resized_actual_cond[key] = new_cond_item
|
||||
else:
|
||||
resized_actual_cond[key] = cond_item
|
||||
finally:
|
||||
del cond_item # just in case to prevent VRAM issues
|
||||
resized_cond.append(resized_actual_cond)
|
||||
return resized_cond
|
||||
|
||||
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
|
||||
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001)
|
||||
matches = torch.nonzero(mask)
|
||||
if torch.numel(matches) == 0:
|
||||
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
|
||||
self._step = int(matches[0].item())
|
||||
|
||||
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
|
||||
full_length = x_in.size(self.dim) # TODO: choose dim based on model
|
||||
context_windows = self.context_schedule.func(full_length, self, model_options)
|
||||
context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows]
|
||||
return context_windows
|
||||
|
||||
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
self.set_step(timestep, model_options)
|
||||
context_windows = self.get_context_windows(model, x_in, model_options)
|
||||
enumerated_context_windows = list(enumerate(context_windows))
|
||||
|
||||
conds_final = [torch.zeros_like(x_in) for _ in conds]
|
||||
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
|
||||
counts_final = [torch.ones(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
|
||||
else:
|
||||
counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
|
||||
biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds]
|
||||
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options)
|
||||
|
||||
for enum_window in enumerated_context_windows:
|
||||
results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options)
|
||||
for result in results:
|
||||
self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep,
|
||||
conds_final, counts_final, biases_final)
|
||||
try:
|
||||
# finalize conds
|
||||
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
|
||||
# relative is already normalized, so return as is
|
||||
del counts_final
|
||||
return conds_final
|
||||
else:
|
||||
# normalize conds via division by context usage counts
|
||||
for i in range(len(conds_final)):
|
||||
conds_final[i] /= counts_final[i]
|
||||
del counts_final
|
||||
return conds_final
|
||||
finally:
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options)
|
||||
|
||||
def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]],
|
||||
model_options, device=None, first_device=None):
|
||||
results: list[ContextResults] = []
|
||||
for window_idx, window in enumerated_context_windows:
|
||||
# allow processing to end between context window executions for faster Cancel
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)
|
||||
|
||||
# update exposed params
|
||||
model_options["transformer_options"]["context_window"] = window
|
||||
# get subsections of x, timestep, conds
|
||||
sub_x = window.get_tensor(x_in, device)
|
||||
sub_timestep = window.get_tensor(timestep, device, dim=0)
|
||||
sub_conds = [self.get_resized_cond(cond, x_in, window, device) for cond in conds]
|
||||
|
||||
sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options)
|
||||
if device is not None:
|
||||
for i in range(len(sub_conds_out)):
|
||||
sub_conds_out[i] = sub_conds_out[i].to(x_in.device)
|
||||
results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window))
|
||||
return results
|
||||
|
||||
|
||||
def combine_context_window_results(self, x_in: torch.Tensor, sub_conds_out, sub_conds, window: IndexListContextWindow, window_idx: int, total_windows: int, timestep: torch.Tensor,
|
||||
conds_final: list[torch.Tensor], counts_final: list[torch.Tensor], biases_final: list[torch.Tensor]):
|
||||
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
|
||||
for pos, idx in enumerate(window.index_list):
|
||||
# bias is the influence of a specific index in relation to the whole context window
|
||||
bias = 1 - abs(idx - (window.index_list[0] + window.index_list[-1]) / 2) / ((window.index_list[-1] - window.index_list[0] + 1e-2) / 2)
|
||||
bias = max(1e-2, bias)
|
||||
# take weighted average relative to total bias of current idx
|
||||
for i in range(len(sub_conds_out)):
|
||||
bias_total = biases_final[i][idx]
|
||||
prev_weight = (bias_total / (bias_total + bias))
|
||||
new_weight = (bias / (bias_total + bias))
|
||||
# account for dims of tensors
|
||||
idx_window = [slice(None)] * self.dim + [idx]
|
||||
pos_window = [slice(None)] * self.dim + [pos]
|
||||
# apply new values
|
||||
conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight
|
||||
biases_final[i][idx] = bias_total + bias
|
||||
else:
|
||||
# add conds and counts based on weights of fuse method
|
||||
weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep)
|
||||
weights_tensor = match_weights_to_dim(weights, x_in, self.dim, device=x_in.device)
|
||||
for i in range(len(sub_conds_out)):
|
||||
window.add_window(conds_final[i], sub_conds_out[i] * weights_tensor)
|
||||
window.add_window(counts_final[i], weights_tensor)
|
||||
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.COMBINE_CONTEXT_WINDOW_RESULTS, self.callbacks):
|
||||
callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final)
|
||||
|
||||
|
||||
def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, *args, **kwargs):
|
||||
# limit noise_shape length to context_length for more accurate vram use estimation
|
||||
model_options = kwargs.get("model_options", None)
|
||||
if model_options is None:
|
||||
raise Exception("model_options not found in prepare_sampling_wrapper; this should never happen, something went wrong.")
|
||||
handler: IndexListContextHandler = model_options.get("context_handler", None)
|
||||
if handler is not None:
|
||||
noise_shape = list(noise_shape)
|
||||
noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length)
|
||||
return executor(model, noise_shape, *args, **kwargs)
|
||||
|
||||
|
||||
def create_prepare_sampling_wrapper(model: ModelPatcher):
|
||||
model.add_wrapper_with_key(
|
||||
comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING,
|
||||
"ContextWindows_prepare_sampling",
|
||||
_prepare_sampling_wrapper
|
||||
)
|
||||
|
||||
|
||||
def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
|
||||
total_dims = len(x_in.shape)
|
||||
weights_tensor = torch.Tensor(weights).to(device=device)
|
||||
for _ in range(dim):
|
||||
weights_tensor = weights_tensor.unsqueeze(0)
|
||||
for _ in range(total_dims - dim - 1):
|
||||
weights_tensor = weights_tensor.unsqueeze(-1)
|
||||
return weights_tensor
|
||||
|
||||
def get_shape_for_dim(x_in: torch.Tensor, dim: int) -> list[int]:
|
||||
total_dims = len(x_in.shape)
|
||||
shape = []
|
||||
for _ in range(dim):
|
||||
shape.append(1)
|
||||
shape.append(x_in.shape[dim])
|
||||
for _ in range(total_dims - dim - 1):
|
||||
shape.append(1)
|
||||
return shape
|
||||
|
||||
class ContextSchedules:
|
||||
UNIFORM_LOOPED = "looped_uniform"
|
||||
UNIFORM_STANDARD = "standard_uniform"
|
||||
STATIC_STANDARD = "standard_static"
|
||||
BATCHED = "batched"
|
||||
|
||||
|
||||
# from https://github.com/neggles/animatediff-cli/blob/main/src/animatediff/pipelines/context.py
|
||||
def create_windows_uniform_looped(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
|
||||
windows = []
|
||||
if num_frames < handler.context_length:
|
||||
windows.append(list(range(num_frames)))
|
||||
return windows
|
||||
|
||||
context_stride = min(handler.context_stride, int(np.ceil(np.log2(num_frames / handler.context_length))) + 1)
|
||||
# obtain uniform windows as normal, looping and all
|
||||
for context_step in 1 << np.arange(context_stride):
|
||||
pad = int(round(num_frames * ordered_halving(handler._step)))
|
||||
for j in range(
|
||||
int(ordered_halving(handler._step) * context_step) + pad,
|
||||
num_frames + pad + (0 if handler.closed_loop else -handler.context_overlap),
|
||||
(handler.context_length * context_step - handler.context_overlap),
|
||||
):
|
||||
windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)])
|
||||
|
||||
return windows
|
||||
|
||||
def create_windows_uniform_standard(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
|
||||
# unlike looped, uniform_straight does NOT allow windows that loop back to the beginning;
|
||||
# instead, they get shifted to the corresponding end of the frames.
|
||||
# in the case that a window (shifted or not) is identical to the previous one, it gets skipped.
|
||||
windows = []
|
||||
if num_frames <= handler.context_length:
|
||||
windows.append(list(range(num_frames)))
|
||||
return windows
|
||||
|
||||
context_stride = min(handler.context_stride, int(np.ceil(np.log2(num_frames / handler.context_length))) + 1)
|
||||
# first, obtain uniform windows as normal, looping and all
|
||||
for context_step in 1 << np.arange(context_stride):
|
||||
pad = int(round(num_frames * ordered_halving(handler._step)))
|
||||
for j in range(
|
||||
int(ordered_halving(handler._step) * context_step) + pad,
|
||||
num_frames + pad + (-handler.context_overlap),
|
||||
(handler.context_length * context_step - handler.context_overlap),
|
||||
):
|
||||
windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)])
|
||||
|
||||
# now that windows are created, shift any windows that loop, and delete duplicate windows
|
||||
delete_idxs = []
|
||||
win_i = 0
|
||||
while win_i < len(windows):
|
||||
# if window is rolls over itself, need to shift it
|
||||
is_roll, roll_idx = does_window_roll_over(windows[win_i], num_frames)
|
||||
if is_roll:
|
||||
roll_val = windows[win_i][roll_idx] # roll_val might not be 0 for windows of higher strides
|
||||
shift_window_to_end(windows[win_i], num_frames=num_frames)
|
||||
# check if next window (cyclical) is missing roll_val
|
||||
if roll_val not in windows[(win_i+1) % len(windows)]:
|
||||
# need to insert new window here - just insert window starting at roll_val
|
||||
windows.insert(win_i+1, list(range(roll_val, roll_val + handler.context_length)))
|
||||
# delete window if it's not unique
|
||||
for pre_i in range(0, win_i):
|
||||
if windows[win_i] == windows[pre_i]:
|
||||
delete_idxs.append(win_i)
|
||||
break
|
||||
win_i += 1
|
||||
|
||||
# reverse delete_idxs so that they will be deleted in an order that doesn't break idx correlation
|
||||
delete_idxs.reverse()
|
||||
for i in delete_idxs:
|
||||
windows.pop(i)
|
||||
|
||||
return windows
|
||||
|
||||
|
||||
def create_windows_static_standard(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
|
||||
windows = []
|
||||
if num_frames <= handler.context_length:
|
||||
windows.append(list(range(num_frames)))
|
||||
return windows
|
||||
# always return the same set of windows
|
||||
delta = handler.context_length - handler.context_overlap
|
||||
for start_idx in range(0, num_frames, delta):
|
||||
# if past the end of frames, move start_idx back to allow same context_length
|
||||
ending = start_idx + handler.context_length
|
||||
if ending >= num_frames:
|
||||
final_delta = ending - num_frames
|
||||
final_start_idx = start_idx - final_delta
|
||||
windows.append(list(range(final_start_idx, final_start_idx + handler.context_length)))
|
||||
break
|
||||
windows.append(list(range(start_idx, start_idx + handler.context_length)))
|
||||
return windows
|
||||
|
||||
|
||||
def create_windows_batched(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
|
||||
windows = []
|
||||
if num_frames <= handler.context_length:
|
||||
windows.append(list(range(num_frames)))
|
||||
return windows
|
||||
# always return the same set of windows;
|
||||
# no overlap, just cut up based on context_length;
|
||||
# last window size will be different if num_frames % opts.context_length != 0
|
||||
for start_idx in range(0, num_frames, handler.context_length):
|
||||
windows.append(list(range(start_idx, min(start_idx + handler.context_length, num_frames))))
|
||||
return windows
|
||||
|
||||
|
||||
def create_windows_default(num_frames: int, handler: IndexListContextHandler):
|
||||
return [list(range(num_frames))]
|
||||
|
||||
|
||||
CONTEXT_MAPPING = {
|
||||
ContextSchedules.UNIFORM_LOOPED: create_windows_uniform_looped,
|
||||
ContextSchedules.UNIFORM_STANDARD: create_windows_uniform_standard,
|
||||
ContextSchedules.STATIC_STANDARD: create_windows_static_standard,
|
||||
ContextSchedules.BATCHED: create_windows_batched,
|
||||
}
|
||||
|
||||
|
||||
def get_matching_context_schedule(context_schedule: str) -> ContextSchedule:
|
||||
func = CONTEXT_MAPPING.get(context_schedule, None)
|
||||
if func is None:
|
||||
raise ValueError(f"Unknown context_schedule '{context_schedule}'.")
|
||||
return ContextSchedule(context_schedule, func)
|
||||
|
||||
|
||||
def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None):
|
||||
return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs)
|
||||
|
||||
|
||||
def create_weights_flat(length: int, **kwargs) -> list[float]:
|
||||
# weight is the same for all
|
||||
return [1.0] * length
|
||||
|
||||
def create_weights_pyramid(length: int, **kwargs) -> list[float]:
|
||||
# weight is based on the distance away from the edge of the context window;
|
||||
# based on weighted average concept in FreeNoise paper
|
||||
if length % 2 == 0:
|
||||
max_weight = length // 2
|
||||
weight_sequence = list(range(1, max_weight + 1, 1)) + list(range(max_weight, 0, -1))
|
||||
else:
|
||||
max_weight = (length + 1) // 2
|
||||
weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1))
|
||||
return weight_sequence
|
||||
|
||||
def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, **kwargs):
|
||||
# based on code in Kijai's WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302
|
||||
# only expected overlap is given different weights
|
||||
weights_torch = torch.ones((length))
|
||||
# blend left-side on all except first window
|
||||
if min(idxs) > 0:
|
||||
ramp_up = torch.linspace(1e-37, 1, handler.context_overlap)
|
||||
weights_torch[:handler.context_overlap] = ramp_up
|
||||
# blend right-side on all except last window
|
||||
if max(idxs) < full_length-1:
|
||||
ramp_down = torch.linspace(1, 1e-37, handler.context_overlap)
|
||||
weights_torch[-handler.context_overlap:] = ramp_down
|
||||
return weights_torch
|
||||
|
||||
class ContextFuseMethods:
|
||||
FLAT = "flat"
|
||||
PYRAMID = "pyramid"
|
||||
RELATIVE = "relative"
|
||||
OVERLAP_LINEAR = "overlap-linear"
|
||||
|
||||
LIST = [PYRAMID, FLAT, OVERLAP_LINEAR]
|
||||
LIST_STATIC = [PYRAMID, RELATIVE, FLAT, OVERLAP_LINEAR]
|
||||
|
||||
|
||||
FUSE_MAPPING = {
|
||||
ContextFuseMethods.FLAT: create_weights_flat,
|
||||
ContextFuseMethods.PYRAMID: create_weights_pyramid,
|
||||
ContextFuseMethods.RELATIVE: create_weights_pyramid,
|
||||
ContextFuseMethods.OVERLAP_LINEAR: create_weights_overlap_linear,
|
||||
}
|
||||
|
||||
def get_matching_fuse_method(fuse_method: str) -> ContextFuseMethod:
|
||||
func = FUSE_MAPPING.get(fuse_method, None)
|
||||
if func is None:
|
||||
raise ValueError(f"Unknown fuse_method '{fuse_method}'.")
|
||||
return ContextFuseMethod(fuse_method, func)
|
||||
|
||||
# Returns fraction that has denominator that is a power of 2
|
||||
def ordered_halving(val):
|
||||
# get binary value, padded with 0s for 64 bits
|
||||
bin_str = f"{val:064b}"
|
||||
# flip binary value, padding included
|
||||
bin_flip = bin_str[::-1]
|
||||
# convert binary to int
|
||||
as_int = int(bin_flip, 2)
|
||||
# divide by 1 << 64, equivalent to 2**64, or 18446744073709551616,
|
||||
# or b10000000000000000000000000000000000000000000000000000000000000000 (1 with 64 zero's)
|
||||
return as_int / (1 << 64)
|
||||
|
||||
|
||||
def get_missing_indexes(windows: list[list[int]], num_frames: int) -> list[int]:
|
||||
all_indexes = list(range(num_frames))
|
||||
for w in windows:
|
||||
for val in w:
|
||||
try:
|
||||
all_indexes.remove(val)
|
||||
except ValueError:
|
||||
pass
|
||||
return all_indexes
|
||||
|
||||
|
||||
def does_window_roll_over(window: list[int], num_frames: int) -> tuple[bool, int]:
|
||||
prev_val = -1
|
||||
for i, val in enumerate(window):
|
||||
val = val % num_frames
|
||||
if val < prev_val:
|
||||
return True, i
|
||||
prev_val = val
|
||||
return False, -1
|
||||
|
||||
|
||||
def shift_window_to_start(window: list[int], num_frames: int):
|
||||
start_val = window[0]
|
||||
for i in range(len(window)):
|
||||
# 1) subtract each element by start_val to move vals relative to the start of all frames
|
||||
# 2) add num_frames and take modulus to get adjusted vals
|
||||
window[i] = ((window[i] - start_val) + num_frames) % num_frames
|
||||
|
||||
|
||||
def shift_window_to_end(window: list[int], num_frames: int):
|
||||
# 1) shift window to start
|
||||
shift_window_to_start(window, num_frames)
|
||||
end_val = window[-1]
|
||||
end_delta = num_frames - end_val - 1
|
||||
for i in range(len(window)):
|
||||
# 2) add end_delta to each val to slide windows to end
|
||||
window[i] = window[i] + end_delta
|
||||
@@ -28,6 +28,7 @@ import comfy.model_detection
|
||||
import comfy.model_patcher
|
||||
import comfy.ops
|
||||
import comfy.latent_formats
|
||||
import comfy.model_base
|
||||
|
||||
import comfy.cldm.cldm
|
||||
import comfy.t2i_adapter.adapter
|
||||
@@ -35,6 +36,7 @@ import comfy.ldm.cascade.controlnet
|
||||
import comfy.cldm.mmdit
|
||||
import comfy.ldm.hydit.controlnet
|
||||
import comfy.ldm.flux.controlnet
|
||||
import comfy.ldm.qwen_image.controlnet
|
||||
import comfy.cldm.dit_embedder
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
@@ -43,7 +45,6 @@ if TYPE_CHECKING:
|
||||
|
||||
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||
current_batch_size = tensor.shape[0]
|
||||
#print(current_batch_size, target_batch_size)
|
||||
if current_batch_size == 1:
|
||||
return tensor
|
||||
|
||||
@@ -236,11 +237,11 @@ class ControlNet(ControlBase):
|
||||
self.cond_hint = None
|
||||
compression_ratio = self.compression_ratio
|
||||
if self.vae is not None:
|
||||
compression_ratio *= self.vae.downscale_ratio
|
||||
compression_ratio *= self.vae.spacial_compression_encode()
|
||||
else:
|
||||
if self.latent_format is not None:
|
||||
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
|
||||
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
|
||||
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[-1] * compression_ratio, x_noisy.shape[-2] * compression_ratio, self.upscale_algorithm, "center")
|
||||
self.cond_hint = self.preprocess_image(self.cond_hint)
|
||||
if self.vae is not None:
|
||||
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
||||
@@ -252,7 +253,10 @@ class ControlNet(ControlBase):
|
||||
to_concat = []
|
||||
for c in self.extra_concat_orig:
|
||||
c = c.to(self.cond_hint.device)
|
||||
c = comfy.utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center")
|
||||
c = comfy.utils.common_upscale(c, self.cond_hint.shape[-1], self.cond_hint.shape[-2], self.upscale_algorithm, "center")
|
||||
if c.ndim < self.cond_hint.ndim:
|
||||
c = c.unsqueeze(2)
|
||||
c = comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[2], dim=2)
|
||||
to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
|
||||
self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
|
||||
|
||||
@@ -265,12 +269,12 @@ class ControlNet(ControlBase):
|
||||
for c in self.extra_conds:
|
||||
temp = cond.get(c, None)
|
||||
if temp is not None:
|
||||
extra[c] = temp.to(dtype)
|
||||
extra[c] = comfy.model_base.convert_tensor(temp, dtype, x_noisy.device)
|
||||
|
||||
timestep = self.model_sampling_current.timestep(t)
|
||||
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
||||
|
||||
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
|
||||
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=comfy.model_management.cast_to_device(context, x_noisy.device, dtype), **extra)
|
||||
return self.control_merge(control, control_prev, output_dtype=None)
|
||||
|
||||
def copy(self):
|
||||
@@ -582,6 +586,22 @@ def load_controlnet_flux_instantx(sd, model_options={}):
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
return control
|
||||
|
||||
def load_controlnet_qwen_instantx(sd, model_options={}):
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
|
||||
control_latent_channels = sd.get("controlnet_x_embedder.weight").shape[1]
|
||||
|
||||
extra_condition_channels = 0
|
||||
concat_mask = False
|
||||
if control_latent_channels == 68: #inpaint controlnet
|
||||
extra_condition_channels = control_latent_channels - 64
|
||||
concat_mask = True
|
||||
control_model = comfy.ldm.qwen_image.controlnet.QwenImageControlNetModel(extra_condition_channels=extra_condition_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
control_model = controlnet_load_state_dict(control_model, sd)
|
||||
latent_format = comfy.latent_formats.Wan21()
|
||||
extra_conds = []
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
return control
|
||||
|
||||
def convert_mistoline(sd):
|
||||
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
||||
|
||||
@@ -655,8 +675,11 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
||||
return load_controlnet_sd35(controlnet_data, model_options=model_options) #Stability sd3.5 format
|
||||
else:
|
||||
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
|
||||
elif "transformer_blocks.0.img_mlp.net.0.proj.weight" in controlnet_data:
|
||||
return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
|
||||
elif "controlnet_x_embedder.weight" in controlnet_data:
|
||||
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
|
||||
|
||||
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
|
||||
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
|
||||
|
||||
|
||||
@@ -1,55 +1,10 @@
|
||||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
from .ldm.modules.attention import CrossAttention
|
||||
from inspect import isfunction
|
||||
from .ldm.modules.attention import CrossAttention, FeedForward
|
||||
import comfy.ops
|
||||
ops = comfy.ops.manual_cast
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def uniq(arr):
|
||||
return{el: True for el in arr}.keys()
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
# feedforward
|
||||
class GEGLU(nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.proj = ops.Linear(dim_in, dim_out * 2)
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||
return x * torch.nn.functional.gelu(gate)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = default(dim_out, dim)
|
||||
project_in = nn.Sequential(
|
||||
ops.Linear(dim, inner_dim),
|
||||
nn.GELU()
|
||||
) if not glu else GEGLU(dim, inner_dim)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
project_in,
|
||||
nn.Dropout(dropout),
|
||||
ops.Linear(inner_dim, dim_out)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class GatedCrossAttentionDense(nn.Module):
|
||||
def __init__(self, query_dim, context_dim, n_heads, d_head):
|
||||
|
||||
@@ -31,6 +31,20 @@ class LayerScale(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)
|
||||
|
||||
class Dinov2MLP(torch.nn.Module):
|
||||
def __init__(self, hidden_size: int, dtype, device, operations):
|
||||
super().__init__()
|
||||
|
||||
mlp_ratio = 4
|
||||
hidden_features = int(hidden_size * mlp_ratio)
|
||||
self.fc1 = operations.Linear(hidden_size, hidden_features, bias = True, device=device, dtype=dtype)
|
||||
self.fc2 = operations.Linear(hidden_features, hidden_size, bias = True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
||||
hidden_state = self.fc1(hidden_state)
|
||||
hidden_state = torch.nn.functional.gelu(hidden_state)
|
||||
hidden_state = self.fc2(hidden_state)
|
||||
return hidden_state
|
||||
|
||||
class SwiGLUFFN(torch.nn.Module):
|
||||
def __init__(self, dim, dtype, device, operations):
|
||||
@@ -50,12 +64,15 @@ class SwiGLUFFN(torch.nn.Module):
|
||||
|
||||
|
||||
class Dino2Block(torch.nn.Module):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn):
|
||||
super().__init__()
|
||||
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
|
||||
self.layer_scale1 = LayerScale(dim, dtype, device, operations)
|
||||
self.layer_scale2 = LayerScale(dim, dtype, device, operations)
|
||||
self.mlp = SwiGLUFFN(dim, dtype, device, operations)
|
||||
if use_swiglu_ffn:
|
||||
self.mlp = SwiGLUFFN(dim, dtype, device, operations)
|
||||
else:
|
||||
self.mlp = Dinov2MLP(dim, dtype, device, operations)
|
||||
self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
@@ -66,9 +83,10 @@ class Dino2Block(torch.nn.Module):
|
||||
|
||||
|
||||
class Dino2Encoder(torch.nn.Module):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn):
|
||||
super().__init__()
|
||||
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)])
|
||||
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
|
||||
for _ in range(num_layers)])
|
||||
|
||||
def forward(self, x, intermediate_output=None):
|
||||
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
|
||||
@@ -78,8 +96,8 @@ class Dino2Encoder(torch.nn.Module):
|
||||
intermediate_output = len(self.layer) + intermediate_output
|
||||
|
||||
intermediate = None
|
||||
for i, l in enumerate(self.layer):
|
||||
x = l(x, optimized_attention)
|
||||
for i, layer in enumerate(self.layer):
|
||||
x = layer(x, optimized_attention)
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
return x, intermediate
|
||||
@@ -128,9 +146,10 @@ class Dinov2Model(torch.nn.Module):
|
||||
dim = config_dict["hidden_size"]
|
||||
heads = config_dict["num_attention_heads"]
|
||||
layer_norm_eps = config_dict["layer_norm_eps"]
|
||||
use_swiglu_ffn = config_dict["use_swiglu_ffn"]
|
||||
|
||||
self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
|
||||
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations)
|
||||
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
|
||||
self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
|
||||
|
||||
22
comfy/image_encoders/dino2_large.json
Normal file
22
comfy/image_encoders/dino2_large.json
Normal file
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"hidden_size": 1024,
|
||||
"use_mask_token": true,
|
||||
"patch_size": 14,
|
||||
"image_size": 518,
|
||||
"num_channels": 3,
|
||||
"num_attention_heads": 16,
|
||||
"initializer_range": 0.02,
|
||||
"attention_probs_dropout_prob": 0.0,
|
||||
"hidden_dropout_prob": 0.0,
|
||||
"hidden_act": "gelu",
|
||||
"mlp_ratio": 4,
|
||||
"model_type": "dinov2",
|
||||
"num_hidden_layers": 24,
|
||||
"layer_norm_eps": 1e-6,
|
||||
"qkv_bias": true,
|
||||
"use_swiglu_ffn": false,
|
||||
"layerscale_value": 1.0,
|
||||
"drop_path_rate": 0.0,
|
||||
"image_mean": [0.485, 0.456, 0.406],
|
||||
"image_std": [0.229, 0.224, 0.225]
|
||||
}
|
||||
121
comfy/k_diffusion/sa_solver.py
Normal file
121
comfy/k_diffusion/sa_solver.py
Normal file
@@ -0,0 +1,121 @@
|
||||
# SA-Solver: Stochastic Adams Solver (NeurIPS 2023, arXiv:2309.05019)
|
||||
# Conference: https://proceedings.neurips.cc/paper_files/paper/2023/file/f4a6806490d31216a3ba667eb240c897-Paper-Conference.pdf
|
||||
# Codebase ref: https://github.com/scxue/SA-Solver
|
||||
|
||||
import math
|
||||
from typing import Union, Callable
|
||||
import torch
|
||||
|
||||
|
||||
def compute_exponential_coeffs(s: torch.Tensor, t: torch.Tensor, solver_order: int, tau_t: float) -> torch.Tensor:
|
||||
"""Compute (1 + tau^2) * integral of exp((1 + tau^2) * x) * x^p dx from s to t with exp((1 + tau^2) * t) factored out, using integration by parts.
|
||||
|
||||
Integral of exp((1 + tau^2) * x) * x^p dx
|
||||
= product_terms[p] - (p / (1 + tau^2)) * integral of exp((1 + tau^2) * x) * x^(p-1) dx,
|
||||
with base case p=0 where integral equals product_terms[0].
|
||||
|
||||
where
|
||||
product_terms[p] = x^p * exp((1 + tau^2) * x) / (1 + tau^2).
|
||||
|
||||
Construct a recursive coefficient matrix following the above recursive relation to compute all integral terms up to p = (solver_order - 1).
|
||||
Return coefficients used by the SA-Solver in data prediction mode.
|
||||
|
||||
Args:
|
||||
s: Start time s.
|
||||
t: End time t.
|
||||
solver_order: Current order of the solver.
|
||||
tau_t: Stochastic strength parameter in the SDE.
|
||||
|
||||
Returns:
|
||||
Exponential coefficients used in data prediction, with exp((1 + tau^2) * t) factored out, ordered from p=0 to p=solver_order−1, shape (solver_order,).
|
||||
"""
|
||||
tau_mul = 1 + tau_t ** 2
|
||||
h = t - s
|
||||
p = torch.arange(solver_order, dtype=s.dtype, device=s.device)
|
||||
|
||||
# product_terms after factoring out exp((1 + tau^2) * t)
|
||||
# Includes (1 + tau^2) factor from outside the integral
|
||||
product_terms_factored = (t ** p - s ** p * (-tau_mul * h).exp())
|
||||
|
||||
# Lower triangular recursive coefficient matrix
|
||||
# Accumulates recursive coefficients based on p / (1 + tau^2)
|
||||
recursive_depth_mat = p.unsqueeze(1) - p.unsqueeze(0)
|
||||
log_factorial = (p + 1).lgamma()
|
||||
recursive_coeff_mat = log_factorial.unsqueeze(1) - log_factorial.unsqueeze(0)
|
||||
if tau_t > 0:
|
||||
recursive_coeff_mat = recursive_coeff_mat - (recursive_depth_mat * math.log(tau_mul))
|
||||
signs = torch.where(recursive_depth_mat % 2 == 0, 1.0, -1.0)
|
||||
recursive_coeff_mat = (recursive_coeff_mat.exp() * signs).tril()
|
||||
|
||||
return recursive_coeff_mat @ product_terms_factored
|
||||
|
||||
|
||||
def compute_simple_stochastic_adams_b_coeffs(sigma_next: torch.Tensor, curr_lambdas: torch.Tensor, lambda_s: torch.Tensor, lambda_t: torch.Tensor, tau_t: float, is_corrector_step: bool = False) -> torch.Tensor:
|
||||
"""Compute simple order-2 b coefficients from SA-Solver paper (Appendix D. Implementation Details)."""
|
||||
tau_mul = 1 + tau_t ** 2
|
||||
h = lambda_t - lambda_s
|
||||
alpha_t = sigma_next * lambda_t.exp()
|
||||
if is_corrector_step:
|
||||
# Simplified 1-step (order-2) corrector
|
||||
b_1 = alpha_t * (0.5 * tau_mul * h)
|
||||
b_2 = alpha_t * (-h * tau_mul).expm1().neg() - b_1
|
||||
else:
|
||||
# Simplified 2-step predictor
|
||||
b_2 = alpha_t * (0.5 * tau_mul * h ** 2) / (curr_lambdas[-2] - lambda_s)
|
||||
b_1 = alpha_t * (-h * tau_mul).expm1().neg() - b_2
|
||||
return torch.stack([b_2, b_1])
|
||||
|
||||
|
||||
def compute_stochastic_adams_b_coeffs(sigma_next: torch.Tensor, curr_lambdas: torch.Tensor, lambda_s: torch.Tensor, lambda_t: torch.Tensor, tau_t: float, simple_order_2: bool = False, is_corrector_step: bool = False) -> torch.Tensor:
|
||||
"""Compute b_i coefficients for the SA-Solver (see eqs. 15 and 18).
|
||||
|
||||
The solver order corresponds to the number of input lambdas (half-logSNR points).
|
||||
|
||||
Args:
|
||||
sigma_next: Sigma at end time t.
|
||||
curr_lambdas: Lambda time points used to construct the Lagrange basis, shape (N,).
|
||||
lambda_s: Lambda at start time s.
|
||||
lambda_t: Lambda at end time t.
|
||||
tau_t: Stochastic strength parameter in the SDE.
|
||||
simple_order_2: Whether to enable the simple order-2 scheme.
|
||||
is_corrector_step: Flag for corrector step in simple order-2 mode.
|
||||
|
||||
Returns:
|
||||
b_i coefficients for the SA-Solver, shape (N,), where N is the solver order.
|
||||
"""
|
||||
num_timesteps = curr_lambdas.shape[0]
|
||||
|
||||
if simple_order_2 and num_timesteps == 2:
|
||||
return compute_simple_stochastic_adams_b_coeffs(sigma_next, curr_lambdas, lambda_s, lambda_t, tau_t, is_corrector_step)
|
||||
|
||||
# Compute coefficients by solving a linear system from Lagrange basis interpolation
|
||||
exp_integral_coeffs = compute_exponential_coeffs(lambda_s, lambda_t, num_timesteps, tau_t)
|
||||
vandermonde_matrix_T = torch.vander(curr_lambdas, num_timesteps, increasing=True).T
|
||||
lagrange_integrals = torch.linalg.solve(vandermonde_matrix_T, exp_integral_coeffs)
|
||||
|
||||
# (sigma_t * exp(-tau^2 * lambda_t)) * exp((1 + tau^2) * lambda_t)
|
||||
# = sigma_t * exp(lambda_t) = alpha_t
|
||||
# exp((1 + tau^2) * lambda_t) is extracted from the integral
|
||||
alpha_t = sigma_next * lambda_t.exp()
|
||||
return alpha_t * lagrange_integrals
|
||||
|
||||
|
||||
def get_tau_interval_func(start_sigma: float, end_sigma: float, eta: float = 1.0) -> Callable[[Union[torch.Tensor, float]], float]:
|
||||
"""Return a function that controls the stochasticity of SA-Solver.
|
||||
|
||||
When eta = 0, SA-Solver runs as ODE. The official approach uses
|
||||
time t to determine the SDE interval, while here we use sigma instead.
|
||||
|
||||
See:
|
||||
https://github.com/scxue/SA-Solver/blob/main/README.md
|
||||
"""
|
||||
|
||||
def tau_func(sigma: Union[torch.Tensor, float]) -> float:
|
||||
if eta <= 0:
|
||||
return 0.0 # ODE
|
||||
|
||||
if isinstance(sigma, torch.Tensor):
|
||||
sigma = sigma.item()
|
||||
return eta if start_sigma >= sigma >= end_sigma else 0.0
|
||||
|
||||
return tau_func
|
||||
@@ -1,4 +1,5 @@
|
||||
import math
|
||||
from functools import partial
|
||||
|
||||
from scipy import integrate
|
||||
import torch
|
||||
@@ -8,6 +9,7 @@ from tqdm.auto import trange, tqdm
|
||||
|
||||
from . import utils
|
||||
from . import deis
|
||||
from . import sa_solver
|
||||
import comfy.model_patcher
|
||||
import comfy.model_sampling
|
||||
|
||||
@@ -84,24 +86,24 @@ class BatchedBrownianTree:
|
||||
"""A wrapper around torchsde.BrownianTree that enables batches of entropy."""
|
||||
|
||||
def __init__(self, x, t0, t1, seed=None, **kwargs):
|
||||
self.cpu_tree = True
|
||||
if "cpu" in kwargs:
|
||||
self.cpu_tree = kwargs.pop("cpu")
|
||||
self.cpu_tree = kwargs.pop("cpu", True)
|
||||
t0, t1, self.sign = self.sort(t0, t1)
|
||||
w0 = kwargs.get('w0', torch.zeros_like(x))
|
||||
w0 = kwargs.pop('w0', None)
|
||||
if w0 is None:
|
||||
w0 = torch.zeros_like(x)
|
||||
self.batched = False
|
||||
if seed is None:
|
||||
seed = torch.randint(0, 2 ** 63 - 1, []).item()
|
||||
self.batched = True
|
||||
try:
|
||||
assert len(seed) == x.shape[0]
|
||||
seed = (torch.randint(0, 2 ** 63 - 1, ()).item(),)
|
||||
elif isinstance(seed, (tuple, list)):
|
||||
if len(seed) != x.shape[0]:
|
||||
raise ValueError("Passing a list or tuple of seeds to BatchedBrownianTree requires a length matching the batch size.")
|
||||
self.batched = True
|
||||
w0 = w0[0]
|
||||
except TypeError:
|
||||
seed = [seed]
|
||||
self.batched = False
|
||||
if self.cpu_tree:
|
||||
self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed]
|
||||
else:
|
||||
self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
|
||||
seed = (seed,)
|
||||
if self.cpu_tree:
|
||||
t0, w0, t1 = t0.detach().cpu(), w0.detach().cpu(), t1.detach().cpu()
|
||||
self.trees = tuple(torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed)
|
||||
|
||||
@staticmethod
|
||||
def sort(a, b):
|
||||
@@ -109,11 +111,10 @@ class BatchedBrownianTree:
|
||||
|
||||
def __call__(self, t0, t1):
|
||||
t0, t1, sign = self.sort(t0, t1)
|
||||
device, dtype = t0.device, t0.dtype
|
||||
if self.cpu_tree:
|
||||
w = torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign)
|
||||
else:
|
||||
w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
|
||||
|
||||
t0, t1 = t0.detach().cpu().float(), t1.detach().cpu().float()
|
||||
w = torch.stack([tree(t0, t1) for tree in self.trees]).to(device=device, dtype=dtype) * (self.sign * sign)
|
||||
return w if self.batched else w[0]
|
||||
|
||||
|
||||
@@ -142,6 +143,43 @@ class BrownianTreeNoiseSampler:
|
||||
return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
|
||||
|
||||
|
||||
def sigma_to_half_log_snr(sigma, model_sampling):
|
||||
"""Convert sigma to half-logSNR log(alpha_t / sigma_t)."""
|
||||
if isinstance(model_sampling, comfy.model_sampling.CONST):
|
||||
# log((1 - t) / t) = log((1 - sigma) / sigma)
|
||||
return sigma.logit().neg()
|
||||
return sigma.log().neg()
|
||||
|
||||
|
||||
def half_log_snr_to_sigma(half_log_snr, model_sampling):
|
||||
"""Convert half-logSNR log(alpha_t / sigma_t) to sigma."""
|
||||
if isinstance(model_sampling, comfy.model_sampling.CONST):
|
||||
# 1 / (1 + exp(half_log_snr))
|
||||
return half_log_snr.neg().sigmoid()
|
||||
return half_log_snr.neg().exp()
|
||||
|
||||
|
||||
def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4):
|
||||
"""Adjust the first sigma to avoid invalid logSNR."""
|
||||
if len(sigmas) <= 1:
|
||||
return sigmas
|
||||
if isinstance(model_sampling, comfy.model_sampling.CONST):
|
||||
if sigmas[0] >= 1:
|
||||
sigmas = sigmas.clone()
|
||||
sigmas[0] = model_sampling.percent_to_sigma(percent_offset)
|
||||
return sigmas
|
||||
|
||||
|
||||
def ei_h_phi_1(h: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute the result of h*phi_1(h) in exponential integrator methods."""
|
||||
return torch.expm1(h)
|
||||
|
||||
|
||||
def ei_h_phi_2(h: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute the result of h*phi_2(h) in exponential integrator methods."""
|
||||
return (torch.expm1(h) - h) / h
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
||||
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
|
||||
@@ -384,9 +422,13 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, o
|
||||
ds.pop(0)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
cur_order = min(i + 1, order)
|
||||
coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
|
||||
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
|
||||
if sigmas[i + 1] == 0:
|
||||
# Denoising step
|
||||
x = denoised
|
||||
else:
|
||||
cur_order = min(i + 1, order)
|
||||
coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
|
||||
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
|
||||
return x
|
||||
|
||||
|
||||
@@ -682,6 +724,7 @@ def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=Non
|
||||
# logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0)
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
||||
"""DPM-Solver++ (stochastic)."""
|
||||
@@ -693,38 +736,49 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigma_fn = lambda t: t.neg().exp()
|
||||
t_fn = lambda sigma: sigma.log().neg()
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
|
||||
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
|
||||
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
|
||||
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
if sigmas[i + 1] == 0:
|
||||
# Euler method
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
dt = sigmas[i + 1] - sigmas[i]
|
||||
x = x + d * dt
|
||||
# Denoising step
|
||||
x = denoised
|
||||
else:
|
||||
# DPM-Solver++
|
||||
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
|
||||
h = t_next - t
|
||||
s = t + h * r
|
||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||
h = lambda_t - lambda_s
|
||||
lambda_s_1 = lambda_s + r * h
|
||||
fac = 1 / (2 * r)
|
||||
|
||||
sigma_s_1 = sigma_fn(lambda_s_1)
|
||||
|
||||
alpha_s = sigmas[i] * lambda_s.exp()
|
||||
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
|
||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||
|
||||
# Step 1
|
||||
sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
|
||||
s_ = t_fn(sd)
|
||||
x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
|
||||
x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
|
||||
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
|
||||
sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_s_1.neg().exp(), eta)
|
||||
lambda_s_1_ = sd.log().neg()
|
||||
h_ = lambda_s_1_ - lambda_s
|
||||
x_2 = (alpha_s_1 / alpha_s) * (-h_).exp() * x - alpha_s_1 * (-h_).expm1() * denoised
|
||||
if eta > 0 and s_noise > 0:
|
||||
x_2 = x_2 + alpha_s_1 * noise_sampler(sigmas[i], sigma_s_1) * s_noise * su
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
|
||||
# Step 2
|
||||
sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
|
||||
t_next_ = t_fn(sd)
|
||||
sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_t.neg().exp(), eta)
|
||||
lambda_t_ = sd.log().neg()
|
||||
h_ = lambda_t_ - lambda_s
|
||||
denoised_d = (1 - fac) * denoised + fac * denoised_2
|
||||
x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
|
||||
x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
|
||||
x = (alpha_t / alpha_s) * (-h_).exp() * x - alpha_t * (-h_).expm1() * denoised_d
|
||||
if eta > 0 and s_noise > 0:
|
||||
x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * su
|
||||
return x
|
||||
|
||||
|
||||
@@ -753,6 +807,7 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
|
||||
old_denoised = denoised
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
||||
"""DPM-Solver++(2M) SDE."""
|
||||
@@ -768,9 +823,12 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
|
||||
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
|
||||
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||
|
||||
old_denoised = None
|
||||
h_last = None
|
||||
h = None
|
||||
h, h_last = None, None
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
@@ -781,26 +839,34 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
x = denoised
|
||||
else:
|
||||
# DPM-Solver++(2M) SDE
|
||||
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
|
||||
h = s - t
|
||||
eta_h = eta * h
|
||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||
h = lambda_t - lambda_s
|
||||
h_eta = h * (eta + 1)
|
||||
|
||||
x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
|
||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised
|
||||
|
||||
if old_denoised is not None:
|
||||
r = h_last / h
|
||||
if solver_type == 'heun':
|
||||
x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
|
||||
x = x + alpha_t * ((-h_eta).expm1().neg() / (-h_eta) + 1) * (1 / r) * (denoised - old_denoised)
|
||||
elif solver_type == 'midpoint':
|
||||
x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
|
||||
x = x + 0.5 * alpha_t * (-h_eta).expm1().neg() * (1 / r) * (denoised - old_denoised)
|
||||
|
||||
if eta:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
|
||||
if eta > 0 and s_noise > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
|
||||
|
||||
old_denoised = denoised
|
||||
h_last = h
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'):
|
||||
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""DPM-Solver++(3M) SDE."""
|
||||
@@ -814,6 +880,10 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
|
||||
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
|
||||
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||
|
||||
denoised_1, denoised_2 = None, None
|
||||
h, h_1, h_2 = None, None, None
|
||||
|
||||
@@ -825,13 +895,16 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
# Denoising step
|
||||
x = denoised
|
||||
else:
|
||||
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
|
||||
h = s - t
|
||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||
h = lambda_t - lambda_s
|
||||
h_eta = h * (eta + 1)
|
||||
|
||||
x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised
|
||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised
|
||||
|
||||
if h_2 is not None:
|
||||
# DPM-Solver++(3M) SDE
|
||||
r0 = h_1 / h
|
||||
r1 = h_2 / h
|
||||
d1_0 = (denoised - denoised_1) / r0
|
||||
@@ -840,20 +913,22 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
d2 = (d1_0 - d1_1) / (r0 + r1)
|
||||
phi_2 = h_eta.neg().expm1() / h_eta + 1
|
||||
phi_3 = phi_2 / h_eta - 0.5
|
||||
x = x + phi_2 * d1 - phi_3 * d2
|
||||
x = x + (alpha_t * phi_2) * d1 - (alpha_t * phi_3) * d2
|
||||
elif h_1 is not None:
|
||||
# DPM-Solver++(2M) SDE
|
||||
r = h_1 / h
|
||||
d = (denoised - denoised_1) / r
|
||||
phi_2 = h_eta.neg().expm1() / h_eta + 1
|
||||
x = x + phi_2 * d
|
||||
x = x + (alpha_t * phi_2) * d
|
||||
|
||||
if eta:
|
||||
if eta > 0 and s_noise > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
|
||||
|
||||
denoised_1, denoised_2 = denoised, denoised_1
|
||||
h_1, h_2 = h, h_1
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
if len(sigmas) <= 1:
|
||||
@@ -863,6 +938,17 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2m_sde_heun_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'):
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||
return sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
||||
if len(sigmas) <= 1:
|
||||
@@ -872,6 +958,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
||||
if len(sigmas) <= 1:
|
||||
@@ -1009,7 +1096,9 @@ def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
d_cur = (x_cur - denoised) / t_cur
|
||||
|
||||
order = min(max_order, i+1)
|
||||
if order == 1: # First Euler step.
|
||||
if t_next == 0: # Denoising step
|
||||
x_next = denoised
|
||||
elif order == 1: # First Euler step.
|
||||
x_next = x_cur + (t_next - t_cur) * d_cur
|
||||
elif order == 2: # Use one history point.
|
||||
x_next = x_cur + (t_next - t_cur) * (3 * d_cur - buffer_model[-1]) / 2
|
||||
@@ -1027,6 +1116,7 @@ def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
|
||||
return x_next
|
||||
|
||||
|
||||
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
|
||||
#under Apache 2 license
|
||||
def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4):
|
||||
@@ -1050,7 +1140,9 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
||||
d_cur = (x_cur - denoised) / t_cur
|
||||
|
||||
order = min(max_order, i+1)
|
||||
if order == 1: # First Euler step.
|
||||
if t_next == 0: # Denoising step
|
||||
x_next = denoised
|
||||
elif order == 1: # First Euler step.
|
||||
x_next = x_cur + (t_next - t_cur) * d_cur
|
||||
elif order == 2: # Use one history point.
|
||||
h_n = (t_next - t_cur)
|
||||
@@ -1090,6 +1182,7 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
||||
|
||||
return x_next
|
||||
|
||||
|
||||
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
|
||||
#under Apache 2 license
|
||||
@torch.no_grad()
|
||||
@@ -1140,39 +1233,22 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
|
||||
return x_next
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
|
||||
temp = [0]
|
||||
def post_cfg_function(args):
|
||||
temp[0] = args["uncond_denoised"]
|
||||
return args["denoised"]
|
||||
|
||||
model_options = extra_args.get("model_options", {}).copy()
|
||||
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
sigma_hat = sigmas[i]
|
||||
denoised = model(x, sigma_hat * s_in, **extra_args)
|
||||
d = to_d(x, sigma_hat, temp[0])
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
||||
# Euler method
|
||||
x = denoised + d * sigmas[i + 1]
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with Euler method steps."""
|
||||
"""Ancestral sampling with Euler method steps (CFG++)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
|
||||
temp = [0]
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
|
||||
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
|
||||
|
||||
uncond_denoised = None
|
||||
|
||||
def post_cfg_function(args):
|
||||
temp[0] = args["uncond_denoised"]
|
||||
nonlocal uncond_denoised
|
||||
uncond_denoised = args["uncond_denoised"]
|
||||
return args["denoised"]
|
||||
|
||||
model_options = extra_args.get("model_options", {}).copy()
|
||||
@@ -1181,15 +1257,33 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
d = to_d(x, sigmas[i], temp[0])
|
||||
# Euler method
|
||||
x = denoised + d * sigma_down
|
||||
if sigmas[i + 1] > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
if sigmas[i + 1] == 0:
|
||||
# Denoising step
|
||||
x = denoised
|
||||
else:
|
||||
alpha_s = sigmas[i] * lambda_fn(sigmas[i]).exp()
|
||||
alpha_t = sigmas[i + 1] * lambda_fn(sigmas[i + 1]).exp()
|
||||
d = to_d(x, sigmas[i], alpha_s * uncond_denoised) # to noise
|
||||
|
||||
# DDIM stochastic sampling
|
||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i] / alpha_s, sigmas[i + 1] / alpha_t, eta=eta)
|
||||
sigma_down = alpha_t * sigma_down
|
||||
|
||||
# Euler method
|
||||
x = alpha_t * denoised + sigma_down * d
|
||||
if eta > 0 and s_noise > 0:
|
||||
x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
||||
"""Euler method steps (CFG++)."""
|
||||
return sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
||||
@@ -1346,6 +1440,7 @@ def sample_res_multistep_ancestral(model, x, sigmas, extra_args=None, callback=N
|
||||
def sample_res_multistep_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=True)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2., cfg_pp=False):
|
||||
"""Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK"""
|
||||
@@ -1372,31 +1467,32 @@ def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None,
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
dt = sigmas[i + 1] - sigmas[i]
|
||||
if i == 0:
|
||||
if sigmas[i + 1] == 0:
|
||||
# Denoising step
|
||||
x = denoised
|
||||
else:
|
||||
# Euler method
|
||||
if cfg_pp:
|
||||
x = denoised + d * sigmas[i + 1]
|
||||
else:
|
||||
x = x + d * dt
|
||||
else:
|
||||
# Gradient estimation
|
||||
if cfg_pp:
|
||||
|
||||
if i >= 1:
|
||||
# Gradient estimation
|
||||
d_bar = (ge_gamma - 1) * (d - old_d)
|
||||
x = denoised + d * sigmas[i + 1] + d_bar * dt
|
||||
else:
|
||||
d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
|
||||
x = x + d_bar * dt
|
||||
old_d = d
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_gradient_estimation_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
|
||||
return sample_gradient_estimation(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, ge_gamma=ge_gamma, cfg_pp=True)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
|
||||
"""
|
||||
Extended Reverse-Time SDE solver (VE ER-SDE-Solver-3). Arxiv: https://arxiv.org/abs/2309.06169.
|
||||
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1.0, noise_sampler=None, noise_scaler=None, max_stage=3):
|
||||
"""Extended Reverse-Time SDE solver (VP ER-SDE-Solver-3). arXiv: https://arxiv.org/abs/2309.06169.
|
||||
Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.
|
||||
"""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
@@ -1404,12 +1500,18 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
def default_noise_scaler(sigma):
|
||||
return sigma * ((sigma ** 0.3).exp() + 10.0)
|
||||
noise_scaler = default_noise_scaler if noise_scaler is None else noise_scaler
|
||||
def default_er_sde_noise_scaler(x):
|
||||
return x * ((x ** 0.3).exp() + 10.0)
|
||||
|
||||
noise_scaler = default_er_sde_noise_scaler if noise_scaler is None else noise_scaler
|
||||
num_integration_points = 200.0
|
||||
point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device)
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
|
||||
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||
half_log_snrs = sigma_to_half_log_snr(sigmas, model_sampling)
|
||||
er_lambdas = half_log_snrs.neg().exp() # er_lambda_t = sigma_t / alpha_t
|
||||
|
||||
old_denoised = None
|
||||
old_denoised_d = None
|
||||
|
||||
@@ -1420,129 +1522,265 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
|
||||
stage_used = min(max_stage, i + 1)
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
elif stage_used == 1:
|
||||
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
|
||||
x = r * x + (1 - r) * denoised
|
||||
else:
|
||||
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
|
||||
x = r * x + (1 - r) * denoised
|
||||
er_lambda_s, er_lambda_t = er_lambdas[i], er_lambdas[i + 1]
|
||||
alpha_s = sigmas[i] / er_lambda_s
|
||||
alpha_t = sigmas[i + 1] / er_lambda_t
|
||||
r_alpha = alpha_t / alpha_s
|
||||
r = noise_scaler(er_lambda_t) / noise_scaler(er_lambda_s)
|
||||
|
||||
dt = sigmas[i + 1] - sigmas[i]
|
||||
sigma_step_size = -dt / num_integration_points
|
||||
sigma_pos = sigmas[i + 1] + point_indice * sigma_step_size
|
||||
scaled_pos = noise_scaler(sigma_pos)
|
||||
# Stage 1 Euler
|
||||
x = r_alpha * r * x + alpha_t * (1 - r) * denoised
|
||||
|
||||
# Stage 2
|
||||
s = torch.sum(1 / scaled_pos) * sigma_step_size
|
||||
denoised_d = (denoised - old_denoised) / (sigmas[i] - sigmas[i - 1])
|
||||
x = x + (dt + s * noise_scaler(sigmas[i + 1])) * denoised_d
|
||||
if stage_used >= 2:
|
||||
dt = er_lambda_t - er_lambda_s
|
||||
lambda_step_size = -dt / num_integration_points
|
||||
lambda_pos = er_lambda_t + point_indice * lambda_step_size
|
||||
scaled_pos = noise_scaler(lambda_pos)
|
||||
|
||||
if stage_used >= 3:
|
||||
# Stage 3
|
||||
s_u = torch.sum((sigma_pos - sigmas[i]) / scaled_pos) * sigma_step_size
|
||||
denoised_u = (denoised_d - old_denoised_d) / ((sigmas[i] - sigmas[i - 2]) / 2)
|
||||
x = x + ((dt ** 2) / 2 + s_u * noise_scaler(sigmas[i + 1])) * denoised_u
|
||||
old_denoised_d = denoised_d
|
||||
# Stage 2
|
||||
s = torch.sum(1 / scaled_pos) * lambda_step_size
|
||||
denoised_d = (denoised - old_denoised) / (er_lambda_s - er_lambdas[i - 1])
|
||||
x = x + alpha_t * (dt + s * noise_scaler(er_lambda_t)) * denoised_d
|
||||
|
||||
if s_noise != 0 and sigmas[i + 1] > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
|
||||
if stage_used >= 3:
|
||||
# Stage 3
|
||||
s_u = torch.sum((lambda_pos - er_lambda_s) / scaled_pos) * lambda_step_size
|
||||
denoised_u = (denoised_d - old_denoised_d) / ((er_lambda_s - er_lambdas[i - 2]) / 2)
|
||||
x = x + alpha_t * ((dt ** 2) / 2 + s_u * noise_scaler(er_lambda_t)) * denoised_u
|
||||
old_denoised_d = denoised_d
|
||||
|
||||
if s_noise > 0:
|
||||
x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (er_lambda_t ** 2 - er_lambda_s ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
|
||||
old_denoised = denoised
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
|
||||
'''
|
||||
SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 2
|
||||
Arxiv: https://arxiv.org/abs/2305.14267
|
||||
'''
|
||||
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
|
||||
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
|
||||
"""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
inject_noise = eta > 0 and s_noise > 0
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
|
||||
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
|
||||
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
|
||||
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||
|
||||
fac = 1 / (2 * r)
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
else:
|
||||
t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
|
||||
h = t_next - t
|
||||
h_eta = h * (eta + 1)
|
||||
s = t + r * h
|
||||
fac = 1 / (2 * r)
|
||||
sigma_s = s.neg().exp()
|
||||
continue
|
||||
|
||||
coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1()
|
||||
if inject_noise:
|
||||
noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
|
||||
noise_coeff_2 = ((-2 * r * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
|
||||
noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s), noise_sampler(sigma_s, sigmas[i + 1])
|
||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||
h = lambda_t - lambda_s
|
||||
h_eta = h * (eta + 1)
|
||||
lambda_s_1 = torch.lerp(lambda_s, lambda_t, r)
|
||||
sigma_s_1 = sigma_fn(lambda_s_1)
|
||||
|
||||
# Step 1
|
||||
x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
|
||||
if inject_noise:
|
||||
x_2 = x_2 + sigma_s * (noise_coeff_1 * noise_1) * s_noise
|
||||
denoised_2 = model(x_2, sigma_s * s_in, **extra_args)
|
||||
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
|
||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||
|
||||
# Step 2
|
||||
denoised_d = (1 - fac) * denoised + fac * denoised_2
|
||||
x = (coeff_2 + 1) * x - coeff_2 * denoised_d
|
||||
if inject_noise:
|
||||
x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
|
||||
# Step 1
|
||||
x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r * h_eta) * denoised
|
||||
if inject_noise:
|
||||
sde_noise = (-2 * r * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1)
|
||||
x_2 = x_2 + sde_noise * sigma_s_1 * s_noise
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
|
||||
# Step 2
|
||||
denoised_d = torch.lerp(denoised, denoised_2, fac)
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
|
||||
if inject_noise:
|
||||
segment_factor = (r - 1) * h * eta
|
||||
sde_noise = sde_noise * segment_factor.exp()
|
||||
sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigmas[i + 1])
|
||||
x = x + sde_noise * sigmas[i + 1] * s_noise
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
|
||||
'''
|
||||
SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 3
|
||||
Arxiv: https://arxiv.org/abs/2305.14267
|
||||
'''
|
||||
"""SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3.
|
||||
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
|
||||
"""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
inject_noise = eta > 0 and s_noise > 0
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
|
||||
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
|
||||
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
|
||||
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
else:
|
||||
t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
|
||||
h = t_next - t
|
||||
h_eta = h * (eta + 1)
|
||||
s_1 = t + r_1 * h
|
||||
s_2 = t + r_2 * h
|
||||
sigma_s_1, sigma_s_2 = s_1.neg().exp(), s_2.neg().exp()
|
||||
continue
|
||||
|
||||
coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
|
||||
if inject_noise:
|
||||
noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
|
||||
noise_coeff_2 = ((-2 * r_1 * h * eta).expm1() - (-2 * r_2 * h * eta).expm1()).sqrt()
|
||||
noise_coeff_3 = ((-2 * r_2 * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
|
||||
noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
|
||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||
h = lambda_t - lambda_s
|
||||
h_eta = h * (eta + 1)
|
||||
lambda_s_1 = torch.lerp(lambda_s, lambda_t, r_1)
|
||||
lambda_s_2 = torch.lerp(lambda_s, lambda_t, r_2)
|
||||
sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
|
||||
|
||||
# Step 1
|
||||
x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
|
||||
if inject_noise:
|
||||
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
|
||||
alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
|
||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||
|
||||
# Step 2
|
||||
x_3 = (coeff_2 + 1) * x - coeff_2 * denoised + (r_2 / r_1) * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
|
||||
if inject_noise:
|
||||
x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
|
||||
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
|
||||
# Step 1
|
||||
x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r_1 * h_eta) * denoised
|
||||
if inject_noise:
|
||||
sde_noise = (-2 * r_1 * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1)
|
||||
x_2 = x_2 + sde_noise * sigma_s_1 * s_noise
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
|
||||
# Step 3
|
||||
x = (coeff_3 + 1) * x - coeff_3 * denoised + (1. / r_2) * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
|
||||
if inject_noise:
|
||||
x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
|
||||
# Step 2
|
||||
a3_2 = r_2 / r_1 * ei_h_phi_2(-r_2 * h_eta)
|
||||
a3_1 = ei_h_phi_1(-r_2 * h_eta) - a3_2
|
||||
x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * (a3_1 * denoised + a3_2 * denoised_2)
|
||||
if inject_noise:
|
||||
segment_factor = (r_1 - r_2) * h * eta
|
||||
sde_noise = sde_noise * segment_factor.exp()
|
||||
sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigma_s_2)
|
||||
x_3 = x_3 + sde_noise * sigma_s_2 * s_noise
|
||||
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
|
||||
|
||||
# Step 3
|
||||
b3 = ei_h_phi_2(-h_eta) / r_2
|
||||
b1 = ei_h_phi_1(-h_eta) - b3
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b3 * denoised_3)
|
||||
if inject_noise:
|
||||
segment_factor = (r_2 - 1) * h * eta
|
||||
sde_noise = sde_noise * segment_factor.exp()
|
||||
sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_2, sigmas[i + 1])
|
||||
x = x + sde_noise * sigmas[i + 1] * s_noise
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, use_pece=False, simple_order_2=False):
|
||||
"""Stochastic Adams Solver with predictor-corrector method (NeurIPS 2023)."""
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
|
||||
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||
lambdas = sigma_to_half_log_snr(sigmas, model_sampling=model_sampling)
|
||||
|
||||
if tau_func is None:
|
||||
# Use default interval for stochastic sampling
|
||||
start_sigma = model_sampling.percent_to_sigma(0.2)
|
||||
end_sigma = model_sampling.percent_to_sigma(0.8)
|
||||
tau_func = sa_solver.get_tau_interval_func(start_sigma, end_sigma, eta=1.0)
|
||||
|
||||
max_used_order = max(predictor_order, corrector_order)
|
||||
x_pred = x # x: current state, x_pred: predicted next state
|
||||
|
||||
h = 0.0
|
||||
tau_t = 0.0
|
||||
noise = 0.0
|
||||
pred_list = []
|
||||
|
||||
# Lower order near the end to improve stability
|
||||
lower_order_to_end = sigmas[-1].item() == 0
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
# Evaluation
|
||||
denoised = model(x_pred, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({"x": x_pred, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
|
||||
pred_list.append(denoised)
|
||||
pred_list = pred_list[-max_used_order:]
|
||||
|
||||
predictor_order_used = min(predictor_order, len(pred_list))
|
||||
if i == 0 or (sigmas[i + 1] == 0 and not use_pece):
|
||||
corrector_order_used = 0
|
||||
else:
|
||||
corrector_order_used = min(corrector_order, len(pred_list))
|
||||
|
||||
if lower_order_to_end:
|
||||
predictor_order_used = min(predictor_order_used, len(sigmas) - 2 - i)
|
||||
corrector_order_used = min(corrector_order_used, len(sigmas) - 1 - i)
|
||||
|
||||
# Corrector
|
||||
if corrector_order_used == 0:
|
||||
# Update by the predicted state
|
||||
x = x_pred
|
||||
else:
|
||||
curr_lambdas = lambdas[i - corrector_order_used + 1:i + 1]
|
||||
b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs(
|
||||
sigmas[i],
|
||||
curr_lambdas,
|
||||
lambdas[i - 1],
|
||||
lambdas[i],
|
||||
tau_t,
|
||||
simple_order_2,
|
||||
is_corrector_step=True,
|
||||
)
|
||||
pred_mat = torch.stack(pred_list[-corrector_order_used:], dim=1) # (B, K, ...)
|
||||
corr_res = torch.tensordot(pred_mat, b_coeffs, dims=([1], [0])) # (B, ...)
|
||||
x = sigmas[i] / sigmas[i - 1] * (-(tau_t ** 2) * h).exp() * x + corr_res
|
||||
|
||||
if tau_t > 0 and s_noise > 0:
|
||||
# The noise from the previous predictor step
|
||||
x = x + noise
|
||||
|
||||
if use_pece:
|
||||
# Evaluate the corrected state
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
pred_list[-1] = denoised
|
||||
|
||||
# Predictor
|
||||
if sigmas[i + 1] == 0:
|
||||
# Denoising step
|
||||
x = denoised
|
||||
else:
|
||||
tau_t = tau_func(sigmas[i + 1])
|
||||
curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1]
|
||||
b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs(
|
||||
sigmas[i + 1],
|
||||
curr_lambdas,
|
||||
lambdas[i],
|
||||
lambdas[i + 1],
|
||||
tau_t,
|
||||
simple_order_2,
|
||||
is_corrector_step=False,
|
||||
)
|
||||
pred_mat = torch.stack(pred_list[-predictor_order_used:], dim=1) # (B, K, ...)
|
||||
pred_res = torch.tensordot(pred_mat, b_coeffs, dims=([1], [0])) # (B, ...)
|
||||
h = lambdas[i + 1] - lambdas[i]
|
||||
x_pred = sigmas[i + 1] / sigmas[i] * (-(tau_t ** 2) * h).exp() * x + pred_res
|
||||
|
||||
if tau_t > 0 and s_noise > 0:
|
||||
noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise
|
||||
x_pred = x_pred + noise
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False):
|
||||
"""Stochastic Adams Solver with PECE (Predict–Evaluate–Correct–Evaluate) mode (NeurIPS 2023)."""
|
||||
return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2)
|
||||
|
||||
@@ -457,11 +457,170 @@ class Wan21(LatentFormat):
|
||||
latents_std = self.latents_std.to(latent.device, latent.dtype)
|
||||
return latent * latents_std / self.scale_factor + latents_mean
|
||||
|
||||
class Wan22(Wan21):
|
||||
latent_channels = 48
|
||||
latent_dimensions = 3
|
||||
|
||||
latent_rgb_factors = [
|
||||
[ 0.0119, 0.0103, 0.0046],
|
||||
[-0.1062, -0.0504, 0.0165],
|
||||
[ 0.0140, 0.0409, 0.0491],
|
||||
[-0.0813, -0.0677, 0.0607],
|
||||
[ 0.0656, 0.0851, 0.0808],
|
||||
[ 0.0264, 0.0463, 0.0912],
|
||||
[ 0.0295, 0.0326, 0.0590],
|
||||
[-0.0244, -0.0270, 0.0025],
|
||||
[ 0.0443, -0.0102, 0.0288],
|
||||
[-0.0465, -0.0090, -0.0205],
|
||||
[ 0.0359, 0.0236, 0.0082],
|
||||
[-0.0776, 0.0854, 0.1048],
|
||||
[ 0.0564, 0.0264, 0.0561],
|
||||
[ 0.0006, 0.0594, 0.0418],
|
||||
[-0.0319, -0.0542, -0.0637],
|
||||
[-0.0268, 0.0024, 0.0260],
|
||||
[ 0.0539, 0.0265, 0.0358],
|
||||
[-0.0359, -0.0312, -0.0287],
|
||||
[-0.0285, -0.1032, -0.1237],
|
||||
[ 0.1041, 0.0537, 0.0622],
|
||||
[-0.0086, -0.0374, -0.0051],
|
||||
[ 0.0390, 0.0670, 0.2863],
|
||||
[ 0.0069, 0.0144, 0.0082],
|
||||
[ 0.0006, -0.0167, 0.0079],
|
||||
[ 0.0313, -0.0574, -0.0232],
|
||||
[-0.1454, -0.0902, -0.0481],
|
||||
[ 0.0714, 0.0827, 0.0447],
|
||||
[-0.0304, -0.0574, -0.0196],
|
||||
[ 0.0401, 0.0384, 0.0204],
|
||||
[-0.0758, -0.0297, -0.0014],
|
||||
[ 0.0568, 0.1307, 0.1372],
|
||||
[-0.0055, -0.0310, -0.0380],
|
||||
[ 0.0239, -0.0305, 0.0325],
|
||||
[-0.0663, -0.0673, -0.0140],
|
||||
[-0.0416, -0.0047, -0.0023],
|
||||
[ 0.0166, 0.0112, -0.0093],
|
||||
[-0.0211, 0.0011, 0.0331],
|
||||
[ 0.1833, 0.1466, 0.2250],
|
||||
[-0.0368, 0.0370, 0.0295],
|
||||
[-0.3441, -0.3543, -0.2008],
|
||||
[-0.0479, -0.0489, -0.0420],
|
||||
[-0.0660, -0.0153, 0.0800],
|
||||
[-0.0101, 0.0068, 0.0156],
|
||||
[-0.0690, -0.0452, -0.0927],
|
||||
[-0.0145, 0.0041, 0.0015],
|
||||
[ 0.0421, 0.0451, 0.0373],
|
||||
[ 0.0504, -0.0483, -0.0356],
|
||||
[-0.0837, 0.0168, 0.0055]
|
||||
]
|
||||
|
||||
latent_rgb_factors_bias = [0.0317, -0.0878, -0.1388]
|
||||
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.0
|
||||
self.latents_mean = torch.tensor([
|
||||
-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
|
||||
-0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
|
||||
-0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
|
||||
-0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
|
||||
-0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
|
||||
0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
|
||||
]).view(1, self.latent_channels, 1, 1, 1)
|
||||
self.latents_std = torch.tensor([
|
||||
0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
|
||||
0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
|
||||
0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
|
||||
0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
|
||||
0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
|
||||
0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744
|
||||
]).view(1, self.latent_channels, 1, 1, 1)
|
||||
|
||||
class HunyuanImage21(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 2
|
||||
scale_factor = 0.75289
|
||||
|
||||
latent_rgb_factors = [
|
||||
[-0.0154, -0.0397, -0.0521],
|
||||
[ 0.0005, 0.0093, 0.0006],
|
||||
[-0.0805, -0.0773, -0.0586],
|
||||
[-0.0494, -0.0487, -0.0498],
|
||||
[-0.0212, -0.0076, -0.0261],
|
||||
[-0.0179, -0.0417, -0.0505],
|
||||
[ 0.0158, 0.0310, 0.0239],
|
||||
[ 0.0409, 0.0516, 0.0201],
|
||||
[ 0.0350, 0.0553, 0.0036],
|
||||
[-0.0447, -0.0327, -0.0479],
|
||||
[-0.0038, -0.0221, -0.0365],
|
||||
[-0.0423, -0.0718, -0.0654],
|
||||
[ 0.0039, 0.0368, 0.0104],
|
||||
[ 0.0655, 0.0217, 0.0122],
|
||||
[ 0.0490, 0.1638, 0.2053],
|
||||
[ 0.0932, 0.0829, 0.0650],
|
||||
[-0.0186, -0.0209, -0.0135],
|
||||
[-0.0080, -0.0076, -0.0148],
|
||||
[-0.0284, -0.0201, 0.0011],
|
||||
[-0.0642, -0.0294, -0.0777],
|
||||
[-0.0035, 0.0076, -0.0140],
|
||||
[ 0.0519, 0.0731, 0.0887],
|
||||
[-0.0102, 0.0095, 0.0704],
|
||||
[ 0.0068, 0.0218, -0.0023],
|
||||
[-0.0726, -0.0486, -0.0519],
|
||||
[ 0.0260, 0.0295, 0.0263],
|
||||
[ 0.0250, 0.0333, 0.0341],
|
||||
[ 0.0168, -0.0120, -0.0174],
|
||||
[ 0.0226, 0.1037, 0.0114],
|
||||
[ 0.2577, 0.1906, 0.1604],
|
||||
[-0.0646, -0.0137, -0.0018],
|
||||
[-0.0112, 0.0309, 0.0358],
|
||||
[-0.0347, 0.0146, -0.0481],
|
||||
[ 0.0234, 0.0179, 0.0201],
|
||||
[ 0.0157, 0.0313, 0.0225],
|
||||
[ 0.0423, 0.0675, 0.0524],
|
||||
[-0.0031, 0.0027, -0.0255],
|
||||
[ 0.0447, 0.0555, 0.0330],
|
||||
[-0.0152, 0.0103, 0.0299],
|
||||
[-0.0755, -0.0489, -0.0635],
|
||||
[ 0.0853, 0.0788, 0.1017],
|
||||
[-0.0272, -0.0294, -0.0471],
|
||||
[ 0.0440, 0.0400, -0.0137],
|
||||
[ 0.0335, 0.0317, -0.0036],
|
||||
[-0.0344, -0.0621, -0.0984],
|
||||
[-0.0127, -0.0630, -0.0620],
|
||||
[-0.0648, 0.0360, 0.0924],
|
||||
[-0.0781, -0.0801, -0.0409],
|
||||
[ 0.0363, 0.0613, 0.0499],
|
||||
[ 0.0238, 0.0034, 0.0041],
|
||||
[-0.0135, 0.0258, 0.0310],
|
||||
[ 0.0614, 0.1086, 0.0589],
|
||||
[ 0.0428, 0.0350, 0.0205],
|
||||
[ 0.0153, 0.0173, -0.0018],
|
||||
[-0.0288, -0.0455, -0.0091],
|
||||
[ 0.0344, 0.0109, -0.0157],
|
||||
[-0.0205, -0.0247, -0.0187],
|
||||
[ 0.0487, 0.0126, 0.0064],
|
||||
[-0.0220, -0.0013, 0.0074],
|
||||
[-0.0203, -0.0094, -0.0048],
|
||||
[-0.0719, 0.0429, -0.0442],
|
||||
[ 0.1042, 0.0497, 0.0356],
|
||||
[-0.0659, -0.0578, -0.0280],
|
||||
[-0.0060, -0.0322, -0.0234]]
|
||||
|
||||
latent_rgb_factors_bias = [0.0007, -0.0256, -0.0206]
|
||||
|
||||
class HunyuanImage21Refiner(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 3
|
||||
scale_factor = 1.03682
|
||||
|
||||
class Hunyuan3Dv2(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
scale_factor = 0.9990943042622529
|
||||
|
||||
class Hunyuan3Dv2_1(LatentFormat):
|
||||
scale_factor = 1.0039506158752403
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
|
||||
class Hunyuan3Dv2mini(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
@@ -470,3 +629,20 @@ class Hunyuan3Dv2mini(LatentFormat):
|
||||
class ACEAudio(LatentFormat):
|
||||
latent_channels = 8
|
||||
latent_dimensions = 2
|
||||
|
||||
class ChromaRadiance(LatentFormat):
|
||||
latent_channels = 3
|
||||
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors = [
|
||||
# R G B
|
||||
[ 1.0, 0.0, 0.0 ],
|
||||
[ 0.0, 1.0, 0.0 ],
|
||||
[ 0.0, 0.0, 1.0 ]
|
||||
]
|
||||
|
||||
def process_in(self, latent):
|
||||
return latent
|
||||
|
||||
def process_out(self, latent):
|
||||
return latent
|
||||
|
||||
@@ -133,6 +133,7 @@ class Attention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
transformer_options={},
|
||||
**cross_attention_kwargs,
|
||||
) -> torch.Tensor:
|
||||
return self.processor(
|
||||
@@ -140,6 +141,7 @@ class Attention(nn.Module):
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
transformer_options=transformer_options,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
@@ -366,6 +368,7 @@ class CustomerAttnProcessor2_0:
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
transformer_options={},
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
@@ -433,7 +436,7 @@ class CustomerAttnProcessor2_0:
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
hidden_states = optimized_attention(
|
||||
query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
|
||||
query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True, transformer_options=transformer_options,
|
||||
).to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
@@ -697,6 +700,7 @@ class LinearTransformerBlock(nn.Module):
|
||||
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
temb: torch.FloatTensor = None,
|
||||
transformer_options={},
|
||||
):
|
||||
|
||||
N = hidden_states.shape[0]
|
||||
@@ -720,6 +724,7 @@ class LinearTransformerBlock(nn.Module):
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
rotary_freqs_cis=rotary_freqs_cis,
|
||||
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
else:
|
||||
attn_output, _ = self.attn(
|
||||
@@ -729,6 +734,7 @@ class LinearTransformerBlock(nn.Module):
|
||||
encoder_attention_mask=None,
|
||||
rotary_freqs_cis=rotary_freqs_cis,
|
||||
rotary_freqs_cis_cross=None,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
if self.use_adaln_single:
|
||||
@@ -743,6 +749,7 @@ class LinearTransformerBlock(nn.Module):
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
rotary_freqs_cis=rotary_freqs_cis,
|
||||
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
|
||||
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
||||
from .attention import LinearTransformerBlock, t2i_modulate
|
||||
@@ -313,6 +314,7 @@ class ACEStepTransformer2DModel(nn.Module):
|
||||
output_length: int = 0,
|
||||
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
|
||||
controlnet_scale: Union[float, torch.Tensor] = 1.0,
|
||||
transformer_options={},
|
||||
):
|
||||
embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
|
||||
temb = self.t_block(embedded_timestep)
|
||||
@@ -338,12 +340,34 @@ class ACEStepTransformer2DModel(nn.Module):
|
||||
rotary_freqs_cis=rotary_freqs_cis,
|
||||
rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
|
||||
temb=temb,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
output = self.final_layer(hidden_states, embedded_timestep, output_length)
|
||||
return output
|
||||
|
||||
def forward(
|
||||
def forward(self,
|
||||
x,
|
||||
timestep,
|
||||
attention_mask=None,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
text_attention_mask: Optional[torch.LongTensor] = None,
|
||||
speaker_embeds: Optional[torch.FloatTensor] = None,
|
||||
lyric_token_idx: Optional[torch.LongTensor] = None,
|
||||
lyric_mask: Optional[torch.LongTensor] = None,
|
||||
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
|
||||
controlnet_scale: Union[float, torch.Tensor] = 1.0,
|
||||
lyrics_strength=1.0,
|
||||
**kwargs
|
||||
):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
|
||||
).execute(x, timestep, attention_mask, context, text_attention_mask, speaker_embeds, lyric_token_idx, lyric_mask, block_controlnet_hidden_states,
|
||||
controlnet_scale, lyrics_strength, **kwargs)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x,
|
||||
timestep,
|
||||
@@ -371,6 +395,7 @@ class ACEStepTransformer2DModel(nn.Module):
|
||||
|
||||
output_length = hidden_states.shape[-1]
|
||||
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
output = self.decode(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
@@ -380,6 +405,7 @@ class ACEStepTransformer2DModel(nn.Module):
|
||||
output_length=output_length,
|
||||
block_controlnet_hidden_states=block_controlnet_hidden_states,
|
||||
controlnet_scale=controlnet_scale,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
@@ -298,7 +298,8 @@ class Attention(nn.Module):
|
||||
mask = None,
|
||||
context_mask = None,
|
||||
rotary_pos_emb = None,
|
||||
causal = None
|
||||
causal = None,
|
||||
transformer_options={},
|
||||
):
|
||||
h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None
|
||||
|
||||
@@ -363,7 +364,7 @@ class Attention(nn.Module):
|
||||
heads_per_kv_head = h // kv_h
|
||||
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
|
||||
|
||||
out = optimized_attention(q, k, v, h, skip_reshape=True)
|
||||
out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options)
|
||||
out = self.to_out(out)
|
||||
|
||||
if mask is not None:
|
||||
@@ -488,7 +489,8 @@ class TransformerBlock(nn.Module):
|
||||
global_cond=None,
|
||||
mask = None,
|
||||
context_mask = None,
|
||||
rotary_pos_emb = None
|
||||
rotary_pos_emb = None,
|
||||
transformer_options={}
|
||||
):
|
||||
if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:
|
||||
|
||||
@@ -498,12 +500,12 @@ class TransformerBlock(nn.Module):
|
||||
residual = x
|
||||
x = self.pre_norm(x)
|
||||
x = x * (1 + scale_self) + shift_self
|
||||
x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
|
||||
x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options)
|
||||
x = x * torch.sigmoid(1 - gate_self)
|
||||
x = x + residual
|
||||
|
||||
if context is not None:
|
||||
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
|
||||
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options)
|
||||
|
||||
if self.conformer is not None:
|
||||
x = x + self.conformer(x)
|
||||
@@ -517,10 +519,10 @@ class TransformerBlock(nn.Module):
|
||||
x = x + residual
|
||||
|
||||
else:
|
||||
x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)
|
||||
x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options)
|
||||
|
||||
if context is not None:
|
||||
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
|
||||
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options)
|
||||
|
||||
if self.conformer is not None:
|
||||
x = x + self.conformer(x)
|
||||
@@ -606,7 +608,8 @@ class ContinuousTransformer(nn.Module):
|
||||
return_info = False,
|
||||
**kwargs
|
||||
):
|
||||
patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {})
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
batch, seq, device = *x.shape[:2], x.device
|
||||
context = kwargs["context"]
|
||||
|
||||
@@ -632,7 +635,7 @@ class ContinuousTransformer(nn.Module):
|
||||
# Attention layers
|
||||
|
||||
if self.rotary_pos_emb is not None:
|
||||
rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=x.dtype, device=x.device)
|
||||
rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=torch.float, device=x.device)
|
||||
else:
|
||||
rotary_pos_emb = None
|
||||
|
||||
@@ -645,13 +648,13 @@ class ContinuousTransformer(nn.Module):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"])
|
||||
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"], transformer_options=args["transformer_options"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context)
|
||||
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context, transformer_options=transformer_options)
|
||||
# x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
|
||||
|
||||
if return_info:
|
||||
|
||||
@@ -9,6 +9,7 @@ import torch.nn.functional as F
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.ops
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
@@ -84,7 +85,7 @@ class SingleAttention(nn.Module):
|
||||
)
|
||||
|
||||
#@torch.compile()
|
||||
def forward(self, c):
|
||||
def forward(self, c, transformer_options={}):
|
||||
|
||||
bsz, seqlen1, _ = c.shape
|
||||
|
||||
@@ -94,7 +95,7 @@ class SingleAttention(nn.Module):
|
||||
v = v.view(bsz, seqlen1, self.n_heads, self.head_dim)
|
||||
q, k = self.q_norm1(q), self.k_norm1(k)
|
||||
|
||||
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
|
||||
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options)
|
||||
c = self.w1o(output)
|
||||
return c
|
||||
|
||||
@@ -143,7 +144,7 @@ class DoubleAttention(nn.Module):
|
||||
|
||||
|
||||
#@torch.compile()
|
||||
def forward(self, c, x):
|
||||
def forward(self, c, x, transformer_options={}):
|
||||
|
||||
bsz, seqlen1, _ = c.shape
|
||||
bsz, seqlen2, _ = x.shape
|
||||
@@ -167,7 +168,7 @@ class DoubleAttention(nn.Module):
|
||||
torch.cat([cv, xv], dim=1),
|
||||
)
|
||||
|
||||
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
|
||||
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options)
|
||||
|
||||
c, x = output.split([seqlen1, seqlen2], dim=1)
|
||||
c = self.w1o(c)
|
||||
@@ -206,7 +207,7 @@ class MMDiTBlock(nn.Module):
|
||||
self.is_last = is_last
|
||||
|
||||
#@torch.compile()
|
||||
def forward(self, c, x, global_cond, **kwargs):
|
||||
def forward(self, c, x, global_cond, transformer_options={}, **kwargs):
|
||||
|
||||
cres, xres = c, x
|
||||
|
||||
@@ -224,7 +225,7 @@ class MMDiTBlock(nn.Module):
|
||||
x = modulate(self.normX1(x), xshift_msa, xscale_msa)
|
||||
|
||||
# attention
|
||||
c, x = self.attn(c, x)
|
||||
c, x = self.attn(c, x, transformer_options=transformer_options)
|
||||
|
||||
|
||||
c = self.normC2(cres + cgate_msa.unsqueeze(1) * c)
|
||||
@@ -254,13 +255,13 @@ class DiTBlock(nn.Module):
|
||||
self.mlp = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
#@torch.compile()
|
||||
def forward(self, cx, global_cond, **kwargs):
|
||||
def forward(self, cx, global_cond, transformer_options={}, **kwargs):
|
||||
cxres = cx
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modCX(
|
||||
global_cond
|
||||
).chunk(6, dim=1)
|
||||
cx = modulate(self.norm1(cx), shift_msa, scale_msa)
|
||||
cx = self.attn(cx)
|
||||
cx = self.attn(cx, transformer_options=transformer_options)
|
||||
cx = self.norm2(cxres + gate_msa.unsqueeze(1) * cx)
|
||||
mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp))
|
||||
cx = gate_mlp.unsqueeze(1) * mlpout
|
||||
@@ -436,6 +437,13 @@ class MMDiT(nn.Module):
|
||||
return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])
|
||||
|
||||
def forward(self, x, timestep, context, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, transformer_options={}, **kwargs):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
# patchify x, add PE
|
||||
b, c, h, w = x.shape
|
||||
@@ -465,13 +473,14 @@ class MMDiT(nn.Module):
|
||||
out = {}
|
||||
out["txt"], out["img"] = layer(args["txt"],
|
||||
args["img"],
|
||||
args["vec"])
|
||||
args["vec"],
|
||||
transformer_options=args["transformer_options"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
c = out["txt"]
|
||||
x = out["img"]
|
||||
else:
|
||||
c, x = layer(c, x, global_cond, **kwargs)
|
||||
c, x = layer(c, x, global_cond, transformer_options=transformer_options, **kwargs)
|
||||
|
||||
if len(self.single_layers) > 0:
|
||||
c_len = c.size(1)
|
||||
@@ -480,13 +489,13 @@ class MMDiT(nn.Module):
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = layer(args["img"], args["vec"])
|
||||
out["img"] = layer(args["img"], args["vec"], transformer_options=args["transformer_options"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
cx = out["img"]
|
||||
else:
|
||||
cx = layer(cx, global_cond, **kwargs)
|
||||
cx = layer(cx, global_cond, transformer_options=transformer_options, **kwargs)
|
||||
|
||||
x = cx[:, c_len:]
|
||||
|
||||
|
||||
@@ -32,12 +32,12 @@ class OptimizedAttention(nn.Module):
|
||||
|
||||
self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, q, k, v):
|
||||
def forward(self, q, k, v, transformer_options={}):
|
||||
q = self.to_q(q)
|
||||
k = self.to_k(k)
|
||||
v = self.to_v(v)
|
||||
|
||||
out = optimized_attention(q, k, v, self.heads)
|
||||
out = optimized_attention(q, k, v, self.heads, transformer_options=transformer_options)
|
||||
|
||||
return self.out_proj(out)
|
||||
|
||||
@@ -47,13 +47,13 @@ class Attention2D(nn.Module):
|
||||
self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)
|
||||
# self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, kv, self_attn=False):
|
||||
def forward(self, x, kv, self_attn=False, transformer_options={}):
|
||||
orig_shape = x.shape
|
||||
x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
|
||||
if self_attn:
|
||||
kv = torch.cat([x, kv], dim=1)
|
||||
# x = self.attn(x, kv, kv, need_weights=False)[0]
|
||||
x = self.attn(x, kv, kv)
|
||||
x = self.attn(x, kv, kv, transformer_options=transformer_options)
|
||||
x = x.permute(0, 2, 1).view(*orig_shape)
|
||||
return x
|
||||
|
||||
@@ -114,9 +114,9 @@ class AttnBlock(nn.Module):
|
||||
operations.Linear(c_cond, c, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward(self, x, kv):
|
||||
def forward(self, x, kv, transformer_options={}):
|
||||
kv = self.kv_mapper(kv)
|
||||
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
|
||||
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn, transformer_options=transformer_options)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
@@ -173,7 +173,7 @@ class StageB(nn.Module):
|
||||
clip = self.clip_norm(clip)
|
||||
return clip
|
||||
|
||||
def _down_encode(self, x, r_embed, clip):
|
||||
def _down_encode(self, x, r_embed, clip, transformer_options={}):
|
||||
level_outputs = []
|
||||
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
|
||||
for down_block, downscaler, repmap in block_group:
|
||||
@@ -187,7 +187,7 @@ class StageB(nn.Module):
|
||||
elif isinstance(block, AttnBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
AttnBlock)):
|
||||
x = block(x, clip)
|
||||
x = block(x, clip, transformer_options=transformer_options)
|
||||
elif isinstance(block, TimestepBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
TimestepBlock)):
|
||||
@@ -199,7 +199,7 @@ class StageB(nn.Module):
|
||||
level_outputs.insert(0, x)
|
||||
return level_outputs
|
||||
|
||||
def _up_decode(self, level_outputs, r_embed, clip):
|
||||
def _up_decode(self, level_outputs, r_embed, clip, transformer_options={}):
|
||||
x = level_outputs[0]
|
||||
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
|
||||
for i, (up_block, upscaler, repmap) in enumerate(block_group):
|
||||
@@ -216,7 +216,7 @@ class StageB(nn.Module):
|
||||
elif isinstance(block, AttnBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
AttnBlock)):
|
||||
x = block(x, clip)
|
||||
x = block(x, clip, transformer_options=transformer_options)
|
||||
elif isinstance(block, TimestepBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
TimestepBlock)):
|
||||
@@ -228,7 +228,7 @@ class StageB(nn.Module):
|
||||
x = upscaler(x)
|
||||
return x
|
||||
|
||||
def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
|
||||
def forward(self, x, r, effnet, clip, pixels=None, transformer_options={}, **kwargs):
|
||||
if pixels is None:
|
||||
pixels = x.new_zeros(x.size(0), 3, 8, 8)
|
||||
|
||||
@@ -245,8 +245,8 @@ class StageB(nn.Module):
|
||||
nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True))
|
||||
x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear',
|
||||
align_corners=True)
|
||||
level_outputs = self._down_encode(x, r_embed, clip)
|
||||
x = self._up_decode(level_outputs, r_embed, clip)
|
||||
level_outputs = self._down_encode(x, r_embed, clip, transformer_options=transformer_options)
|
||||
x = self._up_decode(level_outputs, r_embed, clip, transformer_options=transformer_options)
|
||||
return self.clf(x)
|
||||
|
||||
def update_weights_ema(self, src_model, beta=0.999):
|
||||
|
||||
@@ -182,7 +182,7 @@ class StageC(nn.Module):
|
||||
clip = self.clip_norm(clip)
|
||||
return clip
|
||||
|
||||
def _down_encode(self, x, r_embed, clip, cnet=None):
|
||||
def _down_encode(self, x, r_embed, clip, cnet=None, transformer_options={}):
|
||||
level_outputs = []
|
||||
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
|
||||
for down_block, downscaler, repmap in block_group:
|
||||
@@ -201,7 +201,7 @@ class StageC(nn.Module):
|
||||
elif isinstance(block, AttnBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
AttnBlock)):
|
||||
x = block(x, clip)
|
||||
x = block(x, clip, transformer_options=transformer_options)
|
||||
elif isinstance(block, TimestepBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
TimestepBlock)):
|
||||
@@ -213,7 +213,7 @@ class StageC(nn.Module):
|
||||
level_outputs.insert(0, x)
|
||||
return level_outputs
|
||||
|
||||
def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
|
||||
def _up_decode(self, level_outputs, r_embed, clip, cnet=None, transformer_options={}):
|
||||
x = level_outputs[0]
|
||||
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
|
||||
for i, (up_block, upscaler, repmap) in enumerate(block_group):
|
||||
@@ -235,7 +235,7 @@ class StageC(nn.Module):
|
||||
elif isinstance(block, AttnBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
AttnBlock)):
|
||||
x = block(x, clip)
|
||||
x = block(x, clip, transformer_options=transformer_options)
|
||||
elif isinstance(block, TimestepBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
TimestepBlock)):
|
||||
@@ -247,7 +247,7 @@ class StageC(nn.Module):
|
||||
x = upscaler(x)
|
||||
return x
|
||||
|
||||
def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, **kwargs):
|
||||
def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, transformer_options={}, **kwargs):
|
||||
# Process the conditioning embeddings
|
||||
r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
|
||||
for c in self.t_conds:
|
||||
@@ -262,8 +262,8 @@ class StageC(nn.Module):
|
||||
|
||||
# Model Blocks
|
||||
x = self.embedding(x)
|
||||
level_outputs = self._down_encode(x, r_embed, clip, cnet)
|
||||
x = self._up_decode(level_outputs, r_embed, clip, cnet)
|
||||
level_outputs = self._down_encode(x, r_embed, clip, cnet, transformer_options=transformer_options)
|
||||
x = self._up_decode(level_outputs, r_embed, clip, cnet, transformer_options=transformer_options)
|
||||
return self.clf(x)
|
||||
|
||||
def update_weights_ema(self, src_model, beta=0.999):
|
||||
|
||||
@@ -76,7 +76,7 @@ class DoubleStreamBlock(nn.Module):
|
||||
)
|
||||
self.flipped_img_txt = flipped_img_txt
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None):
|
||||
def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}):
|
||||
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
|
||||
|
||||
# prepare image for attention
|
||||
@@ -95,7 +95,7 @@ class DoubleStreamBlock(nn.Module):
|
||||
attn = attention(torch.cat((txt_q, img_q), dim=2),
|
||||
torch.cat((txt_k, img_k), dim=2),
|
||||
torch.cat((txt_v, img_v), dim=2),
|
||||
pe=pe, mask=attn_mask)
|
||||
pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||
|
||||
@@ -148,7 +148,7 @@ class SingleStreamBlock(nn.Module):
|
||||
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
|
||||
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
|
||||
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}) -> Tensor:
|
||||
mod = vec
|
||||
x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
|
||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
@@ -157,7 +157,7 @@ class SingleStreamBlock(nn.Module):
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask)
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
x.addcmul_(mod.gate, output)
|
||||
|
||||
@@ -5,6 +5,7 @@ from dataclasses import dataclass
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from einops import rearrange, repeat
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
from comfy.ldm.flux.layers import (
|
||||
@@ -150,8 +151,6 @@ class Chroma(nn.Module):
|
||||
attn_mask: Tensor = None,
|
||||
) -> Tensor:
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
# running on sequences img
|
||||
img = self.img_in(img)
|
||||
@@ -192,14 +191,16 @@ class Chroma(nn.Module):
|
||||
txt=args["txt"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img,
|
||||
"txt": txt,
|
||||
"vec": double_mod,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
"attn_mask": attn_mask,
|
||||
"transformer_options": transformer_options},
|
||||
{"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
@@ -208,7 +209,8 @@ class Chroma(nn.Module):
|
||||
txt=txt,
|
||||
vec=double_mod,
|
||||
pe=pe,
|
||||
attn_mask=attn_mask)
|
||||
attn_mask=attn_mask,
|
||||
transformer_options=transformer_options)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
@@ -228,17 +230,19 @@ class Chroma(nn.Module):
|
||||
out["img"] = block(args["img"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img,
|
||||
"vec": single_mod,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
"attn_mask": attn_mask,
|
||||
"transformer_options": transformer_options},
|
||||
{"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask)
|
||||
img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
@@ -248,19 +252,29 @@ class Chroma(nn.Module):
|
||||
img[:, txt.shape[1] :, ...] += add
|
||||
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
final_mod = self.get_modulations(mod_vectors, "final")
|
||||
img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels)
|
||||
if hasattr(self, "final_layer"):
|
||||
final_mod = self.get_modulations(mod_vectors, "final")
|
||||
img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
|
||||
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, guidance, control, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, h, w = x.shape
|
||||
patch_size = 2
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
||||
|
||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=self.patch_size, pw=self.patch_size)
|
||||
|
||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||
if img.ndim != 3 or context.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
h_len = ((h + (self.patch_size // 2)) // self.patch_size)
|
||||
w_len = ((w + (self.patch_size // 2)) // self.patch_size)
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
@@ -268,4 +282,4 @@ class Chroma(nn.Module):
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h,:w]
|
||||
|
||||
206
comfy/ldm/chroma_radiance/layers.py
Normal file
206
comfy/ldm/chroma_radiance/layers.py
Normal file
@@ -0,0 +1,206 @@
|
||||
# Adapted from https://github.com/lodestone-rock/flow
|
||||
from functools import lru_cache
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from comfy.ldm.flux.layers import RMSNorm
|
||||
|
||||
|
||||
class NerfEmbedder(nn.Module):
|
||||
"""
|
||||
An embedder module that combines input features with a 2D positional
|
||||
encoding that mimics the Discrete Cosine Transform (DCT).
|
||||
|
||||
This module takes an input tensor of shape (B, P^2, C), where P is the
|
||||
patch size, and enriches it with positional information before projecting
|
||||
it to a new hidden size.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
hidden_size_input: int,
|
||||
max_freqs: int,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
"""
|
||||
Initializes the NerfEmbedder.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of channels in the input tensor.
|
||||
hidden_size_input (int): The desired dimension of the output embedding.
|
||||
max_freqs (int): The number of frequency components to use for both
|
||||
the x and y dimensions of the positional encoding.
|
||||
The total number of positional features will be max_freqs^2.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.max_freqs = max_freqs
|
||||
self.hidden_size_input = hidden_size_input
|
||||
|
||||
# A linear layer to project the concatenated input features and
|
||||
# positional encodings to the final output dimension.
|
||||
self.embedder = nn.Sequential(
|
||||
operations.Linear(in_channels + max_freqs**2, hidden_size_input, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
@lru_cache(maxsize=4)
|
||||
def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||
"""
|
||||
Generates and caches 2D DCT-like positional embeddings for a given patch size.
|
||||
|
||||
The LRU cache is a performance optimization that avoids recomputing the
|
||||
same positional grid on every forward pass.
|
||||
|
||||
Args:
|
||||
patch_size (int): The side length of the square input patch.
|
||||
device: The torch device to create the tensors on.
|
||||
dtype: The torch dtype for the tensors.
|
||||
|
||||
Returns:
|
||||
A tensor of shape (1, patch_size^2, max_freqs^2) containing the
|
||||
positional embeddings.
|
||||
"""
|
||||
# Create normalized 1D coordinate grids from 0 to 1.
|
||||
pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
|
||||
pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
|
||||
|
||||
# Create a 2D meshgrid of coordinates.
|
||||
pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij")
|
||||
|
||||
# Reshape positions to be broadcastable with frequencies.
|
||||
# Shape becomes (patch_size^2, 1, 1).
|
||||
pos_x = pos_x.reshape(-1, 1, 1)
|
||||
pos_y = pos_y.reshape(-1, 1, 1)
|
||||
|
||||
# Create a 1D tensor of frequency values from 0 to max_freqs-1.
|
||||
freqs = torch.linspace(0, self.max_freqs - 1, self.max_freqs, dtype=dtype, device=device)
|
||||
|
||||
# Reshape frequencies to be broadcastable for creating 2D basis functions.
|
||||
# freqs_x shape: (1, max_freqs, 1)
|
||||
# freqs_y shape: (1, 1, max_freqs)
|
||||
freqs_x = freqs[None, :, None]
|
||||
freqs_y = freqs[None, None, :]
|
||||
|
||||
# A custom weighting coefficient, not part of standard DCT.
|
||||
# This seems to down-weight the contribution of higher-frequency interactions.
|
||||
coeffs = (1 + freqs_x * freqs_y) ** -1
|
||||
|
||||
# Calculate the 1D cosine basis functions for x and y coordinates.
|
||||
# This is the core of the DCT formulation.
|
||||
dct_x = torch.cos(pos_x * freqs_x * torch.pi)
|
||||
dct_y = torch.cos(pos_y * freqs_y * torch.pi)
|
||||
|
||||
# Combine the 1D basis functions to create 2D basis functions by element-wise
|
||||
# multiplication, and apply the custom coefficients. Broadcasting handles the
|
||||
# combination of all (pos_x, freqs_x) with all (pos_y, freqs_y).
|
||||
# The result is flattened into a feature vector for each position.
|
||||
dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs ** 2)
|
||||
|
||||
return dct
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for the embedder.
|
||||
|
||||
Args:
|
||||
inputs (Tensor): The input tensor of shape (B, P^2, C).
|
||||
|
||||
Returns:
|
||||
Tensor: The output tensor of shape (B, P^2, hidden_size_input).
|
||||
"""
|
||||
# Get the batch size, number of pixels, and number of channels.
|
||||
B, P2, C = inputs.shape
|
||||
|
||||
# Infer the patch side length from the number of pixels (P^2).
|
||||
patch_size = int(P2 ** 0.5)
|
||||
|
||||
input_dtype = inputs.dtype
|
||||
inputs = inputs.to(dtype=self.dtype)
|
||||
|
||||
# Fetch the pre-computed or cached positional embeddings.
|
||||
dct = self.fetch_pos(patch_size, inputs.device, self.dtype)
|
||||
|
||||
# Repeat the positional embeddings for each item in the batch.
|
||||
dct = dct.repeat(B, 1, 1)
|
||||
|
||||
# Concatenate the original input features with the positional embeddings
|
||||
# along the feature dimension.
|
||||
inputs = torch.cat((inputs, dct), dim=-1)
|
||||
|
||||
# Project the combined tensor to the target hidden size.
|
||||
return self.embedder(inputs).to(dtype=input_dtype)
|
||||
|
||||
|
||||
class NerfGLUBlock(nn.Module):
|
||||
"""
|
||||
A NerfBlock using a Gated Linear Unit (GLU) like MLP.
|
||||
"""
|
||||
def __init__(self, hidden_size_s: int, hidden_size_x: int, mlp_ratio, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
# The total number of parameters for the MLP is increased to accommodate
|
||||
# the gate, value, and output projection matrices.
|
||||
# We now need to generate parameters for 3 matrices.
|
||||
total_params = 3 * hidden_size_x**2 * mlp_ratio
|
||||
self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device)
|
||||
self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations)
|
||||
self.mlp_ratio = mlp_ratio
|
||||
|
||||
|
||||
def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_x, hidden_size_x = x.shape
|
||||
mlp_params = self.param_generator(s)
|
||||
|
||||
# Split the generated parameters into three parts for the gate, value, and output projection.
|
||||
fc1_gate_params, fc1_value_params, fc2_params = mlp_params.chunk(3, dim=-1)
|
||||
|
||||
# Reshape the parameters into matrices for batch matrix multiplication.
|
||||
fc1_gate = fc1_gate_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio)
|
||||
fc1_value = fc1_value_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio)
|
||||
fc2 = fc2_params.view(batch_size, hidden_size_x * self.mlp_ratio, hidden_size_x)
|
||||
|
||||
# Normalize the generated weight matrices as in the original implementation.
|
||||
fc1_gate = torch.nn.functional.normalize(fc1_gate, dim=-2)
|
||||
fc1_value = torch.nn.functional.normalize(fc1_value, dim=-2)
|
||||
fc2 = torch.nn.functional.normalize(fc2, dim=-2)
|
||||
|
||||
res_x = x
|
||||
x = self.norm(x)
|
||||
|
||||
# Apply the final output projection.
|
||||
x = torch.bmm(torch.nn.functional.silu(torch.bmm(x, fc1_gate)) * torch.bmm(x, fc1_value), fc2)
|
||||
|
||||
return x + res_x
|
||||
|
||||
|
||||
class NerfFinalLayer(nn.Module):
|
||||
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1.
|
||||
# So we temporarily move the channel dimension to the end for the norm operation.
|
||||
return self.linear(self.norm(x.movedim(1, -1))).movedim(-1, 1)
|
||||
|
||||
|
||||
class NerfFinalLayerConv(nn.Module):
|
||||
def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.conv = operations.Conv2d(
|
||||
in_channels=hidden_size,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1.
|
||||
# So we temporarily move the channel dimension to the end for the norm operation.
|
||||
return self.conv(self.norm(x.movedim(1, -1)).movedim(-1, 1))
|
||||
329
comfy/ldm/chroma_radiance/model.py
Normal file
329
comfy/ldm/chroma_radiance/model.py
Normal file
@@ -0,0 +1,329 @@
|
||||
# Credits:
|
||||
# Original Flux code can be found on: https://github.com/black-forest-labs/flux
|
||||
# Chroma Radiance adaption referenced from https://github.com/lodestone-rock/flow
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from einops import repeat
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
|
||||
from comfy.ldm.chroma.model import Chroma, ChromaParams
|
||||
from comfy.ldm.chroma.layers import (
|
||||
DoubleStreamBlock,
|
||||
SingleStreamBlock,
|
||||
Approximator,
|
||||
)
|
||||
from .layers import (
|
||||
NerfEmbedder,
|
||||
NerfGLUBlock,
|
||||
NerfFinalLayer,
|
||||
NerfFinalLayerConv,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChromaRadianceParams(ChromaParams):
|
||||
patch_size: int
|
||||
nerf_hidden_size: int
|
||||
nerf_mlp_ratio: int
|
||||
nerf_depth: int
|
||||
nerf_max_freqs: int
|
||||
# Setting nerf_tile_size to 0 disables tiling.
|
||||
nerf_tile_size: int
|
||||
# Currently one of linear (legacy) or conv.
|
||||
nerf_final_head_type: str
|
||||
# None means use the same dtype as the model.
|
||||
nerf_embedder_dtype: Optional[torch.dtype]
|
||||
|
||||
|
||||
class ChromaRadiance(Chroma):
|
||||
"""
|
||||
Transformer model for flow matching on sequences.
|
||||
"""
|
||||
|
||||
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
|
||||
if operations is None:
|
||||
raise RuntimeError("Attempt to create ChromaRadiance object without setting operations")
|
||||
nn.Module.__init__(self)
|
||||
self.dtype = dtype
|
||||
params = ChromaRadianceParams(**kwargs)
|
||||
self.params = params
|
||||
self.patch_size = params.patch_size
|
||||
self.in_channels = params.in_channels
|
||||
self.out_channels = params.out_channels
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||
)
|
||||
pe_dim = params.hidden_size // params.num_heads
|
||||
if sum(params.axes_dim) != pe_dim:
|
||||
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
||||
self.hidden_size = params.hidden_size
|
||||
self.num_heads = params.num_heads
|
||||
self.in_dim = params.in_dim
|
||||
self.out_dim = params.out_dim
|
||||
self.hidden_dim = params.hidden_dim
|
||||
self.n_layers = params.n_layers
|
||||
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
||||
self.img_in_patch = operations.Conv2d(
|
||||
params.in_channels,
|
||||
params.hidden_size,
|
||||
kernel_size=params.patch_size,
|
||||
stride=params.patch_size,
|
||||
bias=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
|
||||
# set as nn identity for now, will overwrite it later.
|
||||
self.distilled_guidance_layer = Approximator(
|
||||
in_dim=self.in_dim,
|
||||
hidden_dim=self.hidden_dim,
|
||||
out_dim=self.out_dim,
|
||||
n_layers=self.n_layers,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
DoubleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
# pixel channel concat with DCT
|
||||
self.nerf_image_embedder = NerfEmbedder(
|
||||
in_channels=params.in_channels,
|
||||
hidden_size_input=params.nerf_hidden_size,
|
||||
max_freqs=params.nerf_max_freqs,
|
||||
dtype=params.nerf_embedder_dtype or dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.nerf_blocks = nn.ModuleList([
|
||||
NerfGLUBlock(
|
||||
hidden_size_s=params.hidden_size,
|
||||
hidden_size_x=params.nerf_hidden_size,
|
||||
mlp_ratio=params.nerf_mlp_ratio,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
) for _ in range(params.nerf_depth)
|
||||
])
|
||||
|
||||
if params.nerf_final_head_type == "linear":
|
||||
self.nerf_final_layer = NerfFinalLayer(
|
||||
params.nerf_hidden_size,
|
||||
out_channels=params.in_channels,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
elif params.nerf_final_head_type == "conv":
|
||||
self.nerf_final_layer_conv = NerfFinalLayerConv(
|
||||
params.nerf_hidden_size,
|
||||
out_channels=params.in_channels,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
else:
|
||||
errstr = f"Unsupported nerf_final_head_type {params.nerf_final_head_type}"
|
||||
raise ValueError(errstr)
|
||||
|
||||
self.skip_mmdit = []
|
||||
self.skip_dit = []
|
||||
self.lite = False
|
||||
|
||||
@property
|
||||
def _nerf_final_layer(self) -> nn.Module:
|
||||
if self.params.nerf_final_head_type == "linear":
|
||||
return self.nerf_final_layer
|
||||
if self.params.nerf_final_head_type == "conv":
|
||||
return self.nerf_final_layer_conv
|
||||
# Impossible to get here as we raise an error on unexpected types on initialization.
|
||||
raise NotImplementedError
|
||||
|
||||
def img_in(self, img: Tensor) -> Tensor:
|
||||
img = self.img_in_patch(img) # -> [B, Hidden, H/P, W/P]
|
||||
# flatten into a sequence for the transformer.
|
||||
return img.flatten(2).transpose(1, 2) # -> [B, NumPatches, Hidden]
|
||||
|
||||
def forward_nerf(
|
||||
self,
|
||||
img_orig: Tensor,
|
||||
img_out: Tensor,
|
||||
params: ChromaRadianceParams,
|
||||
) -> Tensor:
|
||||
B, C, H, W = img_orig.shape
|
||||
num_patches = img_out.shape[1]
|
||||
patch_size = params.patch_size
|
||||
|
||||
# Store the raw pixel values of each patch for the NeRF head later.
|
||||
# unfold creates patches: [B, C * P * P, NumPatches]
|
||||
nerf_pixels = nn.functional.unfold(img_orig, kernel_size=patch_size, stride=patch_size)
|
||||
nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P]
|
||||
|
||||
if params.nerf_tile_size > 0 and num_patches > params.nerf_tile_size:
|
||||
# Enable tiling if nerf_tile_size isn't 0 and we actually have more patches than
|
||||
# the tile size.
|
||||
img_dct = self.forward_tiled_nerf(img_out, nerf_pixels, B, C, num_patches, patch_size, params)
|
||||
else:
|
||||
# Reshape for per-patch processing
|
||||
nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size)
|
||||
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
|
||||
|
||||
# Get DCT-encoded pixel embeddings [pixel-dct]
|
||||
img_dct = self.nerf_image_embedder(nerf_pixels)
|
||||
|
||||
# Pass through the dynamic MLP blocks (the NeRF)
|
||||
for block in self.nerf_blocks:
|
||||
img_dct = block(img_dct, nerf_hidden)
|
||||
|
||||
# Reassemble the patches into the final image.
|
||||
img_dct = img_dct.transpose(1, 2) # -> [B*NumPatches, C, P*P]
|
||||
# Reshape to combine with batch dimension for fold
|
||||
img_dct = img_dct.reshape(B, num_patches, -1) # -> [B, NumPatches, C*P*P]
|
||||
img_dct = img_dct.transpose(1, 2) # -> [B, C*P*P, NumPatches]
|
||||
img_dct = nn.functional.fold(
|
||||
img_dct,
|
||||
output_size=(H, W),
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
)
|
||||
return self._nerf_final_layer(img_dct)
|
||||
|
||||
def forward_tiled_nerf(
|
||||
self,
|
||||
nerf_hidden: Tensor,
|
||||
nerf_pixels: Tensor,
|
||||
batch: int,
|
||||
channels: int,
|
||||
num_patches: int,
|
||||
patch_size: int,
|
||||
params: ChromaRadianceParams,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Processes the NeRF head in tiles to save memory.
|
||||
nerf_hidden has shape [B, L, D]
|
||||
nerf_pixels has shape [B, L, C * P * P]
|
||||
"""
|
||||
tile_size = params.nerf_tile_size
|
||||
output_tiles = []
|
||||
# Iterate over the patches in tiles. The dimension L (num_patches) is at index 1.
|
||||
for i in range(0, num_patches, tile_size):
|
||||
end = min(i + tile_size, num_patches)
|
||||
|
||||
# Slice the current tile from the input tensors
|
||||
nerf_hidden_tile = nerf_hidden[:, i:end, :]
|
||||
nerf_pixels_tile = nerf_pixels[:, i:end, :]
|
||||
|
||||
# Get the actual number of patches in this tile (can be smaller for the last tile)
|
||||
num_patches_tile = nerf_hidden_tile.shape[1]
|
||||
|
||||
# Reshape the tile for per-patch processing
|
||||
# [B, NumPatches_tile, D] -> [B * NumPatches_tile, D]
|
||||
nerf_hidden_tile = nerf_hidden_tile.reshape(batch * num_patches_tile, params.hidden_size)
|
||||
# [B, NumPatches_tile, C*P*P] -> [B*NumPatches_tile, C, P*P] -> [B*NumPatches_tile, P*P, C]
|
||||
nerf_pixels_tile = nerf_pixels_tile.reshape(batch * num_patches_tile, channels, patch_size**2).transpose(1, 2)
|
||||
|
||||
# get DCT-encoded pixel embeddings [pixel-dct]
|
||||
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile)
|
||||
|
||||
# pass through the dynamic MLP blocks (the NeRF)
|
||||
for block in self.nerf_blocks:
|
||||
img_dct_tile = block(img_dct_tile, nerf_hidden_tile)
|
||||
|
||||
output_tiles.append(img_dct_tile)
|
||||
|
||||
# Concatenate the processed tiles along the patch dimension
|
||||
return torch.cat(output_tiles, dim=0)
|
||||
|
||||
def radiance_get_override_params(self, overrides: dict) -> ChromaRadianceParams:
|
||||
params = self.params
|
||||
if not overrides:
|
||||
return params
|
||||
params_dict = {k: getattr(params, k) for k in params.__dataclass_fields__}
|
||||
nullable_keys = frozenset(("nerf_embedder_dtype",))
|
||||
bad_keys = tuple(k for k in overrides if k not in params_dict)
|
||||
if bad_keys:
|
||||
e = f"Unknown key(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}"
|
||||
raise ValueError(e)
|
||||
bad_keys = tuple(
|
||||
k
|
||||
for k, v in overrides.items()
|
||||
if type(v) != type(getattr(params, k)) and (v is not None or k not in nullable_keys)
|
||||
)
|
||||
if bad_keys:
|
||||
e = f"Invalid value(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}"
|
||||
raise ValueError(e)
|
||||
# At this point it's all valid keys and values so we can merge with the existing params.
|
||||
params_dict |= overrides
|
||||
return params.__class__(**params_dict)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
timestep: Tensor,
|
||||
context: Tensor,
|
||||
guidance: Optional[Tensor],
|
||||
control: Optional[dict]=None,
|
||||
transformer_options: dict={},
|
||||
**kwargs: dict,
|
||||
) -> Tensor:
|
||||
bs, c, h, w = x.shape
|
||||
img = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
||||
|
||||
if img.ndim != 4:
|
||||
raise ValueError("Input img tensor must be in [B, C, H, W] format.")
|
||||
if context.ndim != 3:
|
||||
raise ValueError("Input txt tensors must have 3 dimensions.")
|
||||
|
||||
params = self.radiance_get_override_params(transformer_options.get("chroma_radiance_options", {}))
|
||||
|
||||
h_len = (img.shape[-2] // self.patch_size)
|
||||
w_len = (img.shape[-1] // self.patch_size)
|
||||
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
|
||||
img_out = self.forward_orig(
|
||||
img,
|
||||
img_ids,
|
||||
context,
|
||||
txt_ids,
|
||||
timestep,
|
||||
guidance,
|
||||
control,
|
||||
transformer_options,
|
||||
attn_mask=kwargs.get("attention_mask", None),
|
||||
)
|
||||
return self.forward_nerf(img, img_out, params)[:, :, :h, :w]
|
||||
@@ -176,6 +176,7 @@ class Attention(nn.Module):
|
||||
context=None,
|
||||
mask=None,
|
||||
rope_emb=None,
|
||||
transformer_options={},
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -184,7 +185,7 @@ class Attention(nn.Module):
|
||||
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
|
||||
"""
|
||||
q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
|
||||
out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
|
||||
out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
out = rearrange(out, " b n s c -> s b (n c)")
|
||||
return self.to_out(out)
|
||||
@@ -546,6 +547,7 @@ class VideoAttn(nn.Module):
|
||||
context: Optional[torch.Tensor] = None,
|
||||
crossattn_mask: Optional[torch.Tensor] = None,
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[dict] = {},
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for video attention.
|
||||
@@ -571,6 +573,7 @@ class VideoAttn(nn.Module):
|
||||
context_M_B_D,
|
||||
crossattn_mask,
|
||||
rope_emb=rope_emb_L_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W)
|
||||
return x_T_H_W_B_D
|
||||
@@ -665,6 +668,7 @@ class DITBuildingBlock(nn.Module):
|
||||
crossattn_mask: Optional[torch.Tensor] = None,
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
adaln_lora_B_3D: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[dict] = {},
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for dynamically configured blocks with adaptive normalization.
|
||||
@@ -702,6 +706,7 @@ class DITBuildingBlock(nn.Module):
|
||||
adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
|
||||
context=None,
|
||||
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
elif self.block_type in ["cross_attn", "ca"]:
|
||||
x = x + gate_1_1_1_B_D * self.block(
|
||||
@@ -709,6 +714,7 @@ class DITBuildingBlock(nn.Module):
|
||||
context=crossattn_emb,
|
||||
crossattn_mask=crossattn_mask,
|
||||
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown block type: {self.block_type}")
|
||||
@@ -784,6 +790,7 @@ class GeneralDITTransformerBlock(nn.Module):
|
||||
crossattn_mask: Optional[torch.Tensor] = None,
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
adaln_lora_B_3D: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[dict] = {},
|
||||
) -> torch.Tensor:
|
||||
for block in self.blocks:
|
||||
x = block(
|
||||
@@ -793,5 +800,6 @@ class GeneralDITTransformerBlock(nn.Module):
|
||||
crossattn_mask,
|
||||
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||
adaln_lora_B_3D=adaln_lora_B_3D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
return x
|
||||
|
||||
@@ -58,7 +58,8 @@ def is_odd(n: int) -> bool:
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
return x * torch.sigmoid(x)
|
||||
# x * sigmoid(x)
|
||||
return torch.nn.functional.silu(x)
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32):
|
||||
|
||||
@@ -27,6 +27,8 @@ from torchvision import transforms
|
||||
from enum import Enum
|
||||
import logging
|
||||
|
||||
import comfy.patcher_extension
|
||||
|
||||
from .blocks import (
|
||||
FinalLayer,
|
||||
GeneralDITTransformerBlock,
|
||||
@@ -435,6 +437,42 @@ class GeneralDIT(nn.Module):
|
||||
latent_condition_sigma: Optional[torch.Tensor] = None,
|
||||
condition_video_augment_sigma: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
|
||||
).execute(x,
|
||||
timesteps,
|
||||
context,
|
||||
attention_mask,
|
||||
fps,
|
||||
image_size,
|
||||
padding_mask,
|
||||
scalar_feature,
|
||||
data_type,
|
||||
latent_condition,
|
||||
latent_condition_sigma,
|
||||
condition_video_augment_sigma,
|
||||
**kwargs)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
# crossattn_emb: torch.Tensor,
|
||||
# crossattn_mask: Optional[torch.Tensor] = None,
|
||||
fps: Optional[torch.Tensor] = None,
|
||||
image_size: Optional[torch.Tensor] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
scalar_feature: Optional[torch.Tensor] = None,
|
||||
data_type: Optional[DataType] = DataType.VIDEO,
|
||||
latent_condition: Optional[torch.Tensor] = None,
|
||||
latent_condition_sigma: Optional[torch.Tensor] = None,
|
||||
condition_video_augment_sigma: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -482,6 +520,7 @@ class GeneralDIT(nn.Module):
|
||||
x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
|
||||
), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}"
|
||||
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
for _, block in self.blocks.items():
|
||||
assert (
|
||||
self.blocks["block0"].x_format == block.x_format
|
||||
@@ -496,6 +535,7 @@ class GeneralDIT(nn.Module):
|
||||
crossattn_mask,
|
||||
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||
adaln_lora_B_3D=adaln_lora_B_3D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")
|
||||
|
||||
@@ -11,6 +11,7 @@ import math
|
||||
from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis
|
||||
from torchvision import transforms
|
||||
|
||||
import comfy.patcher_extension
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
def apply_rotary_pos_emb(
|
||||
@@ -43,7 +44,7 @@ class GPT2FeedForward(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
|
||||
def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor:
|
||||
"""Computes multi-head attention using PyTorch's native implementation.
|
||||
|
||||
This function provides a PyTorch backend alternative to Transformer Engine's attention operation.
|
||||
@@ -70,11 +71,7 @@ def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H
|
||||
q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
|
||||
k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
|
||||
v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
|
||||
result_B_S_HD = rearrange(
|
||||
optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True, skip_output_reshape=True), "b h ... l -> b ... (h l)"
|
||||
)
|
||||
|
||||
return result_B_S_HD
|
||||
return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True, transformer_options=transformer_options)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
@@ -183,8 +180,8 @@ class Attention(nn.Module):
|
||||
|
||||
return q, k, v
|
||||
|
||||
def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
|
||||
result = self.attn_op(q, k, v) # [B, S, H, D]
|
||||
def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor:
|
||||
result = self.attn_op(q, k, v, transformer_options=transformer_options) # [B, S, H, D]
|
||||
return self.output_dropout(self.output_proj(result))
|
||||
|
||||
def forward(
|
||||
@@ -192,6 +189,7 @@ class Attention(nn.Module):
|
||||
x: torch.Tensor,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
rope_emb: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[dict] = {},
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
@@ -199,7 +197,7 @@ class Attention(nn.Module):
|
||||
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
|
||||
"""
|
||||
q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
|
||||
return self.compute_attention(q, k, v)
|
||||
return self.compute_attention(q, k, v, transformer_options=transformer_options)
|
||||
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
@@ -462,6 +460,7 @@ class Block(nn.Module):
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
|
||||
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[dict] = {},
|
||||
) -> torch.Tensor:
|
||||
if extra_per_block_pos_emb is not None:
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
|
||||
@@ -515,6 +514,7 @@ class Block(nn.Module):
|
||||
rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
|
||||
None,
|
||||
rope_emb=rope_emb_L_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
),
|
||||
"b (t h w) d -> b t h w d",
|
||||
t=T,
|
||||
@@ -528,6 +528,7 @@ class Block(nn.Module):
|
||||
layer_norm_cross_attn: Callable,
|
||||
_scale_cross_attn_B_T_1_1_D: torch.Tensor,
|
||||
_shift_cross_attn_B_T_1_1_D: torch.Tensor,
|
||||
transformer_options: Optional[dict] = {},
|
||||
) -> torch.Tensor:
|
||||
_normalized_x_B_T_H_W_D = _fn(
|
||||
_x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D
|
||||
@@ -537,6 +538,7 @@ class Block(nn.Module):
|
||||
rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
|
||||
crossattn_emb,
|
||||
rope_emb=rope_emb_L_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
),
|
||||
"b (t h w) d -> b t h w d",
|
||||
t=T,
|
||||
@@ -550,6 +552,7 @@ class Block(nn.Module):
|
||||
self.layer_norm_cross_attn,
|
||||
scale_cross_attn_B_T_1_1_D,
|
||||
shift_cross_attn_B_T_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
|
||||
|
||||
@@ -809,7 +812,21 @@ class MiniTrainDIT(nn.Module):
|
||||
)
|
||||
return x_B_C_Tt_Hp_Wp
|
||||
|
||||
def forward(
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
fps: Optional[torch.Tensor] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
|
||||
).execute(x, timesteps, context, fps, padding_mask, **kwargs)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
@@ -854,6 +871,7 @@ class MiniTrainDIT(nn.Module):
|
||||
"rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0),
|
||||
"adaln_lora_B_T_3D": adaln_lora_B_T_3D,
|
||||
"extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
|
||||
"transformer_options": kwargs.get("transformer_options", {}),
|
||||
}
|
||||
for block in self.blocks:
|
||||
x_B_T_H_W_D = block(
|
||||
|
||||
@@ -123,6 +123,8 @@ class ControlNetFlux(Flux):
|
||||
|
||||
if y is None:
|
||||
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
|
||||
else:
|
||||
y = y[:, :self.params.vec_in_dim]
|
||||
|
||||
# running on sequences img
|
||||
img = self.img_in(img)
|
||||
|
||||
@@ -118,7 +118,7 @@ class Modulation(nn.Module):
|
||||
def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
|
||||
if modulation_dims is None:
|
||||
if m_add is not None:
|
||||
return tensor * m_mult + m_add
|
||||
return torch.addcmul(m_add, tensor, m_mult)
|
||||
else:
|
||||
return tensor * m_mult
|
||||
else:
|
||||
@@ -159,7 +159,7 @@ class DoubleStreamBlock(nn.Module):
|
||||
)
|
||||
self.flipped_img_txt = flipped_img_txt
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
|
||||
img_mod1, img_mod2 = self.img_mod(vec)
|
||||
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||
|
||||
@@ -182,7 +182,7 @@ class DoubleStreamBlock(nn.Module):
|
||||
attn = attention(torch.cat((img_q, txt_q), dim=2),
|
||||
torch.cat((img_k, txt_k), dim=2),
|
||||
torch.cat((img_v, txt_v), dim=2),
|
||||
pe=pe, mask=attn_mask)
|
||||
pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
|
||||
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
|
||||
else:
|
||||
@@ -190,7 +190,7 @@ class DoubleStreamBlock(nn.Module):
|
||||
attn = attention(torch.cat((txt_q, img_q), dim=2),
|
||||
torch.cat((txt_k, img_k), dim=2),
|
||||
torch.cat((txt_v, img_v), dim=2),
|
||||
pe=pe, mask=attn_mask)
|
||||
pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||
|
||||
@@ -244,7 +244,7 @@ class SingleStreamBlock(nn.Module):
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor:
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor:
|
||||
mod, _ = self.modulation(vec)
|
||||
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
|
||||
@@ -252,7 +252,7 @@ class SingleStreamBlock(nn.Module):
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask)
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
x += apply_mod(output, mod.gate, None, modulation_dims)
|
||||
|
||||
@@ -6,7 +6,7 @@ from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.model_management
|
||||
|
||||
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
|
||||
q_shape = q.shape
|
||||
k_shape = k.shape
|
||||
|
||||
@@ -17,7 +17,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
|
||||
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
|
||||
|
||||
heads = q.shape[1]
|
||||
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
|
||||
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
|
||||
return x
|
||||
|
||||
|
||||
@@ -35,11 +35,10 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
||||
return out.to(dtype=torch.float32, device=pos.device)
|
||||
|
||||
def apply_rope1(x: Tensor, freqs_cis: Tensor):
|
||||
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
||||
x_out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]
|
||||
return x_out.reshape(*x.shape).type_as(x)
|
||||
|
||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
||||
xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2)
|
||||
xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2)
|
||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
|
||||
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||
|
||||
@@ -6,6 +6,7 @@ import torch
|
||||
from torch import Tensor, nn
|
||||
from einops import rearrange, repeat
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.patcher_extension
|
||||
|
||||
from .layers import (
|
||||
DoubleStreamBlock,
|
||||
@@ -105,6 +106,7 @@ class Flux(nn.Module):
|
||||
if y is None:
|
||||
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
|
||||
|
||||
patches = transformer_options.get("patches", {})
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
@@ -116,9 +118,17 @@ class Flux(nn.Module):
|
||||
if guidance is not None:
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
||||
|
||||
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
|
||||
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
if "post_input" in patches:
|
||||
for p in patches["post_input"]:
|
||||
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})
|
||||
img = out["img"]
|
||||
txt = out["txt"]
|
||||
img_ids = out["img_ids"]
|
||||
txt_ids = out["txt_ids"]
|
||||
|
||||
if img_ids is not None:
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
@@ -134,14 +144,16 @@ class Flux(nn.Module):
|
||||
txt=args["txt"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img,
|
||||
"txt": txt,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
"attn_mask": attn_mask,
|
||||
"transformer_options": transformer_options},
|
||||
{"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
@@ -150,14 +162,15 @@ class Flux(nn.Module):
|
||||
txt=txt,
|
||||
vec=vec,
|
||||
pe=pe,
|
||||
attn_mask=attn_mask)
|
||||
attn_mask=attn_mask,
|
||||
transformer_options=transformer_options)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
if i < len(control_i):
|
||||
add = control_i[i]
|
||||
if add is not None:
|
||||
img += add
|
||||
img[:, :add.shape[1]] += add
|
||||
|
||||
if img.dtype == torch.float16:
|
||||
img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504)
|
||||
@@ -171,44 +184,97 @@ class Flux(nn.Module):
|
||||
out["img"] = block(args["img"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
"attn_mask": attn_mask,
|
||||
"transformer_options": transformer_options},
|
||||
{"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
if i < len(control_o):
|
||||
add = control_o[i]
|
||||
if add is not None:
|
||||
img[:, txt.shape[1] :, ...] += add
|
||||
img[:, txt.shape[1] : txt.shape[1] + add.shape[1], ...] += add
|
||||
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
|
||||
def forward(self, x, timestep, context, y=None, guidance=None, control=None, transformer_options={}, **kwargs):
|
||||
def process_img(self, x, index=0, h_offset=0, w_offset=0):
|
||||
bs, c, h, w = x.shape
|
||||
patch_size = self.patch_size
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||
|
||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||
|
||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||
|
||||
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
|
||||
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
|
||||
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
img_ids[:, :, 0] = img_ids[:, :, 1] + index
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, y, guidance, ref_latents, control, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, h_orig, w_orig = x.shape
|
||||
patch_size = self.patch_size
|
||||
|
||||
h_len = ((h_orig + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w_orig + (patch_size // 2)) // patch_size)
|
||||
img, img_ids = self.process_img(x)
|
||||
img_tokens = img.shape[1]
|
||||
if ref_latents is not None:
|
||||
h = 0
|
||||
w = 0
|
||||
index = 0
|
||||
ref_latents_method = kwargs.get("ref_latents_method", "offset")
|
||||
for ref in ref_latents:
|
||||
if ref_latents_method == "index":
|
||||
index += 1
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
elif ref_latents_method == "uxo":
|
||||
index = 0
|
||||
h_offset = h_len * patch_size + h
|
||||
w_offset = w_len * patch_size + w
|
||||
h += ref.shape[-2]
|
||||
w += ref.shape[-1]
|
||||
else:
|
||||
index = 1
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
if ref.shape[-2] + h > ref.shape[-1] + w:
|
||||
w_offset = w
|
||||
else:
|
||||
h_offset = h
|
||||
h = max(h, ref.shape[-2] + h_offset)
|
||||
w = max(w, ref.shape[-1] + w_offset)
|
||||
|
||||
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
|
||||
img = torch.cat([img, kontext], dim=1)
|
||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
|
||||
out = out[:, :img_tokens]
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig]
|
||||
|
||||
@@ -109,6 +109,7 @@ class AsymmetricAttention(nn.Module):
|
||||
scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
|
||||
scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
|
||||
crop_y,
|
||||
transformer_options={},
|
||||
**rope_rotation,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
rope_cos = rope_rotation.get("rope_cos")
|
||||
@@ -143,7 +144,7 @@ class AsymmetricAttention(nn.Module):
|
||||
|
||||
xy = optimized_attention(q,
|
||||
k,
|
||||
v, self.num_heads, skip_reshape=True)
|
||||
v, self.num_heads, skip_reshape=True, transformer_options=transformer_options)
|
||||
|
||||
x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
|
||||
x = self.proj_x(x)
|
||||
@@ -224,6 +225,7 @@ class AsymmetricJointBlock(nn.Module):
|
||||
x: torch.Tensor,
|
||||
c: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
transformer_options={},
|
||||
**attn_kwargs,
|
||||
):
|
||||
"""Forward pass of a block.
|
||||
@@ -256,6 +258,7 @@ class AsymmetricJointBlock(nn.Module):
|
||||
y,
|
||||
scale_x=scale_msa_x,
|
||||
scale_y=scale_msa_y,
|
||||
transformer_options=transformer_options,
|
||||
**attn_kwargs,
|
||||
)
|
||||
|
||||
@@ -524,10 +527,11 @@ class AsymmDiTJoint(nn.Module):
|
||||
args["txt"],
|
||||
rope_cos=args["rope_cos"],
|
||||
rope_sin=args["rope_sin"],
|
||||
crop_y=args["num_tokens"]
|
||||
crop_y=args["num_tokens"],
|
||||
transformer_options=args["transformer_options"]
|
||||
)
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
y_feat = out["txt"]
|
||||
x = out["img"]
|
||||
else:
|
||||
@@ -538,6 +542,7 @@ class AsymmDiTJoint(nn.Module):
|
||||
rope_cos=rope_cos,
|
||||
rope_sin=rope_sin,
|
||||
crop_y=num_tokens,
|
||||
transformer_options=transformer_options,
|
||||
) # (B, M, D), (B, L, D)
|
||||
del y_feat # Final layers don't use dense text features.
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ from comfy.ldm.flux.layers import LastLayer
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
|
||||
@@ -71,8 +72,8 @@ class TimestepEmbed(nn.Module):
|
||||
return t_emb
|
||||
|
||||
|
||||
def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
|
||||
return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2])
|
||||
def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, transformer_options={}):
|
||||
return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2], transformer_options=transformer_options)
|
||||
|
||||
|
||||
class HiDreamAttnProcessor_flashattn:
|
||||
@@ -85,6 +86,7 @@ class HiDreamAttnProcessor_flashattn:
|
||||
image_tokens_masks: Optional[torch.FloatTensor] = None,
|
||||
text_tokens: Optional[torch.FloatTensor] = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
transformer_options={},
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
@@ -132,7 +134,7 @@ class HiDreamAttnProcessor_flashattn:
|
||||
query = torch.cat([query_1, query_2], dim=-1)
|
||||
key = torch.cat([key_1, key_2], dim=-1)
|
||||
|
||||
hidden_states = attention(query, key, value)
|
||||
hidden_states = attention(query, key, value, transformer_options=transformer_options)
|
||||
|
||||
if not attn.single:
|
||||
hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
|
||||
@@ -198,6 +200,7 @@ class HiDreamAttention(nn.Module):
|
||||
image_tokens_masks: torch.FloatTensor = None,
|
||||
norm_text_tokens: torch.FloatTensor = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
transformer_options={},
|
||||
) -> torch.Tensor:
|
||||
return self.processor(
|
||||
self,
|
||||
@@ -205,6 +208,7 @@ class HiDreamAttention(nn.Module):
|
||||
image_tokens_masks = image_tokens_masks,
|
||||
text_tokens = norm_text_tokens,
|
||||
rope = rope,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
|
||||
@@ -405,7 +409,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
|
||||
text_tokens: Optional[torch.FloatTensor] = None,
|
||||
adaln_input: Optional[torch.FloatTensor] = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
|
||||
transformer_options={},
|
||||
) -> torch.FloatTensor:
|
||||
wtype = image_tokens.dtype
|
||||
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \
|
||||
@@ -418,6 +422,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
|
||||
norm_image_tokens,
|
||||
image_tokens_masks,
|
||||
rope = rope,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
image_tokens = gate_msa_i * attn_output_i + image_tokens
|
||||
|
||||
@@ -482,6 +487,7 @@ class HiDreamImageTransformerBlock(nn.Module):
|
||||
text_tokens: Optional[torch.FloatTensor] = None,
|
||||
adaln_input: Optional[torch.FloatTensor] = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
transformer_options={},
|
||||
) -> torch.FloatTensor:
|
||||
wtype = image_tokens.dtype
|
||||
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \
|
||||
@@ -499,6 +505,7 @@ class HiDreamImageTransformerBlock(nn.Module):
|
||||
image_tokens_masks,
|
||||
norm_text_tokens,
|
||||
rope = rope,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
image_tokens = gate_msa_i * attn_output_i + image_tokens
|
||||
@@ -549,6 +556,7 @@ class HiDreamImageBlock(nn.Module):
|
||||
text_tokens: Optional[torch.FloatTensor] = None,
|
||||
adaln_input: torch.FloatTensor = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
transformer_options={},
|
||||
) -> torch.FloatTensor:
|
||||
return self.block(
|
||||
image_tokens,
|
||||
@@ -556,6 +564,7 @@ class HiDreamImageBlock(nn.Module):
|
||||
text_tokens,
|
||||
adaln_input,
|
||||
rope,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
|
||||
@@ -692,7 +701,23 @@ class HiDreamImageTransformer2DModel(nn.Module):
|
||||
raise NotImplementedError
|
||||
return x, x_masks, img_sizes
|
||||
|
||||
def forward(
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states_llama3=None,
|
||||
image_cond=None,
|
||||
control = None,
|
||||
transformer_options = {},
|
||||
):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, t, y, context, encoder_hidden_states_llama3, image_cond, control, transformer_options)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
@@ -769,6 +794,7 @@ class HiDreamImageTransformer2DModel(nn.Module):
|
||||
text_tokens = cur_encoder_hidden_states,
|
||||
adaln_input = adaln_input,
|
||||
rope = rope,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
|
||||
block_id += 1
|
||||
@@ -792,6 +818,7 @@ class HiDreamImageTransformer2DModel(nn.Module):
|
||||
text_tokens=None,
|
||||
adaln_input=adaln_input,
|
||||
rope=rope,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
hidden_states = hidden_states[:, :hidden_states_seq_len]
|
||||
block_id += 1
|
||||
|
||||
@@ -7,6 +7,7 @@ from comfy.ldm.flux.layers import (
|
||||
SingleStreamBlock,
|
||||
timestep_embedding,
|
||||
)
|
||||
import comfy.patcher_extension
|
||||
|
||||
|
||||
class Hunyuan3Dv2(nn.Module):
|
||||
@@ -67,6 +68,13 @@ class Hunyuan3Dv2(nn.Module):
|
||||
self.final_layer = LastLayer(hidden_size, 1, in_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, guidance, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
|
||||
x = x.movedim(-1, -2)
|
||||
timestep = 1.0 - timestep
|
||||
txt = context
|
||||
@@ -91,14 +99,16 @@ class Hunyuan3Dv2(nn.Module):
|
||||
txt=args["txt"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args["transformer_options"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img,
|
||||
"txt": txt,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
"attn_mask": attn_mask,
|
||||
"transformer_options": transformer_options},
|
||||
{"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
@@ -107,7 +117,8 @@ class Hunyuan3Dv2(nn.Module):
|
||||
txt=txt,
|
||||
vec=vec,
|
||||
pe=pe,
|
||||
attn_mask=attn_mask)
|
||||
attn_mask=attn_mask,
|
||||
transformer_options=transformer_options)
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
@@ -118,17 +129,19 @@ class Hunyuan3Dv2(nn.Module):
|
||||
out["img"] = block(args["img"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args["transformer_options"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
"attn_mask": attn_mask,
|
||||
"transformer_options": transformer_options},
|
||||
{"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
|
||||
|
||||
img = img[:, txt.shape[1]:, ...]
|
||||
img = self.final_layer(img, vec)
|
||||
|
||||
@@ -4,81 +4,458 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
from typing import Union, Tuple, List, Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
from einops import repeat, rearrange
|
||||
import math
|
||||
from tqdm import tqdm
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import logging
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
def generate_dense_grid_points(
|
||||
bbox_min: np.ndarray,
|
||||
bbox_max: np.ndarray,
|
||||
octree_resolution: int,
|
||||
indexing: str = "ij",
|
||||
):
|
||||
length = bbox_max - bbox_min
|
||||
num_cells = octree_resolution
|
||||
def fps(src: torch.Tensor, batch: torch.Tensor, sampling_ratio: float, start_random: bool = True):
|
||||
|
||||
x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
|
||||
y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
|
||||
z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
|
||||
[xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
|
||||
xyz = np.stack((xs, ys, zs), axis=-1)
|
||||
grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
|
||||
# manually create the pointer vector
|
||||
assert src.size(0) == batch.numel()
|
||||
|
||||
return xyz, grid_size, length
|
||||
batch_size = int(batch.max()) + 1
|
||||
deg = src.new_zeros(batch_size, dtype = torch.long)
|
||||
|
||||
deg.scatter_add_(0, batch, torch.ones_like(batch))
|
||||
|
||||
ptr_vec = deg.new_zeros(batch_size + 1)
|
||||
torch.cumsum(deg, 0, out=ptr_vec[1:])
|
||||
|
||||
#return fps_sampling(src, ptr_vec, ratio)
|
||||
sampled_indicies = []
|
||||
|
||||
for b in range(batch_size):
|
||||
# start and the end of each batch
|
||||
start, end = ptr_vec[b].item(), ptr_vec[b + 1].item()
|
||||
# points from the point cloud
|
||||
points = src[start:end]
|
||||
|
||||
num_points = points.size(0)
|
||||
num_samples = max(1, math.ceil(num_points * sampling_ratio))
|
||||
|
||||
selected = torch.zeros(num_samples, device = src.device, dtype = torch.long)
|
||||
distances = torch.full((num_points,), float("inf"), device = src.device)
|
||||
|
||||
# select a random start point
|
||||
if start_random:
|
||||
farthest = torch.randint(0, num_points, (1,), device = src.device)
|
||||
else:
|
||||
farthest = torch.tensor([0], device = src.device, dtype = torch.long)
|
||||
|
||||
for i in range(num_samples):
|
||||
selected[i] = farthest
|
||||
centroid = points[farthest].squeeze(0)
|
||||
dist = torch.norm(points - centroid, dim = 1) # compute euclidean distance
|
||||
distances = torch.minimum(distances, dist)
|
||||
farthest = torch.argmax(distances)
|
||||
|
||||
sampled_indicies.append(torch.arange(start, end)[selected])
|
||||
|
||||
return torch.cat(sampled_indicies, dim = 0)
|
||||
class PointCrossAttention(nn.Module):
|
||||
def __init__(self,
|
||||
num_latents: int,
|
||||
downsample_ratio: float,
|
||||
pc_size: int,
|
||||
pc_sharpedge_size: int,
|
||||
point_feats: int,
|
||||
width: int,
|
||||
heads: int,
|
||||
layers: int,
|
||||
fourier_embedder,
|
||||
normal_pe: bool = False,
|
||||
qkv_bias: bool = False,
|
||||
use_ln_post: bool = True,
|
||||
qk_norm: bool = True):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.fourier_embedder = fourier_embedder
|
||||
|
||||
self.pc_size = pc_size
|
||||
self.normal_pe = normal_pe
|
||||
self.downsample_ratio = downsample_ratio
|
||||
self.pc_sharpedge_size = pc_sharpedge_size
|
||||
self.num_latents = num_latents
|
||||
self.point_feats = point_feats
|
||||
|
||||
self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width)
|
||||
|
||||
self.cross_attn = ResidualCrossAttentionBlock(
|
||||
width = width,
|
||||
heads = heads,
|
||||
qkv_bias = qkv_bias,
|
||||
qk_norm = qk_norm
|
||||
)
|
||||
|
||||
self.self_attn = None
|
||||
if layers > 0:
|
||||
self.self_attn = Transformer(
|
||||
width = width,
|
||||
heads = heads,
|
||||
qkv_bias = qkv_bias,
|
||||
qk_norm = qk_norm,
|
||||
layers = layers
|
||||
)
|
||||
|
||||
if use_ln_post:
|
||||
self.ln_post = nn.LayerNorm(width)
|
||||
else:
|
||||
self.ln_post = None
|
||||
|
||||
def sample_points_and_latents(self, point_cloud: torch.Tensor, features: torch.Tensor):
|
||||
|
||||
"""
|
||||
Subsample points randomly from the point cloud (input_pc)
|
||||
Further sample the subsampled points to get query_pc
|
||||
take the fourier embeddings for both input and query pc
|
||||
|
||||
Mental Note: FPS-sampled points (query_pc) act as latent tokens that attend to and learn from the broader context in input_pc.
|
||||
Goal: get a smaller represenation (query_pc) to represent the entire scence structure by learning from a broader subset (input_pc).
|
||||
More computationally efficient.
|
||||
|
||||
Features are additional information for each point in the cloud
|
||||
"""
|
||||
|
||||
B, _, D = point_cloud.shape
|
||||
|
||||
num_latents = int(self.num_latents)
|
||||
|
||||
num_random_query = self.pc_size / (self.pc_size + self.pc_sharpedge_size) * num_latents
|
||||
num_sharpedge_query = num_latents - num_random_query
|
||||
|
||||
# Split random and sharpedge surface points
|
||||
random_pc, sharpedge_pc = torch.split(point_cloud, [self.pc_size, self.pc_sharpedge_size], dim=1)
|
||||
|
||||
# assert statements
|
||||
assert random_pc.shape[1] <= self.pc_size, "Random surface points size must be less than or equal to pc_size"
|
||||
assert sharpedge_pc.shape[1] <= self.pc_sharpedge_size, "Sharpedge surface points size must be less than or equal to pc_sharpedge_size"
|
||||
|
||||
input_random_pc_size = int(num_random_query * self.downsample_ratio)
|
||||
random_query_pc, random_input_pc, random_idx_pc, random_idx_query = \
|
||||
self.subsample(pc = random_pc, num_query = num_random_query, input_pc_size = input_random_pc_size)
|
||||
|
||||
input_sharpedge_pc_size = int(num_sharpedge_query * self.downsample_ratio)
|
||||
|
||||
if input_sharpedge_pc_size == 0:
|
||||
sharpedge_input_pc = torch.zeros(B, 0, D, dtype = random_input_pc.dtype).to(point_cloud.device)
|
||||
sharpedge_query_pc = torch.zeros(B, 0, D, dtype= random_query_pc.dtype).to(point_cloud.device)
|
||||
|
||||
else:
|
||||
sharpedge_query_pc, sharpedge_input_pc, sharpedge_idx_pc, sharpedge_idx_query = \
|
||||
self.subsample(pc = sharpedge_pc, num_query = num_sharpedge_query, input_pc_size = input_sharpedge_pc_size)
|
||||
|
||||
# concat the random and sharpedges
|
||||
query_pc = torch.cat([random_query_pc, sharpedge_query_pc], dim = 1)
|
||||
input_pc = torch.cat([random_input_pc, sharpedge_input_pc], dim = 1)
|
||||
|
||||
query = self.fourier_embedder(query_pc)
|
||||
data = self.fourier_embedder(input_pc)
|
||||
|
||||
if self.point_feats > 0:
|
||||
random_surface_features, sharpedge_surface_features = torch.split(features, [self.pc_size, self.pc_sharpedge_size], dim = 1)
|
||||
|
||||
input_random_surface_features, query_random_features = \
|
||||
self.handle_features(features = random_surface_features, idx_pc = random_idx_pc, batch_size = B,
|
||||
input_pc_size = input_random_pc_size, idx_query = random_idx_query)
|
||||
|
||||
if input_sharpedge_pc_size == 0:
|
||||
input_sharpedge_surface_features = torch.zeros(B, 0, self.point_feats,
|
||||
dtype = input_random_surface_features.dtype, device = point_cloud.device)
|
||||
|
||||
query_sharpedge_features = torch.zeros(B, 0, self.point_feats,
|
||||
dtype = query_random_features.dtype, device = point_cloud.device)
|
||||
else:
|
||||
|
||||
input_sharpedge_surface_features, query_sharpedge_features = \
|
||||
self.handle_features(idx_pc = sharpedge_idx_pc, features = sharpedge_surface_features,
|
||||
batch_size = B, idx_query = sharpedge_idx_query, input_pc_size = input_sharpedge_pc_size)
|
||||
|
||||
query_features = torch.cat([query_random_features, query_sharpedge_features], dim = 1)
|
||||
input_features = torch.cat([input_random_surface_features, input_sharpedge_surface_features], dim = 1)
|
||||
|
||||
if self.normal_pe:
|
||||
# apply the fourier embeddings on the first 3 dims (xyz)
|
||||
input_features_pe = self.fourier_embedder(input_features[..., :3])
|
||||
query_features_pe = self.fourier_embedder(query_features[..., :3])
|
||||
# replace the first 3 dims with the new PE ones
|
||||
input_features = torch.cat([input_features_pe, input_features[..., :3]], dim = -1)
|
||||
query_features = torch.cat([query_features_pe, query_features[..., :3]], dim = -1)
|
||||
|
||||
# concat at the channels dim
|
||||
query = torch.cat([query, query_features], dim = -1)
|
||||
data = torch.cat([data, input_features], dim = -1)
|
||||
|
||||
# don't return pc_info to avoid unnecessary memory usuage
|
||||
return query.view(B, -1, query.shape[-1]), data.view(B, -1, data.shape[-1])
|
||||
|
||||
def forward(self, point_cloud: torch.Tensor, features: torch.Tensor):
|
||||
|
||||
query, data = self.sample_points_and_latents(point_cloud = point_cloud, features = features)
|
||||
|
||||
# apply projections
|
||||
query = self.input_proj(query)
|
||||
data = self.input_proj(data)
|
||||
|
||||
# apply cross attention between query and data
|
||||
latents = self.cross_attn(query, data)
|
||||
|
||||
if self.self_attn is not None:
|
||||
latents = self.self_attn(latents)
|
||||
|
||||
if self.ln_post is not None:
|
||||
latents = self.ln_post(latents)
|
||||
|
||||
return latents
|
||||
|
||||
|
||||
class VanillaVolumeDecoder:
|
||||
def subsample(self, pc, num_query, input_pc_size: int):
|
||||
|
||||
"""
|
||||
num_query: number of points to keep after FPS
|
||||
input_pc_size: number of points to select before FPS
|
||||
"""
|
||||
|
||||
B, _, D = pc.shape
|
||||
query_ratio = num_query / input_pc_size
|
||||
|
||||
# random subsampling of points inside the point cloud
|
||||
idx_pc = torch.randperm(pc.shape[1], device = pc.device)[:input_pc_size]
|
||||
input_pc = pc[:, idx_pc, :]
|
||||
|
||||
# flatten to allow applying fps across the whole batch
|
||||
flattent_input_pc = input_pc.view(B * input_pc_size, D)
|
||||
|
||||
# construct a batch_down tensor to tell fps
|
||||
# which points belong to which batch
|
||||
N_down = int(flattent_input_pc.shape[0] / B)
|
||||
batch_down = torch.arange(B).to(pc.device)
|
||||
batch_down = torch.repeat_interleave(batch_down, N_down)
|
||||
|
||||
idx_query = fps(flattent_input_pc, batch_down, sampling_ratio = query_ratio)
|
||||
query_pc = flattent_input_pc[idx_query].view(B, -1, D)
|
||||
|
||||
return query_pc, input_pc, idx_pc, idx_query
|
||||
|
||||
def handle_features(self, features, idx_pc, input_pc_size, batch_size: int, idx_query):
|
||||
|
||||
B = batch_size
|
||||
|
||||
input_surface_features = features[:, idx_pc, :]
|
||||
flattent_input_features = input_surface_features.view(B * input_pc_size, -1)
|
||||
query_features = flattent_input_features[idx_query].view(B, -1,
|
||||
flattent_input_features.shape[-1])
|
||||
|
||||
return input_surface_features, query_features
|
||||
|
||||
def normalize_mesh(mesh, scale = 0.9999):
|
||||
"""Normalize mesh to fit in [-scale, scale]. Translate mesh so its center is [0,0,0]"""
|
||||
|
||||
bbox = mesh.bounds
|
||||
center = (bbox[1] + bbox[0]) / 2
|
||||
|
||||
max_extent = (bbox[1] - bbox[0]).max()
|
||||
mesh.apply_translation(-center)
|
||||
mesh.apply_scale((2 * scale) / max_extent)
|
||||
|
||||
return mesh
|
||||
|
||||
def sample_pointcloud(mesh, num = 200000):
|
||||
""" Uniformly sample points from the surface of the mesh """
|
||||
|
||||
points, face_idx = mesh.sample(num, return_index = True)
|
||||
normals = mesh.face_normals[face_idx]
|
||||
return torch.from_numpy(points.astype(np.float32)), torch.from_numpy(normals.astype(np.float32))
|
||||
|
||||
def detect_sharp_edges(mesh, threshold=0.985):
|
||||
"""Return edge indices (a, b) that lie on sharp boundaries of the mesh."""
|
||||
|
||||
V, F = mesh.vertices, mesh.faces
|
||||
VN, FN = mesh.vertex_normals, mesh.face_normals
|
||||
|
||||
sharp_mask = np.ones(V.shape[0])
|
||||
for i in range(3):
|
||||
indices = F[:, i]
|
||||
alignment = np.einsum('ij,ij->i', VN[indices], FN)
|
||||
dot_stack = np.stack((sharp_mask[indices], alignment), axis=-1)
|
||||
sharp_mask[indices] = np.min(dot_stack, axis=-1)
|
||||
|
||||
edge_a = np.concatenate([F[:, 0], F[:, 1], F[:, 2]])
|
||||
edge_b = np.concatenate([F[:, 1], F[:, 2], F[:, 0]])
|
||||
sharp_edges = (sharp_mask[edge_a] < threshold) & (sharp_mask[edge_b] < threshold)
|
||||
|
||||
return edge_a[sharp_edges], edge_b[sharp_edges]
|
||||
|
||||
|
||||
def sharp_sample_pointcloud(mesh, num = 16384):
|
||||
""" Sample points preferentially from sharp edges in the mesh. """
|
||||
|
||||
edge_a, edge_b = detect_sharp_edges(mesh)
|
||||
V, VN = mesh.vertices, mesh.vertex_normals
|
||||
|
||||
va, vb = V[edge_a], V[edge_b]
|
||||
na, nb = VN[edge_a], VN[edge_b]
|
||||
|
||||
edge_lengths = np.linalg.norm(vb - va, axis=-1)
|
||||
weights = edge_lengths / edge_lengths.sum()
|
||||
|
||||
indices = np.searchsorted(np.cumsum(weights), np.random.rand(num))
|
||||
t = np.random.rand(num, 1)
|
||||
|
||||
samples = t * va[indices] + (1 - t) * vb[indices]
|
||||
normals = t * na[indices] + (1 - t) * nb[indices]
|
||||
|
||||
return samples.astype(np.float32), normals.astype(np.float32)
|
||||
|
||||
def load_surface_sharpedge(mesh, num_points=4096, num_sharp_points=4096, sharpedge_flag = True, device = "cuda"):
|
||||
"""Load a surface with optional sharp-edge annotations from a trimesh mesh."""
|
||||
|
||||
import trimesh
|
||||
|
||||
try:
|
||||
mesh_full = trimesh.util.concatenate(mesh.dump())
|
||||
except Exception:
|
||||
mesh_full = trimesh.util.concatenate(mesh)
|
||||
|
||||
mesh_full = normalize_mesh(mesh_full)
|
||||
|
||||
faces = mesh_full.faces
|
||||
vertices = mesh_full.vertices
|
||||
origin_face_count = faces.shape[0]
|
||||
|
||||
mesh_surface = trimesh.Trimesh(vertices=vertices, faces=faces[:origin_face_count])
|
||||
mesh_fill = trimesh.Trimesh(vertices=vertices, faces=faces[origin_face_count:])
|
||||
|
||||
area_surface = mesh_surface.area
|
||||
area_fill = mesh_fill.area
|
||||
total_area = area_surface + area_fill
|
||||
|
||||
sample_num = 499712 // 2
|
||||
fill_ratio = area_fill / total_area if total_area > 0 else 0
|
||||
|
||||
num_fill = int(sample_num * fill_ratio)
|
||||
num_surface = sample_num - num_fill
|
||||
|
||||
surf_pts, surf_normals = sample_pointcloud(mesh_surface, num_surface)
|
||||
fill_pts, fill_normals = (torch.zeros(0, 3), torch.zeros(0, 3)) if num_fill == 0 else sample_pointcloud(mesh_fill, num_fill)
|
||||
|
||||
sharp_pts, sharp_normals = sharp_sample_pointcloud(mesh_surface, sample_num)
|
||||
|
||||
def assemble_tensor(points, normals, label=None):
|
||||
|
||||
data = torch.cat([points, normals], dim=1).half().to(device)
|
||||
|
||||
if label is not None:
|
||||
label_tensor = torch.full((data.shape[0], 1), float(label), dtype=torch.float16).to(device)
|
||||
data = torch.cat([data, label_tensor], dim=1)
|
||||
|
||||
return data
|
||||
|
||||
surface = assemble_tensor(torch.cat([surf_pts.to(device), fill_pts.to(device)], dim=0),
|
||||
torch.cat([surf_normals.to(device), fill_normals.to(device)], dim=0),
|
||||
label = 0 if sharpedge_flag else None)
|
||||
|
||||
sharp_surface = assemble_tensor(torch.from_numpy(sharp_pts), torch.from_numpy(sharp_normals),
|
||||
label = 1 if sharpedge_flag else None)
|
||||
|
||||
rng = np.random.default_rng()
|
||||
|
||||
surface = surface[rng.choice(surface.shape[0], num_points, replace = False)]
|
||||
sharp_surface = sharp_surface[rng.choice(sharp_surface.shape[0], num_sharp_points, replace = False)]
|
||||
|
||||
full = torch.cat([surface, sharp_surface], dim = 0).unsqueeze(0)
|
||||
|
||||
return full
|
||||
|
||||
class SharpEdgeSurfaceLoader:
|
||||
""" Load mesh surface and sharp edge samples. """
|
||||
|
||||
def __init__(self, num_uniform_points = 8192, num_sharp_points = 8192):
|
||||
|
||||
self.num_uniform_points = num_uniform_points
|
||||
self.num_sharp_points = num_sharp_points
|
||||
self.total_points = num_uniform_points + num_sharp_points
|
||||
|
||||
def __call__(self, mesh_input, device = "cuda"):
|
||||
mesh = self._load_mesh(mesh_input)
|
||||
return load_surface_sharpedge(mesh, self.num_uniform_points, self.num_sharp_points, device = device)
|
||||
|
||||
@staticmethod
|
||||
def _load_mesh(mesh_input):
|
||||
import trimesh
|
||||
|
||||
if isinstance(mesh_input, str):
|
||||
mesh = trimesh.load(mesh_input, force="mesh", merge_primitives = True)
|
||||
else:
|
||||
mesh = mesh_input
|
||||
|
||||
if isinstance(mesh, trimesh.Scene):
|
||||
combined = None
|
||||
for obj in mesh.geometry.values():
|
||||
combined = obj if combined is None else combined + obj
|
||||
return combined
|
||||
|
||||
return mesh
|
||||
|
||||
class DiagonalGaussianDistribution:
|
||||
def __init__(self, params: torch.Tensor, feature_dim: int = -1):
|
||||
|
||||
# divide quant channels (8) into mean and log variance
|
||||
self.mean, self.logvar = torch.chunk(params, 2, dim = feature_dim)
|
||||
|
||||
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
||||
self.std = torch.exp(0.5 * self.logvar)
|
||||
|
||||
def sample(self):
|
||||
|
||||
eps = torch.randn_like(self.std)
|
||||
z = self.mean + eps * self.std
|
||||
|
||||
return z
|
||||
|
||||
################################################
|
||||
# Volume Decoder
|
||||
################################################
|
||||
|
||||
class VanillaVolumeDecoder():
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
latents: torch.FloatTensor,
|
||||
geo_decoder: Callable,
|
||||
bounds: Union[Tuple[float], List[float], float] = 1.01,
|
||||
num_chunks: int = 10000,
|
||||
octree_resolution: int = None,
|
||||
enable_pbar: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
device = latents.device
|
||||
dtype = latents.dtype
|
||||
batch_size = latents.shape[0]
|
||||
def __call__(self, latents: torch.Tensor, geo_decoder: callable, octree_resolution: int, bounds = 1.01,
|
||||
num_chunks: int = 10_000, enable_pbar: bool = True, **kwargs):
|
||||
|
||||
# 1. generate query points
|
||||
if isinstance(bounds, float):
|
||||
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
|
||||
|
||||
bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
|
||||
xyz_samples, grid_size, length = generate_dense_grid_points(
|
||||
bbox_min=bbox_min,
|
||||
bbox_max=bbox_max,
|
||||
octree_resolution=octree_resolution,
|
||||
indexing="ij"
|
||||
)
|
||||
xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)
|
||||
bbox_min, bbox_max = torch.tensor(bounds[:3]), torch.tensor(bounds[3:])
|
||||
|
||||
x = torch.linspace(bbox_min[0], bbox_max[0], int(octree_resolution) + 1, dtype = torch.float32)
|
||||
y = torch.linspace(bbox_min[1], bbox_max[1], int(octree_resolution) + 1, dtype = torch.float32)
|
||||
z = torch.linspace(bbox_min[2], bbox_max[2], int(octree_resolution) + 1, dtype = torch.float32)
|
||||
|
||||
[xs, ys, zs] = torch.meshgrid(x, y, z, indexing = "ij")
|
||||
xyz = torch.stack((xs, ys, zs), axis=-1).to(latents.device, dtype = latents.dtype).contiguous().reshape(-1, 3)
|
||||
grid_size = [int(octree_resolution) + 1, int(octree_resolution) + 1, int(octree_resolution) + 1]
|
||||
|
||||
# 2. latents to 3d volume
|
||||
batch_logits = []
|
||||
for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding",
|
||||
for start in tqdm(range(0, xyz.shape[0], num_chunks), desc="Volume Decoding",
|
||||
disable=not enable_pbar):
|
||||
chunk_queries = xyz_samples[start: start + num_chunks, :]
|
||||
chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
|
||||
logits = geo_decoder(queries=chunk_queries, latents=latents)
|
||||
|
||||
chunk_queries = xyz[start: start + num_chunks, :]
|
||||
chunk_queries = chunk_queries.unsqueeze(0).repeat(latents.shape[0], 1, 1)
|
||||
logits = geo_decoder(queries = chunk_queries, latents = latents)
|
||||
batch_logits.append(logits)
|
||||
|
||||
grid_logits = torch.cat(batch_logits, dim=1)
|
||||
grid_logits = grid_logits.view((batch_size, *grid_size)).float()
|
||||
grid_logits = torch.cat(batch_logits, dim = 1)
|
||||
grid_logits = grid_logits.view((latents.shape[0], *grid_size)).float()
|
||||
|
||||
return grid_logits
|
||||
|
||||
|
||||
class FourierEmbedder(nn.Module):
|
||||
"""The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
|
||||
each feature dimension of `x[..., i]` into:
|
||||
@@ -175,13 +552,11 @@ class FourierEmbedder(nn.Module):
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
class CrossAttentionProcessor:
|
||||
def __call__(self, attn, q, k, v):
|
||||
out = F.scaled_dot_product_attention(q, k, v)
|
||||
out = comfy.ops.scaled_dot_product_attention(q, k, v)
|
||||
return out
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
"""
|
||||
@@ -232,38 +607,41 @@ class MLP(nn.Module):
|
||||
def forward(self, x):
|
||||
return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))
|
||||
|
||||
|
||||
class QKVMultiheadCrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
heads: int,
|
||||
n_data = None,
|
||||
width=None,
|
||||
qk_norm=False,
|
||||
norm_layer=ops.LayerNorm
|
||||
):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.n_data = n_data
|
||||
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||
|
||||
self.attn_processor = CrossAttentionProcessor()
|
||||
|
||||
def forward(self, q, kv):
|
||||
|
||||
_, n_ctx, _ = q.shape
|
||||
bs, n_data, width = kv.shape
|
||||
|
||||
attn_ch = width // self.heads // 2
|
||||
q = q.view(bs, n_ctx, self.heads, -1)
|
||||
|
||||
kv = kv.view(bs, n_data, self.heads, -1)
|
||||
k, v = torch.split(kv, attn_ch, dim=-1)
|
||||
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
|
||||
out = self.attn_processor(self, q, k, v)
|
||||
out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
|
||||
return out
|
||||
|
||||
q, k, v = [t.permute(0, 2, 1, 3) for t in (q, k, v)]
|
||||
out = F.scaled_dot_product_attention(q, k, v)
|
||||
|
||||
out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
|
||||
|
||||
return out
|
||||
|
||||
class MultiheadCrossAttention(nn.Module):
|
||||
def __init__(
|
||||
@@ -306,7 +684,6 @@ class MultiheadCrossAttention(nn.Module):
|
||||
x = self.c_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class ResidualCrossAttentionBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -366,7 +743,7 @@ class QKVMultiheadAttention(nn.Module):
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
|
||||
q, k, v = [t.permute(0, 2, 1, 3) for t in (q, k, v)]
|
||||
out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
|
||||
return out
|
||||
|
||||
@@ -383,8 +760,7 @@ class MultiheadAttention(nn.Module):
|
||||
drop_path_rate: float = 0.0
|
||||
):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.heads = heads
|
||||
|
||||
self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias)
|
||||
self.c_proj = ops.Linear(width, width)
|
||||
self.attention = QKVMultiheadAttention(
|
||||
@@ -491,7 +867,7 @@ class CrossAttentionDecoder(nn.Module):
|
||||
self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width)
|
||||
if self.downsample_ratio != 1:
|
||||
self.latents_proj = ops.Linear(width * downsample_ratio, width)
|
||||
if self.enable_ln_post == False:
|
||||
if not self.enable_ln_post:
|
||||
qk_norm = False
|
||||
self.cross_attn_decoder = ResidualCrossAttentionBlock(
|
||||
width=width,
|
||||
@@ -522,28 +898,44 @@ class CrossAttentionDecoder(nn.Module):
|
||||
|
||||
class ShapeVAE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
embed_dim: int,
|
||||
width: int,
|
||||
heads: int,
|
||||
num_decoder_layers: int,
|
||||
geo_decoder_downsample_ratio: int = 1,
|
||||
geo_decoder_mlp_expand_ratio: int = 4,
|
||||
geo_decoder_ln_post: bool = True,
|
||||
num_freqs: int = 8,
|
||||
include_pi: bool = True,
|
||||
qkv_bias: bool = True,
|
||||
qk_norm: bool = False,
|
||||
label_type: str = "binary",
|
||||
drop_path_rate: float = 0.0,
|
||||
scale_factor: float = 1.0,
|
||||
self,
|
||||
*,
|
||||
num_latents: int = 4096,
|
||||
embed_dim: int = 64,
|
||||
width: int = 1024,
|
||||
heads: int = 16,
|
||||
num_decoder_layers: int = 16,
|
||||
num_encoder_layers: int = 8,
|
||||
pc_size: int = 81920,
|
||||
pc_sharpedge_size: int = 0,
|
||||
point_feats: int = 4,
|
||||
downsample_ratio: int = 20,
|
||||
geo_decoder_downsample_ratio: int = 1,
|
||||
geo_decoder_mlp_expand_ratio: int = 4,
|
||||
geo_decoder_ln_post: bool = True,
|
||||
num_freqs: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_norm: bool = True,
|
||||
drop_path_rate: float = 0.0,
|
||||
include_pi: bool = False,
|
||||
scale_factor: float = 1.0039506158752403,
|
||||
label_type: str = "binary",
|
||||
):
|
||||
super().__init__()
|
||||
self.geo_decoder_ln_post = geo_decoder_ln_post
|
||||
|
||||
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
|
||||
|
||||
self.encoder = PointCrossAttention(layers = num_encoder_layers,
|
||||
num_latents = num_latents,
|
||||
downsample_ratio = downsample_ratio,
|
||||
heads = heads,
|
||||
pc_size = pc_size,
|
||||
width = width,
|
||||
point_feats = point_feats,
|
||||
fourier_embedder = self.fourier_embedder,
|
||||
pc_sharpedge_size = pc_sharpedge_size)
|
||||
|
||||
self.post_kl = ops.Linear(embed_dim, width)
|
||||
|
||||
self.transformer = Transformer(
|
||||
@@ -583,5 +975,14 @@ class ShapeVAE(nn.Module):
|
||||
grid_logits = self.volume_decoder(latents, self.geo_decoder, bounds=bounds, num_chunks=num_chunks, octree_resolution=octree_resolution, enable_pbar=enable_pbar)
|
||||
return grid_logits.movedim(-2, -1)
|
||||
|
||||
def encode(self, x):
|
||||
return None
|
||||
def encode(self, surface):
|
||||
|
||||
pc, feats = surface[:, :, :3], surface[:, :, 3:]
|
||||
latents = self.encoder(pc, feats)
|
||||
|
||||
moments = self.pre_kl(latents)
|
||||
posterior = DiagonalGaussianDistribution(moments, feature_dim = -1)
|
||||
|
||||
latents = posterior.sample()
|
||||
|
||||
return latents
|
||||
|
||||
659
comfy/ldm/hunyuan3dv2_1/hunyuandit.py
Normal file
659
comfy/ldm/hunyuan3dv2_1/hunyuandit.py
Normal file
@@ -0,0 +1,659 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.model_management
|
||||
|
||||
class GELU(nn.Module):
|
||||
|
||||
def __init__(self, dim_in: int, dim_out: int, operations, device, dtype):
|
||||
super().__init__()
|
||||
self.proj = operations.Linear(dim_in, dim_out, device = device, dtype = dtype)
|
||||
|
||||
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
if gate.device.type == "mps":
|
||||
return F.gelu(gate.to(dtype = torch.float32)).to(dtype = gate.dtype)
|
||||
|
||||
return F.gelu(gate)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
|
||||
hidden_states = self.proj(hidden_states)
|
||||
hidden_states = self.gelu(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
|
||||
def __init__(self, dim: int, dim_out = None, mult: int = 4,
|
||||
dropout: float = 0.0, inner_dim = None, operations = None, device = None, dtype = None):
|
||||
|
||||
super().__init__()
|
||||
if inner_dim is None:
|
||||
inner_dim = int(dim * mult)
|
||||
|
||||
dim_out = dim_out if dim_out is not None else dim
|
||||
|
||||
act_fn = GELU(dim, inner_dim, operations = operations, device = device, dtype = dtype)
|
||||
|
||||
self.net = nn.ModuleList([])
|
||||
self.net.append(act_fn)
|
||||
|
||||
self.net.append(nn.Dropout(dropout))
|
||||
self.net.append(operations.Linear(inner_dim, dim_out, device = device, dtype = dtype))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
for module in self.net:
|
||||
hidden_states = module(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
class AddAuxLoss(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x, loss):
|
||||
# do nothing in forward (no computation)
|
||||
ctx.requires_aux_loss = loss.requires_grad
|
||||
ctx.dtype = loss.dtype
|
||||
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
# add the aux loss gradients
|
||||
grad_loss = None
|
||||
# put the aux grad the same as the main grad loss
|
||||
# aux grad contributes equally
|
||||
if ctx.requires_aux_loss:
|
||||
grad_loss = torch.ones(1, dtype = ctx.dtype, device = grad_output.device)
|
||||
|
||||
return grad_output, grad_loss
|
||||
|
||||
class MoEGate(nn.Module):
|
||||
|
||||
def __init__(self, embed_dim, num_experts=16, num_experts_per_tok=2, aux_loss_alpha=0.01, device = None, dtype = None):
|
||||
|
||||
super().__init__()
|
||||
self.top_k = num_experts_per_tok
|
||||
self.n_routed_experts = num_experts
|
||||
|
||||
self.alpha = aux_loss_alpha
|
||||
|
||||
self.gating_dim = embed_dim
|
||||
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim), device = device, dtype = dtype))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
# flatten hidden states
|
||||
hidden_states = hidden_states.view(-1, hidden_states.size(-1))
|
||||
|
||||
# get logits and pass it to softmax
|
||||
logits = F.linear(hidden_states, comfy.model_management.cast_to(self.weight, dtype=hidden_states.dtype, device=hidden_states.device), bias = None)
|
||||
scores = logits.softmax(dim = -1)
|
||||
|
||||
topk_weight, topk_idx = torch.topk(scores, k = self.top_k, dim = -1, sorted = False)
|
||||
|
||||
if self.training and self.alpha > 0.0:
|
||||
scores_for_aux = scores
|
||||
|
||||
# used bincount instead of one hot encoding
|
||||
counts = torch.bincount(topk_idx.view(-1), minlength = self.n_routed_experts).float()
|
||||
ce = counts / topk_idx.numel() # normalized expert usage
|
||||
|
||||
# mean expert score
|
||||
Pi = scores_for_aux.mean(0)
|
||||
|
||||
# expert balance loss
|
||||
aux_loss = (Pi * ce * self.n_routed_experts).sum() * self.alpha
|
||||
else:
|
||||
aux_loss = None
|
||||
|
||||
return topk_idx, topk_weight, aux_loss
|
||||
|
||||
class MoEBlock(nn.Module):
|
||||
def __init__(self, dim, num_experts: int = 6, moe_top_k: int = 2, dropout: float = 0.0,
|
||||
ff_inner_dim: int = None, operations = None, device = None, dtype = None):
|
||||
super().__init__()
|
||||
|
||||
self.moe_top_k = moe_top_k
|
||||
self.num_experts = num_experts
|
||||
|
||||
self.experts = nn.ModuleList([
|
||||
FeedForward(dim, dropout = dropout, inner_dim = ff_inner_dim, operations = operations, device = device, dtype = dtype)
|
||||
for _ in range(num_experts)
|
||||
])
|
||||
|
||||
self.gate = MoEGate(dim, num_experts = num_experts, num_experts_per_tok = moe_top_k, device = device, dtype = dtype)
|
||||
self.shared_experts = FeedForward(dim, dropout = dropout, inner_dim = ff_inner_dim, operations = operations, device = device, dtype = dtype)
|
||||
|
||||
def forward(self, hidden_states) -> torch.Tensor:
|
||||
|
||||
identity = hidden_states
|
||||
orig_shape = hidden_states.shape
|
||||
topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
flat_topk_idx = topk_idx.view(-1)
|
||||
|
||||
if self.training:
|
||||
|
||||
hidden_states = hidden_states.repeat_interleave(self.moe_top_k, dim = 0)
|
||||
y = torch.empty_like(hidden_states, dtype = hidden_states.dtype)
|
||||
|
||||
for i, expert in enumerate(self.experts):
|
||||
tmp = expert(hidden_states[flat_topk_idx == i])
|
||||
y[flat_topk_idx == i] = tmp.to(hidden_states.dtype)
|
||||
|
||||
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim = 1)
|
||||
y = y.view(*orig_shape)
|
||||
|
||||
y = AddAuxLoss.apply(y, aux_loss)
|
||||
else:
|
||||
y = self.moe_infer(hidden_states, flat_expert_indices = flat_topk_idx,flat_expert_weights = topk_weight.view(-1, 1)).view(*orig_shape)
|
||||
|
||||
y = y + self.shared_experts(identity)
|
||||
|
||||
return y
|
||||
|
||||
@torch.no_grad()
|
||||
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
|
||||
|
||||
expert_cache = torch.zeros_like(x)
|
||||
idxs = flat_expert_indices.argsort()
|
||||
|
||||
# no need for .numpy().cpu() here
|
||||
tokens_per_expert = flat_expert_indices.bincount().cumsum(0)
|
||||
token_idxs = idxs // self.moe_top_k
|
||||
|
||||
for i, end_idx in enumerate(tokens_per_expert):
|
||||
|
||||
start_idx = 0 if i == 0 else tokens_per_expert[i-1]
|
||||
|
||||
if start_idx == end_idx:
|
||||
continue
|
||||
|
||||
expert = self.experts[i]
|
||||
exp_token_idx = token_idxs[start_idx:end_idx]
|
||||
|
||||
expert_tokens = x[exp_token_idx]
|
||||
expert_out = expert(expert_tokens)
|
||||
|
||||
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
|
||||
|
||||
# use index_add_ with a 1-D index tensor directly avoids building a large [N, D] index map and extra memcopy required by scatter_reduce_
|
||||
# + avoid dtype conversion
|
||||
expert_cache.index_add_(0, exp_token_idx, expert_out)
|
||||
|
||||
return expert_cache
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
def __init__(self, num_channels: int, downscale_freq_shift: float = 0.0,
|
||||
scale: float = 1.0, max_period: int = 10000):
|
||||
super().__init__()
|
||||
|
||||
self.num_channels = num_channels
|
||||
half_dim = num_channels // 2
|
||||
|
||||
# precompute the “inv_freq” vector once
|
||||
exponent = -math.log(max_period) * torch.arange(
|
||||
half_dim, dtype=torch.float32
|
||||
) / (half_dim - downscale_freq_shift)
|
||||
|
||||
inv_freq = torch.exp(exponent)
|
||||
|
||||
# pad
|
||||
if num_channels % 2 == 1:
|
||||
# we’ll pad a zero at the end of the cos-half
|
||||
inv_freq = torch.cat([inv_freq, inv_freq.new_zeros(1)])
|
||||
|
||||
# register to buffer so it moves with the device
|
||||
self.register_buffer("inv_freq", inv_freq, persistent = False)
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, timesteps: torch.Tensor):
|
||||
|
||||
x = timesteps.float().unsqueeze(1) * self.inv_freq.to(timesteps.device).unsqueeze(0)
|
||||
|
||||
|
||||
# fused CUDA kernels for sin and cos
|
||||
sin_emb = x.sin()
|
||||
cos_emb = x.cos()
|
||||
|
||||
emb = torch.cat([sin_emb, cos_emb], dim = 1)
|
||||
|
||||
# scale factor
|
||||
if self.scale != 1.0:
|
||||
emb = emb * self.scale
|
||||
|
||||
# If we padded inv_freq for odd, emb is already wide enough; otherwise:
|
||||
if emb.shape[1] > self.num_channels:
|
||||
emb = emb[:, :self.num_channels]
|
||||
|
||||
return emb
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
def __init__(self, hidden_size, frequency_embedding_size = 256, cond_proj_dim = None, operations = None, device = None, dtype = None):
|
||||
super().__init__()
|
||||
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, frequency_embedding_size, bias=True, device = device, dtype = dtype),
|
||||
nn.GELU(),
|
||||
operations.Linear(frequency_embedding_size, hidden_size, bias=True, device = device, dtype = dtype),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
if cond_proj_dim is not None:
|
||||
self.cond_proj = operations.Linear(cond_proj_dim, frequency_embedding_size, bias=False, device = device, dtype = dtype)
|
||||
|
||||
self.time_embed = Timesteps(hidden_size)
|
||||
|
||||
def forward(self, timesteps, condition):
|
||||
|
||||
timestep_embed = self.time_embed(timesteps).type(self.mlp[0].weight.dtype)
|
||||
|
||||
if condition is not None:
|
||||
cond_embed = self.cond_proj(condition)
|
||||
timestep_embed = timestep_embed + cond_embed
|
||||
|
||||
time_conditioned = self.mlp(timestep_embed)
|
||||
|
||||
# for broadcasting with image tokens
|
||||
return time_conditioned.unsqueeze(1)
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, *, width: int, operations = None, device = None, dtype = None):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.fc1 = operations.Linear(width, width * 4, device = device, dtype = dtype)
|
||||
self.fc2 = operations.Linear(width * 4, width, device = device, dtype = dtype)
|
||||
self.gelu = nn.GELU()
|
||||
|
||||
def forward(self, x):
|
||||
return self.fc2(self.gelu(self.fc1(x)))
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
qdim,
|
||||
kdim,
|
||||
num_heads,
|
||||
qkv_bias=True,
|
||||
qk_norm=False,
|
||||
norm_layer=nn.LayerNorm,
|
||||
use_fp16: bool = False,
|
||||
operations = None,
|
||||
dtype = None,
|
||||
device = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.qdim = qdim
|
||||
self.kdim = kdim
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = self.qdim // num_heads
|
||||
|
||||
self.scale = self.head_dim ** -0.5
|
||||
|
||||
self.to_q = operations.Linear(qdim, qdim, bias=qkv_bias, device = device, dtype = dtype)
|
||||
self.to_k = operations.Linear(kdim, qdim, bias=qkv_bias, device = device, dtype = dtype)
|
||||
self.to_v = operations.Linear(kdim, qdim, bias=qkv_bias, device = device, dtype = dtype)
|
||||
|
||||
if use_fp16:
|
||||
eps = 1.0 / 65504
|
||||
else:
|
||||
eps = 1e-6
|
||||
|
||||
if norm_layer == nn.LayerNorm:
|
||||
norm_layer = operations.LayerNorm
|
||||
else:
|
||||
norm_layer = operations.RMSNorm
|
||||
|
||||
self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
|
||||
self.out_proj = operations.Linear(qdim, qdim, bias=True, device = device, dtype = dtype)
|
||||
|
||||
def forward(self, x, y):
|
||||
|
||||
b, s1, _ = x.shape
|
||||
_, s2, _ = y.shape
|
||||
|
||||
y = y.to(next(self.to_k.parameters()).dtype)
|
||||
|
||||
q = self.to_q(x)
|
||||
k = self.to_k(y)
|
||||
v = self.to_v(y)
|
||||
|
||||
kv = torch.cat((k, v), dim=-1)
|
||||
split_size = kv.shape[-1] // self.num_heads // 2
|
||||
|
||||
kv = kv.view(1, -1, self.num_heads, split_size * 2)
|
||||
k, v = torch.split(kv, split_size, dim=-1)
|
||||
|
||||
q = q.view(b, s1, self.num_heads, self.head_dim)
|
||||
k = k.view(b, s2, self.num_heads, self.head_dim)
|
||||
v = v.reshape(b, s2, self.num_heads * self.head_dim)
|
||||
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
x = optimized_attention(
|
||||
q.reshape(b, s1, self.num_heads * self.head_dim),
|
||||
k.reshape(b, s2, self.num_heads * self.head_dim),
|
||||
v,
|
||||
heads=self.num_heads,
|
||||
)
|
||||
|
||||
out = self.out_proj(x)
|
||||
|
||||
return out
|
||||
|
||||
class Attention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
qkv_bias = True,
|
||||
qk_norm = False,
|
||||
norm_layer = nn.LayerNorm,
|
||||
use_fp16: bool = False,
|
||||
operations = None,
|
||||
device = None,
|
||||
dtype = None
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = self.dim // num_heads
|
||||
self.scale = self.head_dim ** -0.5
|
||||
|
||||
self.to_q = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype)
|
||||
self.to_k = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype)
|
||||
self.to_v = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype)
|
||||
|
||||
if use_fp16:
|
||||
eps = 1.0 / 65504
|
||||
else:
|
||||
eps = 1e-6
|
||||
|
||||
if norm_layer == nn.LayerNorm:
|
||||
norm_layer = operations.LayerNorm
|
||||
else:
|
||||
norm_layer = operations.RMSNorm
|
||||
|
||||
self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
|
||||
self.out_proj = operations.Linear(dim, dim, device = device, dtype = dtype)
|
||||
|
||||
def forward(self, x):
|
||||
B, N, _ = x.shape
|
||||
|
||||
query = self.to_q(x)
|
||||
key = self.to_k(x)
|
||||
value = self.to_v(x)
|
||||
|
||||
qkv_combined = torch.cat((query, key, value), dim=-1)
|
||||
split_size = qkv_combined.shape[-1] // self.num_heads // 3
|
||||
|
||||
qkv = qkv_combined.view(1, -1, self.num_heads, split_size * 3)
|
||||
query, key, value = torch.split(qkv, split_size, dim=-1)
|
||||
|
||||
query = query.reshape(B, N, self.num_heads, self.head_dim)
|
||||
key = key.reshape(B, N, self.num_heads, self.head_dim)
|
||||
value = value.reshape(B, N, self.num_heads * self.head_dim)
|
||||
|
||||
query = self.q_norm(query)
|
||||
key = self.k_norm(key)
|
||||
|
||||
x = optimized_attention(
|
||||
query.reshape(B, N, self.num_heads * self.head_dim),
|
||||
key.reshape(B, N, self.num_heads * self.head_dim),
|
||||
value,
|
||||
heads=self.num_heads,
|
||||
)
|
||||
|
||||
x = self.out_proj(x)
|
||||
return x
|
||||
|
||||
class HunYuanDiTBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
c_emb_size,
|
||||
num_heads,
|
||||
text_states_dim=1024,
|
||||
qk_norm=False,
|
||||
norm_layer=nn.LayerNorm,
|
||||
qk_norm_layer=True,
|
||||
qkv_bias=True,
|
||||
skip_connection=True,
|
||||
timested_modulate=False,
|
||||
use_moe: bool = False,
|
||||
num_experts: int = 8,
|
||||
moe_top_k: int = 2,
|
||||
use_fp16: bool = False,
|
||||
operations = None,
|
||||
device = None, dtype = None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# eps can't be 1e-6 in fp16 mode because of numerical stability issues
|
||||
if use_fp16:
|
||||
eps = 1.0 / 65504
|
||||
else:
|
||||
eps = 1e-6
|
||||
|
||||
self.norm1 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
|
||||
|
||||
self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm,
|
||||
norm_layer=qk_norm_layer, use_fp16 = use_fp16, device = device, dtype = dtype, operations = operations)
|
||||
|
||||
self.norm2 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
|
||||
|
||||
self.timested_modulate = timested_modulate
|
||||
if self.timested_modulate:
|
||||
self.default_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(c_emb_size, hidden_size, bias=True, device = device, dtype = dtype)
|
||||
)
|
||||
|
||||
self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=qkv_bias,
|
||||
qk_norm=qk_norm, norm_layer=qk_norm_layer, use_fp16 = use_fp16,
|
||||
device = device, dtype = dtype, operations = operations)
|
||||
|
||||
self.norm3 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
|
||||
|
||||
if skip_connection:
|
||||
self.skip_norm = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
|
||||
self.skip_linear = operations.Linear(2 * hidden_size, hidden_size, device = device, dtype = dtype)
|
||||
else:
|
||||
self.skip_linear = None
|
||||
|
||||
self.use_moe = use_moe
|
||||
|
||||
if self.use_moe:
|
||||
self.moe = MoEBlock(
|
||||
hidden_size,
|
||||
num_experts = num_experts,
|
||||
moe_top_k = moe_top_k,
|
||||
dropout = 0.0,
|
||||
ff_inner_dim = int(hidden_size * 4.0),
|
||||
device = device, dtype = dtype,
|
||||
operations = operations
|
||||
)
|
||||
else:
|
||||
self.mlp = MLP(width=hidden_size, operations=operations, device = device, dtype = dtype)
|
||||
|
||||
def forward(self, hidden_states, conditioning=None, text_states=None, skip_tensor=None):
|
||||
|
||||
if self.skip_linear is not None:
|
||||
combined = torch.cat([skip_tensor, hidden_states], dim=-1)
|
||||
hidden_states = self.skip_linear(combined)
|
||||
hidden_states = self.skip_norm(hidden_states)
|
||||
|
||||
# self attention
|
||||
if self.timested_modulate:
|
||||
modulation_shift = self.default_modulation(conditioning).unsqueeze(dim=1)
|
||||
hidden_states = hidden_states + modulation_shift
|
||||
|
||||
self_attn_out = self.attn1(self.norm1(hidden_states))
|
||||
hidden_states = hidden_states + self_attn_out
|
||||
|
||||
# cross attention
|
||||
hidden_states = hidden_states + self.attn2(self.norm2(hidden_states), text_states)
|
||||
|
||||
# MLP Layer
|
||||
mlp_input = self.norm3(hidden_states)
|
||||
|
||||
if self.use_moe:
|
||||
hidden_states = hidden_states + self.moe(mlp_input)
|
||||
else:
|
||||
hidden_states = hidden_states + self.mlp(mlp_input)
|
||||
|
||||
return hidden_states
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
|
||||
def __init__(self, final_hidden_size, out_channels, operations, use_fp16: bool = False, device = None, dtype = None):
|
||||
super().__init__()
|
||||
|
||||
if use_fp16:
|
||||
eps = 1.0 / 65504
|
||||
else:
|
||||
eps = 1e-6
|
||||
|
||||
self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
|
||||
self.linear = operations.Linear(final_hidden_size, out_channels, bias = True, device = device, dtype = dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm_final(x)
|
||||
x = x[:, 1:]
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
class HunYuanDiTPlain(nn.Module):
|
||||
|
||||
# init with the defaults values from https://huggingface.co/tencent/Hunyuan3D-2.1/blob/main/hunyuan3d-dit-v2-1/config.yaml
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 64,
|
||||
hidden_size: int = 2048,
|
||||
context_dim: int = 1024,
|
||||
depth: int = 21,
|
||||
num_heads: int = 16,
|
||||
qk_norm: bool = True,
|
||||
qkv_bias: bool = False,
|
||||
num_moe_layers: int = 6,
|
||||
guidance_cond_proj_dim = 2048,
|
||||
norm_type = 'layer',
|
||||
num_experts: int = 8,
|
||||
moe_top_k: int = 2,
|
||||
use_fp16: bool = False,
|
||||
dtype = None,
|
||||
device = None,
|
||||
operations = None,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
self.dtype = dtype
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.depth = depth
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
norm = operations.LayerNorm if norm_type == 'layer' else operations.RMSNorm
|
||||
qk_norm = operations.RMSNorm
|
||||
|
||||
self.context_dim = context_dim
|
||||
self.guidance_cond_proj_dim = guidance_cond_proj_dim
|
||||
|
||||
self.x_embedder = operations.Linear(in_channels, hidden_size, bias = True, device = device, dtype = dtype)
|
||||
self.t_embedder = TimestepEmbedder(hidden_size, hidden_size * 4, cond_proj_dim = guidance_cond_proj_dim, device = device, dtype = dtype, operations = operations)
|
||||
|
||||
|
||||
# HUnYuanDiT Blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
HunYuanDiTBlock(hidden_size=hidden_size,
|
||||
c_emb_size=hidden_size,
|
||||
num_heads=num_heads,
|
||||
text_states_dim=context_dim,
|
||||
qk_norm=qk_norm,
|
||||
norm_layer = norm,
|
||||
qk_norm_layer = qk_norm,
|
||||
skip_connection=layer > depth // 2,
|
||||
qkv_bias=qkv_bias,
|
||||
use_moe=True if depth - layer <= num_moe_layers else False,
|
||||
num_experts=num_experts,
|
||||
moe_top_k=moe_top_k,
|
||||
use_fp16 = use_fp16,
|
||||
device = device, dtype = dtype, operations = operations)
|
||||
for layer in range(depth)
|
||||
])
|
||||
|
||||
self.depth = depth
|
||||
|
||||
self.final_layer = FinalLayer(hidden_size, self.out_channels, use_fp16 = use_fp16, operations = operations, device = device, dtype = dtype)
|
||||
|
||||
def forward(self, x, t, context, transformer_options = {}, **kwargs):
|
||||
|
||||
x = x.movedim(-1, -2)
|
||||
uncond_emb, cond_emb = context.chunk(2, dim = 0)
|
||||
|
||||
context = torch.cat([cond_emb, uncond_emb], dim = 0)
|
||||
main_condition = context
|
||||
|
||||
t = 1.0 - t
|
||||
|
||||
time_embedded = self.t_embedder(t, condition = kwargs.get('guidance_cond'))
|
||||
|
||||
x = x.to(dtype = next(self.x_embedder.parameters()).dtype)
|
||||
x_embedded = self.x_embedder(x)
|
||||
|
||||
combined = torch.cat([time_embedded, x_embedded], dim=1)
|
||||
|
||||
def block_wrap(args):
|
||||
return block(
|
||||
args["x"],
|
||||
args["t"],
|
||||
args["cond"],
|
||||
skip_tensor=args.get("skip"),)
|
||||
|
||||
skip_stack = []
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for idx, block in enumerate(self.blocks):
|
||||
if idx <= self.depth // 2:
|
||||
skip_input = None
|
||||
else:
|
||||
skip_input = skip_stack.pop()
|
||||
|
||||
if ("block", idx) in blocks_replace:
|
||||
|
||||
combined = blocks_replace[("block", idx)](
|
||||
{
|
||||
"x": combined,
|
||||
"t": time_embedded,
|
||||
"cond": main_condition,
|
||||
"skip": skip_input,
|
||||
},
|
||||
{"original_block": block_wrap},
|
||||
)
|
||||
else:
|
||||
combined = block(combined, time_embedded, main_condition, skip_tensor=skip_input)
|
||||
|
||||
if idx < self.depth // 2:
|
||||
skip_stack.append(combined)
|
||||
|
||||
output = self.final_layer(combined)
|
||||
output = output.movedim(-2, -1) * (-1.0)
|
||||
|
||||
cond_emb, uncond_emb = output.chunk(2, dim = 0)
|
||||
return torch.cat([uncond_emb, cond_emb])
|
||||
@@ -1,6 +1,7 @@
|
||||
#Based on Flux code because of weird hunyuan video code license.
|
||||
|
||||
import torch
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.flux.layers
|
||||
import comfy.ldm.modules.diffusionmodules.mmdit
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
@@ -39,6 +40,8 @@ class HunyuanVideoParams:
|
||||
patch_size: list
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
byt5: bool
|
||||
meanflow: bool
|
||||
|
||||
|
||||
class SelfAttentionRef(nn.Module):
|
||||
@@ -77,13 +80,13 @@ class TokenRefinerBlock(nn.Module):
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
def forward(self, x, c, mask):
|
||||
def forward(self, x, c, mask, transformer_options={}):
|
||||
mod1, mod2 = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
|
||||
norm_x = self.norm1(x)
|
||||
qkv = self.self_attn.qkv(norm_x)
|
||||
q, k, v = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.heads, -1).permute(2, 0, 3, 1, 4)
|
||||
attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True)
|
||||
attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True, transformer_options=transformer_options)
|
||||
|
||||
x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1)
|
||||
x = x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1)
|
||||
@@ -114,14 +117,14 @@ class IndividualTokenRefiner(nn.Module):
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x, c, mask):
|
||||
def forward(self, x, c, mask, transformer_options={}):
|
||||
m = None
|
||||
if mask is not None:
|
||||
m = mask.view(mask.shape[0], 1, 1, mask.shape[1]).repeat(1, 1, mask.shape[1], 1)
|
||||
m = m + m.transpose(2, 3)
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, c, m)
|
||||
x = block(x, c, m, transformer_options=transformer_options)
|
||||
return x
|
||||
|
||||
|
||||
@@ -149,6 +152,7 @@ class TokenRefiner(nn.Module):
|
||||
x,
|
||||
timesteps,
|
||||
mask,
|
||||
transformer_options={},
|
||||
):
|
||||
t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
|
||||
# m = mask.float().unsqueeze(-1)
|
||||
@@ -157,9 +161,33 @@ class TokenRefiner(nn.Module):
|
||||
|
||||
c = t + self.c_embedder(c.to(x.dtype))
|
||||
x = self.input_embedder(x)
|
||||
x = self.individual_token_refiner(x, c, mask)
|
||||
x = self.individual_token_refiner(x, c, mask, transformer_options=transformer_options)
|
||||
return x
|
||||
|
||||
|
||||
class ByT5Mapper(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, hidden_dim, out_dim1, use_res=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.layernorm = operations.LayerNorm(in_dim, dtype=dtype, device=device)
|
||||
self.fc1 = operations.Linear(in_dim, hidden_dim, dtype=dtype, device=device)
|
||||
self.fc2 = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
|
||||
self.fc3 = operations.Linear(out_dim, out_dim1, dtype=dtype, device=device)
|
||||
self.use_res = use_res
|
||||
self.act_fn = nn.GELU()
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_res:
|
||||
res = x
|
||||
x = self.layernorm(x)
|
||||
x = self.fc1(x)
|
||||
x = self.act_fn(x)
|
||||
x = self.fc2(x)
|
||||
x2 = self.act_fn(x)
|
||||
x2 = self.fc3(x2)
|
||||
if self.use_res:
|
||||
x2 = x2 + res
|
||||
return x2
|
||||
|
||||
class HunyuanVideo(nn.Module):
|
||||
"""
|
||||
Transformer model for flow matching on sequences.
|
||||
@@ -184,9 +212,13 @@ class HunyuanVideo(nn.Module):
|
||||
self.num_heads = params.num_heads
|
||||
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
||||
|
||||
self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=True, dtype=dtype, device=device, operations=operations)
|
||||
self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=len(self.patch_size) == 3, dtype=dtype, device=device, operations=operations)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
if params.vec_in_dim is not None:
|
||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.vector_in = None
|
||||
|
||||
self.guidance_in = (
|
||||
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
|
||||
)
|
||||
@@ -214,6 +246,23 @@ class HunyuanVideo(nn.Module):
|
||||
]
|
||||
)
|
||||
|
||||
if params.byt5:
|
||||
self.byt5_in = ByT5Mapper(
|
||||
in_dim=1472,
|
||||
out_dim=2048,
|
||||
hidden_dim=2048,
|
||||
out_dim1=self.hidden_size,
|
||||
use_res=False,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
else:
|
||||
self.byt5_in = None
|
||||
|
||||
if params.meanflow:
|
||||
self.time_r_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.time_r_in = None
|
||||
|
||||
if final_layer:
|
||||
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
@@ -225,10 +274,12 @@ class HunyuanVideo(nn.Module):
|
||||
txt_ids: Tensor,
|
||||
txt_mask: Tensor,
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
y: Tensor = None,
|
||||
txt_byt5=None,
|
||||
guidance: Tensor = None,
|
||||
guiding_frame_index=None,
|
||||
ref_latent=None,
|
||||
disable_time_r=False,
|
||||
control=None,
|
||||
transformer_options={},
|
||||
) -> Tensor:
|
||||
@@ -239,6 +290,14 @@ class HunyuanVideo(nn.Module):
|
||||
img = self.img_in(img)
|
||||
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
|
||||
|
||||
if (self.time_r_in is not None) and (not disable_time_r):
|
||||
w = torch.where(transformer_options['sigmas'][0] == transformer_options['sample_sigmas'])[0] # This most likely could be improved
|
||||
if len(w) > 0:
|
||||
timesteps_r = transformer_options['sample_sigmas'][w[0] + 1]
|
||||
timesteps_r = timesteps_r.unsqueeze(0).to(device=timesteps.device, dtype=timesteps.dtype)
|
||||
vec_r = self.time_r_in(timestep_embedding(timesteps_r, 256, time_factor=1000.0).to(img.dtype))
|
||||
vec = (vec + vec_r) / 2
|
||||
|
||||
if ref_latent is not None:
|
||||
ref_latent_ids = self.img_ids(ref_latent)
|
||||
ref_latent = self.img_in(ref_latent)
|
||||
@@ -249,13 +308,17 @@ class HunyuanVideo(nn.Module):
|
||||
|
||||
if guiding_frame_index is not None:
|
||||
token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
|
||||
vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
|
||||
if self.vector_in is not None:
|
||||
vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
|
||||
else:
|
||||
vec = torch.cat([(token_replace_vec).unsqueeze(1), (vec).unsqueeze(1)], dim=1)
|
||||
frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2])
|
||||
modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
|
||||
modulation_dims_txt = [(0, None, 1)]
|
||||
else:
|
||||
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
if self.vector_in is not None:
|
||||
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
modulation_dims = None
|
||||
modulation_dims_txt = None
|
||||
|
||||
@@ -266,7 +329,13 @@ class HunyuanVideo(nn.Module):
|
||||
if txt_mask is not None and not torch.is_floating_point(txt_mask):
|
||||
txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max
|
||||
|
||||
txt = self.txt_in(txt, timesteps, txt_mask)
|
||||
txt = self.txt_in(txt, timesteps, txt_mask, transformer_options=transformer_options)
|
||||
|
||||
if self.byt5_in is not None and txt_byt5 is not None:
|
||||
txt_byt5 = self.byt5_in(txt_byt5)
|
||||
txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
|
||||
txt = torch.cat((txt, txt_byt5), dim=1)
|
||||
txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
|
||||
|
||||
ids = torch.cat((img_ids, txt_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
@@ -284,14 +353,14 @@ class HunyuanVideo(nn.Module):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"])
|
||||
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"], transformer_options=args["transformer_options"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt, 'transformer_options': transformer_options}, {"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
else:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt)
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt, transformer_options=transformer_options)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
@@ -306,13 +375,13 @@ class HunyuanVideo(nn.Module):
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"])
|
||||
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"], transformer_options=args["transformer_options"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims, 'transformer_options': transformer_options}, {"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims, transformer_options=transformer_options)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
@@ -327,12 +396,16 @@ class HunyuanVideo(nn.Module):
|
||||
|
||||
img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels)
|
||||
|
||||
shape = initial_shape[-3:]
|
||||
shape = initial_shape[-len(self.patch_size):]
|
||||
for i in range(len(shape)):
|
||||
shape[i] = shape[i] // self.patch_size[i]
|
||||
img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
|
||||
img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
|
||||
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
|
||||
if img.ndim == 8:
|
||||
img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
|
||||
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
|
||||
else:
|
||||
img = img.permute(0, 3, 1, 4, 2, 5)
|
||||
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3])
|
||||
return img
|
||||
|
||||
def img_ids(self, x):
|
||||
@@ -347,9 +420,30 @@ class HunyuanVideo(nn.Module):
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
|
||||
return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
||||
|
||||
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
img_ids = self.img_ids(x)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
|
||||
def img_ids_2d(self, x):
|
||||
bs, c, h, w = x.shape
|
||||
patch_size = self.patch_size
|
||||
h_len = ((h + (patch_size[0] // 2)) // patch_size[0])
|
||||
w_len = ((w + (patch_size[1] // 2)) // patch_size[1])
|
||||
img_ids = torch.zeros((h_len, w_len, 2), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, 0] = img_ids[:, :, 0] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
return repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
|
||||
bs = x.shape[0]
|
||||
if len(self.patch_size) == 3:
|
||||
img_ids = self.img_ids(x)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
else:
|
||||
img_ids = self.img_ids_2d(x)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
|
||||
return out
|
||||
|
||||
136
comfy/ldm/hunyuan_video/vae.py
Normal file
136
comfy/ldm/hunyuan_video/vae.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
class PixelShuffle2D(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, op=ops.Conv2d):
|
||||
super().__init__()
|
||||
self.conv = op(in_dim, out_dim >> 2, 3, 1, 1)
|
||||
self.ratio = (in_dim << 2) // out_dim
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
h2, w2 = h >> 1, w >> 1
|
||||
y = self.conv(x).view(b, -1, h2, 2, w2, 2).permute(0, 3, 5, 1, 2, 4).reshape(b, -1, h2, w2)
|
||||
r = x.view(b, c, h2, 2, w2, 2).permute(0, 3, 5, 1, 2, 4).reshape(b, c << 2, h2, w2)
|
||||
return y + r.view(b, y.shape[1], self.ratio, h2, w2).mean(2)
|
||||
|
||||
|
||||
class PixelUnshuffle2D(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, op=ops.Conv2d):
|
||||
super().__init__()
|
||||
self.conv = op(in_dim, out_dim << 2, 3, 1, 1)
|
||||
self.scale = (out_dim << 2) // in_dim
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
h2, w2 = h << 1, w << 1
|
||||
y = self.conv(x).view(b, 2, 2, -1, h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, -1, h2, w2)
|
||||
r = x.repeat_interleave(self.scale, 1).view(b, 2, 2, -1, h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, -1, h2, w2)
|
||||
return y + r
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
|
||||
ffactor_spatial, downsample_match_channel=True, **_):
|
||||
super().__init__()
|
||||
self.z_channels = z_channels
|
||||
self.block_out_channels = block_out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.conv_in = ops.Conv2d(in_channels, block_out_channels[0], 3, 1, 1)
|
||||
|
||||
self.down = nn.ModuleList()
|
||||
ch = block_out_channels[0]
|
||||
depth = (ffactor_spatial >> 1).bit_length()
|
||||
|
||||
for i, tgt in enumerate(block_out_channels):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
temb_channels=0,
|
||||
conv_op=ops.Conv2d)
|
||||
for j in range(num_res_blocks)])
|
||||
ch = tgt
|
||||
if i < depth:
|
||||
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
|
||||
stage.downsample = PixelShuffle2D(ch, nxt, ops.Conv2d)
|
||||
ch = nxt
|
||||
self.down.append(stage)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
|
||||
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
|
||||
|
||||
self.norm_out = ops.GroupNorm(32, ch, 1e-6, True)
|
||||
self.conv_out = ops.Conv2d(ch, z_channels << 1, 3, 1, 1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_in(x)
|
||||
|
||||
for stage in self.down:
|
||||
for blk in stage.block:
|
||||
x = blk(x)
|
||||
if hasattr(stage, 'downsample'):
|
||||
x = stage.downsample(x)
|
||||
|
||||
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
|
||||
|
||||
b, c, h, w = x.shape
|
||||
grp = c // (self.z_channels << 1)
|
||||
skip = x.view(b, c // grp, grp, h, w).mean(2)
|
||||
|
||||
return self.conv_out(F.silu(self.norm_out(x))) + skip
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
|
||||
ffactor_spatial, upsample_match_channel=True, **_):
|
||||
super().__init__()
|
||||
block_out_channels = block_out_channels[::-1]
|
||||
self.z_channels = z_channels
|
||||
self.block_out_channels = block_out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
|
||||
ch = block_out_channels[0]
|
||||
self.conv_in = ops.Conv2d(z_channels, ch, 3, 1, 1)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
|
||||
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
|
||||
|
||||
self.up = nn.ModuleList()
|
||||
depth = (ffactor_spatial >> 1).bit_length()
|
||||
|
||||
for i, tgt in enumerate(block_out_channels):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
temb_channels=0,
|
||||
conv_op=ops.Conv2d)
|
||||
for j in range(num_res_blocks + 1)])
|
||||
ch = tgt
|
||||
if i < depth:
|
||||
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
|
||||
stage.upsample = PixelUnshuffle2D(ch, nxt, ops.Conv2d)
|
||||
ch = nxt
|
||||
self.up.append(stage)
|
||||
|
||||
self.norm_out = ops.GroupNorm(32, ch, 1e-6, True)
|
||||
self.conv_out = ops.Conv2d(ch, out_channels, 3, 1, 1)
|
||||
|
||||
def forward(self, z):
|
||||
x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
|
||||
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
|
||||
|
||||
for stage in self.up:
|
||||
for blk in stage.block:
|
||||
x = blk(x)
|
||||
if hasattr(stage, 'upsample'):
|
||||
x = stage.upsample(x)
|
||||
|
||||
return self.conv_out(F.silu(self.norm_out(x)))
|
||||
267
comfy/ldm/hunyuan_video/vae_refiner.py
Normal file
267
comfy/ldm/hunyuan_video/vae_refiner.py
Normal file
@@ -0,0 +1,267 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d
|
||||
import comfy.ops
|
||||
import comfy.ldm.models.autoencoder
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
class RMS_norm(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
shape = (dim, 1, 1, 1)
|
||||
self.scale = dim**0.5
|
||||
self.gamma = nn.Parameter(torch.empty(shape))
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(x, dim=1) * self.scale * self.gamma
|
||||
|
||||
class DnSmpl(nn.Module):
|
||||
def __init__(self, ic, oc, tds=True):
|
||||
super().__init__()
|
||||
fct = 2 * 2 * 2 if tds else 1 * 2 * 2
|
||||
assert oc % fct == 0
|
||||
self.conv = VideoConv3d(ic, oc // fct, kernel_size=3)
|
||||
|
||||
self.tds = tds
|
||||
self.gs = fct * ic // oc
|
||||
|
||||
def forward(self, x):
|
||||
r1 = 2 if self.tds else 1
|
||||
h = self.conv(x)
|
||||
|
||||
if self.tds:
|
||||
hf = h[:, :, :1, :, :]
|
||||
b, c, f, ht, wd = hf.shape
|
||||
hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2)
|
||||
hf = hf.permute(0, 4, 6, 1, 2, 3, 5)
|
||||
hf = hf.reshape(b, 2 * 2 * c, f, ht // 2, wd // 2)
|
||||
hf = torch.cat([hf, hf], dim=1)
|
||||
|
||||
hn = h[:, :, 1:, :, :]
|
||||
b, c, frms, ht, wd = hn.shape
|
||||
nf = frms // r1
|
||||
hn = hn.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
|
||||
hn = hn.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
hn = hn.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
|
||||
|
||||
h = torch.cat([hf, hn], dim=2)
|
||||
|
||||
xf = x[:, :, :1, :, :]
|
||||
b, ci, f, ht, wd = xf.shape
|
||||
xf = xf.reshape(b, ci, f, ht // 2, 2, wd // 2, 2)
|
||||
xf = xf.permute(0, 4, 6, 1, 2, 3, 5)
|
||||
xf = xf.reshape(b, 2 * 2 * ci, f, ht // 2, wd // 2)
|
||||
B, C, T, H, W = xf.shape
|
||||
xf = xf.view(B, h.shape[1], self.gs // 2, T, H, W).mean(dim=2)
|
||||
|
||||
xn = x[:, :, 1:, :, :]
|
||||
b, ci, frms, ht, wd = xn.shape
|
||||
nf = frms // r1
|
||||
xn = xn.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
|
||||
xn = xn.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
xn = xn.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
|
||||
B, C, T, H, W = xn.shape
|
||||
xn = xn.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
|
||||
sc = torch.cat([xf, xn], dim=2)
|
||||
else:
|
||||
b, c, frms, ht, wd = h.shape
|
||||
nf = frms // r1
|
||||
h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
|
||||
h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
|
||||
|
||||
b, ci, frms, ht, wd = x.shape
|
||||
nf = frms // r1
|
||||
sc = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
|
||||
sc = sc.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
sc = sc.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
|
||||
B, C, T, H, W = sc.shape
|
||||
sc = sc.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
|
||||
|
||||
return h + sc
|
||||
|
||||
|
||||
class UpSmpl(nn.Module):
|
||||
def __init__(self, ic, oc, tus=True):
|
||||
super().__init__()
|
||||
fct = 2 * 2 * 2 if tus else 1 * 2 * 2
|
||||
self.conv = VideoConv3d(ic, oc * fct, kernel_size=3)
|
||||
|
||||
self.tus = tus
|
||||
self.rp = fct * oc // ic
|
||||
|
||||
def forward(self, x):
|
||||
r1 = 2 if self.tus else 1
|
||||
h = self.conv(x)
|
||||
|
||||
if self.tus:
|
||||
hf = h[:, :, :1, :, :]
|
||||
b, c, f, ht, wd = hf.shape
|
||||
nc = c // (2 * 2)
|
||||
hf = hf.reshape(b, 2, 2, nc, f, ht, wd)
|
||||
hf = hf.permute(0, 3, 4, 5, 1, 6, 2)
|
||||
hf = hf.reshape(b, nc, f, ht * 2, wd * 2)
|
||||
hf = hf[:, : hf.shape[1] // 2]
|
||||
|
||||
hn = h[:, :, 1:, :, :]
|
||||
b, c, frms, ht, wd = hn.shape
|
||||
nc = c // (r1 * 2 * 2)
|
||||
hn = hn.reshape(b, r1, 2, 2, nc, frms, ht, wd)
|
||||
hn = hn.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
hn = hn.reshape(b, nc, frms * r1, ht * 2, wd * 2)
|
||||
|
||||
h = torch.cat([hf, hn], dim=2)
|
||||
|
||||
xf = x[:, :, :1, :, :]
|
||||
b, ci, f, ht, wd = xf.shape
|
||||
xf = xf.repeat_interleave(repeats=self.rp // 2, dim=1)
|
||||
b, c, f, ht, wd = xf.shape
|
||||
nc = c // (2 * 2)
|
||||
xf = xf.reshape(b, 2, 2, nc, f, ht, wd)
|
||||
xf = xf.permute(0, 3, 4, 5, 1, 6, 2)
|
||||
xf = xf.reshape(b, nc, f, ht * 2, wd * 2)
|
||||
|
||||
xn = x[:, :, 1:, :, :]
|
||||
xn = xn.repeat_interleave(repeats=self.rp, dim=1)
|
||||
b, c, frms, ht, wd = xn.shape
|
||||
nc = c // (r1 * 2 * 2)
|
||||
xn = xn.reshape(b, r1, 2, 2, nc, frms, ht, wd)
|
||||
xn = xn.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
xn = xn.reshape(b, nc, frms * r1, ht * 2, wd * 2)
|
||||
sc = torch.cat([xf, xn], dim=2)
|
||||
else:
|
||||
b, c, frms, ht, wd = h.shape
|
||||
nc = c // (r1 * 2 * 2)
|
||||
h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
|
||||
h = h.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2)
|
||||
|
||||
sc = x.repeat_interleave(repeats=self.rp, dim=1)
|
||||
b, c, frms, ht, wd = sc.shape
|
||||
nc = c // (r1 * 2 * 2)
|
||||
sc = sc.reshape(b, r1, 2, 2, nc, frms, ht, wd)
|
||||
sc = sc.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
sc = sc.reshape(b, nc, frms * r1, ht * 2, wd * 2)
|
||||
|
||||
return h + sc
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
|
||||
ffactor_spatial, ffactor_temporal, downsample_match_channel=True, **_):
|
||||
super().__init__()
|
||||
self.z_channels = z_channels
|
||||
self.block_out_channels = block_out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.conv_in = VideoConv3d(in_channels, block_out_channels[0], 3, 1, 1)
|
||||
|
||||
self.down = nn.ModuleList()
|
||||
ch = block_out_channels[0]
|
||||
depth = (ffactor_spatial >> 1).bit_length()
|
||||
depth_temporal = ((ffactor_spatial // ffactor_temporal) >> 1).bit_length()
|
||||
|
||||
for i, tgt in enumerate(block_out_channels):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
temb_channels=0,
|
||||
conv_op=VideoConv3d, norm_op=RMS_norm)
|
||||
for j in range(num_res_blocks)])
|
||||
ch = tgt
|
||||
if i < depth:
|
||||
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
|
||||
stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal)
|
||||
ch = nxt
|
||||
self.down.append(stage)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
|
||||
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
|
||||
|
||||
self.norm_out = RMS_norm(ch)
|
||||
self.conv_out = VideoConv3d(ch, z_channels << 1, 3, 1, 1)
|
||||
|
||||
self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_in(x)
|
||||
|
||||
for stage in self.down:
|
||||
for blk in stage.block:
|
||||
x = blk(x)
|
||||
if hasattr(stage, 'downsample'):
|
||||
x = stage.downsample(x)
|
||||
|
||||
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
|
||||
|
||||
b, c, t, h, w = x.shape
|
||||
grp = c // (self.z_channels << 1)
|
||||
skip = x.view(b, c // grp, grp, t, h, w).mean(2)
|
||||
|
||||
out = self.conv_out(F.silu(self.norm_out(x))) + skip
|
||||
out = self.regul(out)[0]
|
||||
|
||||
out = torch.cat((out[:, :, :1], out), dim=2)
|
||||
out = out.permute(0, 2, 1, 3, 4)
|
||||
b, f_times_2, c, h, w = out.shape
|
||||
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
|
||||
out = out.permute(0, 2, 1, 3, 4).contiguous()
|
||||
return out
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
|
||||
ffactor_spatial, ffactor_temporal, upsample_match_channel=True, **_):
|
||||
super().__init__()
|
||||
block_out_channels = block_out_channels[::-1]
|
||||
self.z_channels = z_channels
|
||||
self.block_out_channels = block_out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
|
||||
ch = block_out_channels[0]
|
||||
self.conv_in = VideoConv3d(z_channels, ch, 3)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
|
||||
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
|
||||
|
||||
self.up = nn.ModuleList()
|
||||
depth = (ffactor_spatial >> 1).bit_length()
|
||||
depth_temporal = (ffactor_temporal >> 1).bit_length()
|
||||
|
||||
for i, tgt in enumerate(block_out_channels):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
temb_channels=0,
|
||||
conv_op=VideoConv3d, norm_op=RMS_norm)
|
||||
for j in range(num_res_blocks + 1)])
|
||||
ch = tgt
|
||||
if i < depth:
|
||||
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
|
||||
stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal)
|
||||
ch = nxt
|
||||
self.up.append(stage)
|
||||
|
||||
self.norm_out = RMS_norm(ch)
|
||||
self.conv_out = VideoConv3d(ch, out_channels, 3)
|
||||
|
||||
def forward(self, z):
|
||||
z = z.permute(0, 2, 1, 3, 4)
|
||||
b, f, c, h, w = z.shape
|
||||
z = z.reshape(b, f, 2, c // 2, h, w)
|
||||
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
|
||||
z = z.permute(0, 2, 1, 3, 4)
|
||||
z = z[:, :, 1:]
|
||||
|
||||
x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
|
||||
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
|
||||
|
||||
for stage in self.up:
|
||||
for blk in stage.block:
|
||||
x = blk(x)
|
||||
if hasattr(stage, 'upsample'):
|
||||
x = stage.upsample(x)
|
||||
|
||||
return self.conv_out(F.silu(self.norm_out(x)))
|
||||
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.modules.attention
|
||||
import comfy.ldm.common_dit
|
||||
from einops import rearrange
|
||||
@@ -261,8 +262,8 @@ class CrossAttention(nn.Module):
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
|
||||
self.q_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device)
|
||||
self.k_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device)
|
||||
self.q_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device)
|
||||
self.k_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device)
|
||||
|
||||
self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
@@ -270,7 +271,7 @@ class CrossAttention(nn.Module):
|
||||
|
||||
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
|
||||
|
||||
def forward(self, x, context=None, mask=None, pe=None):
|
||||
def forward(self, x, context=None, mask=None, pe=None, transformer_options={}):
|
||||
q = self.to_q(x)
|
||||
context = x if context is None else context
|
||||
k = self.to_k(context)
|
||||
@@ -284,9 +285,9 @@ class CrossAttention(nn.Module):
|
||||
k = apply_rotary_emb(k, pe)
|
||||
|
||||
if mask is None:
|
||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
|
||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
else:
|
||||
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
|
||||
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
@@ -302,12 +303,12 @@ class BasicTransformerBlock(nn.Module):
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None):
|
||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
|
||||
|
||||
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa
|
||||
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa
|
||||
|
||||
x += self.attn2(x, context=context, mask=attention_mask)
|
||||
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
|
||||
|
||||
y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
|
||||
x += self.ff(y) * gate_mlp
|
||||
@@ -420,6 +421,13 @@ class LTXVModel(torch.nn.Module):
|
||||
self.patchifier = SymmetricPatchifier(1)
|
||||
|
||||
def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
|
||||
orig_shape = list(x.shape)
|
||||
@@ -471,10 +479,10 @@ class LTXVModel(torch.nn.Module):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"])
|
||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(
|
||||
@@ -482,7 +490,8 @@ class LTXVModel(torch.nn.Module):
|
||||
context=context,
|
||||
attention_mask=attention_mask,
|
||||
timestep=timestep,
|
||||
pe=pe
|
||||
pe=pe,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
|
||||
@@ -973,7 +973,7 @@ class VideoVAE(nn.Module):
|
||||
norm_layer=config.get("norm_layer", "group_norm"),
|
||||
causal=config.get("causal_decoder", False),
|
||||
timestep_conditioning=self.timestep_conditioning,
|
||||
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
|
||||
spatial_padding_mode=config.get("spatial_padding_mode", "reflect"),
|
||||
)
|
||||
|
||||
self.per_channel_statistics = processor()
|
||||
|
||||
@@ -11,6 +11,7 @@ import comfy.ldm.common_dit
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
import comfy.patcher_extension
|
||||
|
||||
|
||||
def modulate(x, scale):
|
||||
@@ -103,6 +104,7 @@ class JointAttention(nn.Module):
|
||||
x: torch.Tensor,
|
||||
x_mask: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
transformer_options={},
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
|
||||
@@ -139,7 +141,7 @@ class JointAttention(nn.Module):
|
||||
if n_rep >= 1:
|
||||
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||
output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True)
|
||||
output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True, transformer_options=transformer_options)
|
||||
|
||||
return self.out(output)
|
||||
|
||||
@@ -267,6 +269,7 @@ class JointTransformerBlock(nn.Module):
|
||||
x_mask: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
adaln_input: Optional[torch.Tensor]=None,
|
||||
transformer_options={},
|
||||
):
|
||||
"""
|
||||
Perform a forward pass through the TransformerBlock.
|
||||
@@ -289,6 +292,7 @@ class JointTransformerBlock(nn.Module):
|
||||
modulate(self.attention_norm1(x), scale_msa),
|
||||
x_mask,
|
||||
freqs_cis,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
)
|
||||
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
|
||||
@@ -303,6 +307,7 @@ class JointTransformerBlock(nn.Module):
|
||||
self.attention_norm1(x),
|
||||
x_mask,
|
||||
freqs_cis,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
)
|
||||
x = x + self.ffn_norm2(
|
||||
@@ -493,7 +498,7 @@ class NextDiT(nn.Module):
|
||||
return imgs
|
||||
|
||||
def patchify_and_embed(
|
||||
self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens
|
||||
self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, transformer_options={}
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
|
||||
bsz = len(x)
|
||||
pH = pW = self.patch_size
|
||||
@@ -553,7 +558,7 @@ class NextDiT(nn.Module):
|
||||
|
||||
# refine context
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)
|
||||
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options)
|
||||
|
||||
# refine image
|
||||
flat_x = []
|
||||
@@ -572,7 +577,7 @@ class NextDiT(nn.Module):
|
||||
padded_img_embed = self.x_embedder(padded_img_embed)
|
||||
padded_img_mask = padded_img_mask.unsqueeze(1)
|
||||
for layer in self.noise_refiner:
|
||||
padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t)
|
||||
padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t, transformer_options=transformer_options)
|
||||
|
||||
if cap_mask is not None:
|
||||
mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
|
||||
@@ -590,8 +595,15 @@ class NextDiT(nn.Module):
|
||||
|
||||
return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
|
||||
|
||||
# def forward(self, x, t, cap_feats, cap_mask):
|
||||
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
|
||||
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
|
||||
|
||||
# def forward(self, x, t, cap_feats, cap_mask):
|
||||
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
|
||||
t = 1.0 - timesteps
|
||||
cap_feats = context
|
||||
cap_mask = attention_mask
|
||||
@@ -608,12 +620,13 @@ class NextDiT(nn.Module):
|
||||
|
||||
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
|
||||
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
x_is_tensor = isinstance(x, torch.Tensor)
|
||||
x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens)
|
||||
x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
|
||||
freqs_cis = freqs_cis.to(x.device)
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x, mask, freqs_cis, adaln_input)
|
||||
x = layer(x, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
|
||||
|
||||
x = self.final_layer(x, adaln_input)
|
||||
x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]
|
||||
|
||||
@@ -11,7 +11,7 @@ from comfy.ldm.modules.ema import LitEma
|
||||
import comfy.ops
|
||||
|
||||
class DiagonalGaussianRegularizer(torch.nn.Module):
|
||||
def __init__(self, sample: bool = True):
|
||||
def __init__(self, sample: bool = False):
|
||||
super().__init__()
|
||||
self.sample = sample
|
||||
|
||||
@@ -19,17 +19,19 @@ class DiagonalGaussianRegularizer(torch.nn.Module):
|
||||
yield from ()
|
||||
|
||||
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
|
||||
log = dict()
|
||||
posterior = DiagonalGaussianDistribution(z)
|
||||
if self.sample:
|
||||
z = posterior.sample()
|
||||
else:
|
||||
z = posterior.mode()
|
||||
kl_loss = posterior.kl()
|
||||
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
|
||||
log["kl_loss"] = kl_loss
|
||||
return z, log
|
||||
return z, None
|
||||
|
||||
class EmptyRegularizer(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
|
||||
return z, None
|
||||
|
||||
class AbstractAutoencoder(torch.nn.Module):
|
||||
"""
|
||||
|
||||
@@ -5,8 +5,9 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, einsum
|
||||
from einops import rearrange, repeat
|
||||
from typing import Optional
|
||||
from typing import Optional, Any, Callable, Union
|
||||
import logging
|
||||
import functools
|
||||
|
||||
from .diffusionmodules.util import AlphaBlender, timestep_embedding
|
||||
from .sub_quadratic_attention import efficient_dot_product_attention
|
||||
@@ -17,23 +18,45 @@ if model_management.xformers_enabled():
|
||||
import xformers
|
||||
import xformers.ops
|
||||
|
||||
if model_management.sage_attention_enabled():
|
||||
try:
|
||||
from sageattention import sageattn
|
||||
except ModuleNotFoundError as e:
|
||||
SAGE_ATTENTION_IS_AVAILABLE = False
|
||||
try:
|
||||
from sageattention import sageattn
|
||||
SAGE_ATTENTION_IS_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
if model_management.sage_attention_enabled():
|
||||
if e.name == "sageattention":
|
||||
logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
|
||||
else:
|
||||
raise e
|
||||
exit(-1)
|
||||
|
||||
if model_management.flash_attention_enabled():
|
||||
try:
|
||||
from flash_attn import flash_attn_func
|
||||
except ModuleNotFoundError:
|
||||
FLASH_ATTENTION_IS_AVAILABLE = False
|
||||
try:
|
||||
from flash_attn import flash_attn_func
|
||||
FLASH_ATTENTION_IS_AVAILABLE = True
|
||||
except ImportError:
|
||||
if model_management.flash_attention_enabled():
|
||||
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
|
||||
exit(-1)
|
||||
|
||||
REGISTERED_ATTENTION_FUNCTIONS = {}
|
||||
def register_attention_function(name: str, func: Callable):
|
||||
# avoid replacing existing functions
|
||||
if name not in REGISTERED_ATTENTION_FUNCTIONS:
|
||||
REGISTERED_ATTENTION_FUNCTIONS[name] = func
|
||||
else:
|
||||
logging.warning(f"Attention function {name} already registered, skipping registration.")
|
||||
|
||||
def get_attention_function(name: str, default: Any=...) -> Union[Callable, None]:
|
||||
if name == "optimized":
|
||||
return optimized_attention
|
||||
elif name not in REGISTERED_ATTENTION_FUNCTIONS:
|
||||
if default is ...:
|
||||
raise KeyError(f"Attention function {name} not found.")
|
||||
else:
|
||||
return default
|
||||
return REGISTERED_ATTENTION_FUNCTIONS[name]
|
||||
|
||||
from comfy.cli_args import args
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
@@ -91,7 +114,27 @@ class FeedForward(nn.Module):
|
||||
def Normalize(in_channels, dtype=None, device=None):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
|
||||
|
||||
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
|
||||
def wrap_attn(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
remove_attn_wrapper_key = False
|
||||
try:
|
||||
if "_inside_attn_wrapper" not in kwargs:
|
||||
transformer_options = kwargs.get("transformer_options", None)
|
||||
remove_attn_wrapper_key = True
|
||||
kwargs["_inside_attn_wrapper"] = True
|
||||
if transformer_options is not None:
|
||||
if "optimized_attention_override" in transformer_options:
|
||||
return transformer_options["optimized_attention_override"](func, *args, **kwargs)
|
||||
return func(*args, **kwargs)
|
||||
finally:
|
||||
if remove_attn_wrapper_key:
|
||||
del kwargs["_inside_attn_wrapper"]
|
||||
return wrapper
|
||||
|
||||
@wrap_attn
|
||||
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
attn_precision = get_attn_precision(attn_precision, q.dtype)
|
||||
|
||||
if skip_reshape:
|
||||
@@ -159,8 +202,8 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
@wrap_attn
|
||||
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
attn_precision = get_attn_precision(attn_precision, query.dtype)
|
||||
|
||||
if skip_reshape:
|
||||
@@ -230,7 +273,8 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
|
||||
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
|
||||
return hidden_states
|
||||
|
||||
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
@wrap_attn
|
||||
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
attn_precision = get_attn_precision(attn_precision, q.dtype)
|
||||
|
||||
if skip_reshape:
|
||||
@@ -359,7 +403,8 @@ try:
|
||||
except:
|
||||
pass
|
||||
|
||||
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
@wrap_attn
|
||||
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
b = q.shape[0]
|
||||
dim_head = q.shape[-1]
|
||||
# check to make sure xformers isn't broken
|
||||
@@ -374,7 +419,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
|
||||
disabled_xformers = True
|
||||
|
||||
if disabled_xformers:
|
||||
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
|
||||
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, **kwargs)
|
||||
|
||||
if skip_reshape:
|
||||
# b h k d -> b k h d
|
||||
@@ -427,8 +472,8 @@ else:
|
||||
#TODO: other GPUs ?
|
||||
SDP_BATCH_LIMIT = 2**31
|
||||
|
||||
|
||||
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
@wrap_attn
|
||||
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
else:
|
||||
@@ -448,7 +493,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
if SDP_BATCH_LIMIT >= b:
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
if not skip_output_reshape:
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
@@ -461,7 +506,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
if mask.shape[0] > 1:
|
||||
m = mask[i : i + SDP_BATCH_LIMIT]
|
||||
|
||||
out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(
|
||||
out[i : i + SDP_BATCH_LIMIT] = comfy.ops.scaled_dot_product_attention(
|
||||
q[i : i + SDP_BATCH_LIMIT],
|
||||
k[i : i + SDP_BATCH_LIMIT],
|
||||
v[i : i + SDP_BATCH_LIMIT],
|
||||
@@ -470,8 +515,8 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
|
||||
return out
|
||||
|
||||
|
||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
@wrap_attn
|
||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
tensor_layout = "HND"
|
||||
@@ -501,7 +546,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
||||
lambda t: t.transpose(1, 2),
|
||||
(q, k, v),
|
||||
)
|
||||
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape)
|
||||
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, **kwargs)
|
||||
|
||||
if tensor_layout == "HND":
|
||||
if not skip_output_reshape:
|
||||
@@ -534,8 +579,8 @@ except AttributeError as error:
|
||||
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
|
||||
assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
|
||||
|
||||
|
||||
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
@wrap_attn
|
||||
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
else:
|
||||
@@ -555,7 +600,8 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
try:
|
||||
assert mask is None
|
||||
if mask is not None:
|
||||
raise RuntimeError("Mask must not be set for Flash attention")
|
||||
out = flash_attn_wrapper(
|
||||
q.transpose(1, 2),
|
||||
k.transpose(1, 2),
|
||||
@@ -597,6 +643,19 @@ else:
|
||||
|
||||
optimized_attention_masked = optimized_attention
|
||||
|
||||
|
||||
# register core-supported attention functions
|
||||
if SAGE_ATTENTION_IS_AVAILABLE:
|
||||
register_attention_function("sage", attention_sage)
|
||||
if FLASH_ATTENTION_IS_AVAILABLE:
|
||||
register_attention_function("flash", attention_flash)
|
||||
if model_management.xformers_enabled():
|
||||
register_attention_function("xformers", attention_xformers)
|
||||
register_attention_function("pytorch", attention_pytorch)
|
||||
register_attention_function("sub_quad", attention_sub_quad)
|
||||
register_attention_function("split", attention_split)
|
||||
|
||||
|
||||
def optimized_attention_for_device(device, mask=False, small_input=False):
|
||||
if small_input:
|
||||
if model_management.pytorch_attention_enabled():
|
||||
@@ -629,7 +688,7 @@ class CrossAttention(nn.Module):
|
||||
|
||||
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
|
||||
|
||||
def forward(self, x, context=None, value=None, mask=None):
|
||||
def forward(self, x, context=None, value=None, mask=None, transformer_options={}):
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
@@ -640,9 +699,9 @@ class CrossAttention(nn.Module):
|
||||
v = self.to_v(context)
|
||||
|
||||
if mask is None:
|
||||
out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
|
||||
out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
else:
|
||||
out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
|
||||
out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
@@ -746,7 +805,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
|
||||
n = self.attn1.to_out(n)
|
||||
else:
|
||||
n = self.attn1(n, context=context_attn1, value=value_attn1)
|
||||
n = self.attn1(n, context=context_attn1, value=value_attn1, transformer_options=transformer_options)
|
||||
|
||||
if "attn1_output_patch" in transformer_patches:
|
||||
patch = transformer_patches["attn1_output_patch"]
|
||||
@@ -786,7 +845,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
|
||||
n = self.attn2.to_out(n)
|
||||
else:
|
||||
n = self.attn2(n, context=context_attn2, value=value_attn2)
|
||||
n = self.attn2(n, context=context_attn2, value=value_attn2, transformer_options=transformer_options)
|
||||
|
||||
if "attn2_output_patch" in transformer_patches:
|
||||
patch = transformer_patches["attn2_output_patch"]
|
||||
@@ -1017,7 +1076,7 @@ class SpatialVideoTransformer(SpatialTransformer):
|
||||
|
||||
B, S, C = x_mix.shape
|
||||
x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps)
|
||||
x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options
|
||||
x_mix = mix_block(x_mix, context=time_context, transformer_options=transformer_options)
|
||||
x_mix = rearrange(
|
||||
x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
|
||||
)
|
||||
|
||||
@@ -109,7 +109,7 @@ class PatchEmbed(nn.Module):
|
||||
def modulate(x, shift, scale):
|
||||
if shift is None:
|
||||
shift = torch.zeros_like(scale)
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
return torch.addcmul(shift.unsqueeze(1), x, 1+ scale.unsqueeze(1))
|
||||
|
||||
|
||||
#################################################################################
|
||||
@@ -564,10 +564,7 @@ class DismantledBlock(nn.Module):
|
||||
assert not self.pre_only
|
||||
attn1 = self.attn.post_attention(attn)
|
||||
attn2 = self.attn2.post_attention(attn2)
|
||||
out1 = gate_msa.unsqueeze(1) * attn1
|
||||
out2 = gate_msa2.unsqueeze(1) * attn2
|
||||
x = x + out1
|
||||
x = x + out2
|
||||
x = gate_cat(x, gate_msa, gate_msa2, attn1, attn2)
|
||||
x = x + gate_mlp.unsqueeze(1) * self.mlp(
|
||||
modulate(self.norm2(x), shift_mlp, scale_mlp)
|
||||
)
|
||||
@@ -594,6 +591,11 @@ class DismantledBlock(nn.Module):
|
||||
)
|
||||
return self.post_attention(attn, *intermediates)
|
||||
|
||||
def gate_cat(x, gate_msa, gate_msa2, attn1, attn2):
|
||||
out1 = gate_msa.unsqueeze(1) * attn1
|
||||
out2 = gate_msa2.unsqueeze(1) * attn2
|
||||
x = torch.stack([x, out1, out2], dim=0).sum(dim=0)
|
||||
return x
|
||||
|
||||
def block_mixing(*args, use_checkpoint=True, **kwargs):
|
||||
if use_checkpoint:
|
||||
@@ -604,7 +606,7 @@ def block_mixing(*args, use_checkpoint=True, **kwargs):
|
||||
return _block_mixing(*args, **kwargs)
|
||||
|
||||
|
||||
def _block_mixing(context, x, context_block, x_block, c):
|
||||
def _block_mixing(context, x, context_block, x_block, c, transformer_options={}):
|
||||
context_qkv, context_intermediates = context_block.pre_attention(context, c)
|
||||
|
||||
if x_block.x_block_self_attn:
|
||||
@@ -620,6 +622,7 @@ def _block_mixing(context, x, context_block, x_block, c):
|
||||
attn = optimized_attention(
|
||||
qkv[0], qkv[1], qkv[2],
|
||||
heads=x_block.attn.num_heads,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
context_attn, x_attn = (
|
||||
attn[:, : context_qkv[0].shape[1]],
|
||||
@@ -635,6 +638,7 @@ def _block_mixing(context, x, context_block, x_block, c):
|
||||
attn2 = optimized_attention(
|
||||
x_qkv2[0], x_qkv2[1], x_qkv2[2],
|
||||
heads=x_block.attn2.num_heads,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
x = x_block.post_attention_x(x_attn, attn2, *x_intermediates)
|
||||
else:
|
||||
@@ -956,10 +960,10 @@ class MMDiT(nn.Module):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"])
|
||||
out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"], transformer_options=args["transformer_options"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
context = out["txt"]
|
||||
x = out["img"]
|
||||
else:
|
||||
@@ -968,6 +972,7 @@ class MMDiT(nn.Module):
|
||||
x,
|
||||
c=c_mod,
|
||||
use_checkpoint=self.use_checkpoint,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
if control is not None:
|
||||
control_o = control.get("output")
|
||||
|
||||
@@ -36,7 +36,7 @@ def get_timestep_embedding(timesteps, embedding_dim):
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return x*torch.sigmoid(x)
|
||||
return torch.nn.functional.silu(x)
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32):
|
||||
@@ -145,7 +145,7 @@ class Downsample(nn.Module):
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
|
||||
dropout, temb_channels=512, conv_op=ops.Conv2d):
|
||||
dropout=0.0, temb_channels=512, conv_op=ops.Conv2d, norm_op=Normalize):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
@@ -153,7 +153,7 @@ class ResnetBlock(nn.Module):
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
self.swish = torch.nn.SiLU(inplace=True)
|
||||
self.norm1 = Normalize(in_channels)
|
||||
self.norm1 = norm_op(in_channels)
|
||||
self.conv1 = conv_op(in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
@@ -162,7 +162,7 @@ class ResnetBlock(nn.Module):
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = ops.Linear(temb_channels,
|
||||
out_channels)
|
||||
self.norm2 = Normalize(out_channels)
|
||||
self.norm2 = norm_op(out_channels)
|
||||
self.dropout = torch.nn.Dropout(dropout, inplace=True)
|
||||
self.conv2 = conv_op(out_channels,
|
||||
out_channels,
|
||||
@@ -183,7 +183,7 @@ class ResnetBlock(nn.Module):
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x, temb):
|
||||
def forward(self, x, temb=None):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = self.swish(h)
|
||||
@@ -285,7 +285,7 @@ def pytorch_attention(q, k, v):
|
||||
)
|
||||
|
||||
try:
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||
out = out.transpose(2, 3).reshape(orig_shape)
|
||||
except model_management.OOM_EXCEPTION:
|
||||
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
|
||||
@@ -305,11 +305,11 @@ def vae_attention():
|
||||
return normal_attention
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels, conv_op=ops.Conv2d):
|
||||
def __init__(self, in_channels, conv_op=ops.Conv2d, norm_op=Normalize):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.norm = norm_op(in_channels)
|
||||
self.q = conv_op(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user