Compare commits

..

198 Commits

Author SHA1 Message Date
Jedrzej Kosinski
0f21df8051 Merge branch 'master' into portable-manager-update 2025-12-15 16:38:56 -08:00
Jedrzej Kosinski
334e74b938 Merge branch 'master' into portable-manager-update 2025-12-08 20:25:28 -08:00
Jedrzej Kosinski
f19a2b53f4 Create install_manager scripts, make update.py attempt to update comfyui_manager package if already installed, add --enable-manager startup arg to all run scripts 2025-11-26 16:42:29 -08:00
Jedrzej Kosinski
7cf52dd51e Make manager warning for --enable-manager not appear if is windows_standalone_build 2025-11-26 16:32:23 -08:00
Dr.Lt.Data
4061eaa469 updated: manager_requirements.txt 2025-11-26 22:39:19 +09:00
Dr.Lt.Data
d69c8b3ac2 updated: manager_requirements.txt 2025-11-26 22:16:40 +09:00
Dr.Lt.Data
a5e0674474 Merge branch 'master' into dr-support-pip-cm 2025-11-26 21:44:25 +09:00
Jedrzej Kosinski
79fb96488a Move manager requirement into its own file 2025-11-25 20:43:23 -08:00
Jedrzej Kosinski
aa878cc193 Merge branch 'master' into dr-support-pip-cm 2025-11-25 20:41:19 -08:00
Dr.Lt.Data
1abf69ea27 Merge branch 'master' into dr-support-pip-cm 2025-11-24 23:34:42 +09:00
Dr.Lt.Data
6206a6d3d2 Merge branch 'master' into dr-support-pip-cm 2025-11-18 23:08:08 +09:00
Dr.Lt.Data
7b9ad5208e Merge branch 'master' into dr-support-pip-cm 2025-11-17 00:56:17 +09:00
Dr.Lt.Data
a58c4fbf68 Merge branch 'master' into dr-support-pip-cm 2025-11-15 08:33:49 +09:00
Dr.Lt.Data
b15ef9917b Merge branch 'master' into dr-support-pip-cm 2025-11-11 01:58:42 +09:00
Dr.Lt.Data
2d4dd3972c Merge branch 'master' into dr-support-pip-cm 2025-11-10 12:48:44 +09:00
Dr.Lt.Data
5cb77fbb18 Merge branch 'master' into dr-support-pip-cm 2025-11-06 00:55:10 +09:00
Dr.Lt.Data
32bd55779a Merge branch 'master' into dr-support-pip-cm 2025-11-05 07:42:29 +09:00
Dr.Lt.Data
671a769dc6 Merge branch 'master' into dr-support-pip-cm 2025-11-04 23:25:51 +09:00
Dr.Lt.Data
d8b821e47b Merge branch 'master' into dr-support-pip-cm 2025-11-03 07:12:55 +09:00
Dr.Lt.Data
16359abbbc Merge branch 'master' into dr-support-pip-cm 2025-11-01 06:27:21 +09:00
Dr.Lt.Data
8f492d8f34 Merge branch 'master' into dr-support-pip-cm 2025-10-31 12:55:36 +09:00
Dr.Lt.Data
ad4b959d7e Merge branch 'master' into dr-support-pip-cm 2025-10-31 07:31:50 +09:00
Dr.Lt.Data
b88c66bfa1 Merge branch 'master' into dr-support-pip-cm 2025-10-30 07:30:50 +09:00
Dr.Lt.Data
de357a01f8 Merge branch 'master' into dr-support-pip-cm 2025-10-28 19:01:11 +09:00
Dr.Lt.Data
c07908a37e Merge branch 'master' into dr-support-pip-cm 2025-10-27 12:50:24 +09:00
Dr.Lt.Data
fe26f30cb6 Merge branch 'master' into dr-support-pip-cm 2025-10-26 12:52:08 +09:00
Dr.Lt.Data
3c4b429251 Merge branch 'master' into dr-support-pip-cm 2025-10-25 10:42:34 +09:00
Dr.Lt.Data
0432bccbcf Merge branch 'master' into dr-support-pip-cm 2025-10-24 12:17:46 +09:00
Dr.Lt.Data
aaf06ace12 Merge branch 'master' into dr-support-pip-cm 2025-10-23 06:54:58 +09:00
Dr.Lt.Data
f46771bd97 update requirements.txt 2025-10-21 12:35:02 +09:00
Dr.Lt.Data
8e1b1b722b Merge branch 'master' into dr-support-pip-cm 2025-10-21 12:34:57 +09:00
Dr.Lt.Data
a1a6f4d7fe Merge branch 'master' into dr-support-pip-cm 2025-10-21 07:26:53 +09:00
Dr.Lt.Data
ee54914a52 Merge branch 'master' into dr-support-pip-cm 2025-10-20 06:35:52 +09:00
Dr.Lt.Data
8f59e2a341 Merge branch 'master' into dr-support-pip-cm 2025-10-19 11:39:42 +09:00
Dr.Lt.Data
7d5e73ea94 Merge branch 'master' into dr-support-pip-cm 2025-10-19 09:37:12 +09:00
Dr.Lt.Data
9dd26b0349 Merge branch 'master' into dr-support-pip-cm 2025-10-18 07:22:23 +09:00
Dr.Lt.Data
c9c68ed78d Merge branch 'master' into dr-support-pip-cm 2025-10-17 22:37:13 +09:00
Dr.Lt.Data
6626f7c5c4 Merge branch 'master' into dr-support-pip-cm 2025-10-17 12:42:54 +09:00
Dr.Lt.Data
0802f3a635 Merge branch 'master' into dr-support-pip-cm 2025-10-16 12:06:19 +09:00
Dr.Lt.Data
19ad129d37 Merge branch 'master' into dr-support-pip-cm 2025-10-16 06:40:04 +09:00
Dr.Lt.Data
db61dc3481 Merge branch 'master' into dr-support-pip-cm 2025-10-15 12:34:12 +09:00
Dr.Lt.Data
5fbc8a1b80 Merge branch 'master' into dr-support-pip-cm 2025-10-15 06:43:20 +09:00
Dr.Lt.Data
b180f47d0e Merge branch 'master' into dr-support-pip-cm 2025-10-14 12:34:58 +09:00
Dr.Lt.Data
2b47f4a38e Merge branch 'master' into dr-support-pip-cm 2025-10-14 07:36:42 +09:00
Dr.Lt.Data
a3af8f35c2 Merge branch 'master' into dr-support-pip-cm 2025-10-13 12:50:41 +09:00
Dr.Lt.Data
5f50b86114 Merge branch 'master' into dr-support-pip-cm 2025-10-13 06:42:04 +09:00
Dr.Lt.Data
4e7f2eeae2 Merge branch 'master' into dr-support-pip-cm 2025-10-10 08:15:03 +09:00
Dr.Lt.Data
fc5703c468 Merge branch 'master' into dr-support-pip-cm 2025-10-09 23:57:10 +09:00
Dr.Lt.Data
05cd5348b6 Merge branch 'master' into dr-support-pip-cm 2025-10-09 10:49:23 +09:00
Dr.Lt.Data
3c000c1de4 Merge branch 'master' into dr-support-pip-cm 2025-10-08 11:04:18 +09:00
Dr.Lt.Data
6b20418ad1 Merge branch 'master' into dr-support-pip-cm 2025-10-07 14:30:16 +09:00
Dr.Lt.Data
2dc24f9870 Merge branch 'master' into dr-support-pip-cm 2025-10-05 07:36:33 +09:00
Dr.Lt.Data
8634b19bc7 Merge branch 'master' into dr-support-pip-cm 2025-10-04 07:09:43 +09:00
Dr.Lt.Data
47436c59d7 Merge branch 'master' into dr-support-pip-cm 2025-10-03 10:23:40 +09:00
Dr.Lt.Data
28092933c1 Merge branch 'master' into dr-support-pip-cm 2025-10-02 12:49:48 +09:00
Dr.Lt.Data
17064a993c Merge branch 'master' into dr-support-pip-cm 2025-10-02 07:31:37 +09:00
Dr.Lt.Data
12f2b59284 Merge branch 'master' into dr-support-pip-cm 2025-10-01 07:17:25 +09:00
Dr.Lt.Data
8cbdaa8855 Merge branch 'master' into dr-support-pip-cm 2025-09-30 12:46:12 +09:00
Dr.Lt.Data
976cee95f8 Merge branch 'master' into dr-support-pip-cm 2025-09-30 06:54:59 +09:00
Dr.Lt.Data
20ac0052f8 Merge branch 'master' into dr-support-pip-cm 2025-09-29 06:58:35 +09:00
Dr.Lt.Data
bc8418f55a Merge branch 'master' into dr-support-pip-cm 2025-09-26 07:00:43 +09:00
Dr.Lt.Data
42f69b1ffd Merge branch 'master' into dr-support-pip-cm 2025-09-25 07:25:27 +09:00
Dr.Lt.Data
581059a83d Merge branch 'master' into dr-support-pip-cm 2025-09-24 07:24:19 +09:00
Dr.Lt.Data
74c1a58566 Merge branch 'master' into dr-support-pip-cm 2025-09-23 07:28:52 +09:00
Dr.Lt.Data
316aa125c9 Merge branch 'master' into dr-support-pip-cm 2025-09-22 12:33:09 +09:00
Dr.Lt.Data
7b1ed9b2b8 Merge branch 'master' into dr-support-pip-cm 2025-09-21 11:24:37 +09:00
Dr.Lt.Data
4ea946778b Merge branch 'master' into dr-support-pip-cm 2025-09-21 10:45:28 +09:00
Dr.Lt.Data
309c92d6c9 Merge branch 'master' into dr-support-pip-cm 2025-09-21 09:33:38 +09:00
Dr.Lt.Data
ca7492c9d4 Merge branch 'master' into dr-support-pip-cm 2025-09-20 07:13:36 +09:00
Dr.Lt.Data
267c54eaae Updated comfyui_manager to version 4.0.2 in requirements.txt 2025-09-19 12:00:17 +09:00
Dr.Lt.Data
fa51f0c60a Merge branch 'master' into dr-support-pip-cm 2025-09-19 12:00:10 +09:00
Dr.Lt.Data
0a084a88a2 Merge branch 'master' into dr-support-pip-cm 2025-09-19 08:16:58 +09:00
Dr.Lt.Data
036aa3efa8 fixed: Even if --enable-manager is applied, it should switch to a disabled state if comfyui_manager is not installed. 2025-09-19 07:38:10 +09:00
comfyanonymous
e7ff647d02 --disable-manager -> --enable-manager 2025-09-17 20:58:42 -04:00
Dr.Lt.Data
77e10752fe Merge branch 'master' into dr-support-pip-cm 2025-09-18 07:32:23 +09:00
Dr.Lt.Data
2c30881d9c Merge branch 'master' into dr-support-pip-cm 2025-09-17 11:56:35 +09:00
Dr.Lt.Data
7fa5990dbc Merge branch 'master' into dr-support-pip-cm 2025-09-17 06:09:40 +09:00
Dr.Lt.Data
07212a2466 Merge branch 'master' into dr-support-pip-cm 2025-09-16 12:39:43 +09:00
Dr.Lt.Data
f4d7a32cd8 Merge branch 'master' into dr-support-pip-cm 2025-09-15 12:16:00 +09:00
Dr.Lt.Data
ce1df28bef Merge branch 'master' into dr-support-pip-cm 2025-09-13 15:41:22 +09:00
Dr.Lt.Data
0f8d57206c Update comfyui_manager dependency in requirements 2025-09-13 08:18:31 +09:00
Dr.Lt.Data
9d70d75f20 Merge branch 'master' into dr-support-pip-cm 2025-09-13 07:30:55 +09:00
Dr.Lt.Data
ff5e92abdb Merge branch 'master' into dr-support-pip-cm 2025-09-12 12:32:12 +09:00
Dr.Lt.Data
033e725b8e Merge branch 'master' into dr-support-pip-cm 2025-09-12 07:53:36 +09:00
Dr.Lt.Data
cc8a026671 Merge branch 'master' into dr-support-pip-cm 2025-09-11 12:28:47 +09:00
Dr.Lt.Data
f9cfea0f2e Merge branch 'master' into dr-support-pip-cm 2025-09-11 06:51:11 +09:00
Dr.Lt.Data
b5745ae0a7 Merge branch 'master' into dr-support-pip-cm 2025-09-10 18:37:03 +09:00
Dr.Lt.Data
2a30a19df7 Merge branch 'master' into dr-support-pip-cm 2025-09-10 11:51:21 +09:00
Dr.Lt.Data
c7f04234c6 Merge branch 'master' into dr-support-pip-cm 2025-09-10 07:11:31 +09:00
Dr.Lt.Data
0e31eca087 Merge branch 'master' into dr-support-pip-cm 2025-09-09 07:42:02 +09:00
Dr.Lt.Data
1c8c9f7f4d Merge branch 'master' into dr-support-pip-cm 2025-09-08 12:33:17 +09:00
Dr.Lt.Data
d4cb177414 Merge pull request #2 from viva-jinyi/fix/system-os
Fix OS reporting in /system_stats API to use sys.platform
2025-09-08 07:50:53 +09:00
Dr.Lt.Data
8a2f805233 Merge branch 'master' into dr-support-pip-cm 2025-09-08 07:44:18 +09:00
Jin Yi
c97f6aa0b2 Fix OS reporting in /system_stats API to use sys.platform
Replace os.name with sys.platform for more detailed OS identification.
This change provides better OS differentiation:
- Windows: "nt" -> "win32"
- macOS: "posix" -> "darwin"
- Linux: "posix" -> "linux"

Previously, both macOS and Linux returned "posix", making them
indistinguishable. Now each OS has a unique identifier, aligning
with the Registry Specifications for proper compatibility checks.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-09-07 14:29:45 +09:00
Dr.Lt.Data
ba3c8e3dbe Merge branch 'master' into dr-support-pip-cm 2025-09-06 04:05:01 +09:00
Dr.Lt.Data
3b65618d13 update requirements 2025-09-06 03:40:20 +09:00
Dr.Lt.Data
fc88b4f939 Merge branch 'master' into dr-support-pip-cm 2025-09-05 18:49:55 +09:00
Dr.Lt.Data
6a1f95caa0 Merge branch 'master' into dr-support-pip-cm 2025-09-05 08:14:53 +09:00
Dr.Lt.Data
2ade597d02 Merge branch 'master' into dr-support-pip-cm 2025-09-04 11:55:51 +09:00
Dr.Lt.Data
08e9c3ddf0 fixed: more robust detection of missing comfyui_manager 2025-09-04 11:51:11 +09:00
Dr.Lt.Data
f8aab7cab0 fixed: more robust detection of missing comfyui_manager 2025-09-04 11:49:48 +09:00
Dr.Lt.Data
561eaf6ccf fixed: Robust detection of missing comfyui_manager 2025-09-04 11:44:53 +09:00
Dr.Lt.Data
31469f962f fixed: issue of not properly detecting the removal of the comfyui_manager package in a conda environment. 2025-09-04 11:31:37 +09:00
Dr.Lt.Data
e0f111c6eb Merge branch 'master' into dr-support-pip-cm 2025-09-04 08:30:56 +09:00
Dr.Lt.Data
74a027f589 Merge branch 'master' into dr-support-pip-cm 2025-09-03 12:01:02 +09:00
Dr.Lt.Data
cc21e84115 Merge branch 'master' into dr-support-pip-cm 2025-09-03 00:07:37 +09:00
Dr.Lt.Data
69bbe1d5a9 modified: SERVER_FEATURE_FLAGS - manager extension is added 2025-09-02 07:44:17 +09:00
Dr.Lt.Data
7fd87423b3 Merge branch 'master' into dr-support-pip-cm 2025-08-31 17:16:00 +09:00
Dr.Lt.Data
1224d58a17 Merge branch 'master' into dr-support-pip-cm 2025-08-30 06:12:16 +09:00
Dr.Lt.Data
2cc7bafb52 Merge branch 'master' into dr-support-pip-cm 2025-08-29 07:36:57 +09:00
Dr.Lt.Data
523b54b9b4 update requirements.txt 2025-08-28 00:29:56 +09:00
Dr.Lt.Data
5c8f724c9a Merge branch 'master' into dr-support-pip-cm 2025-08-28 00:28:43 +09:00
Dr.Lt.Data
eda556d7b4 Merge branch 'dr-support-pip-cm' 2025-08-26 19:45:36 +09:00
Dr.Lt.Data
b7faa5fe3d Merge branch 'master' into dr-support-pip-cm 2025-08-25 06:08:20 +09:00
Dr.Lt.Data
6087e0210c modified: Changed behavior so that if comfyui-manager is not installed, it provides an installation guide message instead of raising an exception. 2025-08-24 16:05:10 +09:00
Dr.Lt.Data
6728792589 Merge branch 'master' into dr-support-pip-cm 2025-08-24 15:43:42 +09:00
Dr.Lt.Data
881db45147 Merge branch 'master' into dr-support-pip-cm 2025-08-23 17:46:43 +09:00
Dr.Lt.Data
26cac3c053 restore custom_nodes dir 2025-08-23 08:47:27 +09:00
Dr.Lt.Data
47350d323a Merge branch 'master' into dr-support-pip-cm 2025-08-23 06:46:25 +09:00
Dr.Lt.Data
117d8ae992 update requirments.txt 2025-08-23 06:45:52 +09:00
Dr.Lt.Data
844e5e7abb Merge branch 'master' into dr-support-pip-cm 2025-08-22 20:00:27 +09:00
Dr.Lt.Data
20953cbfd4 Merge branch 'master' into dr-support-pip-cm 2025-08-22 12:41:27 +09:00
Dr.Lt.Data
7c36368b14 Merge branch 'master' into dr-support-pip-cm 2025-08-22 05:16:03 +09:00
Dr.Lt.Data
d7b4f45c5b Merge branch 'master' into dr-support-pip-cm 2025-08-21 06:44:35 +09:00
Dr.Lt.Data
4b1aac74bb Merge branch 'master' into dr-support-pip-cm 2025-08-20 12:25:03 +09:00
Dr.Lt.Data
be456cb37a Merge branch 'master' into dr-support-pip-cm 2025-08-20 04:02:37 +09:00
Dr.Lt.Data
3dfecd541b Merge branch 'master' into dr-support-pip-cm 2025-08-19 06:24:55 +09:00
Dr.Lt.Data
ca04f8f401 Merge branch 'master' into dr-support-pip-cm 2025-08-18 12:23:05 +09:00
Dr.Lt.Data
8b44e58e6c Merge branch 'master' into dr-support-pip-cm 2025-08-18 07:35:15 +09:00
Dr.Lt.Data
37aa552602 Merge branch 'master' into dr-support-pip-cm 2025-08-15 10:10:54 +09:00
Dr.Lt.Data
91555acf2c Merge branch 'master' into dr-support-pip-cm 2025-08-14 12:01:56 +09:00
Dr.Lt.Data
d7777dc83a Merge branch 'master' into dr-support-pip-cm 2025-08-14 02:36:19 +09:00
Dr.Lt.Data
1c66507261 Merge branch 'master' into dr-support-pip-cm 2025-08-13 12:12:22 +09:00
Dr.Lt.Data
264116dc4d Merge branch 'master' into dr-support-pip-cm 2025-08-12 10:13:31 +09:00
Dr.Lt.Data
d750aa0847 Merge branch 'master' into dr-support-pip-cm 2025-08-11 22:22:29 +09:00
Dr.Lt.Data
37277e4188 Merge branch 'master' into dr-support-pip-cm 2025-08-10 20:57:20 +09:00
Dr.Lt.Data
106510197a Merge branch 'master' into dr-support-pip-cm 2025-08-08 23:48:53 +09:00
Dr.Lt.Data
bf01579b87 Merge branch 'master' into dr-support-pip-cm 2025-08-08 12:07:08 +09:00
Dr.Lt.Data
ab1a79ad74 Merge branch 'master' into dr-support-pip-cm 2025-08-07 12:20:12 +09:00
Dr.Lt.Data
2fe58571e2 Merge branch 'master' into dr-support-pip-cm 2025-08-07 07:45:14 +09:00
Dr.Lt.Data
46209599ff Merge branch 'master' into dr-support-pip-cm 2025-08-05 12:24:25 +09:00
Dr.Lt.Data
02317a1f71 Merge branch 'master' into dr-support-pip-cm 2025-08-05 06:21:27 +09:00
Dr.Lt.Data
ac7e83448e Merge branch 'master' into dr-support-pip-cm 2025-08-04 07:25:20 +09:00
Dr.Lt.Data
56cff964f2 Merge branch 'master' into dr-support-pip-cm 2025-08-01 12:40:30 +09:00
Dr.Lt.Data
5582e2a0f3 Merge branch 'master' into dr-support-pip-cm 2025-07-31 12:33:38 +09:00
Dr.Lt.Data
3c8196a170 Merge branch 'master' into dr-support-pip-cm 2025-07-30 12:14:34 +09:00
Dr.Lt.Data
62c08e4659 Merge branch 'master' into dr-support-pip-cm 2025-07-29 23:44:44 +09:00
Dr.Lt.Data
ac7bde1d03 Merge branch 'master' into dr-support-pip-cm 2025-07-29 12:13:25 +09:00
Dr.Lt.Data
6909638a42 Merge branch 'master' into dr-support-pip-cm 2025-07-27 15:01:02 +09:00
Dr.Lt.Data
d0625d7f7c Merge branch 'master' into dr-support-pip-cm 2025-07-26 09:35:21 +09:00
Dr.Lt.Data
6b19857c93 Merge branch 'master' into dr-support-pip-cm 2025-07-25 12:21:17 +09:00
Dr.Lt.Data
4e904305ce Merge branch 'dr-support-pip-cm' 2025-07-24 12:22:50 +09:00
Dr.Lt.Data
726aa75126 Merge branch 'master' into dr-support-pip-cm 2025-07-23 12:57:43 +09:00
Dr.Lt.Data
74087e26da Merge branch 'master' into dr-support-pip-cm 2025-07-22 07:41:54 +09:00
Dr.Lt.Data
51bf04c5ae Merge branch 'master' into dr-support-pip-cm 2025-07-21 12:15:35 +09:00
Dr.Lt.Data
b603e034e5 Merge branch 'master' into dr-support-pip-cm 2025-07-20 16:31:14 +09:00
Dr.Lt.Data
3c9a0fcf8a Merge branch 'master' into dr-support-pip-cm 2025-07-17 12:23:03 +09:00
Dr.Lt.Data
0adeb9b135 Merge branch 'master' into dr-support-pip-cm 2025-07-15 12:02:07 +09:00
Dr.Lt.Data
98b5183ed8 Merge branch 'master' into dr-support-pip-cm 2025-07-15 06:46:20 +09:00
Dr.Lt.Data
16a0b24da4 Merge branch 'master' into dr-support-pip-cm 2025-07-12 09:19:32 +09:00
Dr.Lt.Data
552fe9df02 Merge branch 'master' into dr-support-pip-cm 2025-07-08 12:34:29 +09:00
Dr.Lt.Data
2ce64b131c Merge branch 'master' into dr-support-pip-cm 2025-07-04 06:35:21 +09:00
Dr.Lt.Data
d6fa7a7c84 Merge branch 'master' into dr-support-pip-cm 2025-07-02 12:03:03 +09:00
Dr.Lt.Data
17cfabec7d added: Apply manager middleware 2025-07-01 12:55:53 +09:00
Dr.Lt.Data
ad633b2953 Merge branch 'master' into dr-support-pip-cm 2025-07-01 12:55:47 +09:00
Dr.Lt.Data
9eba1547f4 Merge branch 'master' into dr-support-pip-cm 2025-06-29 15:31:19 +09:00
Dr.Lt.Data
f398256d11 Merge branch 'master' into dr-support-pip-cm 2025-06-28 10:53:05 +09:00
Dr.Lt.Data
8744ebb4a1 Merge branch 'master' into dr-support-pip-cm 2025-06-27 07:34:33 +09:00
Dr.Lt.Data
d5167d2ded Merge branch 'master' into dr-support-pip-cm 2025-06-26 08:59:09 +09:00
Dr.Lt.Data
364e07d145 Merge branch 'master' into dr-support-pip-cm 2025-06-25 00:26:24 +09:00
Dr.Lt.Data
5a0ec182ec Merge branch 'master' into dr-support-pip-cm 2025-06-23 07:08:23 +09:00
Dr.Lt.Data
39f39c3aa9 Merge branch 'master' into dr-support-pip-cm 2025-06-21 23:51:13 +09:00
Dr.Lt.Data
4e95c0c104 Merge branch 'master' into dr-support-pip-cm 2025-06-20 22:12:07 +09:00
Dr.Lt.Data
d1ab6adc3a Merge branch 'master' into dr-support-pip-cm 2025-06-16 06:38:35 +09:00
Dr.Lt.Data
35a294431f Merge branch 'master' into dr-support-pip-cm 2025-06-09 12:34:23 +09:00
Dr.Lt.Data
baeeeb02b9 Merge branch 'master' into dr-support-pip-cm 2025-06-01 04:34:00 +09:00
Dr.Lt.Data
ef641f3e4b Merge branch 'master' into dr-support-pip-cm 2025-05-26 02:23:34 +09:00
Dr.Lt.Data
9ac185456f Merge branch 'master' into dr-support-pip-cm 2025-05-19 06:04:10 +09:00
Dr.Lt.Data
b69ef5f869 Merge branch 'master' into dr-support-pip-cm 2025-05-10 18:46:26 +09:00
Dr.Lt.Data
31aecbe1ad Merge branch 'master' into dr-support-pip-cm 2025-05-09 06:38:49 +09:00
Dr.Lt.Data
28d23a7813 Merge branch 'master' into dr-support-pip-cm 2025-05-03 22:38:35 +09:00
Dr.Lt.Data
f51047abd3 Merge branch 'master' into dr-support-pip-cm 2025-05-01 02:09:54 +09:00
Dr.Lt.Data
14598c1104 Merge branch 'master' into dr-support-pip-cm 2025-04-28 23:22:56 +09:00
Dr.Lt.Data
57dae1469f modified: --disable-manager will prevent importing comfyui-manager
feat: --disable-manager-ui will disable the endpoints and ui of comfyui-manager
2025-04-28 17:56:50 +09:00
Dr.Lt.Data
ea3d3cc6a4 Merge branch 'master' into dr-support-pip-cm 2025-04-24 08:45:22 +09:00
Dr.Lt.Data
9c2eb2c1dd Merge branch 'master' into dr-support-pip-cm 2025-04-22 02:24:26 +09:00
Dr.Lt.Data
ec82eea1f1 Merge branch 'master' into dr-support-pip-cm 2025-04-21 12:06:29 +09:00
Dr.Lt.Data
4fafc0c58d Merge branch 'master' into dr-support-pip-cm 2025-04-20 19:08:51 +09:00
Dr.Lt.Data
d2ed1dcb9a Merge branch 'master' into dr-support-pip-cm 2025-04-15 23:04:27 +09:00
Dr.Lt.Data
94f61c6378 add --enable-manager-legacy-ui 2025-04-15 01:36:17 +09:00
Dr.Lt.Data
418eaed42c fixed: Ensure that comfyui_manager's prestartup always runs, even when --disable-all-custom-nodes is used.
feat: Disable specific custom nodes according to the policy of `comfyui_manager`.
2025-04-12 21:53:57 +09:00
Dr.Lt.Data
cc975e5f0b add comfyui_manager to requirements.txt
It's still in the development stage, so the version is not pinned yet.
2025-04-12 19:11:02 +09:00
Dr.Lt.Data
311f64ac83 Merge branch 'master' into dr-support-pip-cm 2025-04-12 19:08:15 +09:00
Dr.Lt.Data
8b9f31abdf fixed: ruff check 2025-04-10 12:10:24 +09:00
Dr.Lt.Data
fb1b9c76b0 added: --disable-manager option 2025-04-10 08:40:53 +09:00
Dr.Lt.Data
545d96c12d Merge branch 'master' into dr-support-pip-cm 2025-04-10 08:34:54 +09:00
Dr.Lt.Data
1855efe1c3 Merge branch 'comfyanonymous:master' into dr-support-pip-cm 2025-04-05 15:21:55 +09:00
Dr.Lt.Data
6897a1d077 support pip comfyui-manager 2025-03-19 22:24:04 +09:00
266 changed files with 4011 additions and 30481 deletions

View File

@@ -0,0 +1,4 @@
@echo off
..\python_embeded\python.exe .\install_manager.py ..\ComfyUI\
echo Installed manager through pip package, if not already installed.
pause

View File

@@ -0,0 +1,24 @@
import sys
import os
repo_path = str(sys.argv[1])
repo_manager_req_path = os.path.join(repo_path, "manager_requirements.txt")
if os.path.exists(repo_manager_req_path):
import subprocess
# if not installed, we get 'WARNING: Package(s) not found: comfyui_manager'
# if installed, there will be a line like 'Version: 0.1.0' = False
try:
output = subprocess.check_output([sys.executable, '-s', '-m', 'pip', 'show', 'comfyui_manager'])
if 'Version:' in output.decode('utf-8'):
print("comfyui_manager is already installed, will attempt to update to matching version of ComfyUI.") # noqa: T201
else:
print("comfyui_manager is not installed, will install it now.") # noqa: T201
except:
pass
try:
subprocess.check_call([sys.executable, '-s', '-m', 'pip', 'install', '-r', repo_manager_req_path])
print("comfyui_manager installed successfully.") # noqa: T201
except:
print("Failed to install comfyui_manager, please install it manually.") # noqa: T201

View File

@@ -126,6 +126,8 @@ cur_path = os.path.dirname(update_py_path)
req_path = os.path.join(cur_path, "current_requirements.txt")
repo_req_path = os.path.join(repo_path, "requirements.txt")
manager_req_path = os.path.join(cur_path, "current_manager_requirements.txt")
repo_manager_req_path = os.path.join(repo_path, "manager_requirements.txt")
def files_equal(file1, file2):
try:
@@ -152,6 +154,25 @@ if not os.path.exists(req_path) or not files_equal(repo_req_path, req_path):
except:
pass
if os.path.exists(repo_manager_req_path) and (not os.path.exists(manager_req_path) or not files_equal(repo_manager_req_path, manager_req_path)):
import subprocess
# first, confirm that comfyui_manager package is installed; only update it if it is
# if not installed, we get 'WARNING: Package(s) not found: comfyui_manager'
# if installed, there will be a line like 'Version: 0.1.0'
update_manager = False
try:
output = subprocess.check_output([sys.executable, '-s', '-m', 'pip', 'show', 'comfyui_manager'])
if 'Version:' in output.decode('utf-8'):
update_manager = True
except:
pass
if update_manager:
try:
subprocess.check_call([sys.executable, '-s', '-m', 'pip', 'install', '-r', repo_manager_req_path])
shutil.copy(repo_manager_req_path, manager_req_path)
except:
pass
stable_update_script = os.path.join(repo_path, ".ci/update_windows/update_comfyui_stable.bat")
stable_update_script_to = os.path.join(cur_path, "update_comfyui_stable.bat")

View File

@@ -1,2 +1,2 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --enable-manager
pause

View File

@@ -1,2 +1,2 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --disable-smart-memory
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --enable-manager --disable-smart-memory
pause

View File

@@ -1,2 +1,2 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --enable-manager --fast
pause

View File

@@ -1,3 +1,3 @@
..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe
..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --enable-manager --disable-api-nodes
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
pause

View File

@@ -1,2 +1,2 @@
.\python_embeded\python.exe -s ComfyUI\main.py --cpu --windows-standalone-build
.\python_embeded\python.exe -s ComfyUI\main.py --cpu --windows-standalone-build --enable-manager
pause

View File

@@ -1,3 +1,3 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --enable-manager
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
pause

View File

@@ -1,3 +1,3 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --enable-manager --fast fp16_accumulation
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
pause

View File

@@ -20,7 +20,7 @@ jobs:
git_tag: ${{ inputs.git_tag }}
cache_tag: "cu130"
python_minor: "13"
python_patch: "11"
python_patch: "9"
rel_name: "nvidia"
rel_extra_name: ""
test_release: true
@@ -65,11 +65,11 @@ jobs:
contents: "write"
packages: "write"
pull-requests: "read"
name: "Release AMD ROCm 7.2"
name: "Release AMD ROCm 7.1.1"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "rocm72"
cache_tag: "rocm711"
python_minor: "12"
python_patch: "10"
rel_name: "amd"

View File

@@ -117,7 +117,7 @@ jobs:
./python.exe get-pip.py
./python.exe -s -m pip install ../${{ inputs.cache_tag }}_python_deps/*
grep comfy ../ComfyUI/requirements.txt > ./requirements_comfyui.txt
grep comfyui ../ComfyUI/requirements.txt > ./requirements_comfyui.txt
./python.exe -s -m pip install -r requirements_comfyui.txt
rm requirements_comfyui.txt

View File

@@ -18,7 +18,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}

View File

@@ -5,7 +5,6 @@ on:
push:
branches:
- master
- release/**
paths-ignore:
- 'app/**'
- 'input/**'

View File

@@ -2,9 +2,9 @@ name: Execution Tests
on:
push:
branches: [ main, master, release/** ]
branches: [ main, master ]
pull_request:
branches: [ main, master, release/** ]
branches: [ main, master ]
jobs:
test:

View File

@@ -2,9 +2,9 @@ name: Test server launches without errors
on:
push:
branches: [ main, master, release/** ]
branches: [ main, master ]
pull_request:
branches: [ main, master, release/** ]
branches: [ main, master ]
jobs:
test:
@@ -13,7 +13,7 @@ jobs:
- name: Checkout ComfyUI
uses: actions/checkout@v4
with:
repository: "Comfy-Org/ComfyUI"
repository: "comfyanonymous/ComfyUI"
path: "ComfyUI"
- uses: actions/setup-python@v4
with:
@@ -32,9 +32,7 @@ jobs:
working-directory: ComfyUI
- name: Check for unhandled exceptions in server log
run: |
grep -v "Found comfy_kitchen backend triton: {'available': False, 'disabled': True, 'unavailable_reason': \"ImportError: No module named 'triton'\", 'capabilities': \[\]}" console_output.log | grep -v "Found comfy_kitchen backend triton: {'available': False, 'disabled': False, 'unavailable_reason': \"ImportError: No module named 'triton'\", 'capabilities': \[\]}" > console_output_filtered.log
cat console_output_filtered.log
if grep -qE "Exception|Error" console_output_filtered.log; then
if grep -qE "Exception|Error" console_output.log; then
echo "Unhandled exception/error found in server log."
exit 1
fi

View File

@@ -2,9 +2,9 @@ name: Unit Tests
on:
push:
branches: [ main, master, release/** ]
branches: [ main, master ]
pull_request:
branches: [ main, master, release/** ]
branches: [ main, master ]
jobs:
test:

View File

@@ -1,59 +0,0 @@
name: "CI: Update CI Container"
on:
release:
types: [published]
workflow_dispatch:
inputs:
version:
description: 'ComfyUI version (e.g., v0.7.0)'
required: true
type: string
jobs:
update-ci-container:
runs-on: ubuntu-latest
# Skip pre-releases unless manually triggered
if: github.event_name == 'workflow_dispatch' || !github.event.release.prerelease
steps:
- name: Get version
id: version
run: |
if [ "${{ github.event_name }}" = "release" ]; then
VERSION="${{ github.event.release.tag_name }}"
else
VERSION="${{ inputs.version }}"
fi
echo "version=$VERSION" >> $GITHUB_OUTPUT
- name: Checkout comfyui-ci-container
uses: actions/checkout@v4
with:
repository: comfy-org/comfyui-ci-container
token: ${{ secrets.CI_CONTAINER_PAT }}
- name: Check current version
id: current
run: |
CURRENT=$(grep -oP 'ARG COMFYUI_VERSION=\K.*' Dockerfile || echo "unknown")
echo "current_version=$CURRENT" >> $GITHUB_OUTPUT
- name: Update Dockerfile
run: |
VERSION="${{ steps.version.outputs.version }}"
sed -i "s/^ARG COMFYUI_VERSION=.*/ARG COMFYUI_VERSION=${VERSION}/" Dockerfile
- name: Create Pull Request
id: create-pr
uses: peter-evans/create-pull-request@v7
with:
token: ${{ secrets.CI_CONTAINER_PAT }}
branch: automation/comfyui-${{ steps.version.outputs.version }}
title: "chore: bump ComfyUI to ${{ steps.version.outputs.version }}"
body: |
Updates ComfyUI version from `${{ steps.current.outputs.current_version }}` to `${{ steps.version.outputs.version }}`
**Triggered by:** ${{ github.event_name == 'release' && format('[Release {0}]({1})', github.event.release.tag_name, github.event.release.html_url) || 'Manual workflow dispatch' }}
labels: automation
commit-message: "chore: bump ComfyUI to ${{ steps.version.outputs.version }}"

View File

@@ -6,7 +6,6 @@ on:
- "pyproject.toml"
branches:
- master
- release/**
jobs:
update-version:

View File

@@ -29,7 +29,7 @@ on:
description: 'python patch version'
required: true
type: string
default: "11"
default: "9"
# push:
# branches:
# - master

View File

@@ -108,7 +108,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
- Works fully offline: core will never download anything unless you want to.
- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview) disable with: `--disable-api-nodes`
- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview).
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.
Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/)
@@ -119,9 +119,6 @@ ComfyUI follows a weekly release cycle targeting Monday but this regularly chang
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
- Releases a new stable version (e.g., v0.7.0) roughly every week.
- Starting from v0.4.0 patch versions will be used for fixes backported onto the current stable release.
- Minor versions will be used for releases off the master branch.
- Patch versions may still be used for releases on the master branch in cases where a backport would not make sense.
- Commits outside of the stable release tags may be very unstable and break many custom nodes.
- Serves as the foundation for the desktop release
@@ -183,7 +180,7 @@ Simply download, extract with [7-Zip](https://7-zip.org) or with the windows exp
If you have trouble extracting it, right click the file -> properties -> unblock
The portable above currently comes with python 3.13 and pytorch cuda 13.0. Update your Nvidia drivers if it doesn't start.
Update your Nvidia drivers if it doesn't start.
#### Alternative Downloads:
@@ -208,12 +205,10 @@ comfy install
## Manual Install (Windows, Linux)
Python 3.14 works but some custom nodes may have issues. The free threaded variant works but some dependencies will enable the GIL so it's not fully supported.
Python 3.14 works but you may encounter issues with the torch compile node. The free threaded variant is still missing some dependencies.
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
torch 2.4 and above is supported but some features and optimizations might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old.
### Instructions:
Git clone this repo.
@@ -229,7 +224,7 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
This is the command to install the nightly with ROCm 7.1 which might have some performance improvements:
This is the command to install the nightly with ROCm 7.0 which might have some performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1```
@@ -240,7 +235,7 @@ These have less hardware support than the builds above but they work on windows.
RDNA 3 (RX 7000 series):
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-all/```
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-dgpu/```
RDNA 3.5 (Strix halo/Ryzen AI Max+ 365):

View File

@@ -1,174 +0,0 @@
"""
Initial assets schema
Revision ID: 0001_assets
Revises: None
Create Date: 2025-12-10 00:00:00
"""
from alembic import op
import sqlalchemy as sa
revision = "0001_assets"
down_revision = None
branch_labels = None
depends_on = None
def upgrade() -> None:
# ASSETS: content identity
op.create_table(
"assets",
sa.Column("id", sa.String(length=36), primary_key=True),
sa.Column("hash", sa.String(length=256), nullable=True),
sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"),
sa.Column("mime_type", sa.String(length=255), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
)
op.create_index("uq_assets_hash", "assets", ["hash"], unique=True)
op.create_index("ix_assets_mime_type", "assets", ["mime_type"])
# ASSETS_INFO: user-visible references
op.create_table(
"assets_info",
sa.Column("id", sa.String(length=36), primary_key=True),
sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""),
sa.Column("name", sa.String(length=512), nullable=False),
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False),
sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True),
sa.Column("user_metadata", sa.JSON(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False),
sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
)
op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"])
op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"])
op.create_index("ix_assets_info_name", "assets_info", ["name"])
op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"])
op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"])
op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"])
# TAGS: normalized tag vocabulary
op.create_table(
"tags",
sa.Column("name", sa.String(length=512), primary_key=True),
sa.Column("tag_type", sa.String(length=32), nullable=False, server_default="user"),
sa.CheckConstraint("name = lower(name)", name="ck_tags_lowercase"),
)
op.create_index("ix_tags_tag_type", "tags", ["tag_type"])
# ASSET_INFO_TAGS: many-to-many for tags on AssetInfo
op.create_table(
"asset_info_tags",
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"),
sa.Column("added_at", sa.DateTime(timezone=False), nullable=False),
sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"),
)
op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])
op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"])
# ASSET_CACHE_STATE: N:1 local cache rows per Asset
op.create_table(
"asset_cache_state",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False),
sa.Column("file_path", sa.Text(), nullable=False), # absolute local path to cached file
sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")),
sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
)
op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"])
op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"])
# ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting
op.create_table(
"asset_info_meta",
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
sa.Column("key", sa.String(length=256), nullable=False),
sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"),
sa.Column("val_str", sa.String(length=2048), nullable=True),
sa.Column("val_num", sa.Numeric(38, 10), nullable=True),
sa.Column("val_bool", sa.Boolean(), nullable=True),
sa.Column("val_json", sa.JSON(), nullable=True),
sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"),
)
op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"])
op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"])
op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"])
op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"])
# Tags vocabulary
tags_table = sa.table(
"tags",
sa.column("name", sa.String(length=512)),
sa.column("tag_type", sa.String()),
)
op.bulk_insert(
tags_table,
[
{"name": "models", "tag_type": "system"},
{"name": "input", "tag_type": "system"},
{"name": "output", "tag_type": "system"},
{"name": "configs", "tag_type": "system"},
{"name": "checkpoints", "tag_type": "system"},
{"name": "loras", "tag_type": "system"},
{"name": "vae", "tag_type": "system"},
{"name": "text_encoders", "tag_type": "system"},
{"name": "diffusion_models", "tag_type": "system"},
{"name": "clip_vision", "tag_type": "system"},
{"name": "style_models", "tag_type": "system"},
{"name": "embeddings", "tag_type": "system"},
{"name": "diffusers", "tag_type": "system"},
{"name": "vae_approx", "tag_type": "system"},
{"name": "controlnet", "tag_type": "system"},
{"name": "gligen", "tag_type": "system"},
{"name": "upscale_models", "tag_type": "system"},
{"name": "hypernetworks", "tag_type": "system"},
{"name": "photomaker", "tag_type": "system"},
{"name": "classifiers", "tag_type": "system"},
{"name": "encoder", "tag_type": "system"},
{"name": "decoder", "tag_type": "system"},
{"name": "missing", "tag_type": "system"},
{"name": "rescan", "tag_type": "system"},
],
)
def downgrade() -> None:
op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta")
op.drop_table("asset_info_meta")
op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state")
op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state")
op.drop_constraint("uq_asset_cache_state_file_path", table_name="asset_cache_state")
op.drop_table("asset_cache_state")
op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags")
op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags")
op.drop_table("asset_info_tags")
op.drop_index("ix_tags_tag_type", table_name="tags")
op.drop_table("tags")
op.drop_constraint("uq_assets_info_asset_owner_name", table_name="assets_info")
op.drop_index("ix_assets_info_owner_name", table_name="assets_info")
op.drop_index("ix_assets_info_last_access_time", table_name="assets_info")
op.drop_index("ix_assets_info_created_at", table_name="assets_info")
op.drop_index("ix_assets_info_name", table_name="assets_info")
op.drop_index("ix_assets_info_asset_id", table_name="assets_info")
op.drop_index("ix_assets_info_owner_id", table_name="assets_info")
op.drop_table("assets_info")
op.drop_index("uq_assets_hash", table_name="assets")
op.drop_index("ix_assets_mime_type", table_name="assets")
op.drop_table("assets")

View File

@@ -1,514 +0,0 @@
import logging
import uuid
import urllib.parse
import os
import contextlib
from aiohttp import web
from pydantic import ValidationError
import app.assets.manager as manager
from app import user_manager
from app.assets.api import schemas_in
from app.assets.helpers import get_query_dict
from app.assets.scanner import seed_assets
import folder_paths
ROUTES = web.RouteTableDef()
USER_MANAGER: user_manager.UserManager | None = None
# UUID regex (canonical hyphenated form, case-insensitive)
UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
# Note to any custom node developers reading this code:
# The assets system is not yet fully implemented, do not rely on the code in /app/assets remaining the same.
def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None:
global USER_MANAGER
USER_MANAGER = user_manager_instance
app.add_routes(ROUTES)
def _error_response(status: int, code: str, message: str, details: dict | None = None) -> web.Response:
return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status)
def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
return _error_response(400, code, "Validation failed.", {"errors": ve.json()})
@ROUTES.head("/api/assets/hash/{hash}")
async def head_asset_by_hash(request: web.Request) -> web.Response:
hash_str = request.match_info.get("hash", "").strip().lower()
if not hash_str or ":" not in hash_str:
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
algo, digest = hash_str.split(":", 1)
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
exists = manager.asset_exists(asset_hash=hash_str)
return web.Response(status=200 if exists else 404)
@ROUTES.get("/api/assets")
async def list_assets(request: web.Request) -> web.Response:
"""
GET request to list assets.
"""
query_dict = get_query_dict(request)
try:
q = schemas_in.ListAssetsQuery.model_validate(query_dict)
except ValidationError as ve:
return _validation_error_response("INVALID_QUERY", ve)
payload = manager.list_assets(
include_tags=q.include_tags,
exclude_tags=q.exclude_tags,
name_contains=q.name_contains,
metadata_filter=q.metadata_filter,
limit=q.limit,
offset=q.offset,
sort=q.sort,
order=q.order,
owner_id=USER_MANAGER.get_request_user_id(request),
)
return web.json_response(payload.model_dump(mode="json", exclude_none=True))
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
async def get_asset(request: web.Request) -> web.Response:
"""
GET request to get an asset's info as JSON.
"""
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
result = manager.get_asset(
asset_info_id=asset_info_id,
owner_id=USER_MANAGER.get_request_user_id(request),
)
except ValueError as e:
return _error_response(404, "ASSET_NOT_FOUND", str(e), {"id": asset_info_id})
except Exception:
logging.exception(
"get_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
async def download_asset_content(request: web.Request) -> web.Response:
# question: do we need disposition? could we just stick with one of these?
disposition = request.query.get("disposition", "attachment").lower().strip()
if disposition not in {"inline", "attachment"}:
disposition = "attachment"
try:
abs_path, content_type, filename = manager.resolve_asset_content_for_download(
asset_info_id=str(uuid.UUID(request.match_info["id"])),
owner_id=USER_MANAGER.get_request_user_id(request),
)
except ValueError as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve))
except NotImplementedError as nie:
return _error_response(501, "BACKEND_UNSUPPORTED", str(nie))
except FileNotFoundError:
return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.")
quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'")
cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}'
file_size = os.path.getsize(abs_path)
logging.info(
"download_asset_content: path=%s, size=%d bytes (%.2f MB), content_type=%s, filename=%s",
abs_path,
file_size,
file_size / (1024 * 1024),
content_type,
filename,
)
async def file_sender():
chunk_size = 64 * 1024
with open(abs_path, "rb") as f:
while True:
chunk = f.read(chunk_size)
if not chunk:
break
yield chunk
return web.Response(
body=file_sender(),
content_type=content_type,
headers={
"Content-Disposition": cd,
"Content-Length": str(file_size),
},
)
@ROUTES.post("/api/assets/from-hash")
async def create_asset_from_hash(request: web.Request) -> web.Response:
try:
payload = await request.json()
body = schemas_in.CreateFromHashBody.model_validate(payload)
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
result = manager.create_asset_from_hash(
hash_str=body.hash,
name=body.name,
tags=body.tags,
user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request),
)
if result is None:
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist")
return web.json_response(result.model_dump(mode="json"), status=201)
@ROUTES.post("/api/assets")
async def upload_asset(request: web.Request) -> web.Response:
"""Multipart/form-data endpoint for Asset uploads."""
if not (request.content_type or "").lower().startswith("multipart/"):
return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.")
reader = await request.multipart()
file_present = False
file_client_name: str | None = None
tags_raw: list[str] = []
provided_name: str | None = None
user_metadata_raw: str | None = None
provided_hash: str | None = None
provided_hash_exists: bool | None = None
file_written = 0
tmp_path: str | None = None
while True:
field = await reader.next()
if field is None:
break
fname = getattr(field, "name", "") or ""
if fname == "hash":
try:
s = ((await field.text()) or "").strip().lower()
except Exception:
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
if s:
if ":" not in s:
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
provided_hash = f"{algo}:{digest}"
try:
provided_hash_exists = manager.asset_exists(asset_hash=provided_hash)
except Exception:
provided_hash_exists = None # do not fail the whole request here
elif fname == "file":
file_present = True
file_client_name = (field.filename or "").strip()
if provided_hash and provided_hash_exists is True:
# If client supplied a hash that we know exists, drain but do not write to disk
try:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
if not chunk:
break
file_written += len(chunk)
except Exception:
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file.")
continue # Do not create temp file; we will create AssetInfo from the existing content
# Otherwise, store to temp for hashing/ingest
uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
os.makedirs(unique_dir, exist_ok=True)
tmp_path = os.path.join(unique_dir, ".upload.part")
try:
with open(tmp_path, "wb") as f:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
if not chunk:
break
f.write(chunk)
file_written += len(chunk)
except Exception:
try:
if os.path.exists(tmp_path or ""):
os.remove(tmp_path)
finally:
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.")
elif fname == "tags":
tags_raw.append((await field.text()) or "")
elif fname == "name":
provided_name = (await field.text()) or None
elif fname == "user_metadata":
user_metadata_raw = (await field.text()) or None
# If client did not send file, and we are not doing a from-hash fast path -> error
if not file_present and not (provided_hash and provided_hash_exists):
return _error_response(400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'.")
if file_present and file_written == 0 and not (provided_hash and provided_hash_exists):
# Empty upload is only acceptable if we are fast-pathing from existing hash
try:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
finally:
return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.")
try:
spec = schemas_in.UploadAssetSpec.model_validate({
"tags": tags_raw,
"name": provided_name,
"user_metadata": user_metadata_raw,
"hash": provided_hash,
})
except ValidationError as ve:
try:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
finally:
return _validation_error_response("INVALID_BODY", ve)
# Validate models category against configured folders (consistent with previous behavior)
if spec.tags and spec.tags[0] == "models":
if len(spec.tags) < 2 or spec.tags[1] not in folder_paths.folder_names_and_paths:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
return _error_response(
400, "INVALID_BODY", f"unknown models category '{spec.tags[1] if len(spec.tags) >= 2 else ''}'"
)
owner_id = USER_MANAGER.get_request_user_id(request)
# Fast path: if a valid provided hash exists, create AssetInfo without writing anything
if spec.hash and provided_hash_exists is True:
try:
result = manager.create_asset_from_hash(
hash_str=spec.hash,
name=spec.name or (spec.hash.split(":", 1)[1]),
tags=spec.tags,
user_metadata=spec.user_metadata or {},
owner_id=owner_id,
)
except Exception:
logging.exception("create_asset_from_hash failed for hash=%s, owner_id=%s", spec.hash, owner_id)
return _error_response(500, "INTERNAL", "Unexpected server error.")
if result is None:
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist")
# Drain temp if we accidentally saved (e.g., hash field came after file)
if tmp_path and os.path.exists(tmp_path):
with contextlib.suppress(Exception):
os.remove(tmp_path)
status = 200 if (not result.created_new) else 201
return web.json_response(result.model_dump(mode="json"), status=status)
# Otherwise, we must have a temp file path to ingest
if not tmp_path or not os.path.exists(tmp_path):
# The only case we reach here without a temp file is: client sent a hash that does not exist and no file
return _error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.")
try:
created = manager.upload_asset_from_temp_path(
spec,
temp_path=tmp_path,
client_filename=file_client_name,
owner_id=owner_id,
expected_asset_hash=spec.hash,
)
status = 201 if created.created_new else 200
return web.json_response(created.model_dump(mode="json"), status=status)
except ValueError as e:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
msg = str(e)
if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH":
return _error_response(
400,
"HASH_MISMATCH",
"Uploaded file hash does not match provided hash.",
)
return _error_response(400, "BAD_REQUEST", "Invalid inputs.")
except Exception:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
logging.exception("upload_asset_from_temp_path failed for tmp_path=%s, owner_id=%s", tmp_path, owner_id)
return _error_response(500, "INTERNAL", "Unexpected server error.")
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
async def update_asset(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
body = schemas_in.UpdateAssetBody.model_validate(await request.json())
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
try:
result = manager.update_asset(
asset_info_id=asset_info_id,
name=body.name,
user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request),
)
except (ValueError, PermissionError) as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
logging.exception(
"update_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
async def delete_asset(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
delete_content = request.query.get("delete_content")
delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"}
try:
deleted = manager.delete_asset_reference(
asset_info_id=asset_info_id,
owner_id=USER_MANAGER.get_request_user_id(request),
delete_content_if_orphan=delete_content,
)
except Exception:
logging.exception(
"delete_asset_reference failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
if not deleted:
return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.")
return web.Response(status=204)
@ROUTES.get("/api/tags")
async def get_tags(request: web.Request) -> web.Response:
"""
GET request to list all tags based on query parameters.
"""
query_map = dict(request.rel_url.query)
try:
query = schemas_in.TagsListQuery.model_validate(query_map)
except ValidationError as e:
return web.json_response(
{"error": {"code": "INVALID_QUERY", "message": "Invalid query parameters", "details": e.errors()}},
status=400,
)
result = manager.list_tags(
prefix=query.prefix,
limit=query.limit,
offset=query.offset,
order=query.order,
include_zero=query.include_zero,
owner_id=USER_MANAGER.get_request_user_id(request),
)
return web.json_response(result.model_dump(mode="json"))
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
async def add_asset_tags(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
payload = await request.json()
data = schemas_in.TagsAdd.model_validate(payload)
except ValidationError as ve:
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()})
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
try:
result = manager.add_tags_to_asset(
asset_info_id=asset_info_id,
tags=data.tags,
origin="manual",
owner_id=USER_MANAGER.get_request_user_id(request),
)
except (ValueError, PermissionError) as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
logging.exception(
"add_tags_to_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
async def delete_asset_tags(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
payload = await request.json()
data = schemas_in.TagsRemove.model_validate(payload)
except ValidationError as ve:
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()})
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
try:
result = manager.remove_tags_from_asset(
asset_info_id=asset_info_id,
tags=data.tags,
owner_id=USER_MANAGER.get_request_user_id(request),
)
except ValueError as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
logging.exception(
"remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.post("/api/assets/seed")
async def seed_assets_endpoint(request: web.Request) -> web.Response:
"""Trigger asset seeding for specified roots (models, input, output)."""
try:
payload = await request.json()
roots = payload.get("roots", ["models", "input", "output"])
except Exception:
roots = ["models", "input", "output"]
valid_roots = [r for r in roots if r in ("models", "input", "output")]
if not valid_roots:
return _error_response(400, "INVALID_BODY", "No valid roots specified")
try:
seed_assets(tuple(valid_roots))
except Exception:
logging.exception("seed_assets failed for roots=%s", valid_roots)
return _error_response(500, "INTERNAL", "Seed operation failed")
return web.json_response({"seeded": valid_roots}, status=200)

View File

@@ -1,264 +0,0 @@
import json
from typing import Any, Literal
from pydantic import (
BaseModel,
ConfigDict,
Field,
conint,
field_validator,
model_validator,
)
class ListAssetsQuery(BaseModel):
include_tags: list[str] = Field(default_factory=list)
exclude_tags: list[str] = Field(default_factory=list)
name_contains: str | None = None
# Accept either a JSON string (query param) or a dict
metadata_filter: dict[str, Any] | None = None
limit: conint(ge=1, le=500) = 20
offset: conint(ge=0) = 0
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at"
order: Literal["asc", "desc"] = "desc"
@field_validator("include_tags", "exclude_tags", mode="before")
@classmethod
def _split_csv_tags(cls, v):
# Accept "a,b,c" or ["a","b"] (we are liberal in what we accept)
if v is None:
return []
if isinstance(v, str):
return [t.strip() for t in v.split(",") if t.strip()]
if isinstance(v, list):
out: list[str] = []
for item in v:
if isinstance(item, str):
out.extend([t.strip() for t in item.split(",") if t.strip()])
return out
return v
@field_validator("metadata_filter", mode="before")
@classmethod
def _parse_metadata_json(cls, v):
if v is None or isinstance(v, dict):
return v
if isinstance(v, str) and v.strip():
try:
parsed = json.loads(v)
except Exception as e:
raise ValueError(f"metadata_filter must be JSON: {e}") from e
if not isinstance(parsed, dict):
raise ValueError("metadata_filter must be a JSON object")
return parsed
return None
class UpdateAssetBody(BaseModel):
name: str | None = None
user_metadata: dict[str, Any] | None = None
@model_validator(mode="after")
def _at_least_one(self):
if self.name is None and self.user_metadata is None:
raise ValueError("Provide at least one of: name, user_metadata.")
return self
class CreateFromHashBody(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
hash: str
name: str
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
@field_validator("hash")
@classmethod
def _require_blake3(cls, v):
s = (v or "").strip().lower()
if ":" not in s:
raise ValueError("hash must be 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if algo != "blake3":
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
raise ValueError("hash digest must be lowercase hex")
return s
@field_validator("tags", mode="before")
@classmethod
def _tags_norm(cls, v):
if v is None:
return []
if isinstance(v, list):
out = [str(t).strip().lower() for t in v if str(t).strip()]
seen = set()
dedup = []
for t in out:
if t not in seen:
seen.add(t)
dedup.append(t)
return dedup
if isinstance(v, str):
return [t.strip().lower() for t in v.split(",") if t.strip()]
return []
class TagsListQuery(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
prefix: str | None = Field(None, min_length=1, max_length=256)
limit: int = Field(100, ge=1, le=1000)
offset: int = Field(0, ge=0, le=10_000_000)
order: Literal["count_desc", "name_asc"] = "count_desc"
include_zero: bool = True
@field_validator("prefix")
@classmethod
def normalize_prefix(cls, v: str | None) -> str | None:
if v is None:
return v
v = v.strip()
return v.lower() or None
class TagsAdd(BaseModel):
model_config = ConfigDict(extra="ignore")
tags: list[str] = Field(..., min_length=1)
@field_validator("tags")
@classmethod
def normalize_tags(cls, v: list[str]) -> list[str]:
out = []
for t in v:
if not isinstance(t, str):
raise TypeError("tags must be strings")
tnorm = t.strip().lower()
if tnorm:
out.append(tnorm)
seen = set()
deduplicated = []
for x in out:
if x not in seen:
seen.add(x)
deduplicated.append(x)
return deduplicated
class TagsRemove(TagsAdd):
pass
class UploadAssetSpec(BaseModel):
"""Upload Asset operation.
- tags: ordered; first is root ('models'|'input'|'output');
if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths
- name: display name
- user_metadata: arbitrary JSON object (optional)
- hash: optional canonical 'blake3:<hex>' provided by the client for validation / fast-path
Files created via this endpoint are stored on disk using the **content hash** as the filename stem
and the original extension is preserved when available.
"""
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
tags: list[str] = Field(..., min_length=1)
name: str | None = Field(default=None, max_length=512, description="Display Name")
user_metadata: dict[str, Any] = Field(default_factory=dict)
hash: str | None = Field(default=None)
@field_validator("hash", mode="before")
@classmethod
def _parse_hash(cls, v):
if v is None:
return None
s = str(v).strip().lower()
if not s:
return None
if ":" not in s:
raise ValueError("hash must be 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if algo != "blake3":
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
raise ValueError("hash digest must be lowercase hex")
return f"{algo}:{digest}"
@field_validator("tags", mode="before")
@classmethod
def _parse_tags(cls, v):
"""
Accepts a list of strings (possibly multiple form fields),
where each string can be:
- JSON array (e.g., '["models","loras","foo"]')
- comma-separated ('models, loras, foo')
- single token ('models')
Returns a normalized, deduplicated, ordered list.
"""
items: list[str] = []
if v is None:
return []
if isinstance(v, str):
v = [v]
if isinstance(v, list):
for item in v:
if item is None:
continue
s = str(item).strip()
if not s:
continue
if s.startswith("["):
try:
arr = json.loads(s)
if isinstance(arr, list):
items.extend(str(x) for x in arr)
continue
except Exception:
pass # fallback to CSV parse below
items.extend([p for p in s.split(",") if p.strip()])
else:
return []
# normalize + dedupe
norm = []
seen = set()
for t in items:
tnorm = str(t).strip().lower()
if tnorm and tnorm not in seen:
seen.add(tnorm)
norm.append(tnorm)
return norm
@field_validator("user_metadata", mode="before")
@classmethod
def _parse_metadata_json(cls, v):
if v is None or isinstance(v, dict):
return v or {}
if isinstance(v, str):
s = v.strip()
if not s:
return {}
try:
parsed = json.loads(s)
except Exception as e:
raise ValueError(f"user_metadata must be JSON: {e}") from e
if not isinstance(parsed, dict):
raise ValueError("user_metadata must be a JSON object")
return parsed
return {}
@model_validator(mode="after")
def _validate_order(self):
if not self.tags:
raise ValueError("tags must be provided and non-empty")
root = self.tags[0]
if root not in {"models", "input", "output"}:
raise ValueError("first tag must be one of: models, input, output")
if root == "models":
if len(self.tags) < 2:
raise ValueError("models uploads require a category tag as the second tag")
return self

View File

@@ -1,93 +0,0 @@
from datetime import datetime
from typing import Any
from pydantic import BaseModel, ConfigDict, Field, field_serializer
class AssetSummary(BaseModel):
id: str
name: str
asset_hash: str | None = None
size: int | None = None
mime_type: str | None = None
tags: list[str] = Field(default_factory=list)
preview_url: str | None = None
created_at: datetime | None = None
updated_at: datetime | None = None
last_access_time: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("created_at", "updated_at", "last_access_time")
def _ser_dt(self, v: datetime | None, _info):
return v.isoformat() if v else None
class AssetsList(BaseModel):
assets: list[AssetSummary]
total: int
has_more: bool
class AssetUpdated(BaseModel):
id: str
name: str
asset_hash: str | None = None
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
updated_at: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("updated_at")
def _ser_updated(self, v: datetime | None, _info):
return v.isoformat() if v else None
class AssetDetail(BaseModel):
id: str
name: str
asset_hash: str | None = None
size: int | None = None
mime_type: str | None = None
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
preview_id: str | None = None
created_at: datetime | None = None
last_access_time: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("created_at", "last_access_time")
def _ser_dt(self, v: datetime | None, _info):
return v.isoformat() if v else None
class AssetCreated(AssetDetail):
created_new: bool
class TagUsage(BaseModel):
name: str
count: int
type: str
class TagsList(BaseModel):
tags: list[TagUsage] = Field(default_factory=list)
total: int
has_more: bool
class TagsAdd(BaseModel):
model_config = ConfigDict(str_strip_whitespace=True)
added: list[str] = Field(default_factory=list)
already_present: list[str] = Field(default_factory=list)
total_tags: list[str] = Field(default_factory=list)
class TagsRemove(BaseModel):
model_config = ConfigDict(str_strip_whitespace=True)
removed: list[str] = Field(default_factory=list)
not_present: list[str] = Field(default_factory=list)
total_tags: list[str] = Field(default_factory=list)

View File

@@ -1,204 +0,0 @@
import os
import uuid
import sqlalchemy
from typing import Iterable
from sqlalchemy.orm import Session
from sqlalchemy.dialects import sqlite
from app.assets.helpers import utcnow
from app.assets.database.models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, AssetInfoMeta
MAX_BIND_PARAMS = 800
def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]:
if not rows:
return []
rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row))
for i in range(0, len(rows), rows_per_stmt):
yield rows[i:i + rows_per_stmt]
def _iter_chunks(seq, n: int):
for i in range(0, len(seq), n):
yield seq[i:i + n]
def _rows_per_stmt(cols: int) -> int:
return max(1, MAX_BIND_PARAMS // max(1, cols))
def seed_from_paths_batch(
session: Session,
*,
specs: list[dict],
owner_id: str = "",
) -> dict:
"""Each spec is a dict with keys:
- abs_path: str
- size_bytes: int
- mtime_ns: int
- info_name: str
- tags: list[str]
- fname: Optional[str]
"""
if not specs:
return {"inserted_infos": 0, "won_states": 0, "lost_states": 0}
now = utcnow()
asset_rows: list[dict] = []
state_rows: list[dict] = []
path_to_asset: dict[str, str] = {}
asset_to_info: dict[str, dict] = {} # asset_id -> prepared info row
path_list: list[str] = []
for sp in specs:
ap = os.path.abspath(sp["abs_path"])
aid = str(uuid.uuid4())
iid = str(uuid.uuid4())
path_list.append(ap)
path_to_asset[ap] = aid
asset_rows.append(
{
"id": aid,
"hash": None,
"size_bytes": sp["size_bytes"],
"mime_type": None,
"created_at": now,
}
)
state_rows.append(
{
"asset_id": aid,
"file_path": ap,
"mtime_ns": sp["mtime_ns"],
}
)
asset_to_info[aid] = {
"id": iid,
"owner_id": owner_id,
"name": sp["info_name"],
"asset_id": aid,
"preview_id": None,
"user_metadata": {"filename": sp["fname"]} if sp["fname"] else None,
"created_at": now,
"updated_at": now,
"last_access_time": now,
"_tags": sp["tags"],
"_filename": sp["fname"],
}
# insert all seed Assets (hash=NULL)
ins_asset = sqlite.insert(Asset)
for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)):
session.execute(ins_asset, chunk)
# try to claim AssetCacheState (file_path)
# Insert with ON CONFLICT DO NOTHING, then query to find which paths were actually inserted
ins_state = (
sqlite.insert(AssetCacheState)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
)
for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
session.execute(ins_state, chunk)
# Query to find which of our paths won (were actually inserted)
winners_by_path: set[str] = set()
for chunk in _iter_chunks(path_list, MAX_BIND_PARAMS):
result = session.execute(
sqlalchemy.select(AssetCacheState.file_path)
.where(AssetCacheState.file_path.in_(chunk))
.where(AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk]))
)
winners_by_path.update(result.scalars().all())
all_paths_set = set(path_list)
losers_by_path = all_paths_set - winners_by_path
lost_assets = [path_to_asset[p] for p in losers_by_path]
if lost_assets: # losers get their Asset removed
for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS):
session.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(id_chunk)))
if not winners_by_path:
return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}
# insert AssetInfo only for winners
# Insert with ON CONFLICT DO NOTHING, then query to find which were actually inserted
winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
ins_info = (
sqlite.insert(AssetInfo)
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
)
for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
session.execute(ins_info, chunk)
# Query to find which info rows were actually inserted (by matching our generated IDs)
all_info_ids = [row["id"] for row in winner_info_rows]
inserted_info_ids: set[str] = set()
for chunk in _iter_chunks(all_info_ids, MAX_BIND_PARAMS):
result = session.execute(
sqlalchemy.select(AssetInfo.id).where(AssetInfo.id.in_(chunk))
)
inserted_info_ids.update(result.scalars().all())
# build and insert tag + meta rows for the AssetInfo
tag_rows: list[dict] = []
meta_rows: list[dict] = []
if inserted_info_ids:
for row in winner_info_rows:
iid = row["id"]
if iid not in inserted_info_ids:
continue
for t in row["_tags"]:
tag_rows.append({
"asset_info_id": iid,
"tag_name": t,
"origin": "automatic",
"added_at": now,
})
if row["_filename"]:
meta_rows.append(
{
"asset_info_id": iid,
"key": "filename",
"ordinal": 0,
"val_str": row["_filename"],
"val_num": None,
"val_bool": None,
"val_json": None,
}
)
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS)
return {
"inserted_infos": len(inserted_info_ids),
"won_states": len(winners_by_path),
"lost_states": len(losers_by_path),
}
def bulk_insert_tags_and_meta(
session: Session,
*,
tag_rows: list[dict],
meta_rows: list[dict],
max_bind_params: int,
) -> None:
"""Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.
- tag_rows keys: asset_info_id, tag_name, origin, added_at
- meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
"""
if tag_rows:
ins_links = (
sqlite.insert(AssetInfoTag)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params):
session.execute(ins_links, chunk)
if meta_rows:
ins_meta = (
sqlite.insert(AssetInfoMeta)
.on_conflict_do_nothing(
index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
)
)
for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params):
session.execute(ins_meta, chunk)

View File

@@ -1,233 +0,0 @@
from __future__ import annotations
import uuid
from datetime import datetime
from typing import Any
from sqlalchemy import (
JSON,
BigInteger,
Boolean,
CheckConstraint,
DateTime,
ForeignKey,
Index,
Integer,
Numeric,
String,
Text,
UniqueConstraint,
)
from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship
from app.assets.helpers import utcnow
from app.database.models import to_dict, Base
class Asset(Base):
__tablename__ = "assets"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
hash: Mapped[str | None] = mapped_column(String(256), nullable=True)
size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
mime_type: Mapped[str | None] = mapped_column(String(255))
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=utcnow
)
infos: Mapped[list[AssetInfo]] = relationship(
"AssetInfo",
back_populates="asset",
primaryjoin=lambda: Asset.id == foreign(AssetInfo.asset_id),
foreign_keys=lambda: [AssetInfo.asset_id],
cascade="all,delete-orphan",
passive_deletes=True,
)
preview_of: Mapped[list[AssetInfo]] = relationship(
"AssetInfo",
back_populates="preview_asset",
primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id),
foreign_keys=lambda: [AssetInfo.preview_id],
viewonly=True,
)
cache_states: Mapped[list[AssetCacheState]] = relationship(
back_populates="asset",
cascade="all, delete-orphan",
passive_deletes=True,
)
__table_args__ = (
Index("uq_assets_hash", "hash", unique=True),
Index("ix_assets_mime_type", "mime_type"),
CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
return to_dict(self, include_none=include_none)
def __repr__(self) -> str:
return f"<Asset id={self.id} hash={(self.hash or '')[:12]}>"
class AssetCacheState(Base):
__tablename__ = "asset_cache_state"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False)
file_path: Mapped[str] = mapped_column(Text, nullable=False)
mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
asset: Mapped[Asset] = relationship(back_populates="cache_states")
__table_args__ = (
Index("ix_asset_cache_state_file_path", "file_path"),
Index("ix_asset_cache_state_asset_id", "asset_id"),
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
return to_dict(self, include_none=include_none)
def __repr__(self) -> str:
return f"<AssetCacheState id={self.id} asset_id={self.asset_id} path={self.file_path!r}>"
class AssetInfo(Base):
__tablename__ = "assets_info"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
name: Mapped[str] = mapped_column(String(512), nullable=False)
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False)
preview_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL"))
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON(none_as_null=True))
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
asset: Mapped[Asset] = relationship(
"Asset",
back_populates="infos",
foreign_keys=[asset_id],
lazy="selectin",
)
preview_asset: Mapped[Asset | None] = relationship(
"Asset",
back_populates="preview_of",
foreign_keys=[preview_id],
)
metadata_entries: Mapped[list[AssetInfoMeta]] = relationship(
back_populates="asset_info",
cascade="all,delete-orphan",
passive_deletes=True,
)
tag_links: Mapped[list[AssetInfoTag]] = relationship(
back_populates="asset_info",
cascade="all,delete-orphan",
passive_deletes=True,
overlaps="tags,asset_infos",
)
tags: Mapped[list[Tag]] = relationship(
secondary="asset_info_tags",
back_populates="asset_infos",
lazy="selectin",
viewonly=True,
overlaps="tag_links,asset_info_links,asset_infos,tag",
)
__table_args__ = (
UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
Index("ix_assets_info_owner_name", "owner_id", "name"),
Index("ix_assets_info_owner_id", "owner_id"),
Index("ix_assets_info_asset_id", "asset_id"),
Index("ix_assets_info_name", "name"),
Index("ix_assets_info_created_at", "created_at"),
Index("ix_assets_info_last_access_time", "last_access_time"),
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
data = to_dict(self, include_none=include_none)
data["tags"] = [t.name for t in self.tags]
return data
def __repr__(self) -> str:
return f"<AssetInfo id={self.id} name={self.name!r} asset_id={self.asset_id}>"
class AssetInfoMeta(Base):
__tablename__ = "asset_info_meta"
asset_info_id: Mapped[str] = mapped_column(
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
)
key: Mapped[str] = mapped_column(String(256), primary_key=True)
ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0)
val_str: Mapped[str | None] = mapped_column(String(2048), nullable=True)
val_num: Mapped[float | None] = mapped_column(Numeric(38, 10), nullable=True)
val_bool: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
val_json: Mapped[Any | None] = mapped_column(JSON(none_as_null=True), nullable=True)
asset_info: Mapped[AssetInfo] = relationship(back_populates="metadata_entries")
__table_args__ = (
Index("ix_asset_info_meta_key", "key"),
Index("ix_asset_info_meta_key_val_str", "key", "val_str"),
Index("ix_asset_info_meta_key_val_num", "key", "val_num"),
Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"),
)
class AssetInfoTag(Base):
__tablename__ = "asset_info_tags"
asset_info_id: Mapped[str] = mapped_column(
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
)
tag_name: Mapped[str] = mapped_column(
String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
)
origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
added_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=utcnow
)
asset_info: Mapped[AssetInfo] = relationship(back_populates="tag_links")
tag: Mapped[Tag] = relationship(back_populates="asset_info_links")
__table_args__ = (
Index("ix_asset_info_tags_tag_name", "tag_name"),
Index("ix_asset_info_tags_asset_info_id", "asset_info_id"),
)
class Tag(Base):
__tablename__ = "tags"
name: Mapped[str] = mapped_column(String(512), primary_key=True)
tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")
asset_info_links: Mapped[list[AssetInfoTag]] = relationship(
back_populates="tag",
overlaps="asset_infos,tags",
)
asset_infos: Mapped[list[AssetInfo]] = relationship(
secondary="asset_info_tags",
back_populates="tags",
viewonly=True,
overlaps="asset_info_links,tag_links,tags,asset_info",
)
__table_args__ = (
Index("ix_tags_tag_type", "tag_type"),
)
def __repr__(self) -> str:
return f"<Tag {self.name}>"

View File

@@ -1,976 +0,0 @@
import os
import logging
import sqlalchemy as sa
from collections import defaultdict
from datetime import datetime
from typing import Iterable, Any
from sqlalchemy import select, delete, exists, func
from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, contains_eager, noload
from app.assets.database.models import Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag
from app.assets.helpers import (
compute_relative_filename, escape_like_prefix, normalize_tags, project_kv, utcnow
)
from typing import Sequence
def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
"""Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
owner_id = (owner_id or "").strip()
if owner_id == "":
return AssetInfo.owner_id == ""
return AssetInfo.owner_id.in_(["", owner_id])
def pick_best_live_path(states: Sequence[AssetCacheState]) -> str:
"""
Return the best on-disk path among cache states:
1) Prefer a path that exists with needs_verify == False (already verified).
2) Otherwise, pick the first path that exists.
3) Otherwise return empty string.
"""
alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)]
if not alive:
return ""
for s in alive:
if not getattr(s, "needs_verify", False):
return s.file_path
return alive[0].file_path
def apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
) -> sa.sql.Select:
"""include_tags: every tag must be present; exclude_tags: none may be present."""
include_tags = normalize_tags(include_tags)
exclude_tags = normalize_tags(exclude_tags)
if include_tags:
for tag_name in include_tags:
stmt = stmt.where(
exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name == tag_name)
)
)
if exclude_tags:
stmt = stmt.where(
~exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name.in_(exclude_tags))
)
)
return stmt
def apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: dict | None = None,
) -> sa.sql.Select:
"""Apply filters using asset_info_meta projection table."""
if not metadata_filter:
return stmt
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
return sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
*preds,
)
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
if value is None:
no_row_for_key = sa.not_(
sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
)
)
null_row = _exists_for_pred(
key,
AssetInfoMeta.val_json.is_(None),
AssetInfoMeta.val_str.is_(None),
AssetInfoMeta.val_num.is_(None),
AssetInfoMeta.val_bool.is_(None),
)
return sa.or_(no_row_for_key, null_row)
if isinstance(value, bool):
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
if isinstance(value, (int, float)):
from decimal import Decimal
num = value if isinstance(value, Decimal) else Decimal(str(value))
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
if isinstance(value, str):
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
for k, v in metadata_filter.items():
if isinstance(v, list):
ors = [_exists_clause_for_value(k, elem) for elem in v]
if ors:
stmt = stmt.where(sa.or_(*ors))
else:
stmt = stmt.where(_exists_clause_for_value(k, v))
return stmt
def asset_exists_by_hash(
session: Session,
*,
asset_hash: str,
) -> bool:
"""
Check if an asset with a given hash exists in database.
"""
row = (
session.execute(
select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1)
)
).first()
return row is not None
def asset_info_exists_for_asset_id(
session: Session,
*,
asset_id: str,
) -> bool:
q = (
select(sa.literal(True))
.select_from(AssetInfo)
.where(AssetInfo.asset_id == asset_id)
.limit(1)
)
return (session.execute(q)).first() is not None
def get_asset_by_hash(
session: Session,
*,
asset_hash: str,
) -> Asset | None:
return (
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
).scalars().first()
def get_asset_info_by_id(
session: Session,
*,
asset_info_id: str,
) -> AssetInfo | None:
return session.get(AssetInfo, asset_info_id)
def list_asset_infos_page(
session: Session,
owner_id: str = "",
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
) -> tuple[list[AssetInfo], dict[str, list[str]], int]:
base = (
select(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
.where(visible_owner_clause(owner_id))
)
if name_contains:
escaped, esc = escape_like_prefix(name_contains)
base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
base = apply_tag_filters(base, include_tags, exclude_tags)
base = apply_metadata_filter(base, metadata_filter)
sort = (sort or "created_at").lower()
order = (order or "desc").lower()
sort_map = {
"name": AssetInfo.name,
"created_at": AssetInfo.created_at,
"updated_at": AssetInfo.updated_at,
"last_access_time": AssetInfo.last_access_time,
"size": Asset.size_bytes,
}
sort_col = sort_map.get(sort, AssetInfo.created_at)
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
base = base.order_by(sort_exp).limit(limit).offset(offset)
count_stmt = (
select(sa.func.count())
.select_from(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(visible_owner_clause(owner_id))
)
if name_contains:
escaped, esc = escape_like_prefix(name_contains)
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
total = int((session.execute(count_stmt)).scalar_one() or 0)
infos = (session.execute(base)).unique().scalars().all()
id_list: list[str] = [i.id for i in infos]
tag_map: dict[str, list[str]] = defaultdict(list)
if id_list:
rows = session.execute(
select(AssetInfoTag.asset_info_id, Tag.name)
.join(Tag, Tag.name == AssetInfoTag.tag_name)
.where(AssetInfoTag.asset_info_id.in_(id_list))
.order_by(AssetInfoTag.added_at)
)
for aid, tag_name in rows.all():
tag_map[aid].append(tag_name)
return infos, tag_map, total
def fetch_asset_info_asset_and_tags(
session: Session,
asset_info_id: str,
owner_id: str = "",
) -> tuple[AssetInfo, Asset, list[str]] | None:
stmt = (
select(AssetInfo, Asset, Tag.name)
.join(Asset, Asset.id == AssetInfo.asset_id)
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
.where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
.options(noload(AssetInfo.tags))
.order_by(Tag.name.asc())
)
rows = (session.execute(stmt)).all()
if not rows:
return None
first_info, first_asset, _ = rows[0]
tags: list[str] = []
seen: set[str] = set()
for _info, _asset, tag_name in rows:
if tag_name and tag_name not in seen:
seen.add(tag_name)
tags.append(tag_name)
return first_info, first_asset, tags
def fetch_asset_info_and_asset(
session: Session,
*,
asset_info_id: str,
owner_id: str = "",
) -> tuple[AssetInfo, Asset] | None:
stmt = (
select(AssetInfo, Asset)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
.limit(1)
.options(noload(AssetInfo.tags))
)
row = session.execute(stmt)
pair = row.first()
if not pair:
return None
return pair[0], pair[1]
def list_cache_states_by_asset_id(
session: Session, *, asset_id: str
) -> Sequence[AssetCacheState]:
return (
session.execute(
select(AssetCacheState)
.where(AssetCacheState.asset_id == asset_id)
.order_by(AssetCacheState.id.asc())
)
).scalars().all()
def touch_asset_info_by_id(
session: Session,
*,
asset_info_id: str,
ts: datetime | None = None,
only_if_newer: bool = True,
) -> None:
ts = ts or utcnow()
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
if only_if_newer:
stmt = stmt.where(
sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts)
)
session.execute(stmt.values(last_access_time=ts))
def create_asset_info_for_existing_asset(
session: Session,
*,
asset_hash: str,
name: str,
user_metadata: dict | None = None,
tags: Sequence[str] | None = None,
tag_origin: str = "manual",
owner_id: str = "",
) -> AssetInfo:
"""Create or return an existing AssetInfo for an Asset identified by asset_hash."""
now = utcnow()
asset = get_asset_by_hash(session, asset_hash=asset_hash)
if not asset:
raise ValueError(f"Unknown asset hash {asset_hash}")
info = AssetInfo(
owner_id=owner_id,
name=name,
asset_id=asset.id,
preview_id=None,
created_at=now,
updated_at=now,
last_access_time=now,
)
try:
with session.begin_nested():
session.add(info)
session.flush()
except IntegrityError:
existing = (
session.execute(
select(AssetInfo)
.options(noload(AssetInfo.tags))
.where(
AssetInfo.asset_id == asset.id,
AssetInfo.name == name,
AssetInfo.owner_id == owner_id,
)
.limit(1)
)
).unique().scalars().first()
if not existing:
raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
return existing
# metadata["filename"] hack
new_meta = dict(user_metadata or {})
computed_filename = None
try:
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
if p:
computed_filename = compute_relative_filename(p)
except Exception:
computed_filename = None
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta:
replace_asset_info_metadata_projection(
session,
asset_info_id=info.id,
user_metadata=new_meta,
)
if tags is not None:
set_asset_info_tags(
session,
asset_info_id=info.id,
tags=tags,
origin=tag_origin,
)
return info
def set_asset_info_tags(
session: Session,
*,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
) -> dict:
desired = normalize_tags(tags)
current = set(
tag_name for (tag_name,) in (
session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
).all()
)
to_add = [t for t in desired if t not in current]
to_remove = [t for t in current if t not in desired]
if to_add:
ensure_tags_exist(session, to_add, tag_type="user")
session.add_all([
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow())
for t in to_add
])
session.flush()
if to_remove:
session.execute(
delete(AssetInfoTag)
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
)
session.flush()
return {"added": to_add, "removed": to_remove, "total": desired}
def replace_asset_info_metadata_projection(
session: Session,
*,
asset_info_id: str,
user_metadata: dict | None = None,
) -> None:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info.user_metadata = user_metadata or {}
info.updated_at = utcnow()
session.flush()
session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
session.flush()
if not user_metadata:
return
rows: list[AssetInfoMeta] = []
for k, v in user_metadata.items():
for r in project_kv(k, v):
rows.append(
AssetInfoMeta(
asset_info_id=asset_info_id,
key=r["key"],
ordinal=int(r["ordinal"]),
val_str=r.get("val_str"),
val_num=r.get("val_num"),
val_bool=r.get("val_bool"),
val_json=r.get("val_json"),
)
)
if rows:
session.add_all(rows)
session.flush()
def ingest_fs_asset(
session: Session,
*,
asset_hash: str,
abs_path: str,
size_bytes: int,
mtime_ns: int,
mime_type: str | None = None,
info_name: str | None = None,
owner_id: str = "",
preview_id: str | None = None,
user_metadata: dict | None = None,
tags: Sequence[str] = (),
tag_origin: str = "manual",
require_existing_tags: bool = False,
) -> dict:
"""
Idempotently upsert:
- Asset by content hash (create if missing)
- AssetCacheState(file_path) pointing to asset_id
- Optionally AssetInfo + tag links and metadata projection
Returns flags and ids.
"""
locator = os.path.abspath(abs_path)
now = utcnow()
if preview_id:
if not session.get(Asset, preview_id):
preview_id = None
out: dict[str, Any] = {
"asset_created": False,
"asset_updated": False,
"state_created": False,
"state_updated": False,
"asset_info_id": None,
}
# 1) Asset by hash
asset = (
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
).scalars().first()
if not asset:
vals = {
"hash": asset_hash,
"size_bytes": int(size_bytes),
"mime_type": mime_type,
"created_at": now,
}
res = session.execute(
sqlite.insert(Asset)
.values(**vals)
.on_conflict_do_nothing(index_elements=[Asset.hash])
)
if int(res.rowcount or 0) > 0:
out["asset_created"] = True
asset = (
session.execute(
select(Asset).where(Asset.hash == asset_hash).limit(1)
)
).scalars().first()
if not asset:
raise RuntimeError("Asset row not found after upsert.")
else:
changed = False
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
asset.size_bytes = int(size_bytes)
changed = True
if mime_type and asset.mime_type != mime_type:
asset.mime_type = mime_type
changed = True
if changed:
out["asset_updated"] = True
# 2) AssetCacheState upsert by file_path (unique)
vals = {
"asset_id": asset.id,
"file_path": locator,
"mtime_ns": int(mtime_ns),
}
ins = (
sqlite.insert(AssetCacheState)
.values(**vals)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
)
res = session.execute(ins)
if int(res.rowcount or 0) > 0:
out["state_created"] = True
else:
upd = (
sa.update(AssetCacheState)
.where(AssetCacheState.file_path == locator)
.where(
sa.or_(
AssetCacheState.asset_id != asset.id,
AssetCacheState.mtime_ns.is_(None),
AssetCacheState.mtime_ns != int(mtime_ns),
)
)
.values(asset_id=asset.id, mtime_ns=int(mtime_ns))
)
res2 = session.execute(upd)
if int(res2.rowcount or 0) > 0:
out["state_updated"] = True
# 3) Optional AssetInfo + tags + metadata
if info_name:
try:
with session.begin_nested():
info = AssetInfo(
owner_id=owner_id,
name=info_name,
asset_id=asset.id,
preview_id=preview_id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
session.flush()
out["asset_info_id"] = info.id
except IntegrityError:
pass
existing_info = (
session.execute(
select(AssetInfo)
.where(
AssetInfo.asset_id == asset.id,
AssetInfo.name == info_name,
(AssetInfo.owner_id == owner_id),
)
.limit(1)
)
).unique().scalar_one_or_none()
if not existing_info:
raise RuntimeError("Failed to update or insert AssetInfo.")
if preview_id and existing_info.preview_id != preview_id:
existing_info.preview_id = preview_id
existing_info.updated_at = now
if existing_info.last_access_time < now:
existing_info.last_access_time = now
session.flush()
out["asset_info_id"] = existing_info.id
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
if norm and out["asset_info_id"] is not None:
if not require_existing_tags:
ensure_tags_exist(session, norm, tag_type="user")
existing_tag_names = set(
name for (name,) in (session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
)
missing = [t for t in norm if t not in existing_tag_names]
if missing and require_existing_tags:
raise ValueError(f"Unknown tags: {missing}")
existing_links = set(
tag_name
for (tag_name,) in (
session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
)
).all()
)
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
if to_add:
session.add_all(
[
AssetInfoTag(
asset_info_id=out["asset_info_id"],
tag_name=t,
origin=tag_origin,
added_at=now,
)
for t in to_add
]
)
session.flush()
# metadata["filename"] hack
if out["asset_info_id"] is not None:
primary_path = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
computed_filename = compute_relative_filename(primary_path) if primary_path else None
current_meta = existing_info.user_metadata or {}
new_meta = dict(current_meta)
if user_metadata is not None:
for k, v in user_metadata.items():
new_meta[k] = v
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta != current_meta:
replace_asset_info_metadata_projection(
session,
asset_info_id=out["asset_info_id"],
user_metadata=new_meta,
)
try:
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
except Exception:
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
return out
def update_asset_info_full(
session: Session,
*,
asset_info_id: str,
name: str | None = None,
tags: Sequence[str] | None = None,
user_metadata: dict | None = None,
tag_origin: str = "manual",
asset_info_row: Any = None,
) -> AssetInfo:
if not asset_info_row:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
else:
info = asset_info_row
touched = False
if name is not None and name != info.name:
info.name = name
touched = True
computed_filename = None
try:
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=info.asset_id))
if p:
computed_filename = compute_relative_filename(p)
except Exception:
computed_filename = None
if user_metadata is not None:
new_meta = dict(user_metadata)
if computed_filename:
new_meta["filename"] = computed_filename
replace_asset_info_metadata_projection(
session, asset_info_id=asset_info_id, user_metadata=new_meta
)
touched = True
else:
if computed_filename:
current_meta = info.user_metadata or {}
if current_meta.get("filename") != computed_filename:
new_meta = dict(current_meta)
new_meta["filename"] = computed_filename
replace_asset_info_metadata_projection(
session, asset_info_id=asset_info_id, user_metadata=new_meta
)
touched = True
if tags is not None:
set_asset_info_tags(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=tag_origin,
)
touched = True
if touched and user_metadata is None:
info.updated_at = utcnow()
session.flush()
return info
def delete_asset_info_by_id(
session: Session,
*,
asset_info_id: str,
owner_id: str,
) -> bool:
stmt = sa.delete(AssetInfo).where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
return int((session.execute(stmt)).rowcount or 0) > 0
def list_tags_with_usage(
session: Session,
prefix: str | None = None,
limit: int = 100,
offset: int = 0,
include_zero: bool = True,
order: str = "count_desc",
owner_id: str = "",
) -> tuple[list[tuple[str, str, int]], int]:
counts_sq = (
select(
AssetInfoTag.tag_name.label("tag_name"),
func.count(AssetInfoTag.asset_info_id).label("cnt"),
)
.select_from(AssetInfoTag)
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
.where(visible_owner_clause(owner_id))
.group_by(AssetInfoTag.tag_name)
.subquery()
)
q = (
select(
Tag.name,
Tag.tag_type,
func.coalesce(counts_sq.c.cnt, 0).label("count"),
)
.select_from(Tag)
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
)
if prefix:
escaped, esc = escape_like_prefix(prefix.strip().lower())
q = q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
if order == "name_asc":
q = q.order_by(Tag.name.asc())
else:
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
total_q = select(func.count()).select_from(Tag)
if prefix:
escaped, esc = escape_like_prefix(prefix.strip().lower())
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
total_q = total_q.where(
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
)
rows = (session.execute(q.limit(limit).offset(offset))).all()
total = (session.execute(total_q)).scalar_one()
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
return rows_norm, int(total or 0)
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
wanted = normalize_tags(list(names))
if not wanted:
return
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
ins = (
sqlite.insert(Tag)
.values(rows)
.on_conflict_do_nothing(index_elements=[Tag.name])
)
session.execute(ins)
def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]:
return [
tag_name for (tag_name,) in (
session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
]
def add_tags_to_asset_info(
session: Session,
*,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
create_if_missing: bool = True,
asset_info_row: Any = None,
) -> dict:
if not asset_info_row:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"added": [], "already_present": [], "total_tags": total}
if create_if_missing:
ensure_tags_exist(session, norm, tag_type="user")
current = {
tag_name
for (tag_name,) in (
session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
}
want = set(norm)
to_add = sorted(want - current)
if to_add:
with session.begin_nested() as nested:
try:
session.add_all(
[
AssetInfoTag(
asset_info_id=asset_info_id,
tag_name=t,
origin=origin,
added_at=utcnow(),
)
for t in to_add
]
)
session.flush()
except IntegrityError:
nested.rollback()
after = set(get_asset_tags(session, asset_info_id=asset_info_id))
return {
"added": sorted(((after - current) & want)),
"already_present": sorted(want & current),
"total_tags": sorted(after),
}
def remove_tags_from_asset_info(
session: Session,
*,
asset_info_id: str,
tags: Sequence[str],
) -> dict:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": [], "not_present": [], "total_tags": total}
existing = {
tag_name
for (tag_name,) in (
session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
}
to_remove = sorted(set(t for t in norm if t in existing))
not_present = sorted(set(t for t in norm if t not in existing))
if to_remove:
session.execute(
delete(AssetInfoTag)
.where(
AssetInfoTag.asset_info_id == asset_info_id,
AssetInfoTag.tag_name.in_(to_remove),
)
)
session.flush()
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
def remove_missing_tag_for_asset_id(
session: Session,
*,
asset_id: str,
) -> None:
session.execute(
sa.delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
AssetInfoTag.tag_name == "missing",
)
)
def set_asset_info_preview(
session: Session,
*,
asset_info_id: str,
preview_asset_id: str | None = None,
) -> None:
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if preview_asset_id is None:
info.preview_id = None
else:
# validate preview asset exists
if not session.get(Asset, preview_asset_id):
raise ValueError(f"Preview Asset {preview_asset_id} not found")
info.preview_id = preview_asset_id
info.updated_at = utcnow()
session.flush()

View File

@@ -1,62 +0,0 @@
from typing import Iterable
import sqlalchemy
from sqlalchemy.orm import Session
from sqlalchemy.dialects import sqlite
from app.assets.helpers import normalize_tags, utcnow
from app.assets.database.models import Tag, AssetInfoTag, AssetInfo
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
wanted = normalize_tags(list(names))
if not wanted:
return
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
ins = (
sqlite.insert(Tag)
.values(rows)
.on_conflict_do_nothing(index_elements=[Tag.name])
)
return session.execute(ins)
def add_missing_tag_for_asset_id(
session: Session,
*,
asset_id: str,
origin: str = "automatic",
) -> None:
select_rows = (
sqlalchemy.select(
AssetInfo.id.label("asset_info_id"),
sqlalchemy.literal("missing").label("tag_name"),
sqlalchemy.literal(origin).label("origin"),
sqlalchemy.literal(utcnow()).label("added_at"),
)
.where(AssetInfo.asset_id == asset_id)
.where(
sqlalchemy.not_(
sqlalchemy.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing"))
)
)
)
session.execute(
sqlite.insert(AssetInfoTag)
.from_select(
["asset_info_id", "tag_name", "origin", "added_at"],
select_rows,
)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
def remove_missing_tag_for_asset_id(
session: Session,
*,
asset_id: str,
) -> None:
session.execute(
sqlalchemy.delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id.in_(sqlalchemy.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
AssetInfoTag.tag_name == "missing",
)
)

View File

@@ -1,75 +0,0 @@
from blake3 import blake3
from typing import IO
import os
import asyncio
DEFAULT_CHUNK = 8 * 1024 *1024 # 8MB
# NOTE: this allows hashing different representations of a file-like object
def blake3_hash(
fp: str | IO[bytes],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
"""
Returns a BLAKE3 hex digest for ``fp``, which may be:
- a filename (str/bytes) or PathLike
- an open binary file object
If ``fp`` is a file object, it must be opened in **binary** mode and support
``read``, ``seek``, and ``tell``. The function will seek to the start before
reading and will attempt to restore the original position afterward.
"""
# duck typing to check if input is a file-like object
if hasattr(fp, "read"):
return _hash_file_obj(fp, chunk_size)
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj(f, chunk_size)
async def blake3_hash_async(
fp: str | IO[bytes],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
"""Async wrapper for ``blake3_hash_sync``.
Uses a worker thread so the event loop remains responsive.
"""
# If it is a path, open inside the worker thread to keep I/O off the loop.
if hasattr(fp, "read"):
return await asyncio.to_thread(blake3_hash, fp, chunk_size)
def _worker() -> str:
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj(f, chunk_size)
return await asyncio.to_thread(_worker)
def _hash_file_obj(file_obj: IO, chunk_size: int = DEFAULT_CHUNK) -> str:
"""
Hash an already-open binary file object by streaming in chunks.
- Seeks to the beginning before reading (if supported).
- Restores the original position afterward (if tell/seek are supported).
"""
if chunk_size <= 0:
chunk_size = DEFAULT_CHUNK
# in case file object is already open and not at the beginning, track so can be restored after hashing
orig_pos = file_obj.tell()
try:
# seek to the beginning before reading
if orig_pos != 0:
file_obj.seek(0)
h = blake3()
while True:
chunk = file_obj.read(chunk_size)
if not chunk:
break
h.update(chunk)
return h.hexdigest()
finally:
# restore original position in file object, if needed
if orig_pos != 0:
file_obj.seek(orig_pos)

View File

@@ -1,312 +0,0 @@
import contextlib
import os
from decimal import Decimal
from aiohttp import web
from datetime import datetime, timezone
from pathlib import Path
from typing import Literal, Any
import folder_paths
RootType = Literal["models", "input", "output"]
ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")
def get_query_dict(request: web.Request) -> dict[str, Any]:
"""
Gets a dictionary of query parameters from the request.
'request.query' is a MultiMapping[str], needs to be converted to a dictionary to be validated by Pydantic.
"""
query_dict = {
key: request.query.getall(key) if len(request.query.getall(key)) > 1 else request.query.get(key)
for key in request.query.keys()
}
return query_dict
def list_tree(base_dir: str) -> list[str]:
out: list[str] = []
base_abs = os.path.abspath(base_dir)
if not os.path.isdir(base_abs):
return out
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
for name in filenames:
out.append(os.path.abspath(os.path.join(dirpath, name)))
return out
def prefixes_for_root(root: RootType) -> list[str]:
if root == "models":
bases: list[str] = []
for _bucket, paths in get_comfy_models_folders():
bases.extend(paths)
return [os.path.abspath(p) for p in bases]
if root == "input":
return [os.path.abspath(folder_paths.get_input_directory())]
if root == "output":
return [os.path.abspath(folder_paths.get_output_directory())]
return []
def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]:
"""Escapes %, _ and the escape char itself in a LIKE prefix.
Returns (escaped_prefix, escape_char). Caller should append '%' and pass escape=escape_char to .like().
"""
s = s.replace(escape, escape + escape) # escape the escape char first
s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards
return s, escape
def fast_asset_file_check(
*,
mtime_db: int | None,
size_db: int | None,
stat_result: os.stat_result,
) -> bool:
if mtime_db is None:
return False
actual_mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000))
if int(mtime_db) != int(actual_mtime_ns):
return False
sz = int(size_db or 0)
if sz > 0:
return int(stat_result.st_size) == sz
return True
def utcnow() -> datetime:
"""Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC."""
return datetime.now(timezone.utc).replace(tzinfo=None)
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
"""Build a list of (folder_name, base_paths[]) categories that are configured for model locations.
We trust `folder_paths.folder_names_and_paths` and include a category if
*any* of its base paths lies under the Comfy `models_dir`.
"""
targets: list[tuple[str, list[str]]] = []
models_root = os.path.abspath(folder_paths.models_dir)
for name, values in folder_paths.folder_names_and_paths.items():
paths, _exts = values[0], values[1] # NOTE: this prevents nodepacks that hackily edit folder_... from breaking ComfyUI
if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
targets.append((name, paths))
return targets
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
root = tags[0]
if root == "models":
if len(tags) < 2:
raise ValueError("at least two tags required for model asset")
try:
bases = folder_paths.folder_names_and_paths[tags[1]][0]
except KeyError:
raise ValueError(f"unknown model category '{tags[1]}'")
if not bases:
raise ValueError(f"no base path configured for category '{tags[1]}'")
base_dir = os.path.abspath(bases[0])
raw_subdirs = tags[2:]
else:
base_dir = os.path.abspath(
folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory()
)
raw_subdirs = tags[1:]
for i in raw_subdirs:
if i in (".", ".."):
raise ValueError("invalid path component in tags")
return base_dir, raw_subdirs if raw_subdirs else []
def ensure_within_base(candidate: str, base: str) -> None:
cand_abs = os.path.abspath(candidate)
base_abs = os.path.abspath(base)
try:
if os.path.commonpath([cand_abs, base_abs]) != base_abs:
raise ValueError("destination escapes base directory")
except Exception:
raise ValueError("invalid destination path")
def compute_relative_filename(file_path: str) -> str | None:
"""
Return the model's path relative to the last well-known folder (the model category),
using forward slashes, eg:
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
For non-model paths, returns None.
NOTE: this is a temporary helper, used only for initializing metadata["filename"] field.
"""
try:
root_category, rel_path = get_relative_to_root_category_path_of_asset(file_path)
except ValueError:
return None
p = Path(rel_path)
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
if not parts:
return None
if root_category == "models":
# parts[0] is the category ("checkpoints", "vae", etc) drop it
inside = parts[1:] if len(parts) > 1 else [parts[0]]
return "/".join(inside)
return "/".join(parts) # input/output: keep all parts
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
"""Given an absolute or relative file path, determine which root category the path belongs to:
- 'input' if the file resides under `folder_paths.get_input_directory()`
- 'output' if the file resides under `folder_paths.get_output_directory()`
- 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()`
Returns:
(root_category, relative_path_inside_that_root)
For 'models', the relative path is prefixed with the category name:
e.g. ('models', 'vae/test/sub/ae.safetensors')
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
"""
fp_abs = os.path.abspath(file_path)
def _is_within(child: str, parent: str) -> bool:
try:
return os.path.commonpath([child, parent]) == parent
except Exception:
return False
def _rel(child: str, parent: str) -> str:
return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep)
# 1) input
input_base = os.path.abspath(folder_paths.get_input_directory())
if _is_within(fp_abs, input_base):
return "input", _rel(fp_abs, input_base)
# 2) output
output_base = os.path.abspath(folder_paths.get_output_directory())
if _is_within(fp_abs, output_base):
return "output", _rel(fp_abs, output_base)
# 3) models (check deepest matching base to avoid ambiguity)
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
for bucket, bases in get_comfy_models_folders():
for b in bases:
base_abs = os.path.abspath(b)
if not _is_within(fp_abs, base_abs):
continue
cand = (len(base_abs), bucket, _rel(fp_abs, base_abs))
if best is None or cand[0] > best[0]:
best = cand
if best is not None:
_, bucket, rel_inside = best
combined = os.path.join(bucket, rel_inside)
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}")
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
"""Return a tuple (name, tags) derived from a filesystem path.
Semantics:
- Root category is determined by `get_relative_to_root_category_path_of_asset`.
- The returned `name` is the base filename with extension from the relative path.
- The returned `tags` are:
[root_category] + parent folders of the relative path (in order)
For 'models', this means:
file '/.../ModelsDir/vae/test_tag/ae.safetensors'
-> root_category='models', some_path='vae/test_tag/ae.safetensors'
-> name='ae.safetensors', tags=['models', 'vae', 'test_tag']
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
"""
root_category, some_path = get_relative_to_root_category_path_of_asset(file_path)
p = Path(some_path)
parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)]
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
def normalize_tags(tags: list[str] | None) -> list[str]:
"""
Normalize a list of tags by:
- Stripping whitespace and converting to lowercase.
- Removing duplicates.
"""
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
def collect_models_files() -> list[str]:
out: list[str] = []
for folder_name, bases in get_comfy_models_folders():
rel_files = folder_paths.get_filename_list(folder_name) or []
for rel_path in rel_files:
abs_path = folder_paths.get_full_path(folder_name, rel_path)
if not abs_path:
continue
abs_path = os.path.abspath(abs_path)
allowed = False
for b in bases:
base_abs = os.path.abspath(b)
with contextlib.suppress(Exception):
if os.path.commonpath([abs_path, base_abs]) == base_abs:
allowed = True
break
if allowed:
out.append(abs_path)
return out
def is_scalar(v):
if v is None:
return True
if isinstance(v, bool):
return True
if isinstance(v, (int, float, Decimal, str)):
return True
return False
def project_kv(key: str, value):
"""
Turn a metadata key/value into typed projection rows.
Returns list[dict] with keys:
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
"""
rows: list[dict] = []
def _null_row(ordinal: int) -> dict:
return {
"key": key, "ordinal": ordinal,
"val_str": None, "val_num": None, "val_bool": None, "val_json": None
}
if value is None:
rows.append(_null_row(0))
return rows
if is_scalar(value):
if isinstance(value, bool):
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
elif isinstance(value, (int, float, Decimal)):
num = value if isinstance(value, Decimal) else Decimal(str(value))
rows.append({"key": key, "ordinal": 0, "val_num": num})
elif isinstance(value, str):
rows.append({"key": key, "ordinal": 0, "val_str": value})
else:
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows
if isinstance(value, list):
if all(is_scalar(x) for x in value):
for i, x in enumerate(value):
if x is None:
rows.append(_null_row(i))
elif isinstance(x, bool):
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
elif isinstance(x, (int, float, Decimal)):
num = x if isinstance(x, Decimal) else Decimal(str(x))
rows.append({"key": key, "ordinal": i, "val_num": num})
elif isinstance(x, str):
rows.append({"key": key, "ordinal": i, "val_str": x})
else:
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
for i, x in enumerate(value):
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows

View File

@@ -1,516 +0,0 @@
import os
import mimetypes
import contextlib
from typing import Sequence
from app.database.db import create_session
from app.assets.api import schemas_out, schemas_in
from app.assets.database.queries import (
asset_exists_by_hash,
asset_info_exists_for_asset_id,
get_asset_by_hash,
get_asset_info_by_id,
fetch_asset_info_asset_and_tags,
fetch_asset_info_and_asset,
create_asset_info_for_existing_asset,
touch_asset_info_by_id,
update_asset_info_full,
delete_asset_info_by_id,
list_cache_states_by_asset_id,
list_asset_infos_page,
list_tags_with_usage,
get_asset_tags,
add_tags_to_asset_info,
remove_tags_from_asset_info,
pick_best_live_path,
ingest_fs_asset,
set_asset_info_preview,
)
from app.assets.helpers import resolve_destination_from_tags, ensure_within_base
from app.assets.database.models import Asset
def _safe_sort_field(requested: str | None) -> str:
if not requested:
return "created_at"
v = requested.lower()
if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
return v
return "created_at"
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
st = os.stat(path, follow_symlinks=True)
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
def _safe_filename(name: str | None, fallback: str) -> str:
n = os.path.basename((name or "").strip() or fallback)
if n:
return n
return fallback
def asset_exists(*, asset_hash: str) -> bool:
"""
Check if an asset with a given hash exists in database.
"""
with create_session() as session:
return asset_exists_by_hash(session, asset_hash=asset_hash)
def list_assets(
*,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
owner_id: str = "",
) -> schemas_out.AssetsList:
sort = _safe_sort_field(sort)
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
with create_session() as session:
infos, tag_map, total = list_asset_infos_page(
session,
owner_id=owner_id,
include_tags=include_tags,
exclude_tags=exclude_tags,
name_contains=name_contains,
metadata_filter=metadata_filter,
limit=limit,
offset=offset,
sort=sort,
order=order,
)
summaries: list[schemas_out.AssetSummary] = []
for info in infos:
asset = info.asset
tags = tag_map.get(info.id, [])
summaries.append(
schemas_out.AssetSummary(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset else None,
mime_type=asset.mime_type if asset else None,
tags=tags,
created_at=info.created_at,
updated_at=info.updated_at,
last_access_time=info.last_access_time,
)
)
return schemas_out.AssetsList(
assets=summaries,
total=total,
has_more=(offset + len(summaries)) < total,
)
def get_asset(
*,
asset_info_id: str,
owner_id: str = "",
) -> schemas_out.AssetDetail:
with create_session() as session:
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not res:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info, asset, tag_names = res
preview_id = info.preview_id
return schemas_out.AssetDetail(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
mime_type=asset.mime_type if asset else None,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
)
def resolve_asset_content_for_download(
*,
asset_info_id: str,
owner_id: str = "",
) -> tuple[str, str, str]:
with create_session() as session:
pair = fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not pair:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info, asset = pair
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
abs_path = pick_best_live_path(states)
if not abs_path:
raise FileNotFoundError
touch_asset_info_by_id(session, asset_info_id=asset_info_id)
session.commit()
ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
download_name = info.name or os.path.basename(abs_path)
return abs_path, ctype, download_name
def upload_asset_from_temp_path(
spec: schemas_in.UploadAssetSpec,
*,
temp_path: str,
client_filename: str | None = None,
owner_id: str = "",
expected_asset_hash: str | None = None,
) -> schemas_out.AssetCreated:
"""
Create new asset or update existing asset from a temporary file path.
"""
try:
# NOTE: blake3 is not required right now, so this will fail if blake3 is not installed in local environment
import app.assets.hashing as hashing
digest = hashing.blake3_hash(temp_path)
except Exception as e:
raise RuntimeError(f"failed to hash uploaded file: {e}")
asset_hash = "blake3:" + digest
if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower():
raise ValueError("HASH_MISMATCH")
with create_session() as session:
existing = get_asset_by_hash(session, asset_hash=asset_hash)
if existing is not None:
with contextlib.suppress(Exception):
if temp_path and os.path.exists(temp_path):
os.remove(temp_path)
display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest)
info = create_asset_info_for_existing_asset(
session,
asset_hash=asset_hash,
name=display_name,
user_metadata=spec.user_metadata or {},
tags=spec.tags or [],
tag_origin="manual",
owner_id=owner_id,
)
tag_names = get_asset_tags(session, asset_info_id=info.id)
session.commit()
return schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=existing.hash,
size=int(existing.size_bytes) if existing.size_bytes is not None else None,
mime_type=existing.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=False,
)
base_dir, subdirs = resolve_destination_from_tags(spec.tags)
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
os.makedirs(dest_dir, exist_ok=True)
src_for_ext = (client_filename or spec.name or "").strip()
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
ext = _ext if 0 < len(_ext) <= 16 else ""
hashed_basename = f"{digest}{ext}"
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
ensure_within_base(dest_abs, base_dir)
content_type = (
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
or mimetypes.guess_type(hashed_basename, strict=False)[0]
or "application/octet-stream"
)
try:
os.replace(temp_path, dest_abs)
except Exception as e:
raise RuntimeError(f"failed to move uploaded file into place: {e}")
try:
size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs)
except OSError as e:
raise RuntimeError(f"failed to stat destination file: {e}")
with create_session() as session:
result = ingest_fs_asset(
session,
asset_hash=asset_hash,
abs_path=dest_abs,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
mime_type=content_type,
info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest),
owner_id=owner_id,
preview_id=None,
user_metadata=spec.user_metadata or {},
tags=spec.tags,
tag_origin="manual",
require_existing_tags=False,
)
info_id = result["asset_info_id"]
if not info_id:
raise RuntimeError("failed to create asset metadata")
pair = fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id)
if not pair:
raise RuntimeError("inconsistent DB state after ingest")
info, asset = pair
tag_names = get_asset_tags(session, asset_info_id=info.id)
created_result = schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=asset.hash,
size=int(asset.size_bytes),
mime_type=asset.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=result["asset_created"],
)
session.commit()
return created_result
def update_asset(
*,
asset_info_id: str,
name: str | None = None,
tags: list[str] | None = None,
user_metadata: dict | None = None,
owner_id: str = "",
) -> schemas_out.AssetUpdated:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
info = update_asset_info_full(
session,
asset_info_id=asset_info_id,
name=name,
tags=tags,
user_metadata=user_metadata,
tag_origin="manual",
asset_info_row=info_row,
)
tag_names = get_asset_tags(session, asset_info_id=asset_info_id)
result = schemas_out.AssetUpdated(
id=info.id,
name=info.name,
asset_hash=info.asset.hash if info.asset else None,
tags=tag_names,
user_metadata=info.user_metadata or {},
updated_at=info.updated_at,
)
session.commit()
return result
def set_asset_preview(
*,
asset_info_id: str,
preview_asset_id: str | None = None,
owner_id: str = "",
) -> schemas_out.AssetDetail:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
set_asset_info_preview(
session,
asset_info_id=asset_info_id,
preview_asset_id=preview_asset_id,
)
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not res:
raise RuntimeError("State changed during preview update")
info, asset, tags = res
result = schemas_out.AssetDetail(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
mime_type=asset.mime_type if asset else None,
tags=tags,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
)
session.commit()
return result
def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
asset_id = info_row.asset_id if info_row else None
deleted = delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not deleted:
session.commit()
return False
if not delete_content_if_orphan or not asset_id:
session.commit()
return True
still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id)
if still_exists:
session.commit()
return True
states = list_cache_states_by_asset_id(session, asset_id=asset_id)
file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]
asset_row = session.get(Asset, asset_id)
if asset_row is not None:
session.delete(asset_row)
session.commit()
for p in file_paths:
with contextlib.suppress(Exception):
if p and os.path.isfile(p):
os.remove(p)
return True
def create_asset_from_hash(
*,
hash_str: str,
name: str,
tags: list[str] | None = None,
user_metadata: dict | None = None,
owner_id: str = "",
) -> schemas_out.AssetCreated | None:
canonical = hash_str.strip().lower()
with create_session() as session:
asset = get_asset_by_hash(session, asset_hash=canonical)
if not asset:
return None
info = create_asset_info_for_existing_asset(
session,
asset_hash=canonical,
name=_safe_filename(name, fallback=canonical.split(":", 1)[1]),
user_metadata=user_metadata or {},
tags=tags or [],
tag_origin="manual",
owner_id=owner_id,
)
tag_names = get_asset_tags(session, asset_info_id=info.id)
result = schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=asset.hash,
size=int(asset.size_bytes),
mime_type=asset.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=False,
)
session.commit()
return result
def add_tags_to_asset(
*,
asset_info_id: str,
tags: list[str],
origin: str = "manual",
owner_id: str = "",
) -> schemas_out.TagsAdd:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = add_tags_to_asset_info(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=origin,
create_if_missing=True,
asset_info_row=info_row,
)
session.commit()
return schemas_out.TagsAdd(**data)
def remove_tags_from_asset(
*,
asset_info_id: str,
tags: list[str],
owner_id: str = "",
) -> schemas_out.TagsRemove:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = remove_tags_from_asset_info(
session,
asset_info_id=asset_info_id,
tags=tags,
)
session.commit()
return schemas_out.TagsRemove(**data)
def list_tags(
prefix: str | None = None,
limit: int = 100,
offset: int = 0,
order: str = "count_desc",
include_zero: bool = True,
owner_id: str = "",
) -> schemas_out.TagsList:
limit = max(1, min(1000, limit))
offset = max(0, offset)
with create_session() as session:
rows, total = list_tags_with_usage(
session,
prefix=prefix,
limit=limit,
offset=offset,
include_zero=include_zero,
order=order,
owner_id=owner_id,
)
tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows]
return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total)

View File

@@ -1,263 +0,0 @@
import contextlib
import time
import logging
import os
import sqlalchemy
import folder_paths
from app.database.db import create_session, dependencies_available
from app.assets.helpers import (
collect_models_files, compute_relative_filename, fast_asset_file_check, get_name_and_tags_from_asset_path,
list_tree,prefixes_for_root, escape_like_prefix,
RootType
)
from app.assets.database.tags import add_missing_tag_for_asset_id, ensure_tags_exist, remove_missing_tag_for_asset_id
from app.assets.database.bulk_ops import seed_from_paths_batch
from app.assets.database.models import Asset, AssetCacheState, AssetInfo
def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> None:
"""
Scan the given roots and seed the assets into the database.
"""
if not dependencies_available():
if enable_logging:
logging.warning("Database dependencies not available, skipping assets scan")
return
t_start = time.perf_counter()
created = 0
skipped_existing = 0
orphans_pruned = 0
paths: list[str] = []
try:
existing_paths: set[str] = set()
for r in roots:
try:
survivors: set[str] = _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True)
if survivors:
existing_paths.update(survivors)
except Exception as e:
logging.exception("fast DB scan failed for %s: %s", r, e)
try:
orphans_pruned = _prune_orphaned_assets(roots)
except Exception as e:
logging.exception("orphan pruning failed: %s", e)
if "models" in roots:
paths.extend(collect_models_files())
if "input" in roots:
paths.extend(list_tree(folder_paths.get_input_directory()))
if "output" in roots:
paths.extend(list_tree(folder_paths.get_output_directory()))
specs: list[dict] = []
tag_pool: set[str] = set()
for p in paths:
abs_p = os.path.abspath(p)
if abs_p in existing_paths:
skipped_existing += 1
continue
try:
stat_p = os.stat(abs_p, follow_symlinks=False)
except OSError:
continue
# skip empty files
if not stat_p.st_size:
continue
name, tags = get_name_and_tags_from_asset_path(abs_p)
specs.append(
{
"abs_path": abs_p,
"size_bytes": stat_p.st_size,
"mtime_ns": getattr(stat_p, "st_mtime_ns", int(stat_p.st_mtime * 1_000_000_000)),
"info_name": name,
"tags": tags,
"fname": compute_relative_filename(abs_p),
}
)
for t in tags:
tag_pool.add(t)
# if no file specs, nothing to do
if not specs:
return
with create_session() as sess:
if tag_pool:
ensure_tags_exist(sess, tag_pool, tag_type="user")
result = seed_from_paths_batch(sess, specs=specs, owner_id="")
created += result["inserted_infos"]
sess.commit()
finally:
if enable_logging:
logging.info(
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, orphans_pruned=%d, total_seen=%d)",
roots,
time.perf_counter() - t_start,
created,
skipped_existing,
orphans_pruned,
len(paths),
)
def _prune_orphaned_assets(roots: tuple[RootType, ...]) -> int:
"""Prune cache states outside configured prefixes, then delete orphaned seed assets."""
all_prefixes = [os.path.abspath(p) for r in roots for p in prefixes_for_root(r)]
if not all_prefixes:
return 0
def make_prefix_condition(prefix: str):
base = prefix if prefix.endswith(os.sep) else prefix + os.sep
escaped, esc = escape_like_prefix(base)
return AssetCacheState.file_path.like(escaped + "%", escape=esc)
matches_valid_prefix = sqlalchemy.or_(*[make_prefix_condition(p) for p in all_prefixes])
orphan_subq = (
sqlalchemy.select(Asset.id)
.outerjoin(AssetCacheState, AssetCacheState.asset_id == Asset.id)
.where(Asset.hash.is_(None), AssetCacheState.id.is_(None))
).scalar_subquery()
with create_session() as sess:
sess.execute(sqlalchemy.delete(AssetCacheState).where(~matches_valid_prefix))
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id.in_(orphan_subq)))
result = sess.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(orphan_subq)))
sess.commit()
return result.rowcount
def _fast_db_consistency_pass(
root: RootType,
*,
collect_existing_paths: bool = False,
update_missing_tags: bool = False,
) -> set[str] | None:
"""Fast DB+FS pass for a root:
- Toggle needs_verify per state using fast check
- For hashed assets with at least one fast-ok state in this root: delete stale missing states
- For seed assets with all states missing: delete Asset and its AssetInfos
- Optionally add/remove 'missing' tags based on fast-ok in this root
- Optionally return surviving absolute paths
"""
prefixes = prefixes_for_root(root)
if not prefixes:
return set() if collect_existing_paths else None
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
escaped, esc = escape_like_prefix(base)
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
with create_session() as sess:
rows = (
sess.execute(
sqlalchemy.select(
AssetCacheState.id,
AssetCacheState.file_path,
AssetCacheState.mtime_ns,
AssetCacheState.needs_verify,
AssetCacheState.asset_id,
Asset.hash,
Asset.size_bytes,
)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(sqlalchemy.or_(*conds))
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
)
).all()
by_asset: dict[str, dict] = {}
for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows:
acc = by_asset.get(aid)
if acc is None:
acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []}
by_asset[aid] = acc
fast_ok = False
try:
exists = True
fast_ok = fast_asset_file_check(
mtime_db=mtime_db,
size_db=acc["size_db"],
stat_result=os.stat(fp, follow_symlinks=True),
)
except FileNotFoundError:
exists = False
except OSError:
exists = False
acc["states"].append({
"sid": sid,
"fp": fp,
"exists": exists,
"fast_ok": fast_ok,
"needs_verify": bool(needs_verify),
})
to_set_verify: list[int] = []
to_clear_verify: list[int] = []
stale_state_ids: list[int] = []
survivors: set[str] = set()
for aid, acc in by_asset.items():
a_hash = acc["hash"]
states = acc["states"]
any_fast_ok = any(s["fast_ok"] for s in states)
all_missing = all(not s["exists"] for s in states)
for s in states:
if not s["exists"]:
continue
if s["fast_ok"] and s["needs_verify"]:
to_clear_verify.append(s["sid"])
if not s["fast_ok"] and not s["needs_verify"]:
to_set_verify.append(s["sid"])
if a_hash is None:
if states and all_missing: # remove seed Asset completely, if no valid AssetCache exists
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id == aid))
asset = sess.get(Asset, aid)
if asset:
sess.delete(asset)
else:
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
continue
if any_fast_ok: # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records
for s in states:
if not s["exists"]:
stale_state_ids.append(s["sid"])
if update_missing_tags:
with contextlib.suppress(Exception):
remove_missing_tag_for_asset_id(sess, asset_id=aid)
elif update_missing_tags:
with contextlib.suppress(Exception):
add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
if stale_state_ids:
sess.execute(sqlalchemy.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids)))
if to_set_verify:
sess.execute(
sqlalchemy.update(AssetCacheState)
.where(AssetCacheState.id.in_(to_set_verify))
.values(needs_verify=True)
)
if to_clear_verify:
sess.execute(
sqlalchemy.update(AssetCacheState)
.where(AssetCacheState.id.in_(to_clear_verify))
.values(needs_verify=False)
)
sess.commit()
return survivors if collect_existing_paths else None

View File

@@ -1,21 +1,14 @@
from typing import Any
from datetime import datetime
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import declarative_base
class Base(DeclarativeBase):
pass
Base = declarative_base()
def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
def to_dict(obj):
fields = obj.__table__.columns.keys()
out: dict[str, Any] = {}
for field in fields:
val = getattr(obj, field)
if val is None and not include_none:
continue
if isinstance(val, datetime):
out[field] = val.isoformat()
else:
out[field] = val
return out
return {
field: (val.to_dict() if hasattr(val, "to_dict") else val)
for field in fields
if (val := getattr(obj, field))
}
# TODO: Define models here

View File

@@ -44,7 +44,7 @@ class ModelFileManager:
@routes.get("/experiment/models/{folder}")
async def get_all_models(request):
folder = request.match_info.get("folder", None)
if folder not in folder_paths.folder_names_and_paths:
if not folder in folder_paths.folder_names_and_paths:
return web.Response(status=404)
files = self.get_model_file_list(folder)
return web.json_response(files)
@@ -55,7 +55,7 @@ class ModelFileManager:
path_index = int(request.match_info.get("path_index", None))
filename = request.match_info.get("filename", None)
if folder_name not in folder_paths.folder_names_and_paths:
if not folder_name in folder_paths.folder_names_and_paths:
return web.Response(status=404)
folders = folder_paths.folder_names_and_paths[folder_name]

View File

@@ -10,7 +10,6 @@ import hashlib
class Source:
custom_node = "custom_node"
templates = "templates"
class SubgraphEntry(TypedDict):
source: str
@@ -39,18 +38,6 @@ class CustomNodeSubgraphEntryInfo(TypedDict):
class SubgraphManager:
def __init__(self):
self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None
self.cached_blueprint_subgraphs: dict[SubgraphEntry] | None = None
def _create_entry(self, file: str, source: str, node_pack: str) -> tuple[str, SubgraphEntry]:
"""Create a subgraph entry from a file path. Expects normalized path (forward slashes)."""
entry_id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
entry: SubgraphEntry = {
"source": source,
"name": os.path.splitext(os.path.basename(file))[0],
"path": file,
"info": {"node_pack": node_pack},
}
return entry_id, entry
async def load_entry_data(self, entry: SubgraphEntry):
with open(entry['path'], 'r') as f:
@@ -73,60 +60,53 @@ class SubgraphManager:
return entries
async def get_custom_node_subgraphs(self, loadedModules, force_reload=False):
"""Load subgraphs from custom nodes."""
# if not forced to reload and cached, return cache
if not force_reload and self.cached_custom_node_subgraphs is not None:
return self.cached_custom_node_subgraphs
# Load subgraphs from custom nodes
subfolder = "subgraphs"
subgraphs_dict: dict[SubgraphEntry] = {}
for folder in folder_paths.get_folder_paths("custom_nodes"):
pattern = os.path.join(folder, "*/subgraphs/*.json")
for file in glob.glob(pattern):
file = file.replace('\\', '/')
node_pack = "custom_nodes." + file.split('/')[-3]
entry_id, entry = self._create_entry(file, Source.custom_node, node_pack)
subgraphs_dict[entry_id] = entry
for folder in folder_paths.get_folder_paths("custom_nodes"):
pattern = os.path.join(folder, f"*/{subfolder}/*.json")
matched_files = glob.glob(pattern)
for file in matched_files:
# replace backslashes with forward slashes
file = file.replace('\\', '/')
info: CustomNodeSubgraphEntryInfo = {
"node_pack": "custom_nodes." + file.split('/')[-3]
}
source = Source.custom_node
# hash source + path to make sure id will be as unique as possible, but
# reproducible across backend reloads
id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
entry: SubgraphEntry = {
"source": Source.custom_node,
"name": os.path.splitext(os.path.basename(file))[0],
"path": file,
"info": info,
}
subgraphs_dict[id] = entry
self.cached_custom_node_subgraphs = subgraphs_dict
return subgraphs_dict
async def get_blueprint_subgraphs(self, force_reload=False):
"""Load subgraphs from the blueprints directory."""
if not force_reload and self.cached_blueprint_subgraphs is not None:
return self.cached_blueprint_subgraphs
subgraphs_dict: dict[SubgraphEntry] = {}
blueprints_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'blueprints')
if os.path.exists(blueprints_dir):
for file in glob.glob(os.path.join(blueprints_dir, "*.json")):
file = file.replace('\\', '/')
entry_id, entry = self._create_entry(file, Source.templates, "comfyui")
subgraphs_dict[entry_id] = entry
self.cached_blueprint_subgraphs = subgraphs_dict
return subgraphs_dict
async def get_all_subgraphs(self, loadedModules, force_reload=False):
"""Get all subgraphs from all sources (custom nodes and blueprints)."""
custom_node_subgraphs = await self.get_custom_node_subgraphs(loadedModules, force_reload)
blueprint_subgraphs = await self.get_blueprint_subgraphs(force_reload)
return {**custom_node_subgraphs, **blueprint_subgraphs}
async def get_subgraph(self, id: str, loadedModules):
"""Get a specific subgraph by ID from any source."""
entry = (await self.get_all_subgraphs(loadedModules)).get(id)
if entry is not None and entry.get('data') is None:
async def get_custom_node_subgraph(self, id: str, loadedModules):
subgraphs = await self.get_custom_node_subgraphs(loadedModules)
entry: SubgraphEntry = subgraphs.get(id, None)
if entry is not None and entry.get('data', None) is None:
await self.load_entry_data(entry)
return entry
def add_routes(self, routes, loadedModules):
@routes.get("/global_subgraphs")
async def get_global_subgraphs(request):
subgraphs_dict = await self.get_all_subgraphs(loadedModules)
subgraphs_dict = await self.get_custom_node_subgraphs(loadedModules)
# NOTE: we may want to include other sources of global subgraphs such as templates in the future;
# that's the reasoning for the current implementation
return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True))
@routes.get("/global_subgraphs/{id}")
async def get_global_subgraph(request):
id = request.match_info.get("id", None)
subgraph = await self.get_subgraph(id, loadedModules)
subgraph = await self.get_custom_node_subgraph(id, loadedModules)
return web.json_response(await self.sanitize_entry(subgraph))

View File

@@ -25,11 +25,11 @@ class AudioEncoderModel():
elif model_type == "whisper3":
self.model = WhisperLargeV3(**model_config)
self.model.eval()
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.model_sample_rate = 16000
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
return self.model.load_state_dict(sd, strict=False)
def get_sd(self):
return self.model.state_dict()

View File

@@ -159,7 +159,6 @@ class PerformanceFeature(enum.Enum):
Fp8MatrixMultiplication = "fp8_matrix_mult"
CublasOps = "cublas_ops"
AutoTune = "autotune"
DynamicVRAM = "dynamic_vram"
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
@@ -232,7 +231,6 @@ database_default_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
)
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.")
if comfy.options.args_parsing:
args = parser.parse_args()
@@ -258,6 +256,3 @@ elif args.fast == []:
# '--fast' is provided with a list of performance features, use that list
else:
args.fast = set(args.fast)
def enables_dynamic_vram():
return PerformanceFeature.DynamicVRAM in args.fast and not args.highvram and not args.gpu_only

View File

@@ -1,59 +1,6 @@
import torch
from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.ops
import math
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
image = image[:, :, :, :3] if image.shape[3] > 3 else image
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
std = torch.tensor(std, device=image.device, dtype=image.dtype)
image = image.movedim(-1, 1)
if not (image.shape[2] == size and image.shape[3] == size):
if crop:
scale = (size / min(image.shape[2], image.shape[3]))
scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
else:
scale_size = (size, size)
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
h = (image.shape[2] - size)//2
w = (image.shape[3] - size)//2
image = image[:,:,h:h+size,w:w+size]
image = torch.clip((255. * image), 0, 255).round() / 255.0
return (image - mean.view([3,1,1])) / std.view([3,1,1])
def siglip2_flex_calc_resolution(oh, ow, patch_size, max_num_patches, eps=1e-5):
def scale_dim(size, scale):
scaled = math.ceil(size * scale / patch_size) * patch_size
return max(patch_size, int(scaled))
# Binary search for optimal scale
lo, hi = eps / 10, 100.0
while hi - lo >= eps:
mid = (lo + hi) / 2
h, w = scale_dim(oh, mid), scale_dim(ow, mid)
if (h // patch_size) * (w // patch_size) <= max_num_patches:
lo = mid
else:
hi = mid
return scale_dim(oh, lo), scale_dim(ow, lo)
def siglip2_preprocess(image, size, patch_size, num_patches, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True):
if size > 0:
return clip_preprocess(image, size=size, mean=mean, std=std, crop=crop)
image = image[:, :, :, :3] if image.shape[3] > 3 else image
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
std = torch.tensor(std, device=image.device, dtype=image.dtype)
image = image.movedim(-1, 1)
b, c, h, w = image.shape
h, w = siglip2_flex_calc_resolution(h, w, patch_size, num_patches)
image = torch.nn.functional.interpolate(image, size=(h, w), mode="bilinear", antialias=True)
image = torch.clip((255. * image), 0, 255).round() / 255.0
return (image - mean.view([3, 1, 1])) / std.view([3, 1, 1])
class CLIPAttention(torch.nn.Module):
def __init__(self, embed_dim, heads, dtype, device, operations):
@@ -209,27 +156,6 @@ class CLIPTextModel(torch.nn.Module):
out = self.text_projection(x[2])
return (x[0], x[1], out, x[2])
def siglip2_pos_embed(embed_weight, embeds, orig_shape):
embed_weight_len = round(embed_weight.shape[0] ** 0.5)
embed_weight = comfy.ops.cast_to_input(embed_weight, embeds).movedim(1, 0).reshape(1, -1, embed_weight_len, embed_weight_len)
embed_weight = torch.nn.functional.interpolate(embed_weight, size=orig_shape, mode="bilinear", align_corners=False, antialias=True)
embed_weight = embed_weight.reshape(-1, embed_weight.shape[-2] * embed_weight.shape[-1]).movedim(0, 1)
return embeds + embed_weight
class Siglip2Embeddings(torch.nn.Module):
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", num_patches=None, dtype=None, device=None, operations=None):
super().__init__()
self.patch_embedding = operations.Linear(num_channels * patch_size * patch_size, embed_dim, dtype=dtype, device=device)
self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device)
self.patch_size = patch_size
def forward(self, pixel_values):
b, c, h, w = pixel_values.shape
img = pixel_values.movedim(1, -1).reshape(b, h // self.patch_size, self.patch_size, w // self.patch_size, self.patch_size, c)
img = img.permute(0, 1, 3, 2, 4, 5)
img = img.reshape(b, img.shape[1] * img.shape[2], -1)
img = self.patch_embedding(img)
return siglip2_pos_embed(self.position_embedding.weight, img, (h // self.patch_size, w // self.patch_size))
class CLIPVisionEmbeddings(torch.nn.Module):
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None):
@@ -273,11 +199,8 @@ class CLIPVision(torch.nn.Module):
intermediate_activation = config_dict["hidden_act"]
model_type = config_dict["model_type"]
if model_type in ["siglip2_vision_model"]:
self.embeddings = Siglip2Embeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, num_patches=config_dict.get("num_patches", None), dtype=dtype, device=device, operations=operations)
else:
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
if model_type in ["siglip_vision_model", "siglip2_vision_model"]:
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
if model_type == "siglip_vision_model":
self.pre_layrnorm = lambda a: a
self.output_layernorm = True
else:

View File

@@ -1,5 +1,6 @@
from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
import os
import torch
import json
import logging
@@ -16,12 +17,28 @@ class Output:
def __setitem__(self, key, item):
setattr(self, key, item)
clip_preprocess = comfy.clip_model.clip_preprocess # Prevent some stuff from breaking, TODO: remove eventually
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
image = image[:, :, :, :3] if image.shape[3] > 3 else image
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
std = torch.tensor(std, device=image.device, dtype=image.dtype)
image = image.movedim(-1, 1)
if not (image.shape[2] == size and image.shape[3] == size):
if crop:
scale = (size / min(image.shape[2], image.shape[3]))
scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
else:
scale_size = (size, size)
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
h = (image.shape[2] - size)//2
w = (image.shape[3] - size)//2
image = image[:,:,h:h+size,w:w+size]
image = torch.clip((255. * image), 0, 255).round() / 255.0
return (image - mean.view([3,1,1])) / std.view([3,1,1])
IMAGE_ENCODERS = {
"clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
}
@@ -33,10 +50,9 @@ class ClipVisionModel():
self.image_size = config.get("image_size", 224)
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
self.model_type = config.get("model_type", "clip_vision_model")
self.config = config.copy()
model_class = IMAGE_ENCODERS.get(self.model_type)
if self.model_type == "siglip_vision_model":
model_type = config.get("model_type", "clip_vision_model")
model_class = IMAGE_ENCODERS.get(model_type)
if model_type == "siglip_vision_model":
self.return_all_hidden_states = True
else:
self.return_all_hidden_states = False
@@ -47,26 +63,22 @@ class ClipVisionModel():
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
self.model.eval()
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
return self.model.load_state_dict(sd, strict=False)
def get_sd(self):
return self.model.state_dict()
def encode_image(self, image, crop=True):
comfy.model_management.load_model_gpu(self.patcher)
if self.model_type == "siglip2_vision_model":
pixel_values = comfy.clip_model.siglip2_preprocess(image.to(self.load_device), size=self.image_size, patch_size=self.config.get("patch_size", 16), num_patches=self.config.get("num_patches", 256), mean=self.image_mean, std=self.image_std, crop=crop).float()
else:
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
outputs = Output()
outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
outputs["image_sizes"] = [pixel_values.shape[1:]] * pixel_values.shape[0]
if self.return_all_hidden_states:
all_hs = out[1].to(comfy.model_management.intermediate_device())
outputs["penultimate_hidden_states"] = all_hs[:, -2]
@@ -113,14 +125,10 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0]
if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
patch_embedding_shape = sd["vision_model.embeddings.patch_embedding.weight"].shape
if len(patch_embedding_shape) == 2:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip2_base_naflex.json")
else:
if embed_shape == 729:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
elif embed_shape == 1024:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json")
if embed_shape == 729:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
elif embed_shape == 1024:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json")
elif embed_shape == 577:
if "multi_modal_projector.linear_1.bias" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json")

View File

@@ -1,14 +0,0 @@
{
"num_channels": 3,
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1152,
"image_size": -1,
"intermediate_size": 4304,
"model_type": "siglip2_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"patch_size": 16,
"num_patches": 256,
"image_mean": [0.5, 0.5, 0.5],
"image_std": [0.5, 0.5, 0.5]
}

View File

@@ -236,8 +236,6 @@ class ComfyNodeABC(ABC):
"""Flags a node as experimental, informing users that it may change or not work as expected."""
DEPRECATED: bool
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
DEV_ONLY: bool
"""Flags a node as dev-only, hiding it from search/menus unless dev mode is enabled."""
API_NODE: Optional[bool]
"""Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""

View File

@@ -143,7 +143,7 @@ class IndexListContextHandler(ContextHandlerABC):
# if multiple conds, split based on primary region
if self.split_conds_to_windows and len(cond_in) > 1:
region = window.get_region_index(len(cond_in))
logging.info(f"Splitting conds to windows; using region {region} for window {window.index_list[0]}-{window.index_list[-1]} with center ratio {window.center_ratio:.3f}")
logging.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}")
cond_in = [cond_in[region]]
# cond object is a list containing a dict - outer list is irrelevant, so just loop through it
for actual_cond in cond_in:
@@ -188,12 +188,6 @@ class IndexListContextHandler(ContextHandlerABC):
audio_cond = cond_value.cond
if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim):
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1))
# Handle vace_context (temporal dim is 3)
elif cond_key == "vace_context" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
vace_cond = cond_value.cond
if vace_cond.ndim >= 4 and vace_cond.size(3) == x_in.size(self.dim):
sliced_vace = window.get_tensor(vace_cond, device, dim=3, retain_index_list=self.cond_retain_index_list)
new_cond_item[cond_key] = cond_value._copy_with(sliced_vace)
# if has cond that is a Tensor, check if needs to be subset
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \

View File

@@ -203,7 +203,7 @@ class ControlNet(ControlBase):
self.control_model = control_model
self.load_device = load_device
if control_model is not None:
self.control_model_wrapped = comfy.model_patcher.CoreModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
self.compression_ratio = compression_ratio
self.global_average_pooling = global_average_pooling

View File

@@ -65,147 +65,3 @@ def stochastic_rounding(value, dtype, seed=0):
return output
return value.to(dtype=dtype)
# TODO: improve this?
def stochastic_float_to_fp4_e2m1(x, generator):
orig_shape = x.shape
sign = torch.signbit(x).to(torch.uint8)
exp = torch.floor(torch.log2(x.abs()) + 1.0).clamp(0, 3)
x += (torch.rand(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator) - 0.5) * (2 ** (exp - 2.0)) * 1.25
x = x.abs()
exp = torch.floor(torch.log2(x) + 1.1925).clamp(0, 3)
mantissa = torch.where(
exp > 0,
(x / (2.0 ** (exp - 1)) - 1.0) * 2.0,
(x * 2.0),
out=x
).round().to(torch.uint8)
del x
exp = exp.to(torch.uint8)
fp4 = (sign << 3) | (exp << 1) | mantissa
del sign, exp, mantissa
fp4_flat = fp4.view(-1)
packed = (fp4_flat[0::2] << 4) | fp4_flat[1::2]
return packed.reshape(list(orig_shape)[:-1] + [-1])
def to_blocked(input_matrix, flatten: bool = True) -> torch.Tensor:
"""
Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.
See:
https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
Args:
input_matrix: Input tensor of shape (H, W)
Returns:
Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
"""
def ceil_div(a, b):
return (a + b - 1) // b
rows, cols = input_matrix.shape
n_row_blocks = ceil_div(rows, 128)
n_col_blocks = ceil_div(cols, 4)
# Calculate the padded shape
padded_rows = n_row_blocks * 128
padded_cols = n_col_blocks * 4
padded = input_matrix
if (rows, cols) != (padded_rows, padded_cols):
padded = torch.zeros(
(padded_rows, padded_cols),
device=input_matrix.device,
dtype=input_matrix.dtype,
)
padded[:rows, :cols] = input_matrix
# Rearrange the blocks
blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
if flatten:
return rearranged.flatten()
return rearranged.reshape(padded_rows, padded_cols)
def stochastic_round_quantize_nvfp4_block(x, per_tensor_scale, generator):
F4_E2M1_MAX = 6.0
F8_E4M3_MAX = 448.0
orig_shape = x.shape
block_size = 16
x = x.reshape(orig_shape[0], -1, block_size)
scaled_block_scales_fp8 = torch.clamp(((torch.amax(torch.abs(x), dim=-1)) / F4_E2M1_MAX) / per_tensor_scale.to(x.dtype), max=F8_E4M3_MAX).to(torch.float8_e4m3fn)
x = x / (per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)).unsqueeze(-1)
x = x.view(orig_shape).nan_to_num()
data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator)
return data_lp, scaled_block_scales_fp8
def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
def roundup(x: int, multiple: int) -> int:
"""Round up x to the nearest multiple."""
return ((x + multiple - 1) // multiple) * multiple
generator = torch.Generator(device=x.device)
generator.manual_seed(seed)
# Handle padding
if pad_16x:
rows, cols = x.shape
padded_rows = roundup(rows, 16)
padded_cols = roundup(cols, 16)
if padded_rows != rows or padded_cols != cols:
x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
x, blocked_scaled = stochastic_round_quantize_nvfp4_block(x, per_tensor_scale, generator)
return x, to_blocked(blocked_scaled, flatten=False)
def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=0, block_size=4096 * 4096):
def roundup(x: int, multiple: int) -> int:
"""Round up x to the nearest multiple."""
return ((x + multiple - 1) // multiple) * multiple
orig_shape = x.shape
# Handle padding
if pad_16x:
rows, cols = x.shape
padded_rows = roundup(rows, 16)
padded_cols = roundup(cols, 16)
if padded_rows != rows or padded_cols != cols:
x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
# Note: We update orig_shape because the output tensor logic below assumes x.shape matches
# what we want to produce. If we pad here, we want the padded output.
orig_shape = x.shape
orig_shape = list(orig_shape)
output_fp4 = torch.empty(orig_shape[:-1] + [orig_shape[-1] // 2], dtype=torch.uint8, device=x.device)
output_block = torch.empty(orig_shape[:-1] + [orig_shape[-1] // 16], dtype=torch.float8_e4m3fn, device=x.device)
generator = torch.Generator(device=x.device)
generator.manual_seed(seed)
num_slices = max(1, (x.numel() / block_size))
slice_size = max(1, (round(x.shape[0] / num_slices)))
for i in range(0, x.shape[0], slice_size):
fp4, block = stochastic_round_quantize_nvfp4_block(x[i: i + slice_size], per_tensor_scale, generator=generator)
output_fp4[i:i + slice_size].copy_(fp4)
output_block[i:i + slice_size].copy_(block)
return output_fp4, to_blocked(output_block, flatten=False)

View File

@@ -527,8 +527,7 @@ class HookKeyframeGroup:
if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0:
break
# if eval_c is outside the percent range, stop looking further
else:
break
else: break
# update steps current context is used
self._current_used_steps += 1
# update current timestep this was performed on

View File

@@ -1,12 +1,11 @@
import math
import time
from functools import partial
from scipy import integrate
import torch
from torch import nn
import torchsde
from tqdm.auto import trange as trange_, tqdm
from tqdm.auto import trange, tqdm
from . import utils
from . import deis
@@ -14,36 +13,6 @@ from . import sa_solver
import comfy.model_patcher
import comfy.model_sampling
import comfy.memory_management
def trange(*args, **kwargs):
if comfy.memory_management.aimdo_allocator is None:
return trange_(*args, **kwargs)
pbar = trange_(*args, **kwargs, smoothing=1.0)
pbar._i = 0
pbar.set_postfix_str(" Model Initializing ... ")
_update = pbar.update
def warmup_update(n=1):
pbar._i += 1
if pbar._i == 1:
pbar.i1_time = time.time()
pbar.set_postfix_str(" Model Initialization complete! ")
elif pbar._i == 2:
#bring forward the effective start time based the the diff between first and second iteration
#to attempt to remove load overhead from the final step rate estimate.
pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
pbar.set_postfix_str("")
_update(n)
pbar.update = warmup_update
return pbar
def append_zero(x):
return torch.cat([x, x.new_zeros([1])])
@@ -105,9 +74,6 @@ def get_ancestral_step(sigma_from, sigma_to, eta=1.):
def default_noise_sampler(x, seed=None):
if seed is not None:
if x.device == torch.device("cpu"):
seed += 1
generator = torch.Generator(device=x.device)
generator.manual_seed(seed)
else:
@@ -1652,17 +1618,6 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
x = x + sde_noise * sigmas[i + 1] * s_noise
return x
@torch.no_grad()
def sample_exp_heun_2_x0(model, x, sigmas, extra_args=None, callback=None, disable=None, solver_type="phi_2"):
"""Deterministic exponential Heun second order method in data prediction (x0) and logSNR time."""
return sample_seeds_2(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None, r=1.0, solver_type=solver_type)
@torch.no_grad()
def sample_exp_heun_2_x0_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type="phi_2"):
"""Stochastic exponential Heun second order method in data prediction (x0) and logSNR time."""
return sample_seeds_2(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=1.0, solver_type=solver_type)
@torch.no_grad()
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
@@ -1810,7 +1765,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
# Predictor
if sigmas[i + 1] == 0:
# Denoising step
x_pred = denoised
x = denoised
else:
tau_t = tau_func(sigmas[i + 1])
curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1]
@@ -1831,7 +1786,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
if tau_t > 0 and s_noise > 0:
noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise
x_pred = x_pred + noise
return x_pred
return x
@torch.no_grad()

View File

@@ -8,7 +8,6 @@ class LatentFormat:
latent_rgb_factors_bias = None
latent_rgb_factors_reshape = None
taesd_decoder_name = None
spacial_downscale_ratio = 8
def process_in(self, latent):
return latent * self.scale_factor
@@ -81,7 +80,6 @@ class SD_X4(LatentFormat):
class SC_Prior(LatentFormat):
latent_channels = 16
spacial_downscale_ratio = 42
def __init__(self):
self.scale_factor = 1.0
self.latent_rgb_factors = [
@@ -104,7 +102,6 @@ class SC_Prior(LatentFormat):
]
class SC_B(LatentFormat):
spacial_downscale_ratio = 4
def __init__(self):
self.scale_factor = 1.0 / 0.43
self.latent_rgb_factors = [
@@ -184,7 +181,6 @@ class Flux(SD3):
class Flux2(LatentFormat):
latent_channels = 128
spacial_downscale_ratio = 16
def __init__(self):
self.latent_rgb_factors =[
@@ -276,7 +272,6 @@ class Mochi(LatentFormat):
class LTXV(LatentFormat):
latent_channels = 128
latent_dimensions = 3
spacial_downscale_ratio = 32
def __init__(self):
self.latent_rgb_factors = [
@@ -412,11 +407,6 @@ class LTXV(LatentFormat):
self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512]
class LTXAV(LTXV):
def __init__(self):
self.latent_rgb_factors = None
self.latent_rgb_factors_bias = None
class HunyuanVideo(LatentFormat):
latent_channels = 16
latent_dimensions = 3
@@ -520,7 +510,6 @@ class Wan21(LatentFormat):
class Wan22(Wan21):
latent_channels = 48
latent_dimensions = 3
spacial_downscale_ratio = 16
latent_rgb_factors = [
[ 0.0119, 0.0103, 0.0046],
@@ -598,7 +587,6 @@ class Wan22(Wan21):
class HunyuanImage21(LatentFormat):
latent_channels = 64
latent_dimensions = 2
spacial_downscale_ratio = 32
scale_factor = 0.75289
latent_rgb_factors = [
@@ -732,7 +720,6 @@ class HunyuanVideo15(LatentFormat):
latent_rgb_factors_bias = [ 0.0456, -0.0202, -0.0644]
latent_channels = 32
latent_dimensions = 3
spacial_downscale_ratio = 16
scale_factor = 1.03682
taesd_decoder_name = "lighttaehy1_5"
@@ -755,13 +742,8 @@ class ACEAudio(LatentFormat):
latent_channels = 8
latent_dimensions = 2
class ACEAudio15(LatentFormat):
latent_channels = 64
latent_dimensions = 1
class ChromaRadiance(LatentFormat):
latent_channels = 3
spacial_downscale_ratio = 1
def __init__(self):
self.latent_rgb_factors = [

File diff suppressed because it is too large Load Diff

View File

@@ -1,202 +0,0 @@
from comfy.ldm.cosmos.predict2 import MiniTrainDIT
import torch
from torch import nn
import torch.nn.functional as F
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(x, cos, sin, unsqueeze_dim=1):
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
x_embed = (x * cos) + (rotate_half(x) * sin)
return x_embed
class RotaryEmbedding(nn.Module):
def __init__(self, head_dim):
super().__init__()
self.rope_theta = 10000
inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.int64).to(dtype=torch.float) / head_dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
@torch.no_grad()
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class Attention(nn.Module):
def __init__(self, query_dim, context_dim, n_heads, head_dim, device=None, dtype=None, operations=None):
super().__init__()
inner_dim = head_dim * n_heads
self.n_heads = n_heads
self.head_dim = head_dim
self.query_dim = query_dim
self.context_dim = context_dim
self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype)
self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
self.o_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype)
def forward(self, x, mask=None, context=None, position_embeddings=None, position_embeddings_context=None):
context = x if context is None else context
input_shape = x.shape[:-1]
q_shape = (*input_shape, self.n_heads, self.head_dim)
context_shape = context.shape[:-1]
kv_shape = (*context_shape, self.n_heads, self.head_dim)
query_states = self.q_norm(self.q_proj(x).view(q_shape)).transpose(1, 2)
key_states = self.k_norm(self.k_proj(context).view(kv_shape)).transpose(1, 2)
value_states = self.v_proj(context).view(kv_shape).transpose(1, 2)
if position_embeddings is not None:
assert position_embeddings_context is not None
cos, sin = position_embeddings
query_states = apply_rotary_pos_emb(query_states, cos, sin)
cos, sin = position_embeddings_context
key_states = apply_rotary_pos_emb(key_states, cos, sin)
attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask)
attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output
def init_weights(self):
torch.nn.init.zeros_(self.o_proj.weight)
class TransformerBlock(nn.Module):
def __init__(self, source_dim, model_dim, num_heads=16, mlp_ratio=4.0, use_self_attn=False, layer_norm=False, device=None, dtype=None, operations=None):
super().__init__()
self.use_self_attn = use_self_attn
if self.use_self_attn:
self.norm_self_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)
self.self_attn = Attention(
query_dim=model_dim,
context_dim=model_dim,
n_heads=num_heads,
head_dim=model_dim//num_heads,
device=device,
dtype=dtype,
operations=operations,
)
self.norm_cross_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)
self.cross_attn = Attention(
query_dim=model_dim,
context_dim=source_dim,
n_heads=num_heads,
head_dim=model_dim//num_heads,
device=device,
dtype=dtype,
operations=operations,
)
self.norm_mlp = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)
self.mlp = nn.Sequential(
operations.Linear(model_dim, int(model_dim * mlp_ratio), device=device, dtype=dtype),
nn.GELU(),
operations.Linear(int(model_dim * mlp_ratio), model_dim, device=device, dtype=dtype)
)
def forward(self, x, context, target_attention_mask=None, source_attention_mask=None, position_embeddings=None, position_embeddings_context=None):
if self.use_self_attn:
normed = self.norm_self_attn(x)
attn_out = self.self_attn(normed, mask=target_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings)
x = x + attn_out
normed = self.norm_cross_attn(x)
attn_out = self.cross_attn(normed, mask=source_attention_mask, context=context, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context)
x = x + attn_out
x = x + self.mlp(self.norm_mlp(x))
return x
def init_weights(self):
torch.nn.init.zeros_(self.mlp[2].weight)
self.cross_attn.init_weights()
class LLMAdapter(nn.Module):
def __init__(
self,
source_dim=1024,
target_dim=1024,
model_dim=1024,
num_layers=6,
num_heads=16,
use_self_attn=True,
layer_norm=False,
device=None,
dtype=None,
operations=None,
):
super().__init__()
self.embed = operations.Embedding(32128, target_dim, device=device, dtype=dtype)
if model_dim != target_dim:
self.in_proj = operations.Linear(target_dim, model_dim, device=device, dtype=dtype)
else:
self.in_proj = nn.Identity()
self.rotary_emb = RotaryEmbedding(model_dim//num_heads)
self.blocks = nn.ModuleList([
TransformerBlock(source_dim, model_dim, num_heads=num_heads, use_self_attn=use_self_attn, layer_norm=layer_norm, device=device, dtype=dtype, operations=operations) for _ in range(num_layers)
])
self.out_proj = operations.Linear(model_dim, target_dim, device=device, dtype=dtype)
self.norm = operations.RMSNorm(target_dim, eps=1e-6, device=device, dtype=dtype)
def forward(self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None):
if target_attention_mask is not None:
target_attention_mask = target_attention_mask.to(torch.bool)
if target_attention_mask.ndim == 2:
target_attention_mask = target_attention_mask.unsqueeze(1).unsqueeze(1)
if source_attention_mask is not None:
source_attention_mask = source_attention_mask.to(torch.bool)
if source_attention_mask.ndim == 2:
source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1)
x = self.in_proj(self.embed(target_input_ids))
context = source_hidden_states
position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
position_embeddings = self.rotary_emb(x, position_ids)
position_embeddings_context = self.rotary_emb(x, position_ids_context)
for block in self.blocks:
x = block(x, context, target_attention_mask=target_attention_mask, source_attention_mask=source_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context)
return self.norm(self.out_proj(x))
class Anima(MiniTrainDIT):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations"))
def preprocess_text_embeds(self, text_embeds, text_ids):
if text_ids is not None:
return self.llm_adapter(text_embeds, text_ids)
else:
return text_embeds

View File

@@ -270,7 +270,7 @@ class ChromaRadiance(Chroma):
bad_keys = tuple(
k
for k, v in overrides.items()
if not isinstance(v, type(getattr(params, k))) and (v is not None or k not in nullable_keys)
if type(v) != type(getattr(params, k)) and (v is not None or k not in nullable_keys)
)
if bad_keys:
e = f"Invalid value(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}"

View File

@@ -13,7 +13,6 @@ from torchvision import transforms
import comfy.patcher_extension
from comfy.ldm.modules.attention import optimized_attention
import comfy.ldm.common_dit
def apply_rotary_pos_emb(
t: torch.Tensor,
@@ -836,8 +835,6 @@ class MiniTrainDIT(nn.Module):
padding_mask: Optional[torch.Tensor] = None,
**kwargs,
):
orig_shape = list(x.shape)
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_temporal, self.patch_spatial, self.patch_spatial))
x_B_C_T_H_W = x
timesteps_B_T = timesteps
crossattn_emb = context
@@ -885,5 +882,5 @@ class MiniTrainDIT(nn.Module):
)
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)[:, :, :orig_shape[-3], :orig_shape[-2], :orig_shape[-1]]
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
return x_B_C_Tt_Hp_Wp

View File

@@ -4,7 +4,6 @@ from torch import Tensor
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
import logging
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
@@ -14,6 +13,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transforme
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
return x
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
assert dim % 2 == 0
if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled():
@@ -28,20 +28,13 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
return out.to(dtype=torch.float32, device=pos.device)
def apply_rope1(x: Tensor, freqs_cis: Tensor):
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
try:
import comfy.quant_ops
apply_rope = comfy.quant_ops.ck.apply_rope
apply_rope1 = comfy.quant_ops.ck.apply_rope1
except:
logging.warning("No comfy kitchen, using old apply_rope functions.")
def apply_rope1(x: Tensor, freqs_cis: Tensor):
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
x_out = freqs_cis[..., 0] * x_[..., 0]
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
x_out = freqs_cis[..., 0] * x_[..., 0]
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
return x_out.reshape(*x.shape).type_as(x)
return x_out.reshape(*x.shape).type_as(x)
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)

View File

@@ -3,8 +3,7 @@ import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm
import comfy.model_management
import comfy.model_patcher
import model_management, model_patcher
class SRResidualCausalBlock3D(nn.Module):
def __init__(self, channels: int):
@@ -103,20 +102,20 @@ UPSAMPLERS = {
class HunyuanVideo15SRModel():
def __init__(self, model_type, config):
self.load_device = comfy.model_management.vae_device()
offload_device = comfy.model_management.vae_offload_device()
self.dtype = comfy.model_management.vae_dtype(self.load_device)
self.load_device = model_management.vae_device()
offload_device = model_management.vae_offload_device()
self.dtype = model_management.vae_dtype(self.load_device)
self.model_class = UPSAMPLERS.get(model_type)
self.model = self.model_class(**config).eval()
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=True, assign=self.patcher.is_dynamic())
return self.model.load_state_dict(sd, strict=True)
def get_sd(self):
return self.model.state_dict()
def resample_latent(self, latent):
comfy.model_management.load_model_gpu(self.patcher)
model_management.load_model_gpu(self.patcher)
return self.model(latent.to(self.load_device))

View File

@@ -1,871 +0,0 @@
from typing import Tuple
import torch
import torch.nn as nn
from comfy.ldm.lightricks.model import (
CrossAttention,
FeedForward,
AdaLayerNormSingle,
PixArtAlphaTextProjection,
LTXVModel,
)
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
import comfy.ldm.common_dit
class CompressedTimestep:
"""Store video timestep embeddings in compressed form using per-frame indexing."""
__slots__ = ('data', 'batch_size', 'num_frames', 'patches_per_frame', 'feature_dim')
def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
"""
tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
patches_per_frame: Number of spatial patches per frame (height * width in latent space), or None to disable compression
"""
self.batch_size, num_tokens, self.feature_dim = tensor.shape
# Check if compression is valid (num_tokens must be divisible by patches_per_frame)
if patches_per_frame is not None and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
self.patches_per_frame = patches_per_frame
self.num_frames = num_tokens // patches_per_frame
# Reshape to [batch, frames, patches_per_frame, feature_dim] and store one value per frame
# All patches in a frame are identical, so we only keep the first one
reshaped = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)
self.data = reshaped[:, :, 0, :].contiguous() # [batch, frames, feature_dim]
else:
# Not divisible or too small - store directly without compression
self.patches_per_frame = 1
self.num_frames = num_tokens
self.data = tensor
def expand(self):
"""Expand back to original tensor."""
if self.patches_per_frame == 1:
return self.data
# [batch, frames, feature_dim] -> [batch, frames, patches_per_frame, feature_dim] -> [batch, tokens, feature_dim]
expanded = self.data.unsqueeze(2).expand(self.batch_size, self.num_frames, self.patches_per_frame, self.feature_dim)
return expanded.reshape(self.batch_size, -1, self.feature_dim)
def expand_for_computation(self, scale_shift_table: torch.Tensor, batch_size: int, indices: slice = slice(None, None)):
"""Compute ada values on compressed per-frame data, then expand spatially."""
num_ada_params = scale_shift_table.shape[0]
# No compression - compute directly
if self.patches_per_frame == 1:
num_tokens = self.data.shape[1]
dim_per_param = self.feature_dim // num_ada_params
reshaped = self.data.reshape(batch_size, num_tokens, num_ada_params, dim_per_param)[:, :, indices, :]
table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=self.data.device, dtype=self.data.dtype)
ada_values = (table_values + reshaped).unbind(dim=2)
return ada_values
# Compressed: compute on per-frame data then expand spatially
# Reshape: [batch, frames, feature_dim] -> [batch, frames, num_ada_params, dim_per_param]
frame_reshaped = self.data.reshape(batch_size, self.num_frames, num_ada_params, -1)[:, :, indices, :]
table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(
device=self.data.device, dtype=self.data.dtype
)
frame_ada = (table_values + frame_reshaped).unbind(dim=2)
# Expand each ada parameter spatially: [batch, frames, dim] -> [batch, frames, patches, dim] -> [batch, tokens, dim]
return tuple(
frame_val.unsqueeze(2).expand(batch_size, self.num_frames, self.patches_per_frame, -1)
.reshape(batch_size, -1, frame_val.shape[-1])
for frame_val in frame_ada
)
class BasicAVTransformerBlock(nn.Module):
def __init__(
self,
v_dim,
a_dim,
v_heads,
a_heads,
vd_head,
ad_head,
v_context_dim=None,
a_context_dim=None,
attn_precision=None,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.attn_precision = attn_precision
self.attn1 = CrossAttention(
query_dim=v_dim,
heads=v_heads,
dim_head=vd_head,
context_dim=None,
attn_precision=self.attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
self.audio_attn1 = CrossAttention(
query_dim=a_dim,
heads=a_heads,
dim_head=ad_head,
context_dim=None,
attn_precision=self.attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
self.attn2 = CrossAttention(
query_dim=v_dim,
context_dim=v_context_dim,
heads=v_heads,
dim_head=vd_head,
attn_precision=self.attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
self.audio_attn2 = CrossAttention(
query_dim=a_dim,
context_dim=a_context_dim,
heads=a_heads,
dim_head=ad_head,
attn_precision=self.attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
# Q: Video, K,V: Audio
self.audio_to_video_attn = CrossAttention(
query_dim=v_dim,
context_dim=a_dim,
heads=a_heads,
dim_head=ad_head,
attn_precision=self.attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
# Q: Audio, K,V: Video
self.video_to_audio_attn = CrossAttention(
query_dim=a_dim,
context_dim=v_dim,
heads=a_heads,
dim_head=ad_head,
attn_precision=self.attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
self.ff = FeedForward(
v_dim, dim_out=v_dim, glu=True, dtype=dtype, device=device, operations=operations
)
self.audio_ff = FeedForward(
a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations
)
self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype))
self.audio_scale_shift_table = nn.Parameter(
torch.empty(6, a_dim, device=device, dtype=dtype)
)
self.scale_shift_table_a2v_ca_audio = nn.Parameter(
torch.empty(5, a_dim, device=device, dtype=dtype)
)
self.scale_shift_table_a2v_ca_video = nn.Parameter(
torch.empty(5, v_dim, device=device, dtype=dtype)
)
def get_ada_values(
self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice = slice(None, None)
):
if isinstance(timestep, CompressedTimestep):
return timestep.expand_for_computation(scale_shift_table, batch_size, indices)
num_ada_params = scale_shift_table.shape[0]
ada_values = (
scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=timestep.device, dtype=timestep.dtype)
+ timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[:, :, indices, :]
).unbind(dim=2)
return ada_values
def get_av_ca_ada_values(
self,
scale_shift_table: torch.Tensor,
batch_size: int,
scale_shift_timestep: torch.Tensor,
gate_timestep: torch.Tensor,
num_scale_shift_values: int = 4,
):
scale_shift_ada_values = self.get_ada_values(
scale_shift_table[:num_scale_shift_values, :],
batch_size,
scale_shift_timestep,
)
gate_ada_values = self.get_ada_values(
scale_shift_table[num_scale_shift_values:, :],
batch_size,
gate_timestep,
)
return (*scale_shift_ada_values, *gate_ada_values)
def forward(
self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
run_vx = transformer_options.get("run_vx", True)
run_ax = transformer_options.get("run_ax", True)
vx, ax = x
run_ax = run_ax and ax.numel() > 0
run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0
run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True)
# video
if run_vx:
# video self-attention
vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2)))
norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
del vshift_msa, vscale_msa
attn1_out = self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options)
del norm_vx
# video cross-attention
vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
vx.addcmul_(attn1_out, vgate_msa)
del vgate_msa, attn1_out
vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options))
# audio
if run_ax:
# audio self-attention
ashift_msa, ascale_msa = (self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 2)))
norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa
del ashift_msa, ascale_msa
attn1_out = self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
del norm_ax
# audio cross-attention
agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0]
ax.addcmul_(attn1_out, agate_msa)
del agate_msa, attn1_out
ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options))
# video - audio cross attention.
if run_a2v or run_v2a:
vx_norm3 = comfy.ldm.common_dit.rms_norm(vx)
ax_norm3 = comfy.ldm.common_dit.rms_norm(ax)
# audio to video cross attention
if run_a2v:
scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v = self.get_ada_values(
self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[:2]
scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v = self.get_ada_values(
self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[:2]
vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_a2v_v) + shift_ca_video_hidden_states_a2v_v
ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + shift_ca_audio_hidden_states_a2v
del scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v, scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v
a2v_out = self.audio_to_video_attn(vx_scaled, context=ax_scaled, pe=v_cross_pe, k_pe=a_cross_pe, transformer_options=transformer_options)
del vx_scaled, ax_scaled
gate_out_a2v = self.get_ada_values(self.scale_shift_table_a2v_ca_video[4:, :], vx.shape[0], v_cross_gate_timestep)[0]
vx.addcmul_(a2v_out, gate_out_a2v)
del gate_out_a2v, a2v_out
# video to audio cross attention
if run_v2a:
scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a = self.get_ada_values(
self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[2:4]
scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a = self.get_ada_values(
self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[2:4]
ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + shift_ca_audio_hidden_states_v2a
vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + shift_ca_video_hidden_states_v2a
del scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a, scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a
v2a_out = self.video_to_audio_attn(ax_scaled, context=vx_scaled, pe=a_cross_pe, k_pe=v_cross_pe, transformer_options=transformer_options)
del ax_scaled, vx_scaled
gate_out_v2a = self.get_ada_values(self.scale_shift_table_a2v_ca_audio[4:, :], ax.shape[0], a_cross_gate_timestep)[0]
ax.addcmul_(v2a_out, gate_out_v2a)
del gate_out_v2a, v2a_out
del vx_norm3, ax_norm3
# video feedforward
if run_vx:
vshift_mlp, vscale_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, 5))
vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp
del vshift_mlp, vscale_mlp
ff_out = self.ff(vx_scaled)
del vx_scaled
vgate_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(5, 6))[0]
vx.addcmul_(ff_out, vgate_mlp)
del vgate_mlp, ff_out
# audio feedforward
if run_ax:
ashift_mlp, ascale_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, 5))
ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp
del ashift_mlp, ascale_mlp
ff_out = self.audio_ff(ax_scaled)
del ax_scaled
agate_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(5, 6))[0]
ax.addcmul_(ff_out, agate_mlp)
del agate_mlp, ff_out
return vx, ax
class LTXAVModel(LTXVModel):
"""LTXAV model for audio-video generation."""
def __init__(
self,
in_channels=128,
audio_in_channels=128,
cross_attention_dim=4096,
audio_cross_attention_dim=2048,
attention_head_dim=128,
audio_attention_head_dim=64,
num_attention_heads=32,
audio_num_attention_heads=32,
caption_channels=3840,
num_layers=48,
positional_embedding_theta=10000.0,
positional_embedding_max_pos=[20, 2048, 2048],
audio_positional_embedding_max_pos=[20],
causal_temporal_positioning=False,
vae_scale_factors=(8, 32, 32),
use_middle_indices_grid=False,
timestep_scale_multiplier=1000.0,
av_ca_timestep_scale_multiplier=1.0,
dtype=None,
device=None,
operations=None,
**kwargs,
):
# Store audio-specific parameters
self.audio_in_channels = audio_in_channels
self.audio_cross_attention_dim = audio_cross_attention_dim
self.audio_attention_head_dim = audio_attention_head_dim
self.audio_num_attention_heads = audio_num_attention_heads
self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos
# Calculate audio dimensions
self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim
self.audio_out_channels = audio_in_channels
# Audio-specific constants
self.num_audio_channels = 8
self.audio_frequency_bins = 16
self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier
super().__init__(
in_channels=in_channels,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim,
num_attention_heads=num_attention_heads,
caption_channels=caption_channels,
num_layers=num_layers,
positional_embedding_theta=positional_embedding_theta,
positional_embedding_max_pos=positional_embedding_max_pos,
causal_temporal_positioning=causal_temporal_positioning,
vae_scale_factors=vae_scale_factors,
use_middle_indices_grid=use_middle_indices_grid,
timestep_scale_multiplier=timestep_scale_multiplier,
dtype=dtype,
device=device,
operations=operations,
**kwargs,
)
def _init_model_components(self, device, dtype, **kwargs):
"""Initialize LTXAV-specific components."""
# Audio-specific projections
self.audio_patchify_proj = self.operations.Linear(
self.audio_in_channels, self.audio_inner_dim, bias=True, dtype=dtype, device=device
)
# Audio-specific AdaLN
self.audio_adaln_single = AdaLayerNormSingle(
self.audio_inner_dim,
use_additional_conditions=False,
dtype=dtype,
device=device,
operations=self.operations,
)
num_scale_shift_values = 4
self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
self.inner_dim,
use_additional_conditions=False,
embedding_coefficient=num_scale_shift_values,
dtype=dtype,
device=device,
operations=self.operations,
)
self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle(
self.inner_dim,
use_additional_conditions=False,
embedding_coefficient=1,
dtype=dtype,
device=device,
operations=self.operations,
)
self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle(
self.audio_inner_dim,
use_additional_conditions=False,
embedding_coefficient=num_scale_shift_values,
dtype=dtype,
device=device,
operations=self.operations,
)
self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle(
self.audio_inner_dim,
use_additional_conditions=False,
embedding_coefficient=1,
dtype=dtype,
device=device,
operations=self.operations,
)
# Audio caption projection
self.audio_caption_projection = PixArtAlphaTextProjection(
in_features=self.caption_channels,
hidden_size=self.audio_inner_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
def _init_transformer_blocks(self, device, dtype, **kwargs):
"""Initialize transformer blocks for LTXAV."""
self.transformer_blocks = nn.ModuleList(
[
BasicAVTransformerBlock(
v_dim=self.inner_dim,
a_dim=self.audio_inner_dim,
v_heads=self.num_attention_heads,
a_heads=self.audio_num_attention_heads,
vd_head=self.attention_head_dim,
ad_head=self.audio_attention_head_dim,
v_context_dim=self.cross_attention_dim,
a_context_dim=self.audio_cross_attention_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
for _ in range(self.num_layers)
]
)
def _init_output_components(self, device, dtype):
"""Initialize output components for LTXAV."""
# Video output components
super()._init_output_components(device, dtype)
# Audio output components
self.audio_scale_shift_table = nn.Parameter(
torch.empty(2, self.audio_inner_dim, dtype=dtype, device=device)
)
self.audio_norm_out = self.operations.LayerNorm(
self.audio_inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
)
self.audio_proj_out = self.operations.Linear(
self.audio_inner_dim, self.audio_out_channels, dtype=dtype, device=device
)
self.a_patchifier = AudioPatchifier(1, start_end=True)
def separate_audio_and_video_latents(self, x, audio_length):
"""Separate audio and video latents from combined input."""
# vx = x[:, : self.in_channels]
# ax = x[:, self.in_channels :]
#
# ax = ax.reshape(ax.shape[0], -1)
# ax = ax[:, : audio_length * self.num_audio_channels * self.audio_frequency_bins]
#
# ax = ax.reshape(
# ax.shape[0], self.num_audio_channels, audio_length, self.audio_frequency_bins
# )
vx = x[0]
ax = x[1] if len(x) > 1 else torch.zeros(
(vx.shape[0], self.num_audio_channels, 0, self.audio_frequency_bins),
device=vx.device, dtype=vx.dtype
)
return vx, ax
def recombine_audio_and_video_latents(self, vx, ax, target_shape=None):
if ax.numel() == 0:
return vx
else:
return [vx, ax]
"""Recombine audio and video latents for output."""
# if ax.device != vx.device or ax.dtype != vx.dtype:
# logging.warning("Audio and video latents are on different devices or dtypes.")
# ax = ax.to(device=vx.device, dtype=vx.dtype)
# logging.warning(f"Audio audio latent moved to device: {ax.device}, dtype: {ax.dtype}")
#
# ax = ax.reshape(ax.shape[0], -1)
# # pad to f x h x w of the video latents
# divisor = vx.shape[-1] * vx.shape[-2] * vx.shape[-3]
# if target_shape is None:
# repetitions = math.ceil(ax.shape[-1] / divisor)
# else:
# repetitions = target_shape[1] - vx.shape[1]
# padded_len = repetitions * divisor
# ax = F.pad(ax, (0, padded_len - ax.shape[-1]))
# ax = ax.reshape(ax.shape[0], -1, vx.shape[-3], vx.shape[-2], vx.shape[-1])
# return torch.cat([vx, ax], dim=1)
def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
"""Process input for LTXAV - separate audio and video, then patchify."""
audio_length = kwargs.get("audio_length", 0)
# Separate audio and video latents
vx, ax = self.separate_audio_and_video_latents(x, audio_length)
has_spatial_mask = False
if denoise_mask is not None:
# check if any frame has spatial variation (inpainting)
for frame_idx in range(denoise_mask.shape[2]):
frame_mask = denoise_mask[0, 0, frame_idx]
if frame_mask.numel() > 0 and frame_mask.min() != frame_mask.max():
has_spatial_mask = True
break
[vx, v_pixel_coords, additional_args] = super()._process_input(
vx, keyframe_idxs, denoise_mask, **kwargs
)
additional_args["has_spatial_mask"] = has_spatial_mask
ax, a_latent_coords = self.a_patchifier.patchify(ax)
ax = self.audio_patchify_proj(ax)
# additional_args.update({"av_orig_shape": list(x.shape)})
return [vx, ax], [v_pixel_coords, a_latent_coords], additional_args
def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
"""Prepare timestep embeddings."""
# TODO: some code reuse is needed here.
grid_mask = kwargs.get("grid_mask", None)
if grid_mask is not None:
timestep = timestep[:, grid_mask]
timestep_scaled = timestep * self.timestep_scale_multiplier
v_timestep, v_embedded_timestep = self.adaln_single(
timestep_scaled.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
# Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
# Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
orig_shape = kwargs.get("orig_shape")
has_spatial_mask = kwargs.get("has_spatial_mask", None)
v_patches_per_frame = None
if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5:
# orig_shape[3] = height, orig_shape[4] = width (in latent space)
v_patches_per_frame = orig_shape[3] * orig_shape[4]
# Reshape to [batch_size, num_tokens, dim] and compress for storage
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
# Prepare audio timestep
a_timestep = kwargs.get("a_timestep")
if a_timestep is not None:
a_timestep_scaled = a_timestep * self.timestep_scale_multiplier
a_timestep_flat = a_timestep_scaled.flatten()
timestep_flat = timestep_scaled.flatten()
av_ca_factor = self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier
# Cross-attention timesteps - compress these too
av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
a_timestep_flat,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
timestep_flat,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
timestep_flat * av_ca_factor,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
a_timestep_flat * av_ca_factor,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
# Compress cross-attention timesteps (only video side, audio is too small to benefit)
# v_patches_per_frame is None for spatial masks, set for temporal masks or no mask
cross_av_timestep_ss = [
av_ca_audio_scale_shift_timestep.view(batch_size, -1, av_ca_audio_scale_shift_timestep.shape[-1]),
CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed if possible
CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed if possible
av_ca_v2a_gate_noise_timestep.view(batch_size, -1, av_ca_v2a_gate_noise_timestep.shape[-1]),
]
a_timestep, a_embedded_timestep = self.audio_adaln_single(
a_timestep_flat,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
# Audio timesteps
a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1])
else:
a_timestep = timestep_scaled
a_embedded_timestep = kwargs.get("embedded_timestep")
cross_av_timestep_ss = []
return [v_timestep, a_timestep, cross_av_timestep_ss], [
v_embedded_timestep,
a_embedded_timestep,
]
def _prepare_context(self, context, batch_size, x, attention_mask=None):
vx = x[0]
ax = x[1]
v_context, a_context = torch.split(
context, int(context.shape[-1] / 2), len(context.shape) - 1
)
v_context, attention_mask = super()._prepare_context(
v_context, batch_size, vx, attention_mask
)
if self.audio_caption_projection is not None:
a_context = self.audio_caption_projection(a_context)
a_context = a_context.view(batch_size, -1, ax.shape[-1])
return [v_context, a_context], attention_mask
def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype):
v_pixel_coords = pixel_coords[0]
v_pe = super()._prepare_positional_embeddings(v_pixel_coords, frame_rate, x_dtype)
a_latent_coords = pixel_coords[1]
a_pe = self._precompute_freqs_cis(
a_latent_coords,
dim=self.audio_inner_dim,
out_dtype=x_dtype,
max_pos=self.audio_positional_embedding_max_pos,
use_middle_indices_grid=self.use_middle_indices_grid,
num_attention_heads=self.audio_num_attention_heads,
)
# calculate positional embeddings for the middle of the token duration, to use in av cross attention layers.
max_pos = max(
self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0]
)
v_pixel_coords = v_pixel_coords.to(torch.float32)
v_pixel_coords[:, 0] = v_pixel_coords[:, 0] * (1.0 / frame_rate)
av_cross_video_freq_cis = self._precompute_freqs_cis(
v_pixel_coords[:, 0:1, :],
dim=self.audio_cross_attention_dim,
out_dtype=x_dtype,
max_pos=[max_pos],
use_middle_indices_grid=True,
num_attention_heads=self.audio_num_attention_heads,
)
av_cross_audio_freq_cis = self._precompute_freqs_cis(
a_latent_coords[:, 0:1, :],
dim=self.audio_cross_attention_dim,
out_dtype=x_dtype,
max_pos=[max_pos],
use_middle_indices_grid=True,
num_attention_heads=self.audio_num_attention_heads,
)
return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)]
def _process_transformer_blocks(
self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs
):
vx = x[0]
ax = x[1]
v_context = context[0]
a_context = context[1]
v_timestep = timestep[0]
a_timestep = timestep[1]
v_pe, av_cross_video_freq_cis = pe[0]
a_pe, av_cross_audio_freq_cis = pe[1]
(
av_ca_audio_scale_shift_timestep,
av_ca_video_scale_shift_timestep,
av_ca_a2v_gate_noise_timestep,
av_ca_v2a_gate_noise_timestep,
) = timestep[2]
"""Process transformer blocks for LTXAV."""
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
# Process transformer blocks
for i, block in enumerate(self.transformer_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(
args["img"],
v_context=args["v_context"],
a_context=args["a_context"],
attention_mask=args["attention_mask"],
v_timestep=args["v_timestep"],
a_timestep=args["a_timestep"],
v_pe=args["v_pe"],
a_pe=args["a_pe"],
v_cross_pe=args["v_cross_pe"],
a_cross_pe=args["a_cross_pe"],
v_cross_scale_shift_timestep=args["v_cross_scale_shift_timestep"],
a_cross_scale_shift_timestep=args["a_cross_scale_shift_timestep"],
v_cross_gate_timestep=args["v_cross_gate_timestep"],
a_cross_gate_timestep=args["a_cross_gate_timestep"],
transformer_options=args["transformer_options"],
)
return out
out = blocks_replace[("double_block", i)](
{
"img": (vx, ax),
"v_context": v_context,
"a_context": a_context,
"attention_mask": attention_mask,
"v_timestep": v_timestep,
"a_timestep": a_timestep,
"v_pe": v_pe,
"a_pe": a_pe,
"v_cross_pe": av_cross_video_freq_cis,
"a_cross_pe": av_cross_audio_freq_cis,
"v_cross_scale_shift_timestep": av_ca_video_scale_shift_timestep,
"a_cross_scale_shift_timestep": av_ca_audio_scale_shift_timestep,
"v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep,
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
"transformer_options": transformer_options,
},
{"original_block": block_wrap},
)
vx, ax = out["img"]
else:
vx, ax = block(
(vx, ax),
v_context=v_context,
a_context=a_context,
attention_mask=attention_mask,
v_timestep=v_timestep,
a_timestep=a_timestep,
v_pe=v_pe,
a_pe=a_pe,
v_cross_pe=av_cross_video_freq_cis,
a_cross_pe=av_cross_audio_freq_cis,
v_cross_scale_shift_timestep=av_ca_video_scale_shift_timestep,
a_cross_scale_shift_timestep=av_ca_audio_scale_shift_timestep,
v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep,
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
transformer_options=transformer_options,
)
return [vx, ax]
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
vx = x[0]
ax = x[1]
v_embedded_timestep = embedded_timestep[0]
a_embedded_timestep = embedded_timestep[1]
# Expand compressed video timestep if needed
if isinstance(v_embedded_timestep, CompressedTimestep):
v_embedded_timestep = v_embedded_timestep.expand()
vx = super()._process_output(vx, v_embedded_timestep, keyframe_idxs, **kwargs)
# Process audio output
a_scale_shift_values = (
self.audio_scale_shift_table[None, None].to(device=a_embedded_timestep.device, dtype=a_embedded_timestep.dtype)
+ a_embedded_timestep[:, :, None]
)
a_shift, a_scale = a_scale_shift_values[:, :, 0], a_scale_shift_values[:, :, 1]
ax = self.audio_norm_out(ax)
ax = ax * (1 + a_scale) + a_shift
ax = self.audio_proj_out(ax)
# Unpatchify audio
ax = self.a_patchifier.unpatchify(
ax, channels=self.num_audio_channels, freq=self.audio_frequency_bins
)
# Recombine audio and video
original_shape = kwargs.get("av_orig_shape")
return self.recombine_audio_and_video_latents(vx, ax, original_shape)
def forward(
self,
x,
timestep,
context,
attention_mask=None,
frame_rate=25,
transformer_options={},
keyframe_idxs=None,
**kwargs,
):
"""
Forward pass for LTXAV model.
Args:
x: Combined audio-video input tensor
timestep: Tuple of (video_timestep, audio_timestep) or single timestep
context: Context tensor (e.g., text embeddings)
attention_mask: Attention mask tensor
frame_rate: Frame rate for temporal processing
transformer_options: Additional options for transformer blocks
keyframe_idxs: Keyframe indices for temporal processing
**kwargs: Additional keyword arguments including audio_length
Returns:
Combined audio-video output tensor
"""
# Handle timestep format
if isinstance(timestep, (tuple, list)) and len(timestep) == 2:
v_timestep, a_timestep = timestep
kwargs["a_timestep"] = a_timestep
timestep = v_timestep
else:
kwargs["a_timestep"] = timestep
# Call parent forward method
return super().forward(
x,
timestep,
context,
attention_mask,
frame_rate,
transformer_options,
keyframe_idxs,
**kwargs,
)

View File

@@ -1,305 +0,0 @@
import math
from typing import Optional
import comfy.ldm.common_dit
import torch
from comfy.ldm.lightricks.model import (
CrossAttention,
FeedForward,
generate_freq_grid_np,
interleaved_freqs_cis,
split_freqs_cis,
)
from torch import nn
class BasicTransformerBlock1D(nn.Module):
r"""
A basic Transformer block.
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
attention_bias (:
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
upcast_attention (`bool`, *optional*):
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
standardization_norm (`str`, *optional*, defaults to `"layer_norm"`): The type of pre-normalization to use. Can be `"layer_norm"` or `"rms_norm"`.
norm_eps (`float`, *optional*, defaults to 1e-5): Epsilon value for normalization layers.
qk_norm (`str`, *optional*, defaults to None):
Set to 'layer_norm' or `rms_norm` to perform query and key normalization.
final_dropout (`bool` *optional*, defaults to False):
Whether to apply a final dropout after the last feed-forward layer.
ff_inner_dim (`int`, *optional*): Dimension of the inner feed-forward layer. If not provided, defaults to `dim * 4`.
ff_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the feed-forward layer.
attention_out_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the attention output layer.
use_rope (`bool`, *optional*, defaults to `False`): Whether to use Rotary Position Embeddings (RoPE).
ffn_dim_mult (`int`, *optional*, defaults to 4): Multiplier for the inner dimension of the feed-forward layer.
"""
def __init__(
self,
dim,
n_heads,
d_head,
context_dim=None,
attn_precision=None,
dtype=None,
device=None,
operations=None,
):
super().__init__()
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
self.attn1 = CrossAttention(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
context_dim=None,
dtype=dtype,
device=device,
operations=operations,
)
# 3. Feed-forward
self.ff = FeedForward(
dim,
dim_out=dim,
glu=True,
dtype=dtype,
device=device,
operations=operations,
)
def forward(self, hidden_states, attention_mask=None, pe=None) -> torch.FloatTensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 1. Normalization Before Self-Attention
norm_hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states)
norm_hidden_states = norm_hidden_states.squeeze(1)
# 2. Self-Attention
attn_output = self.attn1(norm_hidden_states, mask=attention_mask, pe=pe)
hidden_states = attn_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 3. Normalization before Feed-Forward
norm_hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states)
# 4. Feed-forward
ff_output = self.ff(norm_hidden_states)
hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
return hidden_states
class Embeddings1DConnector(nn.Module):
_supports_gradient_checkpointing = True
def __init__(
self,
in_channels=128,
cross_attention_dim=2048,
attention_head_dim=128,
num_attention_heads=30,
num_layers=2,
positional_embedding_theta=10000.0,
positional_embedding_max_pos=[4096],
causal_temporal_positioning=False,
num_learnable_registers: Optional[int] = 128,
dtype=None,
device=None,
operations=None,
split_rope=False,
double_precision_rope=False,
**kwargs,
):
super().__init__()
self.dtype = dtype
self.out_channels = in_channels
self.num_attention_heads = num_attention_heads
self.inner_dim = num_attention_heads * attention_head_dim
self.causal_temporal_positioning = causal_temporal_positioning
self.positional_embedding_theta = positional_embedding_theta
self.positional_embedding_max_pos = positional_embedding_max_pos
self.split_rope = split_rope
self.double_precision_rope = double_precision_rope
self.transformer_1d_blocks = nn.ModuleList(
[
BasicTransformerBlock1D(
self.inner_dim,
num_attention_heads,
attention_head_dim,
context_dim=cross_attention_dim,
dtype=dtype,
device=device,
operations=operations,
)
for _ in range(num_layers)
]
)
inner_dim = num_attention_heads * attention_head_dim
self.num_learnable_registers = num_learnable_registers
if self.num_learnable_registers:
self.learnable_registers = nn.Parameter(
torch.rand(
self.num_learnable_registers, inner_dim, dtype=dtype, device=device
)
* 2.0
- 1.0
)
def get_fractional_positions(self, indices_grid):
fractional_positions = torch.stack(
[
indices_grid[:, i] / self.positional_embedding_max_pos[i]
for i in range(1)
],
dim=-1,
)
return fractional_positions
def precompute_freqs(self, indices_grid, spacing):
source_dtype = indices_grid.dtype
dtype = (
torch.float32
if source_dtype in (torch.bfloat16, torch.float16)
else source_dtype
)
fractional_positions = self.get_fractional_positions(indices_grid)
indices = (
generate_freq_grid_np(
self.positional_embedding_theta,
indices_grid.shape[1],
self.inner_dim,
)
if self.double_precision_rope
else self.generate_freq_grid(spacing, dtype, fractional_positions.device)
).to(device=fractional_positions.device)
if spacing == "exp_2":
freqs = (
(indices * fractional_positions.unsqueeze(-1))
.transpose(-1, -2)
.flatten(2)
)
else:
freqs = (
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
.transpose(-1, -2)
.flatten(2)
)
return freqs
def generate_freq_grid(self, spacing, dtype, device):
dim = self.inner_dim
theta = self.positional_embedding_theta
n_pos_dims = 1
n_elem = 2 * n_pos_dims # 2 for cos and sin e.g. x 3 = 6
start = 1
end = theta
if spacing == "exp":
indices = theta ** (torch.arange(0, dim, n_elem, device="cpu", dtype=torch.float32) / (dim - n_elem))
indices = indices.to(dtype=dtype, device=device)
elif spacing == "exp_2":
indices = 1.0 / theta ** (torch.arange(0, dim, n_elem, device=device) / dim)
indices = indices.to(dtype=dtype)
elif spacing == "linear":
indices = torch.linspace(
start, end, dim // n_elem, device=device, dtype=dtype
)
elif spacing == "sqrt":
indices = torch.linspace(
start**2, end**2, dim // n_elem, device=device, dtype=dtype
).sqrt()
indices = indices * math.pi / 2
return indices
def precompute_freqs_cis(self, indices_grid, spacing="exp"):
dim = self.inner_dim
n_elem = 2 # 2 because of cos and sin
freqs = self.precompute_freqs(indices_grid, spacing)
if self.split_rope:
expected_freqs = dim // 2
current_freqs = freqs.shape[-1]
pad_size = expected_freqs - current_freqs
cos_freq, sin_freq = split_freqs_cis(
freqs, pad_size, self.num_attention_heads
)
else:
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
):
"""
The [`Transformer2DModel`] forward method.
Args:
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
Input `hidden_states`.
indices_grid (`torch.LongTensor` of shape `(batch size, 3, num latent pixels)`):
attention_mask ( `torch.Tensor`, *optional*):
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
negative values to the attention scores corresponding to "discard" tokens.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
# 1. Input
if self.num_learnable_registers:
num_registers_duplications = math.ceil(
max(1024, hidden_states.shape[1]) / self.num_learnable_registers
)
learnable_registers = torch.tile(
self.learnable_registers.to(hidden_states), (num_registers_duplications, 1)
)
hidden_states = torch.cat((hidden_states, learnable_registers[hidden_states.shape[1]:].unsqueeze(0).repeat(hidden_states.shape[0], 1, 1)), dim=1)
if attention_mask is not None:
attention_mask = torch.zeros([1, 1, 1, hidden_states.shape[1]], dtype=attention_mask.dtype, device=attention_mask.device)
indices_grid = torch.arange(
hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device
)
indices_grid = indices_grid[None, None, :]
freqs_cis = self.precompute_freqs_cis(indices_grid)
# 2. Blocks
for block_idx, block in enumerate(self.transformer_1d_blocks):
hidden_states = block(
hidden_states, attention_mask=attention_mask, pe=freqs_cis
)
# 3. Output
# if self.output_scale is not None:
# hidden_states = hidden_states / self.output_scale
hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states)
return hidden_states, attention_mask

View File

@@ -1,292 +0,0 @@
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
def _rational_for_scale(scale: float) -> Tuple[int, int]:
mapping = {0.75: (3, 4), 1.5: (3, 2), 2.0: (2, 1), 4.0: (4, 1)}
if float(scale) not in mapping:
raise ValueError(
f"Unsupported spatial_scale {scale}. Choose from {list(mapping.keys())}"
)
return mapping[float(scale)]
class PixelShuffleND(nn.Module):
def __init__(self, dims, upscale_factors=(2, 2, 2)):
super().__init__()
assert dims in [1, 2, 3], "dims must be 1, 2, or 3"
self.dims = dims
self.upscale_factors = upscale_factors
def forward(self, x):
if self.dims == 3:
return rearrange(
x,
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
p1=self.upscale_factors[0],
p2=self.upscale_factors[1],
p3=self.upscale_factors[2],
)
elif self.dims == 2:
return rearrange(
x,
"b (c p1 p2) h w -> b c (h p1) (w p2)",
p1=self.upscale_factors[0],
p2=self.upscale_factors[1],
)
elif self.dims == 1:
return rearrange(
x,
"b (c p1) f h w -> b c (f p1) h w",
p1=self.upscale_factors[0],
)
class BlurDownsample(nn.Module):
"""
Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel.
Applies only on H,W. Works for dims=2 or dims=3 (per-frame).
"""
def __init__(self, dims: int, stride: int):
super().__init__()
assert dims in (2, 3)
assert stride >= 1 and isinstance(stride, int)
self.dims = dims
self.stride = stride
# 5x5 separable binomial kernel [1,4,6,4,1] (outer product), normalized
k = torch.tensor([1.0, 4.0, 6.0, 4.0, 1.0])
k2d = k[:, None] @ k[None, :]
k2d = (k2d / k2d.sum()).float() # shape (5,5)
self.register_buffer("kernel", k2d[None, None, :, :]) # (1,1,5,5)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.stride == 1:
return x
def _apply_2d(x2d: torch.Tensor) -> torch.Tensor:
# x2d: (B, C, H, W)
B, C, H, W = x2d.shape
weight = self.kernel.expand(C, 1, 5, 5) # depthwise
x2d = F.conv2d(
x2d, weight=weight, bias=None, stride=self.stride, padding=2, groups=C
)
return x2d
if self.dims == 2:
return _apply_2d(x)
else:
# dims == 3: apply per-frame on H,W
b, c, f, h, w = x.shape
x = rearrange(x, "b c f h w -> (b f) c h w")
x = _apply_2d(x)
h2, w2 = x.shape[-2:]
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f, h=h2, w=w2)
return x
class SpatialRationalResampler(nn.Module):
"""
Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased
downsample by 'den' using fixed blur + stride. Operates on H,W only.
For dims==3, work per-frame for spatial scaling (temporal axis untouched).
"""
def __init__(self, mid_channels: int, scale: float):
super().__init__()
self.scale = float(scale)
self.num, self.den = _rational_for_scale(self.scale)
self.conv = nn.Conv2d(
mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1
)
self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num))
self.blur_down = BlurDownsample(dims=2, stride=self.den)
def forward(self, x: torch.Tensor) -> torch.Tensor:
b, c, f, h, w = x.shape
x = rearrange(x, "b c f h w -> (b f) c h w")
x = self.conv(x)
x = self.pixel_shuffle(x)
x = self.blur_down(x)
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
return x
class ResBlock(nn.Module):
def __init__(
self, channels: int, mid_channels: Optional[int] = None, dims: int = 3
):
super().__init__()
if mid_channels is None:
mid_channels = channels
Conv = nn.Conv2d if dims == 2 else nn.Conv3d
self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1)
self.norm1 = nn.GroupNorm(32, mid_channels)
self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1)
self.norm2 = nn.GroupNorm(32, channels)
self.activation = nn.SiLU()
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
x = self.conv1(x)
x = self.norm1(x)
x = self.activation(x)
x = self.conv2(x)
x = self.norm2(x)
x = self.activation(x + residual)
return x
class LatentUpsampler(nn.Module):
"""
Model to spatially upsample VAE latents.
Args:
in_channels (`int`): Number of channels in the input latent
mid_channels (`int`): Number of channels in the middle layers
num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling)
dims (`int`): Number of dimensions for convolutions (2 or 3)
spatial_upsample (`bool`): Whether to spatially upsample the latent
temporal_upsample (`bool`): Whether to temporally upsample the latent
"""
def __init__(
self,
in_channels: int = 128,
mid_channels: int = 512,
num_blocks_per_stage: int = 4,
dims: int = 3,
spatial_upsample: bool = True,
temporal_upsample: bool = False,
spatial_scale: float = 2.0,
rational_resampler: bool = False,
):
super().__init__()
self.in_channels = in_channels
self.mid_channels = mid_channels
self.num_blocks_per_stage = num_blocks_per_stage
self.dims = dims
self.spatial_upsample = spatial_upsample
self.temporal_upsample = temporal_upsample
self.spatial_scale = float(spatial_scale)
self.rational_resampler = rational_resampler
Conv = nn.Conv2d if dims == 2 else nn.Conv3d
self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1)
self.initial_norm = nn.GroupNorm(32, mid_channels)
self.initial_activation = nn.SiLU()
self.res_blocks = nn.ModuleList(
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
)
if spatial_upsample and temporal_upsample:
self.upsampler = nn.Sequential(
nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(3),
)
elif spatial_upsample:
if rational_resampler:
self.upsampler = SpatialRationalResampler(
mid_channels=mid_channels, scale=self.spatial_scale
)
else:
self.upsampler = nn.Sequential(
nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(2),
)
elif temporal_upsample:
self.upsampler = nn.Sequential(
nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(1),
)
else:
raise ValueError(
"Either spatial_upsample or temporal_upsample must be True"
)
self.post_upsample_res_blocks = nn.ModuleList(
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
)
self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1)
def forward(self, latent: torch.Tensor) -> torch.Tensor:
b, c, f, h, w = latent.shape
if self.dims == 2:
x = rearrange(latent, "b c f h w -> (b f) c h w")
x = self.initial_conv(x)
x = self.initial_norm(x)
x = self.initial_activation(x)
for block in self.res_blocks:
x = block(x)
x = self.upsampler(x)
for block in self.post_upsample_res_blocks:
x = block(x)
x = self.final_conv(x)
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
else:
x = self.initial_conv(latent)
x = self.initial_norm(x)
x = self.initial_activation(x)
for block in self.res_blocks:
x = block(x)
if self.temporal_upsample:
x = self.upsampler(x)
x = x[:, :, 1:, :, :]
else:
if isinstance(self.upsampler, SpatialRationalResampler):
x = self.upsampler(x)
else:
x = rearrange(x, "b c f h w -> (b f) c h w")
x = self.upsampler(x)
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
for block in self.post_upsample_res_blocks:
x = block(x)
x = self.final_conv(x)
return x
@classmethod
def from_config(cls, config):
return cls(
in_channels=config.get("in_channels", 4),
mid_channels=config.get("mid_channels", 128),
num_blocks_per_stage=config.get("num_blocks_per_stage", 4),
dims=config.get("dims", 2),
spatial_upsample=config.get("spatial_upsample", True),
temporal_upsample=config.get("temporal_upsample", False),
spatial_scale=config.get("spatial_scale", 2.0),
rational_resampler=config.get("rational_resampler", False),
)
def config(self):
return {
"_class_name": "LatentUpsampler",
"in_channels": self.in_channels,
"mid_channels": self.mid_channels,
"num_blocks_per_stage": self.num_blocks_per_stage,
"dims": self.dims,
"spatial_upsample": self.spatial_upsample,
"temporal_upsample": self.temporal_upsample,
"spatial_scale": self.spatial_scale,
"rational_resampler": self.rational_resampler,
}

View File

@@ -1,47 +1,13 @@
from abc import ABC, abstractmethod
from enum import Enum
import functools
import math
from typing import Dict, Optional, Tuple
from einops import rearrange
import numpy as np
import torch
from torch import nn
import comfy.patcher_extension
import comfy.ldm.modules.attention
import comfy.ldm.common_dit
import math
from typing import Dict, Optional, Tuple
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
def _log_base(x, base):
return np.log(x) / np.log(base)
class LTXRopeType(str, Enum):
INTERLEAVED = "interleaved"
SPLIT = "split"
KEY = "rope_type"
@classmethod
def from_dict(cls, kwargs, default=None):
if default is None:
default = cls.INTERLEAVED
return cls(kwargs.get(cls.KEY, default))
class LTXFrequenciesPrecision(str, Enum):
FLOAT32 = "float32"
FLOAT64 = "float64"
KEY = "frequencies_precision"
@classmethod
def from_dict(cls, kwargs, default=None):
if default is None:
default = cls.FLOAT32
return cls(kwargs.get(cls.KEY, default))
from comfy.ldm.flux.math import apply_rope1
def get_timestep_embedding(
timesteps: torch.Tensor,
@@ -73,7 +39,9 @@ def get_timestep_embedding(
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
exponent = -math.log(max_period) * torch.arange(
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent)
@@ -105,9 +73,7 @@ class TimestepEmbedding(nn.Module):
post_act_fn: Optional[str] = None,
cond_proj_dim=None,
sample_proj_bias=True,
dtype=None,
device=None,
operations=None,
dtype=None, device=None, operations=None,
):
super().__init__()
@@ -124,9 +90,7 @@ class TimestepEmbedding(nn.Module):
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = operations.Linear(
time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device
)
self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device)
if post_act_fn is None:
self.post_act = None
@@ -175,22 +139,12 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
"""
def __init__(
self,
embedding_dim,
size_emb_dim,
use_additional_conditions: bool = False,
dtype=None,
device=None,
operations=None,
):
def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
super().__init__()
self.outdim = size_emb_dim
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(
in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations
)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations)
def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
timesteps_proj = self.time_proj(timestep)
@@ -209,22 +163,15 @@ class AdaLayerNormSingle(nn.Module):
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
"""
def __init__(
self, embedding_dim: int, embedding_coefficient: int = 6, use_additional_conditions: bool = False, dtype=None, device=None, operations=None
):
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
super().__init__()
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
embedding_dim,
size_emb_dim=embedding_dim // 3,
use_additional_conditions=use_additional_conditions,
dtype=dtype,
device=device,
operations=operations,
embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions, dtype=dtype, device=device, operations=operations
)
self.silu = nn.SiLU()
self.linear = operations.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True, dtype=dtype, device=device)
self.linear = operations.Linear(embedding_dim, 6 * embedding_dim, bias=True, dtype=dtype, device=device)
def forward(
self,
@@ -238,7 +185,6 @@ class AdaLayerNormSingle(nn.Module):
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
return self.linear(self.silu(embedded_timestep)), embedded_timestep
class PixArtAlphaTextProjection(nn.Module):
"""
Projects caption embeddings. Also handles dropout for classifier-free guidance.
@@ -246,24 +192,18 @@ class PixArtAlphaTextProjection(nn.Module):
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
"""
def __init__(
self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None
):
def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None):
super().__init__()
if out_features is None:
out_features = hidden_size
self.linear_1 = operations.Linear(
in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device
)
self.linear_1 = operations.Linear(in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device)
if act_fn == "gelu_tanh":
self.act_1 = nn.GELU(approximate="tanh")
elif act_fn == "silu":
self.act_1 = nn.SiLU()
else:
raise ValueError(f"Unknown activation function: {act_fn}")
self.linear_2 = operations.Linear(
in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device
)
self.linear_2 = operations.Linear(in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device)
def forward(self, caption):
hidden_states = self.linear_1(caption)
@@ -282,68 +222,23 @@ class GELU_approx(nn.Module):
class FeedForward(nn.Module):
def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0.0, dtype=None, device=None, operations=None):
def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=None):
super().__init__()
inner_dim = int(dim * mult)
project_in = GELU_approx(dim, inner_dim, dtype=dtype, device=device, operations=operations)
self.net = nn.Sequential(
project_in, nn.Dropout(dropout), operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
project_in,
nn.Dropout(dropout),
operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
)
def forward(self, x):
return self.net(x)
def apply_rotary_emb(input_tensor, freqs_cis):
cos_freqs, sin_freqs = freqs_cis[0], freqs_cis[1]
split_pe = freqs_cis[2] if len(freqs_cis) > 2 else False
return (
apply_split_rotary_emb(input_tensor, cos_freqs, sin_freqs)
if split_pe else
apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs)
)
def apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs): # TODO: remove duplicate funcs and pick the best/fastest one
t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
t1, t2 = t_dup.unbind(dim=-1)
t_dup = torch.stack((-t2, t1), dim=-1)
input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
return out
def apply_split_rotary_emb(input_tensor, cos, sin):
needs_reshape = False
if input_tensor.ndim != 4 and cos.ndim == 4:
B, H, T, _ = cos.shape
input_tensor = input_tensor.reshape(B, T, H, -1).swapaxes(1, 2)
needs_reshape = True
split_input = rearrange(input_tensor, "... (d r) -> ... d r", d=2)
first_half_input = split_input[..., :1, :]
second_half_input = split_input[..., 1:, :]
output = split_input * cos.unsqueeze(-2)
first_half_output = output[..., :1, :]
second_half_output = output[..., 1:, :]
first_half_output.addcmul_(-sin.unsqueeze(-2), second_half_input)
second_half_output.addcmul_(sin.unsqueeze(-2), first_half_input)
output = rearrange(output, "... d r -> ... (d r)")
return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output
class CrossAttention(nn.Module):
def __init__(
self,
query_dim,
context_dim=None,
heads=8,
dim_head=64,
dropout=0.0,
attn_precision=None,
dtype=None,
device=None,
operations=None,
):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
super().__init__()
inner_dim = dim_head * heads
context_dim = query_dim if context_dim is None else context_dim
@@ -359,11 +254,9 @@ class CrossAttention(nn.Module):
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
self.to_out = nn.Sequential(
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)
)
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
def forward(self, x, context=None, mask=None, pe=None, k_pe=None, transformer_options={}):
def forward(self, x, context=None, mask=None, pe=None, transformer_options={}):
q = self.to_q(x)
context = x if context is None else context
k = self.to_k(context)
@@ -373,8 +266,8 @@ class CrossAttention(nn.Module):
k = self.k_norm(k)
if pe is not None:
q = apply_rotary_emb(q, pe)
k = apply_rotary_emb(k, pe if k_pe is None else k_pe)
q = apply_rope1(q.unsqueeze(1), pe).squeeze(1)
k = apply_rope1(k.unsqueeze(1), pe).squeeze(1)
if mask is None:
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
@@ -384,34 +277,14 @@ class CrossAttention(nn.Module):
class BasicTransformerBlock(nn.Module):
def __init__(
self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None
):
def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None):
super().__init__()
self.attn_precision = attn_precision
self.attn1 = CrossAttention(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
context_dim=None,
attn_precision=self.attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, context_dim=None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations)
self.attn2 = CrossAttention(
query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
attn_precision=self.attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
@@ -433,446 +306,116 @@ class BasicTransformerBlock(nn.Module):
return x
def get_fractional_positions(indices_grid, max_pos):
n_pos_dims = indices_grid.shape[1]
assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})'
fractional_positions = torch.stack(
[indices_grid[:, i] / max_pos[i] for i in range(n_pos_dims)],
axis=-1,
[
indices_grid[:, i] / max_pos[i]
for i in range(3)
],
dim=-1,
)
return fractional_positions
@functools.lru_cache(maxsize=5)
def generate_freq_grid_np(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, _ = None):
theta = positional_embedding_theta
start = 1
end = theta
n_elem = 2 * positional_embedding_max_pos_count
pow_indices = np.power(
theta,
np.linspace(
_log_base(start, theta),
_log_base(end, theta),
inner_dim // n_elem,
dtype=np.float64,
),
)
return torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32)
def generate_freq_grid_pytorch(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, device):
theta = positional_embedding_theta
start = 1
end = theta
n_elem = 2 * positional_embedding_max_pos_count
indices = theta ** (
torch.linspace(
math.log(start, theta),
math.log(end, theta),
inner_dim // n_elem,
device=device,
dtype=torch.float32,
)
)
indices = indices.to(dtype=torch.float32)
indices = indices * math.pi / 2
return indices
def generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid):
if use_middle_indices_grid:
assert(len(indices_grid.shape) == 4 and indices_grid.shape[-1] ==2)
indices_grid_start, indices_grid_end = indices_grid[..., 0], indices_grid[..., 1]
indices_grid = (indices_grid_start + indices_grid_end) / 2.0
elif len(indices_grid.shape) == 4:
indices_grid = indices_grid[..., 0]
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
dtype = torch.float32
device = indices_grid.device
# Get fractional positions and compute frequency indices
fractional_positions = get_fractional_positions(indices_grid, max_pos)
indices = indices.to(device=fractional_positions.device)
indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2
freqs = (
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
.transpose(-1, -2)
.flatten(2)
)
return freqs
# Compute frequencies and apply cos/sin
freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2)
cos_vals = freqs.cos().repeat_interleave(2, dim=-1)
sin_vals = freqs.sin().repeat_interleave(2, dim=-1)
def interleaved_freqs_cis(freqs, pad_size):
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
if pad_size != 0:
cos_padding = torch.ones_like(cos_freq[:, :, : pad_size])
sin_padding = torch.zeros_like(cos_freq[:, :, : pad_size])
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
return cos_freq, sin_freq
# Pad if dim is not divisible by 6
if dim % 6 != 0:
padding_size = dim % 6
cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1)
sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1)
def split_freqs_cis(freqs, pad_size, num_attention_heads):
cos_freq = freqs.cos()
sin_freq = freqs.sin()
# Reshape and extract one value per pair (since repeat_interleave duplicates each value)
cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
if pad_size != 0:
cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])
sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size])
# Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension
freqs_cis = torch.stack([
torch.stack([cos_vals, -sin_vals], dim=-1),
torch.stack([sin_vals, cos_vals], dim=-1)
], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2]
cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1)
sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1)
return freqs_cis
# Reshape freqs to be compatible with multi-head attention
B , T, half_HD = cos_freq.shape
cos_freq = cos_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads)
sin_freq = sin_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads)
class LTXVModel(torch.nn.Module):
def __init__(self,
in_channels=128,
cross_attention_dim=2048,
attention_head_dim=64,
num_attention_heads=32,
cos_freq = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2)
sin_freq = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2)
return cos_freq, sin_freq
caption_channels=4096,
num_layers=28,
class LTXBaseModel(torch.nn.Module, ABC):
"""
Abstract base class for LTX models (Lightricks Transformer models).
This class defines the common interface and shared functionality for all LTX models,
including LTXV (video) and LTXAV (audio-video) variants.
"""
def __init__(
self,
in_channels: int,
cross_attention_dim: int,
attention_head_dim: int,
num_attention_heads: int,
caption_channels: int,
num_layers: int,
positional_embedding_theta: float = 10000.0,
positional_embedding_max_pos: list = [20, 2048, 2048],
causal_temporal_positioning: bool = False,
vae_scale_factors: tuple = (8, 32, 32),
use_middle_indices_grid=False,
timestep_scale_multiplier = 1000.0,
dtype=None,
device=None,
operations=None,
**kwargs,
):
positional_embedding_theta=10000.0,
positional_embedding_max_pos=[20, 2048, 2048],
causal_temporal_positioning=False,
vae_scale_factors=(8, 32, 32),
dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.generator = None
self.vae_scale_factors = vae_scale_factors
self.use_middle_indices_grid = use_middle_indices_grid
self.dtype = dtype
self.in_channels = in_channels
self.cross_attention_dim = cross_attention_dim
self.attention_head_dim = attention_head_dim
self.num_attention_heads = num_attention_heads
self.caption_channels = caption_channels
self.num_layers = num_layers
self.positional_embedding_theta = positional_embedding_theta
self.positional_embedding_max_pos = positional_embedding_max_pos
self.split_positional_embedding = LTXRopeType.from_dict(kwargs)
self.freq_grid_generator = (
generate_freq_grid_np if LTXFrequenciesPrecision.from_dict(kwargs) == LTXFrequenciesPrecision.FLOAT64
else generate_freq_grid_pytorch
)
self.causal_temporal_positioning = causal_temporal_positioning
self.operations = operations
self.timestep_scale_multiplier = timestep_scale_multiplier
# Common dimensions
self.inner_dim = num_attention_heads * attention_head_dim
self.out_channels = in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.causal_temporal_positioning = causal_temporal_positioning
# Initialize common components
self._init_common_components(device, dtype)
# Initialize model-specific components
self._init_model_components(device, dtype, **kwargs)
# Initialize transformer blocks
self._init_transformer_blocks(device, dtype, **kwargs)
# Initialize output components
self._init_output_components(device, dtype)
def _init_common_components(self, device, dtype):
"""Initialize components common to all LTX models
- patchify_proj: Linear projection for patchifying input
- adaln_single: AdaLN layer for timestep embedding
- caption_projection: Linear projection for caption embedding
"""
self.patchify_proj = self.operations.Linear(
self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device
)
self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
self.adaln_single = AdaLayerNormSingle(
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=operations
)
# self.adaln_single.linear = operations.Linear(self.inner_dim, 4 * self.inner_dim, bias=True, dtype=dtype, device=device)
self.caption_projection = PixArtAlphaTextProjection(
in_features=self.caption_channels,
hidden_size=self.inner_dim,
dtype=dtype,
device=device,
operations=self.operations,
in_features=caption_channels, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations
)
@abstractmethod
def _init_model_components(self, device, dtype, **kwargs):
"""Initialize model-specific components. Must be implemented by subclasses."""
pass
@abstractmethod
def _init_transformer_blocks(self, device, dtype, **kwargs):
"""Initialize transformer blocks. Must be implemented by subclasses."""
pass
@abstractmethod
def _init_output_components(self, device, dtype):
"""Initialize output components. Must be implemented by subclasses."""
pass
@abstractmethod
def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
"""Process input data. Must be implemented by subclasses."""
pass
@abstractmethod
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs):
"""Process transformer blocks. Must be implemented by subclasses."""
pass
@abstractmethod
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
"""Process output data. Must be implemented by subclasses."""
pass
def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
"""Prepare timestep embeddings."""
grid_mask = kwargs.get("grid_mask", None)
if grid_mask is not None:
timestep = timestep[:, grid_mask]
timestep = timestep * self.timestep_scale_multiplier
timestep, embedded_timestep = self.adaln_single(
timestep.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
# Second dimension is 1 or number of tokens (if timestep_per_token)
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])
return timestep, embedded_timestep
def _prepare_context(self, context, batch_size, x, attention_mask=None):
"""Prepare context for transformer blocks."""
if self.caption_projection is not None:
context = self.caption_projection(context)
context = context.view(batch_size, -1, x.shape[-1])
return context, attention_mask
def _precompute_freqs_cis(
self,
indices_grid,
dim,
out_dtype,
theta=10000.0,
max_pos=[20, 2048, 2048],
use_middle_indices_grid=False,
num_attention_heads=32,
):
split_mode = self.split_positional_embedding == LTXRopeType.SPLIT
indices = self.freq_grid_generator(theta, indices_grid.shape[1], dim, indices_grid.device)
freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid)
if split_mode:
expected_freqs = dim // 2
current_freqs = freqs.shape[-1]
pad_size = expected_freqs - current_freqs
cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads)
else:
# 2 because of cos and sin by 3 for (t, x, y), 1 for temporal only
n_elem = 2 * indices_grid.shape[1]
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
return cos_freq.to(out_dtype), sin_freq.to(out_dtype), split_mode
def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype):
"""Prepare positional embeddings."""
fractional_coords = pixel_coords.to(torch.float32)
fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
pe = self._precompute_freqs_cis(
fractional_coords,
dim=self.inner_dim,
out_dtype=x_dtype,
max_pos=self.positional_embedding_max_pos,
use_middle_indices_grid=self.use_middle_indices_grid,
num_attention_heads=self.num_attention_heads,
)
return pe
def _prepare_attention_mask(self, attention_mask, x_dtype):
"""Prepare attention mask."""
if attention_mask is not None and not torch.is_floating_point(attention_mask):
attention_mask = (attention_mask - 1).to(x_dtype).reshape(
(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
) * torch.finfo(x_dtype).max
return attention_mask
def forward(
self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs
):
"""
Forward pass for LTX models.
Args:
x: Input tensor
timestep: Timestep tensor
context: Context tensor (e.g., text embeddings)
attention_mask: Attention mask tensor
frame_rate: Frame rate for temporal processing
transformer_options: Additional options for transformer blocks
keyframe_idxs: Keyframe indices for temporal processing
**kwargs: Additional keyword arguments
Returns:
Processed output tensor
"""
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(
comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options
),
).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, denoise_mask=denoise_mask, **kwargs)
def _forward(
self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs
):
"""
Internal forward pass for LTX models.
Args:
x: Input tensor
timestep: Timestep tensor
context: Context tensor (e.g., text embeddings)
attention_mask: Attention mask tensor
frame_rate: Frame rate for temporal processing
transformer_options: Additional options for transformer blocks
keyframe_idxs: Keyframe indices for temporal processing
**kwargs: Additional keyword arguments
Returns:
Processed output tensor
"""
if isinstance(x, list):
input_dtype = x[0].dtype
batch_size = x[0].shape[0]
else:
input_dtype = x.dtype
batch_size = x.shape[0]
# Process input
merged_args = {**transformer_options, **kwargs}
x, pixel_coords, additional_args = self._process_input(x, keyframe_idxs, denoise_mask, **merged_args)
merged_args.update(additional_args)
# Prepare timestep and context
timestep, embedded_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask)
# Prepare attention mask and positional embeddings
attention_mask = self._prepare_attention_mask(attention_mask, input_dtype)
pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype)
# Process transformer blocks
x = self._process_transformer_blocks(
x, context, attention_mask, timestep, pe, transformer_options=transformer_options, **merged_args
)
# Process output
x = self._process_output(x, embedded_timestep, keyframe_idxs, **merged_args)
return x
class LTXVModel(LTXBaseModel):
"""LTXV model for video generation."""
def __init__(
self,
in_channels=128,
cross_attention_dim=2048,
attention_head_dim=64,
num_attention_heads=32,
caption_channels=4096,
num_layers=28,
positional_embedding_theta=10000.0,
positional_embedding_max_pos=[20, 2048, 2048],
causal_temporal_positioning=False,
vae_scale_factors=(8, 32, 32),
use_middle_indices_grid=False,
timestep_scale_multiplier = 1000.0,
dtype=None,
device=None,
operations=None,
**kwargs,
):
super().__init__(
in_channels=in_channels,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim,
num_attention_heads=num_attention_heads,
caption_channels=caption_channels,
num_layers=num_layers,
positional_embedding_theta=positional_embedding_theta,
positional_embedding_max_pos=positional_embedding_max_pos,
causal_temporal_positioning=causal_temporal_positioning,
vae_scale_factors=vae_scale_factors,
use_middle_indices_grid=use_middle_indices_grid,
timestep_scale_multiplier=timestep_scale_multiplier,
dtype=dtype,
device=device,
operations=operations,
**kwargs,
)
def _init_model_components(self, device, dtype, **kwargs):
"""Initialize LTXV-specific components."""
# No additional components needed for LTXV beyond base class
pass
def _init_transformer_blocks(self, device, dtype, **kwargs):
"""Initialize transformer blocks for LTXV."""
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
self.inner_dim,
self.num_attention_heads,
self.attention_head_dim,
context_dim=self.cross_attention_dim,
dtype=dtype,
device=device,
operations=self.operations,
num_attention_heads,
attention_head_dim,
context_dim=cross_attention_dim,
# attn_precision=attn_precision,
dtype=dtype, device=device, operations=operations
)
for _ in range(self.num_layers)
for d in range(num_layers)
]
)
def _init_output_components(self, device, dtype):
"""Initialize output components for LTXV."""
self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device))
self.norm_out = self.operations.LayerNorm(
self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
)
self.proj_out = self.operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)
self.patchifier = SymmetricPatchifier(1, start_end=True)
self.norm_out = operations.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.proj_out = operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)
self.patchifier = SymmetricPatchifier(1)
def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, **kwargs)
def _forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
orig_shape = list(x.shape)
def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
"""Process input for LTXV."""
additional_args = {"orig_shape": list(x.shape)}
x, latent_coords = self.patchifier.patchify(x)
pixel_coords = latent_to_pixel_coords(
latent_coords=latent_coords,
@@ -880,30 +423,44 @@ class LTXVModel(LTXBaseModel):
causal_fix=self.causal_temporal_positioning,
)
grid_mask = None
if keyframe_idxs is not None:
additional_args.update({ "orig_patchified_shape": list(x.shape)})
denoise_mask = self.patchifier.patchify(denoise_mask)[0]
grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0]
additional_args.update({"grid_mask": grid_mask})
x = x[:, grid_mask, :]
pixel_coords = pixel_coords[:, :, grid_mask, ...]
pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs
kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:]
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
fractional_coords = pixel_coords.to(torch.float32)
fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
x = self.patchify_proj(x)
return x, pixel_coords, additional_args
timestep = timestep * 1000.0
if attention_mask is not None and not torch.is_floating_point(attention_mask):
attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max
pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype)
batch_size = x.shape[0]
timestep, embedded_timestep = self.adaln_single(
timestep.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=x.dtype,
)
# Second dimension is 1 or number of tokens (if timestep_per_token)
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
embedded_timestep = embedded_timestep.view(
batch_size, -1, embedded_timestep.shape[-1]
)
# 2. Blocks
if self.caption_projection is not None:
batch_size = x.shape[0]
context = self.caption_projection(context)
context = context.view(
batch_size, -1, x.shape[-1]
)
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs):
"""Process transformer blocks for LTXV."""
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.transformer_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
@@ -921,28 +478,16 @@ class LTXVModel(LTXBaseModel):
transformer_options=transformer_options,
)
return x
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
"""Process output for LTXV."""
# Apply scale-shift modulation
# 3. Output
scale_shift_values = (
self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
)
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
x = self.norm_out(x)
x = x * (1 + scale) + shift
# Modulation
x = torch.addcmul(x, x, scale).add_(shift)
x = self.proj_out(x)
if keyframe_idxs is not None:
grid_mask = kwargs["grid_mask"]
orig_patchified_shape = kwargs["orig_patchified_shape"]
full_x = torch.zeros(orig_patchified_shape, dtype=x.dtype, device=x.device)
full_x[:, grid_mask, :] = x
x = full_x
# Unpatchify to restore original dimensions
orig_shape = kwargs["orig_shape"]
x = self.patchifier.unpatchify(
latents=x,
output_height=orig_shape[3],

View File

@@ -21,23 +21,20 @@ def latent_to_pixel_coords(
Returns:
Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates.
"""
shape = [1] * latent_coords.ndim
shape[1] = -1
pixel_coords = (
latent_coords
* torch.tensor(scale_factors, device=latent_coords.device).view(*shape)
* torch.tensor(scale_factors, device=latent_coords.device)[None, :, None]
)
if causal_fix:
# Fix temporal scale for first frame to 1 due to causality
pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0)
pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0)
return pixel_coords
class Patchifier(ABC):
def __init__(self, patch_size: int, start_end: bool=False):
def __init__(self, patch_size: int):
super().__init__()
self._patch_size = (1, patch_size, patch_size)
self.start_end = start_end
@abstractmethod
def patchify(
@@ -74,23 +71,11 @@ class Patchifier(ABC):
torch.arange(0, latent_width, self._patch_size[2], device=device),
indexing="ij",
)
latent_sample_coords_start = torch.stack(latent_sample_coords, dim=0)
delta = torch.tensor(self._patch_size, device=latent_sample_coords_start.device, dtype=latent_sample_coords_start.dtype)[:, None, None, None]
latent_sample_coords_end = latent_sample_coords_start + delta
latent_sample_coords_start = latent_sample_coords_start.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
latent_sample_coords_start = rearrange(
latent_sample_coords_start, "b c f h w -> b c (f h w)", b=batch_size
latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
latent_coords = rearrange(
latent_coords, "b c f h w -> b c (f h w)", b=batch_size
)
if self.start_end:
latent_sample_coords_end = latent_sample_coords_end.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
latent_sample_coords_end = rearrange(
latent_sample_coords_end, "b c f h w -> b c (f h w)", b=batch_size
)
latent_coords = torch.stack((latent_sample_coords_start, latent_sample_coords_end), dim=-1)
else:
latent_coords = latent_sample_coords_start
return latent_coords
@@ -130,61 +115,3 @@ class SymmetricPatchifier(Patchifier):
q=self._patch_size[2],
)
return latents
class AudioPatchifier(Patchifier):
def __init__(self, patch_size: int,
sample_rate=16000,
hop_length=160,
audio_latent_downsample_factor=4,
is_causal=True,
start_end=False,
shift = 0
):
super().__init__(patch_size, start_end=start_end)
self.hop_length = hop_length
self.sample_rate = sample_rate
self.audio_latent_downsample_factor = audio_latent_downsample_factor
self.is_causal = is_causal
self.shift = shift
def copy_with_shift(self, shift):
return AudioPatchifier(
self.patch_size, self.sample_rate, self.hop_length, self.audio_latent_downsample_factor,
self.is_causal, self.start_end, shift
)
def _get_audio_latent_time_in_sec(self, start_latent, end_latent: int, dtype: torch.dtype, device=torch.device):
audio_latent_frame = torch.arange(start_latent, end_latent, dtype=dtype, device=device)
audio_mel_frame = audio_latent_frame * self.audio_latent_downsample_factor
if self.is_causal:
audio_mel_frame = (audio_mel_frame + 1 - self.audio_latent_downsample_factor).clip(min=0)
return audio_mel_frame * self.hop_length / self.sample_rate
def patchify(self, audio_latents: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# audio_latents: (batch, channels, time, freq)
b, _, t, _ = audio_latents.shape
audio_latents = rearrange(
audio_latents,
"b c t f -> b t (c f)",
)
audio_latents_start_timings = self._get_audio_latent_time_in_sec(self.shift, t + self.shift, torch.float32, audio_latents.device)
audio_latents_start_timings = audio_latents_start_timings.unsqueeze(0).expand(b, -1).unsqueeze(1)
if self.start_end:
audio_latents_end_timings = self._get_audio_latent_time_in_sec(self.shift + 1, t + self.shift + 1, torch.float32, audio_latents.device)
audio_latents_end_timings = audio_latents_end_timings.unsqueeze(0).expand(b, -1).unsqueeze(1)
audio_latents_timings = torch.stack([audio_latents_start_timings, audio_latents_end_timings], dim=-1)
else:
audio_latents_timings = audio_latents_start_timings
return audio_latents, audio_latents_timings
def unpatchify(self, audio_latents: torch.Tensor, channels: int, freq: int) -> torch.Tensor:
# audio_latents: (batch, time, freq * channels)
audio_latents = rearrange(
audio_latents, "b t (c f) -> b c t f", c=channels, f=freq
)
return audio_latents

View File

@@ -1,279 +0,0 @@
import json
from dataclasses import dataclass
import math
import torch
import torchaudio
import comfy.model_management
import comfy.model_patcher
import comfy.utils as utils
from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
from comfy.ldm.lightricks.vae.causal_audio_autoencoder import (
CausalityAxis,
CausalAudioAutoencoder,
)
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder
LATENT_DOWNSAMPLE_FACTOR = 4
@dataclass(frozen=True)
class AudioVAEComponentConfig:
"""Container for model component configuration extracted from metadata."""
autoencoder: dict
vocoder: dict
@classmethod
def from_metadata(cls, metadata: dict) -> "AudioVAEComponentConfig":
assert metadata is not None and "config" in metadata, "Metadata is required for audio VAE"
raw_config = metadata["config"]
if isinstance(raw_config, str):
parsed_config = json.loads(raw_config)
else:
parsed_config = raw_config
audio_config = parsed_config.get("audio_vae")
vocoder_config = parsed_config.get("vocoder")
assert audio_config is not None, "Audio VAE config is required for audio VAE"
assert vocoder_config is not None, "Vocoder config is required for audio VAE"
return cls(autoencoder=audio_config, vocoder=vocoder_config)
class ModelDeviceManager:
"""Manages device placement and GPU residency for the composed model."""
def __init__(self, module: torch.nn.Module):
load_device = comfy.model_management.get_torch_device()
offload_device = comfy.model_management.vae_offload_device()
self.patcher = comfy.model_patcher.ModelPatcher(module, load_device, offload_device)
def ensure_model_loaded(self) -> None:
comfy.model_management.free_memory(
self.patcher.model_size(),
self.patcher.load_device,
)
comfy.model_management.load_model_gpu(self.patcher)
def move_to_load_device(self, tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(self.patcher.load_device)
@property
def load_device(self):
return self.patcher.load_device
class AudioLatentNormalizer:
"""Applies per-channel statistics in patch space and restores original layout."""
def __init__(self, patchfier: AudioPatchifier, statistics_processor: torch.nn.Module):
self.patchifier = patchfier
self.statistics = statistics_processor
def normalize(self, latents: torch.Tensor) -> torch.Tensor:
channels = latents.shape[1]
freq = latents.shape[3]
patched, _ = self.patchifier.patchify(latents)
normalized = self.statistics.normalize(patched)
return self.patchifier.unpatchify(normalized, channels=channels, freq=freq)
def denormalize(self, latents: torch.Tensor) -> torch.Tensor:
channels = latents.shape[1]
freq = latents.shape[3]
patched, _ = self.patchifier.patchify(latents)
denormalized = self.statistics.un_normalize(patched)
return self.patchifier.unpatchify(denormalized, channels=channels, freq=freq)
class AudioPreprocessor:
"""Prepares raw waveforms for the autoencoder by matching training conditions."""
def __init__(self, target_sample_rate: int, mel_bins: int, mel_hop_length: int, n_fft: int):
self.target_sample_rate = target_sample_rate
self.mel_bins = mel_bins
self.mel_hop_length = mel_hop_length
self.n_fft = n_fft
def resample(self, waveform: torch.Tensor, source_rate: int) -> torch.Tensor:
if source_rate == self.target_sample_rate:
return waveform
return torchaudio.functional.resample(waveform, source_rate, self.target_sample_rate)
def waveform_to_mel(
self, waveform: torch.Tensor, waveform_sample_rate: int, device
) -> torch.Tensor:
waveform = self.resample(waveform, waveform_sample_rate)
mel_transform = torchaudio.transforms.MelSpectrogram(
sample_rate=self.target_sample_rate,
n_fft=self.n_fft,
win_length=self.n_fft,
hop_length=self.mel_hop_length,
f_min=0.0,
f_max=self.target_sample_rate / 2.0,
n_mels=self.mel_bins,
window_fn=torch.hann_window,
center=True,
pad_mode="reflect",
power=1.0,
mel_scale="slaney",
norm="slaney",
).to(device)
mel = mel_transform(waveform)
mel = torch.log(torch.clamp(mel, min=1e-5))
return mel.permute(0, 1, 3, 2).contiguous()
class AudioVAE(torch.nn.Module):
"""High-level Audio VAE wrapper exposing encode and decode entry points."""
def __init__(self, state_dict: dict, metadata: dict):
super().__init__()
component_config = AudioVAEComponentConfig.from_metadata(metadata)
vae_sd = utils.state_dict_prefix_replace(state_dict, {"audio_vae.": ""}, filter_keys=True)
vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
self.vocoder = Vocoder(config=component_config.vocoder)
self.autoencoder.load_state_dict(vae_sd, strict=False)
self.vocoder.load_state_dict(vocoder_sd, strict=False)
autoencoder_config = self.autoencoder.get_config()
self.normalizer = AudioLatentNormalizer(
AudioPatchifier(
patch_size=1,
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
sample_rate=autoencoder_config["sampling_rate"],
hop_length=autoencoder_config["mel_hop_length"],
is_causal=autoencoder_config["is_causal"],
),
self.autoencoder.per_channel_statistics,
)
self.preprocessor = AudioPreprocessor(
target_sample_rate=autoencoder_config["sampling_rate"],
mel_bins=autoencoder_config["mel_bins"],
mel_hop_length=autoencoder_config["mel_hop_length"],
n_fft=autoencoder_config["n_fft"],
)
self.device_manager = ModelDeviceManager(self)
def encode(self, audio: dict) -> torch.Tensor:
"""Encode a waveform dictionary into normalized latent tensors."""
waveform = audio["waveform"]
waveform_sample_rate = audio["sample_rate"]
input_device = waveform.device
# Ensure that Audio VAE is loaded on the correct device.
self.device_manager.ensure_model_loaded()
waveform = self.device_manager.move_to_load_device(waveform)
expected_channels = self.autoencoder.encoder.in_channels
if waveform.shape[1] != expected_channels:
if waveform.shape[1] == 1:
waveform = waveform.expand(-1, expected_channels, *waveform.shape[2:])
else:
raise ValueError(
f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}"
)
mel_spec = self.preprocessor.waveform_to_mel(
waveform, waveform_sample_rate, device=self.device_manager.load_device
)
latents = self.autoencoder.encode(mel_spec)
posterior = DiagonalGaussianDistribution(latents)
latent_mode = posterior.mode()
normalized = self.normalizer.normalize(latent_mode)
return normalized.to(input_device)
def decode(self, latents: torch.Tensor) -> torch.Tensor:
"""Decode normalized latent tensors into an audio waveform."""
original_shape = latents.shape
# Ensure that Audio VAE is loaded on the correct device.
self.device_manager.ensure_model_loaded()
latents = self.device_manager.move_to_load_device(latents)
latents = self.normalizer.denormalize(latents)
target_shape = self.target_shape_from_latents(original_shape)
mel_spec = self.autoencoder.decode(latents, target_shape=target_shape)
waveform = self.run_vocoder(mel_spec)
return self.device_manager.move_to_load_device(waveform)
def target_shape_from_latents(self, latents_shape):
batch, _, time, _ = latents_shape
target_length = time * LATENT_DOWNSAMPLE_FACTOR
if self.autoencoder.causality_axis != CausalityAxis.NONE:
target_length -= LATENT_DOWNSAMPLE_FACTOR - 1
return (
batch,
self.autoencoder.decoder.out_ch,
target_length,
self.autoencoder.mel_bins,
)
def num_of_latents_from_frames(self, frames_number: int, frame_rate: int) -> int:
return math.ceil((float(frames_number) / frame_rate) * self.latents_per_second)
def run_vocoder(self, mel_spec: torch.Tensor) -> torch.Tensor:
audio_channels = self.autoencoder.decoder.out_ch
vocoder_input = mel_spec.transpose(2, 3)
if audio_channels == 1:
vocoder_input = vocoder_input.squeeze(1)
elif audio_channels != 2:
raise ValueError(f"Unsupported audio_channels: {audio_channels}")
return self.vocoder(vocoder_input)
@property
def sample_rate(self) -> int:
return int(self.autoencoder.sampling_rate)
@property
def mel_hop_length(self) -> int:
return int(self.autoencoder.mel_hop_length)
@property
def mel_bins(self) -> int:
return int(self.autoencoder.mel_bins)
@property
def latent_channels(self) -> int:
return int(self.autoencoder.decoder.z_channels)
@property
def latent_frequency_bins(self) -> int:
return int(self.mel_bins // LATENT_DOWNSAMPLE_FACTOR)
@property
def latents_per_second(self) -> float:
return self.sample_rate / self.mel_hop_length / LATENT_DOWNSAMPLE_FACTOR
@property
def output_sample_rate(self) -> int:
output_rate = getattr(self.vocoder, "output_sample_rate", None)
if output_rate is not None:
return int(output_rate)
upsample_factor = getattr(self.vocoder, "upsample_factor", None)
if upsample_factor is None:
raise AttributeError(
"Vocoder is missing upsample_factor; cannot infer output sample rate"
)
return int(self.sample_rate * upsample_factor / self.mel_hop_length)
def memory_required(self, input_shape):
return self.device_manager.patcher.model_size()

View File

@@ -1,909 +0,0 @@
from __future__ import annotations
import torch
from torch import nn
from torch.nn import functional as F
from typing import Optional
from enum import Enum
from .pixel_norm import PixelNorm
import comfy.ops
import logging
ops = comfy.ops.disable_weight_init
class StringConvertibleEnum(Enum):
"""
Base enum class that provides string-to-enum conversion functionality.
This mixin adds a str_to_enum() class method that handles conversion from
strings, None, or existing enum instances with case-insensitive matching.
"""
@classmethod
def str_to_enum(cls, value):
"""
Convert a string, enum instance, or None to the appropriate enum member.
Args:
value: Can be an enum instance of this class, a string, or None
Returns:
Enum member of this class
Raises:
ValueError: If the value cannot be converted to a valid enum member
"""
# Already an enum instance of this class
if isinstance(value, cls):
return value
# None maps to NONE member if it exists
if value is None:
if hasattr(cls, "NONE"):
return cls.NONE
raise ValueError(f"{cls.__name__} does not have a NONE member to map None to")
# String conversion (case-insensitive)
if isinstance(value, str):
value_lower = value.lower()
# Try to match against enum values
for member in cls:
# Handle members with None values
if member.value is None:
if value_lower == "none":
return member
# Handle members with string values
elif isinstance(member.value, str) and member.value.lower() == value_lower:
return member
# Build helpful error message with valid values
valid_values = []
for member in cls:
if member.value is None:
valid_values.append("none")
elif isinstance(member.value, str):
valid_values.append(member.value)
raise ValueError(f"Invalid {cls.__name__} string: '{value}'. " f"Valid values are: {valid_values}")
raise ValueError(
f"Cannot convert type {type(value).__name__} to {cls.__name__} enum. "
f"Expected string, None, or {cls.__name__} instance."
)
class AttentionType(StringConvertibleEnum):
"""Enum for specifying the attention mechanism type."""
VANILLA = "vanilla"
LINEAR = "linear"
NONE = "none"
class CausalityAxis(StringConvertibleEnum):
"""Enum for specifying the causality axis in causal convolutions."""
NONE = None
WIDTH = "width"
HEIGHT = "height"
WIDTH_COMPATIBILITY = "width-compatibility"
def Normalize(in_channels, *, num_groups=32, normtype="group"):
if normtype == "group":
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
elif normtype == "pixel":
return PixelNorm(dim=1, eps=1e-6)
else:
raise ValueError(f"Invalid normalization type: {normtype}")
class CausalConv2d(nn.Module):
"""
A causal 2D convolution.
This layer ensures that the output at time `t` only depends on inputs
at time `t` and earlier. It achieves this by applying asymmetric padding
to the time dimension (width) before the convolution.
"""
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
dilation=1,
groups=1,
bias=True,
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
):
super().__init__()
self.causality_axis = causality_axis
# Ensure kernel_size and dilation are tuples
kernel_size = nn.modules.utils._pair(kernel_size)
dilation = nn.modules.utils._pair(dilation)
# Calculate padding dimensions
pad_h = (kernel_size[0] - 1) * dilation[0]
pad_w = (kernel_size[1] - 1) * dilation[1]
# The padding tuple for F.pad is (pad_left, pad_right, pad_top, pad_bottom)
match self.causality_axis:
case CausalityAxis.NONE:
self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
case CausalityAxis.WIDTH | CausalityAxis.WIDTH_COMPATIBILITY:
self.padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2)
case CausalityAxis.HEIGHT:
self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0)
case _:
raise ValueError(f"Invalid causality_axis: {causality_axis}")
# The internal convolution layer uses no padding, as we handle it manually
self.conv = ops.Conv2d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=0,
dilation=dilation,
groups=groups,
bias=bias,
)
def forward(self, x):
# Apply causal padding before convolution
x = F.pad(x, self.padding)
return self.conv(x)
def make_conv2d(
in_channels,
out_channels,
kernel_size,
stride=1,
padding=None,
dilation=1,
groups=1,
bias=True,
causality_axis: Optional[CausalityAxis] = None,
):
"""
Create a 2D convolution layer that can be either causal or non-causal.
Args:
in_channels: Number of input channels
out_channels: Number of output channels
kernel_size: Size of the convolution kernel
stride: Convolution stride
padding: Padding (if None, will be calculated based on causal flag)
dilation: Dilation rate
groups: Number of groups for grouped convolution
bias: Whether to use bias
causality_axis: Dimension along which to apply causality.
Returns:
Either a regular Conv2d or CausalConv2d layer
"""
if causality_axis is not None:
# For causal convolution, padding is handled internally by CausalConv2d
return CausalConv2d(in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis)
else:
# For non-causal convolution, use symmetric padding if not specified
if padding is None:
if isinstance(kernel_size, int):
padding = kernel_size // 2
else:
padding = tuple(k // 2 for k in kernel_size)
return ops.Conv2d(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
)
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv, causality_axis: CausalityAxis = CausalityAxis.HEIGHT):
super().__init__()
self.with_conv = with_conv
self.causality_axis = causality_axis
if self.with_conv:
self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
if self.with_conv:
x = self.conv(x)
# Drop FIRST element in the causal axis to undo encoder's padding, while keeping the length 1 + 2 * n.
# For example, if the input is [0, 1, 2], after interpolation, the output is [0, 0, 1, 1, 2, 2].
# The causal convolution will pad the first element as [-, -, 0, 0, 1, 1, 2, 2],
# So the output elements rely on the following windows:
# 0: [-,-,0]
# 1: [-,0,0]
# 2: [0,0,1]
# 3: [0,1,1]
# 4: [1,1,2]
# 5: [1,2,2]
# Notice that the first and second elements in the output rely only on the first element in the input,
# while all other elements rely on two elements in the input.
# So we can drop the first element to undo the padding (rather than the last element).
# This is a no-op for non-causal convolutions.
match self.causality_axis:
case CausalityAxis.NONE:
pass # x remains unchanged
case CausalityAxis.HEIGHT:
x = x[:, :, 1:, :]
case CausalityAxis.WIDTH:
x = x[:, :, :, 1:]
case CausalityAxis.WIDTH_COMPATIBILITY:
pass # x remains unchanged
case _:
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
return x
class Downsample(nn.Module):
"""
A downsampling layer that can use either a strided convolution
or average pooling. Supports standard and causal padding for the
convolutional mode.
"""
def __init__(self, in_channels, with_conv, causality_axis: CausalityAxis = CausalityAxis.WIDTH):
super().__init__()
self.with_conv = with_conv
self.causality_axis = causality_axis
if self.causality_axis != CausalityAxis.NONE and not self.with_conv:
raise ValueError("causality is only supported when `with_conv=True`.")
if self.with_conv:
# Do time downsampling here
# no asymmetric padding in torch conv, must do it ourselves
self.conv = ops.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
def forward(self, x):
if self.with_conv:
# (pad_left, pad_right, pad_top, pad_bottom)
match self.causality_axis:
case CausalityAxis.NONE:
pad = (0, 1, 0, 1)
case CausalityAxis.WIDTH:
pad = (2, 0, 0, 1)
case CausalityAxis.HEIGHT:
pad = (0, 1, 2, 0)
case CausalityAxis.WIDTH_COMPATIBILITY:
pad = (1, 0, 0, 1)
case _:
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
else:
# This branch is only taken if with_conv=False, which implies causality_axis is NONE.
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
return x
class ResnetBlock(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout,
temb_channels=512,
norm_type="group",
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
):
super().__init__()
self.causality_axis = causality_axis
if self.causality_axis != CausalityAxis.NONE and norm_type == "group":
raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels, normtype=norm_type)
self.non_linearity = nn.SiLU()
self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
if temb_channels > 0:
self.temb_proj = ops.Linear(temb_channels, out_channels)
self.norm2 = Normalize(out_channels, normtype=norm_type)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = make_conv2d(
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
)
else:
self.nin_shortcut = make_conv2d(
in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
)
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = self.non_linearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None]
h = self.norm2(h)
h = self.non_linearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
class AttnBlock(nn.Module):
def __init__(self, in_channels, norm_type="group"):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels, normtype=norm_type)
self.q = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h * w).contiguous()
q = q.permute(0, 2, 1).contiguous() # b,hw,c
k = k.reshape(b, c, h * w).contiguous() # b,c,hw
w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h * w).contiguous()
w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q)
h_ = torch.bmm(v, w_).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = h_.reshape(b, c, h, w).contiguous()
h_ = self.proj_out(h_)
return x + h_
def make_attn(in_channels, attn_type="vanilla", norm_type="group"):
# Convert string to enum if needed
attn_type = AttentionType.str_to_enum(attn_type)
if attn_type != AttentionType.NONE:
logging.info(f"making attention of type '{attn_type.value}' with {in_channels} in_channels")
else:
logging.info(f"making identity attention with {in_channels} in_channels")
match attn_type:
case AttentionType.VANILLA:
return AttnBlock(in_channels, norm_type=norm_type)
case AttentionType.NONE:
return nn.Identity(in_channels)
case AttentionType.LINEAR:
raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")
case _:
raise ValueError(f"Unknown attention type: {attn_type}")
class Encoder(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
double_z=True,
attn_type="vanilla",
mid_block_add_attention=True,
norm_type="group",
causality_axis=CausalityAxis.WIDTH.value,
**ignore_kwargs,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.z_channels = z_channels
self.double_z = double_z
self.norm_type = norm_type
# Convert string to enum if needed (for config loading)
causality_axis = CausalityAxis.str_to_enum(causality_axis)
self.attn_type = AttentionType.str_to_enum(attn_type)
# downsampling
self.conv_in = make_conv2d(
in_channels,
self.ch,
kernel_size=3,
stride=1,
causality_axis=causality_axis,
)
self.non_linearity = nn.SiLU()
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks):
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=causality_axis,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=causality_axis,
)
if mid_block_add_attention:
self.mid.attn_1 = make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type)
else:
self.mid.attn_1 = nn.Identity()
self.mid.block_2 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=causality_axis,
)
# end
self.norm_out = Normalize(block_in, normtype=self.norm_type)
self.conv_out = make_conv2d(
block_in,
2 * z_channels if double_z else z_channels,
kernel_size=3,
stride=1,
causality_axis=causality_axis,
)
def forward(self, x):
"""
Forward pass through the encoder.
Args:
x: Input tensor of shape [batch, channels, time, n_mels]
Returns:
Encoded latent representation
"""
feature_maps = [self.conv_in(x)]
# Process each resolution level (from high to low resolution)
for resolution_level in range(self.num_resolutions):
# Apply residual blocks at current resolution level
for block_idx in range(self.num_res_blocks):
# Apply ResNet block with optional timestep embedding
current_features = self.down[resolution_level].block[block_idx](feature_maps[-1], temb=None)
# Apply attention if configured for this resolution level
if len(self.down[resolution_level].attn) > 0:
current_features = self.down[resolution_level].attn[block_idx](current_features)
# Store processed features
feature_maps.append(current_features)
# Downsample spatial dimensions (except at the final resolution level)
if resolution_level != self.num_resolutions - 1:
downsampled_features = self.down[resolution_level].downsample(feature_maps[-1])
feature_maps.append(downsampled_features)
# === MIDDLE PROCESSING PHASE ===
# Take the lowest resolution features for middle processing
bottleneck_features = feature_maps[-1]
# Apply first middle ResNet block
bottleneck_features = self.mid.block_1(bottleneck_features, temb=None)
# Apply middle attention block
bottleneck_features = self.mid.attn_1(bottleneck_features)
# Apply second middle ResNet block
bottleneck_features = self.mid.block_2(bottleneck_features, temb=None)
# === OUTPUT PHASE ===
# Normalize the bottleneck features
output_features = self.norm_out(bottleneck_features)
# Apply non-linearity (SiLU activation)
output_features = self.non_linearity(output_features)
# Final convolution to produce latent representation
# [batch, channels, time, n_mels] -> [batch, 2 * z_channels if double_z else z_channels, time, n_mels]
return self.conv_out(output_features)
class Decoder(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
give_pre_end=False,
tanh_out=False,
attn_type="vanilla",
mid_block_add_attention=True,
norm_type="group",
causality_axis=CausalityAxis.WIDTH.value,
**ignorekwargs,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.out_ch = out_ch
self.give_pre_end = give_pre_end
self.tanh_out = tanh_out
self.norm_type = norm_type
self.z_channels = z_channels
# Convert string to enum if needed (for config loading)
causality_axis = CausalityAxis.str_to_enum(causality_axis)
self.attn_type = AttentionType.str_to_enum(attn_type)
# compute block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
# z to block_in
self.conv_in = make_conv2d(z_channels, block_in, kernel_size=3, stride=1, causality_axis=causality_axis)
self.non_linearity = nn.SiLU()
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=causality_axis,
)
if mid_block_add_attention:
self.mid.attn_1 = make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type)
else:
self.mid.attn_1 = nn.Identity()
self.mid.block_2 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=causality_axis,
)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks + 1):
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=causality_axis,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in, normtype=self.norm_type)
self.conv_out = make_conv2d(block_in, out_ch, kernel_size=3, stride=1, causality_axis=causality_axis)
def _adjust_output_shape(self, decoded_output, target_shape):
"""
Adjust output shape to match target dimensions for variable-length audio.
This function handles the common case where decoded audio spectrograms need to be
resized to match a specific target shape.
Args:
decoded_output: Tensor of shape (batch, channels, time, frequency)
target_shape: Target shape tuple (batch, channels, time, frequency)
Returns:
Tensor adjusted to match target_shape exactly
"""
# Current output shape: (batch, channels, time, frequency)
_, _, current_time, current_freq = decoded_output.shape
_, target_channels, target_time, target_freq = target_shape
# Step 1: Crop first to avoid exceeding target dimensions
decoded_output = decoded_output[
:, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
]
# Step 2: Calculate padding needed for time and frequency dimensions
time_padding_needed = target_time - decoded_output.shape[2]
freq_padding_needed = target_freq - decoded_output.shape[3]
# Step 3: Apply padding if needed
if time_padding_needed > 0 or freq_padding_needed > 0:
# PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom)
# For audio: pad_left/right = frequency, pad_top/bottom = time
padding = (
0,
max(freq_padding_needed, 0), # frequency padding (left, right)
0,
max(time_padding_needed, 0), # time padding (top, bottom)
)
decoded_output = F.pad(decoded_output, padding)
# Step 4: Final safety crop to ensure exact target shape
decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]
return decoded_output
def get_config(self):
return {
"ch": self.ch,
"out_ch": self.out_ch,
"ch_mult": self.ch_mult,
"num_res_blocks": self.num_res_blocks,
"in_channels": self.in_channels,
"resolution": self.resolution,
"z_channels": self.z_channels,
}
def forward(self, latent_features, target_shape=None):
"""
Decode latent features back to audio spectrograms.
Args:
latent_features: Encoded latent representation of shape (batch, channels, height, width)
target_shape: Optional target output shape (batch, channels, time, frequency)
If provided, output will be cropped/padded to match this shape
Returns:
Reconstructed audio spectrogram of shape (batch, channels, time, frequency)
"""
assert target_shape is not None, "Target shape is required for CausalAudioAutoencoder Decoder"
# Transform latent features to decoder's internal feature dimension
hidden_features = self.conv_in(latent_features)
# Middle processing
hidden_features = self.mid.block_1(hidden_features, temb=None)
hidden_features = self.mid.attn_1(hidden_features)
hidden_features = self.mid.block_2(hidden_features, temb=None)
# Upsampling
# Progressively increase spatial resolution from lowest to highest
for resolution_level in reversed(range(self.num_resolutions)):
# Apply residual blocks at current resolution level
for block_index in range(self.num_res_blocks + 1):
hidden_features = self.up[resolution_level].block[block_index](hidden_features, temb=None)
if len(self.up[resolution_level].attn) > 0:
hidden_features = self.up[resolution_level].attn[block_index](hidden_features)
if resolution_level != 0:
hidden_features = self.up[resolution_level].upsample(hidden_features)
# Output
if self.give_pre_end:
# Return intermediate features before final processing (for debugging/analysis)
decoded_output = hidden_features
else:
# Standard output path: normalize, activate, and convert to output channels
# Final normalization layer
hidden_features = self.norm_out(hidden_features)
# Apply SiLU (Swish) activation function
hidden_features = self.non_linearity(hidden_features)
# Final convolution to map to output channels (typically 2 for stereo audio)
decoded_output = self.conv_out(hidden_features)
# Optional tanh activation to bound output values to [-1, 1] range
if self.tanh_out:
decoded_output = torch.tanh(decoded_output)
# Adjust shape for audio data
if target_shape is not None:
decoded_output = self._adjust_output_shape(decoded_output, target_shape)
return decoded_output
class processor(nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("std-of-means", torch.empty(128))
self.register_buffer("mean-of-means", torch.empty(128))
def un_normalize(self, x):
return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x)
def normalize(self, x):
return (x - self.get_buffer("mean-of-means").to(x)) / self.get_buffer("std-of-means").to(x)
class CausalAudioAutoencoder(nn.Module):
def __init__(self, config=None):
super().__init__()
if config is None:
config = self._guess_config()
# Extract encoder and decoder configs from the new format
model_config = config.get("model", {}).get("params", {})
variables_config = config.get("variables", {})
self.sampling_rate = variables_config.get(
"sampling_rate",
model_config.get("sampling_rate", config.get("sampling_rate", 16000)),
)
encoder_config = model_config.get("encoder", model_config.get("ddconfig", {}))
decoder_config = model_config.get("decoder", encoder_config)
# Load mel spectrogram parameters
self.mel_bins = encoder_config.get("mel_bins", 64)
self.mel_hop_length = model_config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
self.n_fft = model_config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
# Store causality configuration at VAE level (not just in encoder internals)
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.WIDTH.value)
self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value)
self.is_causal = self.causality_axis == CausalityAxis.HEIGHT
self.encoder = Encoder(**encoder_config)
self.decoder = Decoder(**decoder_config)
self.per_channel_statistics = processor()
def _guess_config(self):
encoder_config = {
# Required parameters - based on ltx-video-av-1679000 model metadata
"ch": 128,
"out_ch": 8,
"ch_mult": [1, 2, 4], # Based on metadata: [1, 2, 4] not [1, 2, 4, 8]
"num_res_blocks": 2,
"attn_resolutions": [], # Based on metadata: empty list, no attention
"dropout": 0.0,
"resamp_with_conv": True,
"in_channels": 2, # stereo
"resolution": 256,
"z_channels": 8,
"double_z": True,
"attn_type": "vanilla",
"mid_block_add_attention": False, # Based on metadata: false
"norm_type": "pixel",
"causality_axis": "height", # Based on metadata
"mel_bins": 64, # Based on metadata: mel_bins = 64
}
decoder_config = {
# Inherits encoder config, can override specific params
**encoder_config,
"out_ch": 2, # Stereo audio output (2 channels)
"give_pre_end": False,
"tanh_out": False,
}
config = {
"_class_name": "CausalAudioAutoencoder",
"sampling_rate": 16000,
"model": {
"params": {
"encoder": encoder_config,
"decoder": decoder_config,
}
},
}
return config
def get_config(self):
return {
"sampling_rate": self.sampling_rate,
"mel_bins": self.mel_bins,
"mel_hop_length": self.mel_hop_length,
"n_fft": self.n_fft,
"causality_axis": self.causality_axis.value,
"is_causal": self.is_causal,
}
def encode(self, x):
return self.encoder(x)
def decode(self, x, target_shape=None):
return self.decoder(x, target_shape=target_shape)

View File

@@ -1,11 +1,11 @@
from typing import Tuple, Union
import threading
import torch
import torch.nn as nn
import comfy.ops
ops = comfy.ops.disable_weight_init
class CausalConv3d(nn.Module):
def __init__(
self,
@@ -42,34 +42,23 @@ class CausalConv3d(nn.Module):
padding_mode=spatial_padding_mode,
groups=groups,
)
self.temporal_cache_state={}
def forward(self, x, causal: bool = True):
tid = threading.get_ident()
cached, is_end = self.temporal_cache_state.get(tid, (None, False))
if cached is None:
padding_length = self.time_kernel_size - 1
if not causal:
padding_length = padding_length // 2
if x.shape[2] == 0:
return x
cached = x[:, :, :1, :, :].repeat((1, 1, padding_length, 1, 1))
pieces = [ cached, x ]
if is_end and not causal:
pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)))
needs_caching = not is_end
if needs_caching and x.shape[2] >= self.time_kernel_size - 1:
needs_caching = False
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
x = torch.cat(pieces, dim=2)
if needs_caching:
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :]
if causal:
first_frame_pad = x[:, :, :1, :, :].repeat(
(1, 1, self.time_kernel_size - 1, 1, 1)
)
x = torch.concatenate((first_frame_pad, x), dim=2)
else:
first_frame_pad = x[:, :, :1, :, :].repeat(
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
)
last_frame_pad = x[:, :, -1:, :, :].repeat(
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
)
x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
x = self.conv(x)
return x
@property
def weight(self):

View File

@@ -1,5 +1,4 @@
from __future__ import annotations
import threading
import torch
from torch import nn
from functools import partial
@@ -7,35 +6,12 @@ import math
from einops import rearrange
from typing import List, Optional, Tuple, Union
from .conv_nd_factory import make_conv_nd, make_linear_nd
from .causal_conv3d import CausalConv3d
from .pixel_norm import PixelNorm
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
import comfy.ops
from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
ops = comfy.ops.disable_weight_init
def mark_conv3d_ended(module):
tid = threading.get_ident()
for _, m in module.named_modules():
if isinstance(m, CausalConv3d):
current = m.temporal_cache_state.get(tid, (None, False))
m.temporal_cache_state[tid] = (current[0], True)
def split2(tensor, split_point, dim=2):
return torch.split(tensor, [split_point, tensor.shape[dim] - split_point], dim=dim)
def add_exchange_cache(dest, cache_in, new_input, dim=2):
if dest is not None:
if cache_in is not None:
cache_to_dest = min(dest.shape[dim], cache_in.shape[dim])
lead_in_dest, dest = split2(dest, cache_to_dest, dim=dim)
lead_in_source, cache_in = split2(cache_in, cache_to_dest, dim=dim)
lead_in_dest.add_(lead_in_source)
body, new_input = split2(new_input, dest.shape[dim], dim)
dest.add_(body)
return torch_cat_if_needed([cache_in, new_input], dim=dim)
class Encoder(nn.Module):
r"""
The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
@@ -229,7 +205,7 @@ class Encoder(nn.Module):
self.gradient_checkpointing = False
def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor:
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `Encoder` class."""
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
@@ -278,22 +254,6 @@ class Encoder(nn.Module):
return sample
def forward(self, *args, **kwargs):
#No encoder support so just flag the end so it doesnt use the cache.
mark_conv3d_ended(self)
try:
return self.forward_orig(*args, **kwargs)
finally:
tid = threading.get_ident()
for _, module in self.named_modules():
# ComfyUI doesn't thread this kind of stuff today, but just in case
# we key on the thread to make it thread safe.
tid = threading.get_ident()
if hasattr(module, "temporal_cache_state"):
module.temporal_cache_state.pop(tid, None)
MAX_CHUNK_SIZE=(128 * 1024 ** 2)
class Decoder(nn.Module):
r"""
@@ -381,6 +341,18 @@ class Decoder(nn.Module):
timestep_conditioning=timestep_conditioning,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "attn_res_x":
block = UNetMidBlock3D(
dims=dims,
in_channels=input_channel,
num_layers=block_params["num_layers"],
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=timestep_conditioning,
attention_head_dim=block_params["attention_head_dim"],
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "res_x_y":
output_channel = output_channel // block_params.get("multiplier", 2)
block = ResnetBlock3D(
@@ -456,9 +428,8 @@ class Decoder(nn.Module):
)
self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
def forward_orig(
def forward(
self,
sample: torch.FloatTensor,
timestep: Optional[torch.Tensor] = None,
@@ -466,7 +437,6 @@ class Decoder(nn.Module):
r"""The forward method of the `Decoder` class."""
batch_size = sample.shape[0]
mark_conv3d_ended(self.conv_in)
sample = self.conv_in(sample, causal=self.causal)
checkpoint_fn = (
@@ -475,12 +445,24 @@ class Decoder(nn.Module):
else lambda x: x
)
timestep_shift_scale = None
scaled_timestep = None
if self.timestep_conditioning:
assert (
timestep is not None
), "should pass timestep with timestep_conditioning=True"
scaled_timestep = timestep * self.timestep_scale_multiplier.to(dtype=sample.dtype, device=sample.device)
for up_block in self.up_blocks:
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
sample = checkpoint_fn(up_block)(
sample, causal=self.causal, timestep=scaled_timestep
)
else:
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
sample = self.conv_norm_out(sample)
if self.timestep_conditioning:
embedded_timestep = self.last_time_embedder(
timestep=scaled_timestep.flatten(),
resolution=None,
@@ -501,62 +483,16 @@ class Decoder(nn.Module):
embedded_timestep.shape[-2],
embedded_timestep.shape[-1],
)
timestep_shift_scale = ada_values.unbind(dim=1)
shift, scale = ada_values.unbind(dim=1)
sample = sample * (1 + scale) + shift
output = []
def run_up(idx, sample, ended):
if idx >= len(self.up_blocks):
sample = self.conv_norm_out(sample)
if timestep_shift_scale is not None:
shift, scale = timestep_shift_scale
sample = sample * (1 + scale) + shift
sample = self.conv_act(sample)
if ended:
mark_conv3d_ended(self.conv_out)
sample = self.conv_out(sample, causal=self.causal)
if sample is not None and sample.shape[2] > 0:
output.append(sample)
return
up_block = self.up_blocks[idx]
if (ended):
mark_conv3d_ended(up_block)
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
sample = checkpoint_fn(up_block)(
sample, causal=self.causal, timestep=scaled_timestep
)
else:
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
if sample is None or sample.shape[2] == 0:
return
total_bytes = sample.numel() * sample.element_size()
num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
for chunk_idx, sample1 in enumerate(samples):
run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1)
run_up(0, sample, True)
sample = torch.cat(output, dim=2)
sample = self.conv_act(sample)
sample = self.conv_out(sample, causal=self.causal)
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
return sample
def forward(self, *args, **kwargs):
try:
return self.forward_orig(*args, **kwargs)
finally:
for _, module in self.named_modules():
#ComfyUI doesn't thread this kind of stuff today, but just incase
#we key on the thread to make it thread safe.
tid = threading.get_ident()
if hasattr(module, "temporal_cache_state"):
module.temporal_cache_state.pop(tid, None)
class UNetMidBlock3D(nn.Module):
"""
@@ -727,22 +663,8 @@ class DepthToSpaceUpsample(nn.Module):
)
self.residual = residual
self.out_channels_reduction_factor = out_channels_reduction_factor
self.temporal_cache_state = {}
def forward(self, x, causal: bool = True, timestep: Optional[torch.Tensor] = None):
tid = threading.get_ident()
cached, drop_first_conv, drop_first_res = self.temporal_cache_state.get(tid, (None, True, True))
y = self.conv(x, causal=causal)
y = rearrange(
y,
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
if self.stride[0] == 2 and y.shape[2] > 0 and drop_first_conv:
y = y[:, :, 1:, :, :]
drop_first_conv = False
if self.residual:
# Reshape and duplicate the input to match the output shape
x_in = rearrange(
@@ -754,20 +676,21 @@ class DepthToSpaceUpsample(nn.Module):
)
num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor
x_in = x_in.repeat(1, num_repeat, 1, 1, 1)
if self.stride[0] == 2 and x_in.shape[2] > 0 and drop_first_res:
if self.stride[0] == 2:
x_in = x_in[:, :, 1:, :, :]
drop_first_res = False
if y.shape[2] == 0:
y = None
cached = add_exchange_cache(y, cached, x_in, dim=2)
self.temporal_cache_state[tid] = (cached, drop_first_conv, drop_first_res)
else:
self.temporal_cache_state[tid] = (None, drop_first_conv, False)
return y
x = self.conv(x, causal=causal)
x = rearrange(
x,
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
if self.stride[0] == 2:
x = x[:, :, 1:, :, :]
if self.residual:
x = x + x_in
return x
class LayerNorm(nn.Module):
def __init__(self, dim, eps, elementwise_affine=True) -> None:
@@ -884,8 +807,6 @@ class ResnetBlock3D(nn.Module):
torch.randn(4, in_channels) / in_channels**0.5
)
self.temporal_cache_state={}
def _feed_spatial_noise(
self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
) -> torch.FloatTensor:
@@ -959,12 +880,9 @@ class ResnetBlock3D(nn.Module):
input_tensor = self.conv_shortcut(input_tensor)
tid = threading.get_ident()
cached = self.temporal_cache_state.get(tid, None)
cached = add_exchange_cache(hidden_states, cached, input_tensor, dim=2)
self.temporal_cache_state[tid] = cached
output_tensor = input_tensor + hidden_states
return hidden_states
return output_tensor
def patchify(x, patch_size_hw, patch_size_t=1):

View File

@@ -1,213 +0,0 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
import comfy.ops
import numpy as np
ops = comfy.ops.disable_weight_init
LRELU_SLOPE = 0.1
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
class ResBlock1(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
super(ResBlock1, self).__init__()
self.convs1 = nn.ModuleList(
[
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
),
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
),
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]),
),
]
)
self.convs2 = nn.ModuleList(
[
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
),
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
),
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
),
]
)
def forward(self, x):
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c1(xt)
xt = F.leaky_relu(xt, LRELU_SLOPE)
xt = c2(xt)
x = xt + x
return x
class ResBlock2(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
super(ResBlock2, self).__init__()
self.convs = nn.ModuleList(
[
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
),
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
),
]
)
def forward(self, x):
for c in self.convs:
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c(xt)
x = xt + x
return x
class Vocoder(torch.nn.Module):
"""
Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan.
"""
def __init__(self, config=None):
super(Vocoder, self).__init__()
if config is None:
config = self.get_default_config()
resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11])
upsample_rates = config.get("upsample_rates", [6, 5, 2, 2, 2])
upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 15, 8, 4, 4])
resblock_dilation_sizes = config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
upsample_initial_channel = config.get("upsample_initial_channel", 1024)
stereo = config.get("stereo", True)
resblock = config.get("resblock", "1")
self.output_sample_rate = config.get("output_sample_rate")
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
in_channels = 128 if stereo else 64
self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
resblock_class = ResBlock1 if resblock == "1" else ResBlock2
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
self.ups.append(
ops.ConvTranspose1d(
upsample_initial_channel // (2**i),
upsample_initial_channel // (2 ** (i + 1)),
k,
u,
padding=(k - u) // 2,
)
)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(resblock_class(ch, k, d))
out_channels = 2 if stereo else 1
self.conv_post = ops.Conv1d(ch, out_channels, 7, 1, padding=3)
self.upsample_factor = np.prod([self.ups[i].stride[0] for i in range(len(self.ups))])
def get_default_config(self):
"""Generate default configuration for the vocoder."""
config = {
"resblock_kernel_sizes": [3, 7, 11],
"upsample_rates": [6, 5, 2, 2, 2],
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
"upsample_initial_channel": 1024,
"stereo": True,
"resblock": "1",
}
return config
def forward(self, x):
"""
Forward pass of the vocoder.
Args:
x: Input spectrogram tensor. Can be:
- 3D: (batch_size, channels, time_steps) for mono
- 4D: (batch_size, 2, channels, time_steps) for stereo
Returns:
Audio tensor of shape (batch_size, out_channels, audio_length)
"""
if x.dim() == 4: # stereo
assert x.shape[1] == 2, "Input must have 2 channels for stereo"
x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1)
x = self.conv_pre(x)
for i in range(self.num_upsamples):
x = F.leaky_relu(x, LRELU_SLOPE)
x = self.ups[i](x)
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x

View File

@@ -13,53 +13,10 @@ from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
import comfy.patcher_extension
import comfy.utils
def invert_slices(slices, length):
sorted_slices = sorted(slices)
result = []
current = 0
for start, end in sorted_slices:
if current < start:
result.append((current, start))
current = max(current, end)
if current < length:
result.append((current, length))
return result
def modulate(x, scale, timestep_zero_index=None):
if timestep_zero_index is None:
return x * (1 + scale.unsqueeze(1))
else:
scale = (1 + scale.unsqueeze(1))
actual_batch = scale.size(0) // 2
slices = timestep_zero_index
invert = invert_slices(timestep_zero_index, x.shape[1])
for s in slices:
x[:, s[0]:s[1]] *= scale[actual_batch:]
for s in invert:
x[:, s[0]:s[1]] *= scale[:actual_batch]
return x
def apply_gate(gate, x, timestep_zero_index=None):
if timestep_zero_index is None:
return gate * x
else:
actual_batch = gate.size(0) // 2
slices = timestep_zero_index
invert = invert_slices(timestep_zero_index, x.shape[1])
for s in slices:
x[:, s[0]:s[1]] *= gate[actual_batch:]
for s in invert:
x[:, s[0]:s[1]] *= gate[:actual_batch]
return x
def modulate(x, scale):
return x * (1 + scale.unsqueeze(1))
#############################################################################
# Core NextDiT Model #
@@ -301,7 +258,6 @@ class JointTransformerBlock(nn.Module):
x_mask: torch.Tensor,
freqs_cis: torch.Tensor,
adaln_input: Optional[torch.Tensor]=None,
timestep_zero_index=None,
transformer_options={},
):
"""
@@ -320,18 +276,18 @@ class JointTransformerBlock(nn.Module):
assert adaln_input is not None
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
x = x + apply_gate(gate_msa.unsqueeze(1).tanh(), self.attention_norm2(
x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
clamp_fp16(self.attention(
modulate(self.attention_norm1(x), scale_msa, timestep_zero_index=timestep_zero_index),
modulate(self.attention_norm1(x), scale_msa),
x_mask,
freqs_cis,
transformer_options=transformer_options,
))), timestep_zero_index=timestep_zero_index
))
)
x = x + apply_gate(gate_mlp.unsqueeze(1).tanh(), self.ffn_norm2(
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
clamp_fp16(self.feed_forward(
modulate(self.ffn_norm1(x), scale_mlp, timestep_zero_index=timestep_zero_index),
))), timestep_zero_index=timestep_zero_index
modulate(self.ffn_norm1(x), scale_mlp),
))
)
else:
assert adaln_input is None
@@ -389,37 +345,13 @@ class FinalLayer(nn.Module):
),
)
def forward(self, x, c, timestep_zero_index=None):
def forward(self, x, c):
scale = self.adaLN_modulation(c)
x = modulate(self.norm_final(x), scale, timestep_zero_index=timestep_zero_index)
x = modulate(self.norm_final(x), scale)
x = self.linear(x)
return x
def pad_zimage(feats, pad_token, pad_tokens_multiple):
pad_extra = (-feats.shape[1]) % pad_tokens_multiple
return torch.cat((feats, pad_token.to(device=feats.device, dtype=feats.dtype, copy=True).unsqueeze(0).repeat(feats.shape[0], pad_extra, 1)), dim=1), pad_extra
def pos_ids_x(start_t, H_tokens, W_tokens, batch_size, device, transformer_options={}):
rope_options = transformer_options.get("rope_options", None)
h_scale = 1.0
w_scale = 1.0
h_start = 0
w_start = 0
if rope_options is not None:
h_scale = rope_options.get("scale_y", 1.0)
w_scale = rope_options.get("scale_x", 1.0)
h_start = rope_options.get("shift_y", 0.0)
w_start = rope_options.get("shift_x", 0.0)
x_pos_ids = torch.zeros((batch_size, H_tokens * W_tokens, 3), dtype=torch.float32, device=device)
x_pos_ids[:, :, 0] = start_t
x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
return x_pos_ids
class NextDiT(nn.Module):
"""
Diffusion model with a Transformer backbone.
@@ -446,12 +378,10 @@ class NextDiT(nn.Module):
time_scale=1.0,
pad_tokens_multiple=None,
clip_text_dim=None,
siglip_feat_dim=None,
image_model=None,
device=None,
dtype=None,
operations=None,
**kwargs,
) -> None:
super().__init__()
self.dtype = dtype
@@ -561,43 +491,7 @@ class NextDiT(nn.Module):
for layer_id in range(n_layers)
]
)
if siglip_feat_dim is not None:
self.siglip_embedder = nn.Sequential(
operation_settings.get("operations").RMSNorm(siglip_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
operation_settings.get("operations").Linear(
siglip_feat_dim,
dim,
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
),
)
self.siglip_refiner = nn.ModuleList(
[
JointTransformerBlock(
layer_id,
dim,
n_heads,
n_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
qk_norm,
modulation=False,
operation_settings=operation_settings,
)
for layer_id in range(n_refiner_layers)
]
)
self.siglip_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
else:
self.siglip_embedder = None
self.siglip_refiner = None
self.siglip_pad_token = None
# This norm final is in the lumina 2.0 code but isn't actually used for anything.
# self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings)
if self.pad_tokens_multiple is not None:
@@ -636,168 +530,70 @@ class NextDiT(nn.Module):
imgs = torch.stack(imgs, dim=0)
return imgs
def embed_cap(self, cap_feats=None, offset=0, bsz=1, device=None, dtype=None):
if cap_feats is not None:
cap_feats = self.cap_embedder(cap_feats)
cap_feats_len = cap_feats.shape[1]
if self.pad_tokens_multiple is not None:
cap_feats, _ = pad_zimage(cap_feats, self.cap_pad_token, self.pad_tokens_multiple)
else:
cap_feats_len = 0
cap_feats = self.cap_pad_token.to(device=device, dtype=dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1)
def patchify_and_embed(
self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, transformer_options={}
) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
bsz = len(x)
pH = pW = self.patch_size
device = x[0].device
orig_x = x
if self.pad_tokens_multiple is not None:
pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1)
cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device)
cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0 + offset
embeds = (cap_feats,)
freqs_cis = (self.rope_embedder(cap_pos_ids).movedim(1, 2),)
return embeds, freqs_cis, cap_feats_len
def embed_all(self, x, cap_feats=None, siglip_feats=None, offset=0, omni=False, transformer_options={}):
bsz = 1
pH = pW = self.patch_size
device = x.device
embeds, freqs_cis, cap_feats_len = self.embed_cap(cap_feats, offset=offset, bsz=bsz, device=device, dtype=x.dtype)
if (not omni) or self.siglip_embedder is None:
cap_feats_len = embeds[0].shape[1] + offset
embeds += (None,)
freqs_cis += (None,)
else:
cap_feats_len += offset
if siglip_feats is not None:
b, h, w, c = siglip_feats.shape
siglip_feats = siglip_feats.permute(0, 3, 1, 2).reshape(b, h * w, c)
siglip_feats = self.siglip_embedder(siglip_feats)
siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device)
siglip_pos_ids[:, :, 0] = cap_feats_len + 2
siglip_pos_ids[:, :, 1] = (torch.linspace(0, h * 8 - 1, steps=h, dtype=torch.float32, device=device).floor()).view(-1, 1).repeat(1, w).flatten()
siglip_pos_ids[:, :, 2] = (torch.linspace(0, w * 8 - 1, steps=w, dtype=torch.float32, device=device).floor()).view(1, -1).repeat(h, 1).flatten()
if self.siglip_pad_token is not None:
siglip_feats, pad_extra = pad_zimage(siglip_feats, self.siglip_pad_token, self.pad_tokens_multiple) # TODO: double check
siglip_pos_ids = torch.nn.functional.pad(siglip_pos_ids, (0, 0, 0, pad_extra))
else:
if self.siglip_pad_token is not None:
siglip_feats = self.siglip_pad_token.to(device=device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1)
siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device)
if siglip_feats is None:
embeds += (None,)
freqs_cis += (None,)
else:
embeds += (siglip_feats,)
freqs_cis += (self.rope_embedder(siglip_pos_ids).movedim(1, 2),)
cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0
B, C, H, W = x.shape
x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
x_pos_ids = pos_ids_x(cap_feats_len + 1, H // pH, W // pW, bsz, device, transformer_options=transformer_options)
rope_options = transformer_options.get("rope_options", None)
h_scale = 1.0
w_scale = 1.0
h_start = 0
w_start = 0
if rope_options is not None:
h_scale = rope_options.get("scale_y", 1.0)
w_scale = rope_options.get("scale_x", 1.0)
h_start = rope_options.get("shift_y", 0.0)
w_start = rope_options.get("shift_x", 0.0)
H_tokens, W_tokens = H // pH, W // pW
x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device)
x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1
x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
if self.pad_tokens_multiple is not None:
x, pad_extra = pad_zimage(x, self.x_pad_token, self.pad_tokens_multiple)
pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra))
embeds += (x,)
freqs_cis += (self.rope_embedder(x_pos_ids).movedim(1, 2),)
return embeds, freqs_cis, cap_feats_len + len(freqs_cis) - 1
def patchify_and_embed(
self, x: torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}
) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
bsz = x.shape[0]
cap_mask = None # TODO?
main_siglip = None
orig_x = x
embeds = ([], [], [])
freqs_cis = ([], [], [])
leftover_cap = []
start_t = 0
omni = len(ref_latents) > 0
if omni:
for i, ref in enumerate(ref_latents):
if i < len(ref_contexts):
ref_con = ref_contexts[i]
else:
ref_con = None
if i < len(siglip_feats):
sig_feat = siglip_feats[i]
else:
sig_feat = None
out = self.embed_all(ref, ref_con, sig_feat, offset=start_t, omni=omni, transformer_options=transformer_options)
for i, e in enumerate(out[0]):
if e is not None:
embeds[i].append(comfy.utils.repeat_to_batch_size(e, bsz))
freqs_cis[i].append(out[1][i])
start_t = out[2]
leftover_cap = ref_contexts[len(ref_latents):]
H, W = x.shape[-2], x.shape[-1]
img_sizes = [(H, W)] * bsz
out = self.embed_all(x, cap_feats, main_siglip, offset=start_t, omni=omni, transformer_options=transformer_options)
img_len = out[0][-1].shape[1]
cap_len = out[0][0].shape[1]
for i, e in enumerate(out[0]):
if e is not None:
e = comfy.utils.repeat_to_batch_size(e, bsz)
embeds[i].append(e)
freqs_cis[i].append(out[1][i])
start_t = out[2]
for cap in leftover_cap:
out = self.embed_cap(cap, offset=start_t, bsz=bsz, device=x.device, dtype=x.dtype)
cap_len += out[0][0].shape[1]
embeds[0].append(comfy.utils.repeat_to_batch_size(out[0][0], bsz))
freqs_cis[0].append(out[1][0])
start_t += out[2]
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
patches = transformer_options.get("patches", {})
# refine context
cap_feats = torch.cat(embeds[0], dim=1)
cap_freqs_cis = torch.cat(freqs_cis[0], dim=1)
for layer in self.context_refiner:
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options)
feats = (cap_feats,)
fc = (cap_freqs_cis,)
if omni and len(embeds[1]) > 0:
siglip_mask = None
siglip_feats_combined = torch.cat(embeds[1], dim=1)
siglip_feats_freqs_cis = torch.cat(freqs_cis[1], dim=1)
if self.siglip_refiner is not None:
for layer in self.siglip_refiner:
siglip_feats_combined = layer(siglip_feats_combined, siglip_mask, siglip_feats_freqs_cis, transformer_options=transformer_options)
feats += (siglip_feats_combined,)
fc += (siglip_feats_freqs_cis,)
cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
padded_img_mask = None
x = torch.cat(embeds[-1], dim=1)
fc_x = torch.cat(freqs_cis[-1], dim=1)
if omni:
timestep_zero_index = [(x.shape[1] - img_len, x.shape[1])]
else:
timestep_zero_index = None
x_input = x
for i, layer in enumerate(self.noise_refiner):
x = layer(x, padded_img_mask, fc_x, t, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options)
x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
if "noise_refiner" in patches:
for p in patches["noise_refiner"]:
out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": fc_x, "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
if "img" in out:
x = out["img"]
padded_full_embed = torch.cat(feats + (x,), dim=1)
if timestep_zero_index is not None:
ind = padded_full_embed.shape[1] - x.shape[1]
timestep_zero_index = [(ind + x.shape[1] - img_len, ind + x.shape[1])]
timestep_zero_index.append((feats[0].shape[1] - cap_len, feats[0].shape[1]))
padded_full_embed = torch.cat((cap_feats, x), dim=1)
mask = None
l_effective_cap_len = [padded_full_embed.shape[1] - img_len] * bsz
return padded_full_embed, mask, img_sizes, l_effective_cap_len, torch.cat(fc + (fc_x,), dim=1), timestep_zero_index
img_sizes = [(H, W)] * bsz
l_effective_cap_len = [cap_feats.shape[1]] * bsz
return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
@@ -807,11 +603,7 @@ class NextDiT(nn.Module):
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
# def forward(self, x, t, cap_feats, cap_mask):
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}, **kwargs):
omni = len(ref_latents) > 0
if omni:
timesteps = torch.cat([timesteps * 0, timesteps], dim=0)
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs):
t = 1.0 - timesteps
cap_feats = context
cap_mask = attention_mask
@@ -826,26 +618,25 @@ class NextDiT(nn.Module):
t = self.t_embedder(t * self.time_scale, dtype=x.dtype) # (N, D)
adaln_input = t
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
if self.clip_text_pooled_proj is not None:
pooled = kwargs.get("clip_text_pooled", None)
if pooled is not None:
pooled = self.clip_text_pooled_proj(pooled)
else:
pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype)
pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype)
adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))
patches = transformer_options.get("patches", {})
x_is_tensor = isinstance(x, torch.Tensor)
img, mask, img_size, cap_size, freqs_cis, timestep_zero_index = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, ref_latents=ref_latents, ref_contexts=ref_contexts, siglip_feats=siglip_feats, transformer_options=transformer_options)
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options)
freqs_cis = freqs_cis.to(img.device)
transformer_options["total_blocks"] = len(self.layers)
transformer_options["block_type"] = "double"
img_input = img
for i, layer in enumerate(self.layers):
transformer_options["block_index"] = i
img = layer(img, mask, freqs_cis, adaln_input, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options)
img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
if "double_block" in patches:
for p in patches["double_block"]:
out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
@@ -854,7 +645,8 @@ class NextDiT(nn.Module):
if "txt" in out:
img[:, :cap_size[0]] = out["txt"]
img = self.final_layer(img, adaln_input, timestep_zero_index=timestep_zero_index)
img = self.final_layer(img, adaln_input)
img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]
return -img

View File

@@ -30,13 +30,6 @@ except ImportError as e:
raise e
exit(-1)
SAGE_ATTENTION3_IS_AVAILABLE = False
try:
from sageattn3 import sageattn3_blackwell
SAGE_ATTENTION3_IS_AVAILABLE = True
except ImportError:
pass
FLASH_ATTENTION_IS_AVAILABLE = False
try:
from flash_attn import flash_attn_func
@@ -570,93 +563,6 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
out = out.reshape(b, -1, heads * dim_head)
return out
@wrap_attn
def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
exception_fallback = False
if (q.device.type != "cuda" or
q.dtype not in (torch.float16, torch.bfloat16) or
mask is not None):
return attention_pytorch(
q, k, v, heads,
mask=mask,
attn_precision=attn_precision,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
**kwargs
)
if skip_reshape:
B, H, L, D = q.shape
if H != heads:
return attention_pytorch(
q, k, v, heads,
mask=mask,
attn_precision=attn_precision,
skip_reshape=True,
skip_output_reshape=skip_output_reshape,
**kwargs
)
q_s, k_s, v_s = q, k, v
N = q.shape[2]
dim_head = D
else:
B, N, inner_dim = q.shape
if inner_dim % heads != 0:
return attention_pytorch(
q, k, v, heads,
mask=mask,
attn_precision=attn_precision,
skip_reshape=False,
skip_output_reshape=skip_output_reshape,
**kwargs
)
dim_head = inner_dim // heads
if dim_head >= 256 or N <= 1024:
return attention_pytorch(
q, k, v, heads,
mask=mask,
attn_precision=attn_precision,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
**kwargs
)
if not skip_reshape:
q_s, k_s, v_s = map(
lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(),
(q, k, v),
)
B, H, L, D = q_s.shape
try:
out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False)
except Exception as e:
exception_fallback = True
logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e)
if exception_fallback:
if not skip_reshape:
del q_s, k_s, v_s
return attention_pytorch(
q, k, v, heads,
mask=mask,
attn_precision=attn_precision,
skip_reshape=False,
skip_output_reshape=skip_output_reshape,
**kwargs
)
if skip_reshape:
if not skip_output_reshape:
out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
else:
if skip_output_reshape:
pass
else:
out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
return out
try:
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
@@ -744,8 +650,6 @@ optimized_attention_masked = optimized_attention
# register core-supported attention functions
if SAGE_ATTENTION_IS_AVAILABLE:
register_attention_function("sage", attention_sage)
if SAGE_ATTENTION3_IS_AVAILABLE:
register_attention_function("sage3", attention3_sage)
if FLASH_ATTENTION_IS_AVAILABLE:
register_attention_function("flash", attention_flash)
if model_management.xformers_enabled():

View File

@@ -14,13 +14,10 @@ if model_management.xformers_enabled_vae():
import xformers.ops
def torch_cat_if_needed(xl, dim):
xl = [x for x in xl if x is not None and x.shape[dim] > 0]
if len(xl) > 1:
return torch.cat(xl, dim)
elif len(xl) == 1:
return xl[0]
else:
return None
return xl[0]
def get_timestep_embedding(timesteps, embedding_dim):
"""
@@ -397,8 +394,7 @@ class Model(nn.Module):
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
super().__init__()
if use_linear_attn:
attn_type = "linear"
if use_linear_attn: attn_type = "linear"
self.ch = ch
self.temb_ch = self.ch*4
self.num_resolutions = len(ch_mult)
@@ -552,8 +548,7 @@ class Encoder(nn.Module):
conv3d=False, time_compress=None,
**ignore_kwargs):
super().__init__()
if use_linear_attn:
attn_type = "linear"
if use_linear_attn: attn_type = "linear"
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)

View File

@@ -45,7 +45,7 @@ class LitEma(nn.Module):
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
else:
assert key not in self.m_name2s_name
assert not key in self.m_name2s_name
def copy_to(self, model):
m_param = dict(model.named_parameters())
@@ -54,7 +54,7 @@ class LitEma(nn.Module):
if m_param[key].requires_grad:
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
else:
assert key not in self.m_name2s_name
assert not key in self.m_name2s_name
def store(self, parameters):
"""

View File

@@ -61,7 +61,7 @@ def apply_rotary_emb(x, freqs_cis):
class QwenTimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None):
def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
self.timestep_embedder = TimestepEmbedding(
@@ -72,19 +72,9 @@ class QwenTimestepProjEmbeddings(nn.Module):
operations=operations
)
self.use_additional_t_cond = use_additional_t_cond
if self.use_additional_t_cond:
self.addition_t_embedding = operations.Embedding(2, embedding_dim, device=device, dtype=dtype)
def forward(self, timestep, hidden_states, addition_t_cond=None):
def forward(self, timestep, hidden_states):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))
if self.use_additional_t_cond:
if addition_t_cond is None:
addition_t_cond = torch.zeros((timesteps_emb.shape[0]), device=timesteps_emb.device, dtype=torch.long)
timesteps_emb += self.addition_t_embedding(addition_t_cond, out_dtype=timesteps_emb.dtype)
return timesteps_emb
@@ -170,14 +160,8 @@ class Attention(nn.Module):
joint_query = apply_rope1(joint_query, image_rotary_emb)
joint_key = apply_rope1(joint_key, image_rotary_emb)
if encoder_hidden_states_mask is not None:
attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device)
attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask
else:
attn_mask = None
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
attn_mask, transformer_options=transformer_options,
attention_mask, transformer_options=transformer_options,
skip_reshape=True)
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
@@ -336,11 +320,10 @@ class QwenImageTransformer2DModel(nn.Module):
num_attention_heads: int = 24,
joint_attention_dim: int = 3584,
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
default_ref_method="index",
image_model=None,
final_layer=True,
use_additional_t_cond=False,
dtype=None,
device=None,
operations=None,
@@ -351,14 +334,12 @@ class QwenImageTransformer2DModel(nn.Module):
self.in_channels = in_channels
self.out_channels = out_channels or in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.default_ref_method = default_ref_method
self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
self.time_text_embed = QwenTimestepProjEmbeddings(
embedding_dim=self.inner_dim,
pooled_projection_dim=pooled_projection_dim,
use_additional_t_cond=use_additional_t_cond,
dtype=dtype,
device=device,
operations=operations
@@ -380,9 +361,6 @@ class QwenImageTransformer2DModel(nn.Module):
for _ in range(num_layers)
])
if self.default_ref_method == "index_timestep_zero":
self.register_buffer("__index_timestep_zero__", torch.tensor([]))
if final_layer:
self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
@@ -392,33 +370,27 @@ class QwenImageTransformer2DModel(nn.Module):
patch_size = self.patch_size
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-3], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
hidden_states = hidden_states.reshape(orig_shape[0], orig_shape[-3] * (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
t_len = t
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device)
img_ids[:, :, 0] = img_ids[:, :, 1] + index
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
if t_len > 1:
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(1)
else:
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + index
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(0) - (h_len // 2)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0).unsqueeze(0) - (w_len // 2)
return hidden_states, repeat(img_ids, "t h w c -> b (t h w) c", b=bs), orig_shape
def forward(self, x, timestep, context, attention_mask=None, ref_latents=None, additional_t_cond=None, transformer_options={}, **kwargs):
def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, attention_mask, ref_latents, additional_t_cond, transformer_options, **kwargs)
).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs)
def _forward(
self,
@@ -426,8 +398,8 @@ class QwenImageTransformer2DModel(nn.Module):
timesteps,
context,
attention_mask=None,
guidance: torch.Tensor = None,
ref_latents=None,
additional_t_cond=None,
transformer_options={},
control=None,
**kwargs
@@ -436,9 +408,6 @@ class QwenImageTransformer2DModel(nn.Module):
encoder_hidden_states = context
encoder_hidden_states_mask = attention_mask
if encoder_hidden_states_mask is not None and not torch.is_floating_point(encoder_hidden_states_mask):
encoder_hidden_states_mask = (encoder_hidden_states_mask - 1).to(x.dtype) * torch.finfo(x.dtype).max
hidden_states, img_ids, orig_shape = self.process_img(x)
num_embeds = hidden_states.shape[1]
@@ -447,19 +416,14 @@ class QwenImageTransformer2DModel(nn.Module):
h = 0
w = 0
index = 0
ref_method = kwargs.get("ref_latents_method", self.default_ref_method)
ref_method = kwargs.get("ref_latents_method", "index")
index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
negative_ref_method = ref_method == "negative_index"
timestep_zero = ref_method == "index_timestep_zero"
for ref in ref_latents:
if index_ref_method:
index += 1
h_offset = 0
w_offset = 0
elif negative_ref_method:
index -= 1
h_offset = 0
w_offset = 0
else:
index = 1
h_offset = 0
@@ -489,7 +453,14 @@ class QwenImageTransformer2DModel(nn.Module):
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
encoder_hidden_states = self.txt_in(encoder_hidden_states)
temb = self.time_text_embed(timestep, hidden_states, additional_t_cond)
if guidance is not None:
guidance = guidance * 1000
temb = (
self.time_text_embed(timestep, hidden_states)
if guidance is None
else self.time_text_embed(timestep, guidance, hidden_states)
)
patches_replace = transformer_options.get("patches_replace", {})
patches = transformer_options.get("patches", {})
@@ -537,6 +508,6 @@ class QwenImageTransformer2DModel(nn.Module):
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-3], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
hidden_states = hidden_states.permute(0, 4, 1, 2, 5, 3, 6)
hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]

View File

@@ -71,7 +71,7 @@ def count_params(model, verbose=False):
def instantiate_from_config(config):
if "target" not in config:
if not "target" in config:
if config == '__is_first_stage__':
return None
elif config == "__is_unconditional__":

View File

@@ -62,8 +62,6 @@ class WanSelfAttention(nn.Module):
x(Tensor): Shape [B, L, num_heads, C / num_heads]
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
patches = transformer_options.get("patches", {})
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
def qkv_fn_q(x):
@@ -88,10 +86,6 @@ class WanSelfAttention(nn.Module):
transformer_options=transformer_options,
)
if "attn1_patch" in patches:
for p in patches["attn1_patch"]:
x = p({"x": x, "q": q, "k": k, "transformer_options": transformer_options})
x = self.o(x)
return x
@@ -231,8 +225,6 @@ class WanAttentionBlock(nn.Module):
"""
# assert e.dtype == torch.float32
patches = transformer_options.get("patches", {})
if e.ndim < 4:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
else:
@@ -250,11 +242,6 @@ class WanAttentionBlock(nn.Module):
# cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
if "attn2_patch" in patches:
for p in patches["attn2_patch"]:
x = p({"x": x, "transformer_options": transformer_options})
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
x = torch.addcmul(x, y, repeat_e(e[5], x))
return x
@@ -501,7 +488,7 @@ class WanModel(torch.nn.Module):
self.blocks = nn.ModuleList([
wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads,
window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
for i in range(num_layers)
for _ in range(num_layers)
])
# head
@@ -554,7 +541,6 @@ class WanModel(torch.nn.Module):
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
grid_sizes = x.shape[2:]
transformer_options["grid_sizes"] = grid_sizes
x = x.flatten(2).transpose(1, 2)
# time embeddings
@@ -582,10 +568,7 @@ class WanModel(torch.nn.Module):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@@ -752,7 +735,6 @@ class VaceWanModel(WanModel):
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
grid_sizes = x.shape[2:]
transformer_options["grid_sizes"] = grid_sizes
x = x.flatten(2).transpose(1, 2)
# time embeddings
@@ -781,10 +763,7 @@ class VaceWanModel(WanModel):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@@ -883,10 +862,7 @@ class CameraWanModel(WanModel):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@@ -1350,19 +1326,16 @@ class WanModel_S2V(WanModel):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], transformer_options=args["transformer_options"])
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, transformer_options=transformer_options)
x = block(x, e=e0, freqs=freqs, context=context)
if audio_emb is not None:
x = self.audio_injector(x, i, audio_emb, audio_emb_global, seq_len)
# head
@@ -1601,10 +1574,7 @@ class HumoWanModel(WanModel):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}

View File

@@ -523,10 +523,7 @@ class AnimateWanModel(WanModel):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}

View File

@@ -1,500 +0,0 @@
import torch
from einops import rearrange, repeat
import comfy
from comfy.ldm.modules.attention import optimized_attention
def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, split_num=8):
scale = 1.0 / visual_q.shape[-1] ** 0.5
visual_q = visual_q.transpose(1, 2) * scale
B, H, x_seqlens, K = visual_q.shape
x_ref_attn_maps = []
for class_idx, ref_target_mask in enumerate(ref_target_masks):
ref_target_mask = ref_target_mask.view(1, 1, 1, -1)
x_ref_attnmap = torch.zeros(B, H, x_seqlens, device=visual_q.device, dtype=visual_q.dtype)
chunk_size = min(max(x_seqlens // split_num, 1), x_seqlens)
for i in range(0, x_seqlens, chunk_size):
end_i = min(i + chunk_size, x_seqlens)
attn_chunk = visual_q[:, :, i:end_i] @ ref_k.permute(0, 2, 3, 1) # B, H, chunk, ref_seqlens
# Apply softmax
attn_max = attn_chunk.max(dim=-1, keepdim=True).values
attn_chunk = (attn_chunk - attn_max).exp()
attn_sum = attn_chunk.sum(dim=-1, keepdim=True)
attn_chunk = attn_chunk / (attn_sum + 1e-8)
# Apply mask and sum
masked_attn = attn_chunk * ref_target_mask
x_ref_attnmap[:, :, i:end_i] = masked_attn.sum(-1) / (ref_target_mask.sum() + 1e-8)
del attn_chunk, masked_attn
# Average across heads
x_ref_attnmap = x_ref_attnmap.mean(dim=1) # B, x_seqlens
x_ref_attn_maps.append(x_ref_attnmap)
del visual_q, ref_k
return torch.cat(x_ref_attn_maps, dim=0)
def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=2):
"""Args:
query (torch.tensor): B M H K
key (torch.tensor): B M H K
shape (tuple): (N_t, N_h, N_w)
ref_target_masks: [B, N_h * N_w]
"""
N_t, N_h, N_w = shape
x_seqlens = N_h * N_w
ref_k = ref_k[:, :x_seqlens]
_, seq_lens, heads, _ = visual_q.shape
class_num, _ = ref_target_masks.shape
x_ref_attn_maps = torch.zeros(class_num, seq_lens).to(visual_q)
split_chunk = heads // split_num
for i in range(split_num):
x_ref_attn_maps_perhead = calculate_x_ref_attn_map(
visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :],
ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :],
ref_target_masks
)
x_ref_attn_maps += x_ref_attn_maps_perhead
return x_ref_attn_maps / split_num
def normalize_and_scale(column, source_range, target_range, epsilon=1e-8):
source_min, source_max = source_range
new_min, new_max = target_range
normalized = (column - source_min) / (source_max - source_min + epsilon)
scaled = normalized * (new_max - new_min) + new_min
return scaled
def rotate_half(x):
x = rearrange(x, "... (d r) -> ... d r", r=2)
x1, x2 = x.unbind(dim=-1)
x = torch.stack((-x2, x1), dim=-1)
return rearrange(x, "... d r -> ... (d r)")
def get_audio_embeds(encoded_audio, audio_start, audio_end):
audio_embs = []
human_num = len(encoded_audio)
audio_frames = encoded_audio[0].shape[0]
indices = (torch.arange(4 + 1) - 2) * 1
for human_idx in range(human_num):
if audio_end > audio_frames: # in case of not enough audio for current window, pad with first audio frame as that's most likely silence
pad_len = audio_end - audio_frames
pad_shape = list(encoded_audio[human_idx].shape)
pad_shape[0] = pad_len
pad_tensor = encoded_audio[human_idx][:1].repeat(pad_len, *([1] * (encoded_audio[human_idx].dim() - 1)))
encoded_audio_in = torch.cat([encoded_audio[human_idx], pad_tensor], dim=0)
else:
encoded_audio_in = encoded_audio[human_idx]
center_indices = torch.arange(audio_start, audio_end, 1).unsqueeze(1) + indices.unsqueeze(0)
center_indices = torch.clamp(center_indices, min=0, max=encoded_audio_in.shape[0] - 1)
audio_emb = encoded_audio_in[center_indices].unsqueeze(0)
audio_embs.append(audio_emb)
return torch.cat(audio_embs, dim=0)
def project_audio_features(audio_proj, encoded_audio, audio_start, audio_end):
audio_embs = get_audio_embeds(encoded_audio, audio_start, audio_end)
first_frame_audio_emb_s = audio_embs[:, :1, ...]
latter_frame_audio_emb = audio_embs[:, 1:, ...]
latter_frame_audio_emb = rearrange(latter_frame_audio_emb, "b (n_t n) w s c -> b n_t n w s c", n=4)
middle_index = audio_proj.seq_len // 2
latter_first_frame_audio_emb = latter_frame_audio_emb[:, :, :1, :middle_index+1, ...]
latter_first_frame_audio_emb = rearrange(latter_first_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
latter_last_frame_audio_emb = latter_frame_audio_emb[:, :, -1:, middle_index:, ...]
latter_last_frame_audio_emb = rearrange(latter_last_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
latter_middle_frame_audio_emb = latter_frame_audio_emb[:, :, 1:-1, middle_index:middle_index+1, ...]
latter_middle_frame_audio_emb = rearrange(latter_middle_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
latter_frame_audio_emb_s = torch.cat([latter_first_frame_audio_emb, latter_middle_frame_audio_emb, latter_last_frame_audio_emb], dim=2)
audio_emb = audio_proj(first_frame_audio_emb_s, latter_frame_audio_emb_s)
audio_emb = torch.cat(audio_emb.split(1), dim=2)
return audio_emb
class RotaryPositionalEmbedding1D(torch.nn.Module):
def __init__(self,
head_dim,
):
super().__init__()
self.head_dim = head_dim
self.base = 10000
def precompute_freqs_cis_1d(self, pos_indices):
freqs = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2)[: (self.head_dim // 2)].float() / self.head_dim))
freqs = freqs.to(pos_indices.device)
freqs = torch.einsum("..., f -> ... f", pos_indices.float(), freqs)
freqs = repeat(freqs, "... n -> ... (n r)", r=2)
return freqs
def forward(self, x, pos_indices):
freqs_cis = self.precompute_freqs_cis_1d(pos_indices)
x_ = x.float()
freqs_cis = freqs_cis.float().to(x.device)
cos, sin = freqs_cis.cos(), freqs_cis.sin()
cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
x_ = (x_ * cos) + (rotate_half(x_) * sin)
return x_.type_as(x)
class SingleStreamAttention(torch.nn.Module):
def __init__(
self,
dim: int,
encoder_hidden_states_dim: int,
num_heads: int,
qkv_bias: bool,
device=None, dtype=None, operations=None
) -> None:
super().__init__()
self.dim = dim
self.encoder_hidden_states_dim = encoder_hidden_states_dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.q_linear = operations.Linear(dim, dim, bias=qkv_bias, device=device, dtype=dtype)
self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
self.kv_linear = operations.Linear(encoder_hidden_states_dim, dim * 2, bias=qkv_bias, device=device, dtype=dtype)
def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None) -> torch.Tensor:
N_t, N_h, N_w = shape
expected_tokens = N_t * N_h * N_w
actual_tokens = x.shape[1]
x_extra = None
if actual_tokens != expected_tokens:
x_extra = x[:, -N_h * N_w:, :]
x = x[:, :-N_h * N_w, :]
N_t = N_t - 1
B = x.shape[0]
S = N_h * N_w
x = x.view(B * N_t, S, self.dim)
# get q for hidden_state
q = self.q_linear(x).view(B * N_t, S, self.num_heads, self.head_dim)
# get kv from encoder_hidden_states # shape: (B, N, num_heads, head_dim)
kv = self.kv_linear(encoder_hidden_states)
encoder_k, encoder_v = kv.view(B * N_t, encoder_hidden_states.shape[1], 2, self.num_heads, self.head_dim).unbind(2)
#print("q.shape", q.shape) #torch.Size([21, 1024, 40, 128])
x = optimized_attention(
q.transpose(1, 2),
encoder_k.transpose(1, 2),
encoder_v.transpose(1, 2),
heads=self.num_heads, skip_reshape=True, skip_output_reshape=True).transpose(1, 2)
# linear transform
x = self.proj(x.reshape(B * N_t, S, self.dim))
x = x.view(B, N_t * S, self.dim)
if x_extra is not None:
x = torch.cat([x, torch.zeros_like(x_extra)], dim=1)
return x
class SingleStreamMultiAttention(SingleStreamAttention):
def __init__(
self,
dim: int,
encoder_hidden_states_dim: int,
num_heads: int,
qkv_bias: bool,
class_range: int = 24,
class_interval: int = 4,
device=None, dtype=None, operations=None
) -> None:
super().__init__(
dim=dim,
encoder_hidden_states_dim=encoder_hidden_states_dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
device=device,
dtype=dtype,
operations=operations
)
# Rotary-embedding layout parameters
self.class_interval = class_interval
self.class_range = class_range
self.max_humans = self.class_range // self.class_interval
# Constant bucket used for background tokens
self.rope_bak = int(self.class_range // 2)
self.rope_1d = RotaryPositionalEmbedding1D(self.head_dim)
def forward(
self,
x: torch.Tensor,
encoder_hidden_states: torch.Tensor,
shape=None,
x_ref_attn_map=None
) -> torch.Tensor:
encoder_hidden_states = encoder_hidden_states.squeeze(0).to(x.device)
human_num = x_ref_attn_map.shape[0] if x_ref_attn_map is not None else 1
# Single-speaker fall-through
if human_num <= 1:
return super().forward(x, encoder_hidden_states, shape)
N_t, N_h, N_w = shape
x_extra = None
if x.shape[0] * N_t != encoder_hidden_states.shape[0]:
x_extra = x[:, -N_h * N_w:, :]
x = x[:, :-N_h * N_w, :]
N_t = N_t - 1
x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t)
# Query projection
B, N, C = x.shape
q = self.q_linear(x)
q = q.view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
# Use `class_range` logic for 2 speakers
rope_h1 = (0, self.class_interval)
rope_h2 = (self.class_range - self.class_interval, self.class_range)
rope_bak = int(self.class_range // 2)
# Normalize and scale attention maps for each speaker
max_values = x_ref_attn_map.max(1).values[:, None, None]
min_values = x_ref_attn_map.min(1).values[:, None, None]
max_min_values = torch.cat([max_values, min_values], dim=2)
human1_max_value, human1_min_value = max_min_values[0, :, 0].max(), max_min_values[0, :, 1].min()
human2_max_value, human2_min_value = max_min_values[1, :, 0].max(), max_min_values[1, :, 1].min()
human1 = normalize_and_scale(x_ref_attn_map[0], (human1_min_value, human1_max_value), rope_h1)
human2 = normalize_and_scale(x_ref_attn_map[1], (human2_min_value, human2_max_value), rope_h2)
back = torch.full((x_ref_attn_map.size(1),), rope_bak, dtype=human1.dtype, device=human1.device)
# Token-wise speaker dominance
max_indices = x_ref_attn_map.argmax(dim=0)
normalized_map = torch.stack([human1, human2, back], dim=1)
normalized_pos = normalized_map[torch.arange(x_ref_attn_map.size(1)), max_indices]
# Apply rotary to Q
q = rearrange(q, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
q = self.rope_1d(q, normalized_pos)
q = rearrange(q, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)
# Keys / Values
_, N_a, _ = encoder_hidden_states.shape
encoder_kv = self.kv_linear(encoder_hidden_states)
encoder_kv = encoder_kv.view(B, N_a, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
encoder_k, encoder_v = encoder_kv.unbind(0)
# Rotary for keys assign centre of each speaker bucket to its context tokens
per_frame = torch.zeros(N_a, dtype=encoder_k.dtype, device=encoder_k.device)
per_frame[: per_frame.size(0) // 2] = (rope_h1[0] + rope_h1[1]) / 2
per_frame[per_frame.size(0) // 2 :] = (rope_h2[0] + rope_h2[1]) / 2
encoder_pos = torch.cat([per_frame] * N_t, dim=0)
encoder_k = rearrange(encoder_k, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
encoder_k = self.rope_1d(encoder_k, encoder_pos)
encoder_k = rearrange(encoder_k, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)
# Final attention
q = rearrange(q, "B H M K -> B M H K")
encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
encoder_v = rearrange(encoder_v, "B H M K -> B M H K")
x = optimized_attention(
q.transpose(1, 2),
encoder_k.transpose(1, 2),
encoder_v.transpose(1, 2),
heads=self.num_heads, skip_reshape=True, skip_output_reshape=True).transpose(1, 2)
# Linear projection
x = x.reshape(B, N, C)
x = self.proj(x)
# Restore original layout
x = rearrange(x, "(B N_t) S C -> B (N_t S) C", N_t=N_t)
if x_extra is not None:
x = torch.cat([x, torch.zeros_like(x_extra)], dim=1)
return x
class MultiTalkAudioProjModel(torch.nn.Module):
def __init__(
self,
seq_len: int = 5,
seq_len_vf: int = 12,
blocks: int = 12,
channels: int = 768,
intermediate_dim: int = 512,
out_dim: int = 768,
context_tokens: int = 32,
device=None, dtype=None, operations=None
):
super().__init__()
self.seq_len = seq_len
self.blocks = blocks
self.channels = channels
self.input_dim = seq_len * blocks * channels
self.input_dim_vf = seq_len_vf * blocks * channels
self.intermediate_dim = intermediate_dim
self.context_tokens = context_tokens
self.out_dim = out_dim
# define multiple linear layers
self.proj1 = operations.Linear(self.input_dim, intermediate_dim, device=device, dtype=dtype)
self.proj1_vf = operations.Linear(self.input_dim_vf, intermediate_dim, device=device, dtype=dtype)
self.proj2 = operations.Linear(intermediate_dim, intermediate_dim, device=device, dtype=dtype)
self.proj3 = operations.Linear(intermediate_dim, context_tokens * out_dim, device=device, dtype=dtype)
self.norm = operations.LayerNorm(out_dim, device=device, dtype=dtype)
def forward(self, audio_embeds, audio_embeds_vf):
video_length = audio_embeds.shape[1] + audio_embeds_vf.shape[1]
B, _, _, S, C = audio_embeds.shape
# process audio of first frame
audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
batch_size, window_size, blocks, channels = audio_embeds.shape
audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
# process audio of latter frame
audio_embeds_vf = rearrange(audio_embeds_vf, "bz f w b c -> (bz f) w b c")
batch_size_vf, window_size_vf, blocks_vf, channels_vf = audio_embeds_vf.shape
audio_embeds_vf = audio_embeds_vf.view(batch_size_vf, window_size_vf * blocks_vf * channels_vf)
# first projection
audio_embeds = torch.relu(self.proj1(audio_embeds))
audio_embeds_vf = torch.relu(self.proj1_vf(audio_embeds_vf))
audio_embeds = rearrange(audio_embeds, "(bz f) c -> bz f c", bz=B)
audio_embeds_vf = rearrange(audio_embeds_vf, "(bz f) c -> bz f c", bz=B)
audio_embeds_c = torch.concat([audio_embeds, audio_embeds_vf], dim=1)
batch_size_c, N_t, C_a = audio_embeds_c.shape
audio_embeds_c = audio_embeds_c.view(batch_size_c*N_t, C_a)
# second projection
audio_embeds_c = torch.relu(self.proj2(audio_embeds_c))
context_tokens = self.proj3(audio_embeds_c).reshape(batch_size_c*N_t, self.context_tokens, self.out_dim)
# normalization and reshape
context_tokens = self.norm(context_tokens)
context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)
return context_tokens
class WanMultiTalkAttentionBlock(torch.nn.Module):
def __init__(self, in_dim=5120, out_dim=768, device=None, dtype=None, operations=None):
super().__init__()
self.audio_cross_attn = SingleStreamMultiAttention(in_dim, out_dim, num_heads=40, qkv_bias=True, device=device, dtype=dtype, operations=operations)
self.norm_x = operations.LayerNorm(in_dim, device=device, dtype=dtype, elementwise_affine=True)
class MultiTalkGetAttnMapPatch:
def __init__(self, ref_target_masks=None):
self.ref_target_masks = ref_target_masks
def __call__(self, kwargs):
transformer_options = kwargs.get("transformer_options", {})
x = kwargs["x"]
if self.ref_target_masks is not None:
x_ref_attn_map = get_attn_map_with_target(kwargs["q"], kwargs["k"], transformer_options["grid_sizes"], ref_target_masks=self.ref_target_masks.to(x.device))
transformer_options["x_ref_attn_map"] = x_ref_attn_map
return x
class MultiTalkCrossAttnPatch:
def __init__(self, model_patch, audio_scale=1.0, ref_target_masks=None):
self.model_patch = model_patch
self.audio_scale = audio_scale
self.ref_target_masks = ref_target_masks
def __call__(self, kwargs):
transformer_options = kwargs.get("transformer_options", {})
block_idx = transformer_options.get("block_index", None)
x = kwargs["x"]
if block_idx is None:
return torch.zeros_like(x)
audio_embeds = transformer_options.get("audio_embeds")
x_ref_attn_map = transformer_options.pop("x_ref_attn_map", None)
norm_x = self.model_patch.model.blocks[block_idx].norm_x(x)
x_audio = self.model_patch.model.blocks[block_idx].audio_cross_attn(
norm_x, audio_embeds.to(x.dtype),
shape=transformer_options["grid_sizes"],
x_ref_attn_map=x_ref_attn_map
)
x = x + x_audio * self.audio_scale
return x
def models(self):
return [self.model_patch]
class MultiTalkApplyModelWrapper:
def __init__(self, init_latents):
self.init_latents = init_latents
def __call__(self, executor, x, *args, **kwargs):
x[:, :, :self.init_latents.shape[2]] = self.init_latents.to(x)
samples = executor(x, *args, **kwargs)
return samples
class InfiniteTalkOuterSampleWrapper:
def __init__(self, motion_frames_latent, model_patch, is_extend=False):
self.motion_frames_latent = motion_frames_latent
self.model_patch = model_patch
self.is_extend = is_extend
def __call__(self, executor, *args, **kwargs):
model_patcher = executor.class_obj.model_patcher
model_options = executor.class_obj.model_options
process_latent_in = model_patcher.model.process_latent_in
# for InfiniteTalk, model input first latent(s) need to always be replaced on every step
if self.motion_frames_latent is not None:
wrappers = model_options["transformer_options"]["wrappers"]
w = wrappers.setdefault(comfy.patcher_extension.WrappersMP.APPLY_MODEL, {})
w["MultiTalk_apply_model"] = [MultiTalkApplyModelWrapper(process_latent_in(self.motion_frames_latent))]
# run the sampling process
result = executor(*args, **kwargs)
# insert motion frames before decoding
if self.is_extend:
overlap = self.motion_frames_latent.shape[2]
result = torch.cat([self.motion_frames_latent.to(result), result[:, :, overlap:]], dim=2)
return result
def to(self, device_or_dtype):
if isinstance(device_or_dtype, torch.device):
if self.motion_frames_latent is not None:
self.motion_frames_latent = self.motion_frames_latent.to(device_or_dtype)
return self

View File

@@ -5,7 +5,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from comfy.ldm.modules.diffusionmodules.model import vae_attention, torch_cat_if_needed
from comfy.ldm.modules.diffusionmodules.model import vae_attention
import comfy.ops
ops = comfy.ops.disable_weight_init
@@ -20,29 +20,22 @@ class CausalConv3d(ops.Conv3d):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._padding = 2 * self.padding[0]
self.padding = (0, self.padding[1], self.padding[2])
self._padding = (self.padding[2], self.padding[2], self.padding[1],
self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
def forward(self, x, cache_x=None, cache_list=None, cache_idx=None):
if cache_list is not None:
cache_x = cache_list[cache_idx]
cache_list[cache_idx] = None
if cache_x is None and x.shape[2] == 1:
#Fast path - the op will pad for use by truncating the weight
#and save math on a pile of zeros.
return super().forward(x, autopad="causal_zero")
if self._padding > 0:
padding_needed = self._padding
if cache_x is not None:
cache_x = cache_x.to(x.device)
padding_needed = max(0, padding_needed - cache_x.shape[2])
padding_shape = list(x.shape)
padding_shape[2] = padding_needed
padding = torch.zeros(padding_shape, device=x.device, dtype=x.dtype)
x = torch_cat_if_needed([padding, cache_x, x], dim=2)
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
cache_x = cache_x.to(x.device)
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
del cache_x
x = F.pad(x, padding)
return super().forward(x)
@@ -234,7 +227,6 @@ class Encoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
input_channels=3,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
@@ -253,7 +245,7 @@ class Encoder3d(nn.Module):
scale = 1.0
# init block
self.conv1 = CausalConv3d(input_channels, dims[0], 3, padding=1)
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
# downsample blocks
downsamples = []
@@ -339,7 +331,6 @@ class Decoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
output_channels=3,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
@@ -387,7 +378,7 @@ class Decoder3d(nn.Module):
# output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, output_channels, 3, padding=1))
CausalConv3d(out_dim, 3, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]):
## conv1
@@ -458,7 +449,6 @@ class WanVAE(nn.Module):
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
image_channels=3,
dropout=0.0):
super().__init__()
self.dim = dim
@@ -470,21 +460,19 @@ class WanVAE(nn.Module):
self.temperal_upsample = temperal_downsample[::-1]
# modules
self.encoder = Encoder3d(dim, z_dim * 2, image_channels, dim_mult, num_res_blocks,
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
attn_scales, self.temperal_downsample, dropout)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(dim, z_dim, image_channels, dim_mult, num_res_blocks,
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
attn_scales, self.temperal_upsample, dropout)
def encode(self, x):
conv_idx = [0]
feat_map = [None] * count_conv3d(self.decoder)
## cache
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
feat_map = None
if iter_ > 1:
feat_map = [None] * count_conv3d(self.decoder)
## 对encode输入的x按时间拆分为1、4、4、4....
for i in range(iter_):
conv_idx = [0]
@@ -504,11 +492,10 @@ class WanVAE(nn.Module):
def decode(self, z):
conv_idx = [0]
feat_map = [None] * count_conv3d(self.decoder)
# z: [b,c,t,h,w]
iter_ = z.shape[2]
feat_map = None
if iter_ > 1:
feat_map = [None] * count_conv3d(self.decoder)
x = self.conv2(z)
for i in range(iter_):
conv_idx = [0]

View File

@@ -260,7 +260,6 @@ def model_lora_keys_unet(model, key_map={}):
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
key_map[k[:-len(".weight")]] = to #DiffSynth lora format
for k in sdk:
hidden_size = model.model_config.unet_config.get("hidden_size", 0)
if k.endswith(".weight") and ".linear1." in k:
@@ -323,7 +322,6 @@ def model_lora_keys_unet(model, key_map={}):
key_map["diffusion_model.{}".format(key_lora)] = to
key_map["transformer.{}".format(key_lora)] = to
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
key_map[key_lora] = to
if isinstance(model, comfy.model_base.Kandinsky5):
for k in sdk:
@@ -332,12 +330,6 @@ def model_lora_keys_unet(model, key_map={}):
key_map["{}".format(key_lora)] = k
key_map["transformer.{}".format(key_lora)] = k
if isinstance(model, comfy.model_base.ACEStep15):
for k in sdk:
if k.startswith("diffusion_model.decoder.") and k.endswith(".weight"):
key_lora = k[len("diffusion_model.decoder."):-len(".weight")]
key_map["base_model.model.{}".format(key_lora)] = k # Official base model loras
return key_map

View File

@@ -1,81 +0,0 @@
import math
import torch
from typing import NamedTuple
from comfy.quant_ops import QuantizedTensor
class TensorGeometry(NamedTuple):
shape: any
dtype: torch.dtype
def element_size(self):
info = torch.finfo(self.dtype) if self.dtype.is_floating_point else torch.iinfo(self.dtype)
return info.bits // 8
def numel(self):
return math.prod(self.shape)
def tensors_to_geometries(tensors, dtype=None):
geometries = []
for t in tensors:
if t is None or isinstance(t, QuantizedTensor):
geometries.append(t)
continue
tdtype = t.dtype
if hasattr(t, "_model_dtype"):
tdtype = t._model_dtype
if dtype is not None:
tdtype = dtype
geometries.append(TensorGeometry(shape=t.shape, dtype=tdtype))
return geometries
def vram_aligned_size(tensor):
if isinstance(tensor, list):
return sum([vram_aligned_size(t) for t in tensor])
if isinstance(tensor, QuantizedTensor):
inner_tensors, _ = tensor.__tensor_flatten__()
return vram_aligned_size([ getattr(tensor, attr) for attr in inner_tensors ])
if tensor is None:
return 0
size = tensor.numel() * tensor.element_size()
aligment_req = 1024
return (size + aligment_req - 1) // aligment_req * aligment_req
def interpret_gathered_like(tensors, gathered):
offset = 0
dest_views = []
if gathered.dim() != 1 or gathered.element_size() != 1:
raise ValueError(f"Buffer must be 1D and single-byte (got {gathered.dim()}D {gathered.dtype})")
for tensor in tensors:
if tensor is None:
dest_views.append(None)
continue
if isinstance(tensor, QuantizedTensor):
inner_tensors, qt_ctx = tensor.__tensor_flatten__()
templates = { attr: getattr(tensor, attr) for attr in inner_tensors }
else:
templates = { "data": tensor }
actuals = {}
for attr, template in templates.items():
size = template.numel() * template.element_size()
if offset + size > gathered.numel():
raise ValueError(f"Buffer too small: needs {offset + size} bytes, but only has {gathered.numel()}. ")
actuals[attr] = gathered[offset:offset+size].view(dtype=template.dtype).view(template.shape)
offset += vram_aligned_size(template)
if isinstance(tensor, QuantizedTensor):
dest_views.append(QuantizedTensor.__tensor_unflatten__(actuals, qt_ctx, 0, 0))
else:
dest_views.append(actuals["data"])
return dest_views
aimdo_allocator = None

View File

@@ -20,7 +20,6 @@ import comfy.ldm.hunyuan3dv2_1
import comfy.ldm.hunyuan3dv2_1.hunyuandit
import torch
import logging
import comfy.ldm.lightricks.av_model
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from comfy.ldm.cascade.stage_c import StageC
from comfy.ldm.cascade.stage_b import StageB
@@ -49,8 +48,6 @@ import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2
import comfy.ldm.qwen_image.model
import comfy.ldm.kandinsky5.model
import comfy.ldm.anima.model
import comfy.ldm.ace.ace_step15
import comfy.model_management
import comfy.patcher_extension
@@ -150,8 +147,6 @@ class BaseModel(torch.nn.Module):
self.model_type = model_type
self.model_sampling = model_sampling(model_config, model_type)
comfy.model_management.archive_model_dtypes(self.diffusion_model)
self.adm_channels = unet_config.get("adm_in_channels", None)
if self.adm_channels is None:
self.adm_channels = 0
@@ -302,7 +297,7 @@ class BaseModel(torch.nn.Module):
return out
def load_model_weights(self, sd, unet_prefix="", assign=False):
def load_model_weights(self, sd, unet_prefix=""):
to_load = {}
keys = list(sd.keys())
for k in keys:
@@ -310,7 +305,7 @@ class BaseModel(torch.nn.Module):
to_load[k[len(unet_prefix):]] = sd.pop(k)
to_load = self.model_config.process_unet_state_dict(to_load)
m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=assign)
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
if len(m) > 0:
logging.warning("unet missing: {}".format(m))
@@ -325,7 +320,7 @@ class BaseModel(torch.nn.Module):
def process_latent_out(self, latent):
return self.latent_format.process_out(latent)
def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
extra_sds = []
if clip_state_dict is not None:
extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict))
@@ -333,7 +328,10 @@ class BaseModel(torch.nn.Module):
extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict))
if clip_vision_state_dict is not None:
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
unet_state_dict = self.diffusion_model.state_dict()
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
if self.model_type == ModelType.V_PREDICTION:
unet_state_dict["v_pred"] = torch.tensor([])
@@ -776,8 +774,8 @@ class StableAudio1(BaseModel):
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
sd = super().state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
sd = super().state_dict_for_saving(clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
d = {"conditioner.conditioners.seconds_start.": self.seconds_start_embedder.state_dict(), "conditioner.conditioners.seconds_total.": self.seconds_total_embedder.state_dict()}
for k in d:
s = d[k]
@@ -948,7 +946,7 @@ class GenmoMochi(BaseModel):
class LTXV(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel)
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel) #TODO
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
@@ -979,60 +977,6 @@ class LTXV(BaseModel):
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
return latent_image
class LTXAV(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.av_model.LTXAVModel) #TODO
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
audio_denoise_mask = None
if denoise_mask is not None and "latent_shapes" in kwargs:
denoise_mask = utils.unpack_latents(denoise_mask, kwargs["latent_shapes"])
if len(denoise_mask) > 1:
audio_denoise_mask = denoise_mask[1]
denoise_mask = denoise_mask[0]
if denoise_mask is not None:
out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
if audio_denoise_mask is not None:
out["audio_denoise_mask"] = comfy.conds.CONDRegular(audio_denoise_mask)
keyframe_idxs = kwargs.get("keyframe_idxs", None)
if keyframe_idxs is not None:
out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
latent_shapes = kwargs.get("latent_shapes", None)
if latent_shapes is not None:
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
return out
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
v_timestep = timestep
a_timestep = timestep
if denoise_mask is not None:
v_timestep = self.diffusion_model.patchifier.patchify(((denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1)))[:, :1])[0]
if audio_denoise_mask is not None:
a_timestep = self.diffusion_model.a_patchifier.patchify(((audio_denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (audio_denoise_mask.ndim - 1)))[:, :1, :, :1])[0]
return v_timestep, a_timestep
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
return latent_image
class HunyuanVideo(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
@@ -1148,31 +1092,9 @@ class CosmosPredict2(BaseModel):
sigma = (sigma / (sigma + 1))
return latent_image / (1.0 - sigma)
class Anima(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.anima.model.Anima)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
t5xxl_ids = kwargs.get("t5xxl_ids", None)
t5xxl_weights = kwargs.get("t5xxl_weights", None)
device = kwargs["device"]
if cross_attn is not None:
if t5xxl_ids is not None:
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.unsqueeze(0).to(device=device))
if t5xxl_weights is not None:
cross_attn *= t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
if cross_attn.shape[1] < 512:
cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, 0, 512 - cross_attn.shape[1]))
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
class Lumina2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT)
self.memory_usage_factor_conds = ("ref_latents",)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
@@ -1188,39 +1110,10 @@ class Lumina2(BaseModel):
if 'num_tokens' not in out:
out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])
clip_text_pooled = kwargs.get("pooled_output", None) # NewBie
clip_text_pooled = kwargs["pooled_output"] # Newbie
if clip_text_pooled is not None:
out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
clip_vision_outputs = kwargs.get("clip_vision_outputs", list(map(lambda a: a.get("clip_vision_output"), kwargs.get("unclip_conditioning", [{}])))) # Z Image omni
if clip_vision_outputs is not None and len(clip_vision_outputs) > 0:
sigfeats = []
for clip_vision_output in clip_vision_outputs:
if clip_vision_output is not None:
image_size = clip_vision_output.image_sizes[0]
shape = clip_vision_output.last_hidden_state.shape
sigfeats.append(clip_vision_output.last_hidden_state.reshape(shape[0], image_size[1] // 16, image_size[2] // 16, shape[-1]))
if len(sigfeats) > 0:
out['siglip_feats'] = comfy.conds.CONDList(sigfeats)
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
latents = []
for lat in ref_latents:
latents.append(self.process_latent_in(lat))
out['ref_latents'] = comfy.conds.CONDList(latents)
ref_contexts = kwargs.get("reference_latents_text_embeds", None)
if ref_contexts is not None:
out['ref_contexts'] = comfy.conds.CONDList(ref_contexts)
return out
def extra_conds_shapes(self, **kwargs):
out = {}
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
return out
class WAN21(BaseModel):
@@ -1541,47 +1434,6 @@ class ACEStep(BaseModel):
out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0))
return out
class ACEStep15(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.ace_step15.AceStepConditionGenerationModel)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
device = kwargs["device"]
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
conditioning_lyrics = kwargs.get("conditioning_lyrics", None)
if cross_attn is not None:
out['lyric_embed'] = comfy.conds.CONDRegular(conditioning_lyrics)
refer_audio = kwargs.get("reference_audio_timbre_latents", None)
if refer_audio is None or len(refer_audio) == 0:
refer_audio = torch.tensor([[[-1.3672e-01, -1.5820e-01, 5.8594e-01, -5.7422e-01, 3.0273e-02,
2.7930e-01, -2.5940e-03, -2.0703e-01, -1.6113e-01, -1.4746e-01,
-2.7710e-02, -1.8066e-01, -2.9688e-01, 1.6016e+00, -2.6719e+00,
7.7734e-01, -1.3516e+00, -1.9434e-01, -7.1289e-02, -5.0938e+00,
2.4316e-01, 4.7266e-01, 4.6387e-02, -6.6406e-01, -2.1973e-01,
-6.7578e-01, -1.5723e-01, 9.5312e-01, -2.0020e-01, -1.7109e+00,
5.8984e-01, -5.7422e-01, 5.1562e-01, 2.8320e-01, 1.4551e-01,
-1.8750e-01, -5.9814e-02, 3.6719e-01, -1.0059e-01, -1.5723e-01,
2.0605e-01, -4.3359e-01, -8.2812e-01, 4.5654e-02, -6.6016e-01,
1.4844e-01, 9.4727e-02, 3.8477e-01, -1.2578e+00, -3.3203e-01,
-8.5547e-01, 4.3359e-01, 4.2383e-01, -8.9453e-01, -5.0391e-01,
-5.6152e-02, -2.9219e+00, -2.4658e-02, 5.0391e-01, 9.8438e-01,
7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, 750)
else:
refer_audio = refer_audio[-1]
out['refer_audio'] = comfy.conds.CONDRegular(refer_audio)
audio_codes = kwargs.get("audio_codes", None)
if audio_codes is not None:
out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device))
return out
class Omnigen2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.omnigen.omnigen2.OmniGen2Transformer2DModel)
@@ -1619,9 +1471,6 @@ class QwenImage(BaseModel):
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)

View File

@@ -237,8 +237,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
else:
dit_config["vec_in_dim"] = None
dit_config["num_heads"] = dit_config["hidden_size"] // sum(dit_config["axes_dim"])
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
@@ -253,7 +251,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["image_model"] = "chroma_radiance"
dit_config["in_channels"] = 3
dit_config["out_channels"] = 3
dit_config["patch_size"] = state_dict.get('{}img_in_patch.weight'.format(key_prefix)).size(dim=-1)
dit_config["patch_size"] = 16
dit_config["nerf_hidden_size"] = 64
dit_config["nerf_mlp_ratio"] = 4
dit_config["nerf_depth"] = 4
@@ -261,7 +259,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["nerf_tile_size"] = 512
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
dit_config["nerf_embedder_dtype"] = torch.float32
if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred
if "__x0__" in state_dict_keys: # x0 pred
dit_config["use_x0"] = True
else:
dit_config["use_x0"] = False
@@ -307,7 +305,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
dit_config = {}
dit_config["image_model"] = "ltxav" if f'{key_prefix}audio_adaln_single.linear.weight' in state_dict_keys else "ltxv"
dit_config["image_model"] = "ltxv"
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape
dit_config["attention_head_dim"] = shape[0] // 32
@@ -432,9 +430,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["rope_theta"] = 10000.0
dit_config["ffn_dim_multiplier"] = 4.0
ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
if ctd_weight is not None: # NewBie
if ctd_weight is not None:
dit_config["clip_text_dim"] = ctd_weight.shape[0]
# NewBie also sets axes_lens = [1024, 512, 512] but it's not used in ComfyUI
elif dit_config["dim"] == 3840: # Z image
dit_config["n_heads"] = 30
dit_config["n_kv_heads"] = 30
@@ -444,15 +441,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["ffn_dim_multiplier"] = (8.0 / 3.0)
dit_config["z_image_modulation"] = True
dit_config["time_scale"] = 1000.0
try:
dit_config["allow_fp16"] = torch.std(state_dict['{}layers.{}.ffn_norm1.weight'.format(key_prefix, dit_config["n_layers"] - 2)], unbiased=False).item() < 0.42
except Exception:
pass
if '{}cap_pad_token'.format(key_prefix) in state_dict_keys:
dit_config["pad_tokens_multiple"] = 32
sig_weight = state_dict.get('{}siglip_embedder.0.weight'.format(key_prefix), None)
if sig_weight is not None:
dit_config["siglip_feat_dim"] = sig_weight.shape[0]
return dit_config
@@ -554,8 +544,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}blocks.0.mlp.layer1.weight'.format(key_prefix) in state_dict_keys: # Cosmos predict2
dit_config = {}
dit_config["image_model"] = "cosmos_predict2"
if "{}llm_adapter.blocks.0.cross_attn.q_proj.weight".format(key_prefix) in state_dict_keys:
dit_config["image_model"] = "anima"
dit_config["max_img_h"] = 240
dit_config["max_img_w"] = 240
dit_config["max_frames"] = 128
@@ -630,11 +618,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["image_model"] = "qwen_image"
dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1]
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511
dit_config["default_ref_method"] = "index_timestep_zero"
if "{}time_text_embed.addition_t_embedding.weight".format(key_prefix) in state_dict_keys: # Layered
dit_config["use_additional_t_cond"] = True
dit_config["default_ref_method"] = "negative_index"
return dit_config
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
@@ -655,11 +638,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.')
return dit_config
if '{}encoder.lyric_encoder.layers.0.input_layernorm.weight'.format(key_prefix) in state_dict_keys:
dit_config = {}
dit_config["audio_model"] = "ace1.5"
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None

View File

@@ -19,21 +19,13 @@
import psutil
import logging
from enum import Enum
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
import threading
from comfy.cli_args import args, PerformanceFeature
import torch
import sys
import importlib
import platform
import weakref
import gc
import os
from contextlib import nullcontext
import comfy.memory_management
import comfy.utils
import comfy.quant_ops
import comfy_aimdo.torch
import comfy_aimdo.model_vbar
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
@@ -341,42 +333,28 @@ except:
SUPPORT_FP8_OPS = args.supports_fp8_compute
AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]
AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN'
try:
if is_amd():
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1':
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
try:
rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
except:
rocm_version = (6, -1)
def aotriton_supported(gpu_arch):
path = torch.__path__[0]
path = os.path.join(os.path.join(path, "lib"), "aotriton.images")
gfx = set(map(lambda a: a[4:], filter(lambda a: a.startswith("amd-gfx"), os.listdir(path))))
if gpu_arch in gfx:
return True
if "{}x".format(gpu_arch[:-1]) in gfx:
return True
if "{}xx".format(gpu_arch[:-2]) in gfx:
return True
return False
logging.info("AMD arch: {}".format(arch))
logging.info("ROCm version: {}".format(rocm_version))
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton.
if importlib.util.find_spec('triton') is not None: # AMD efficient attention implementation depends on triton. TODO: better way of detecting if it's compiled in or not.
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
ENABLE_PYTORCH_ATTENTION = True
if rocm_version >= (7, 0):
if any((a in arch) for a in ["gfx1200", "gfx1201"]):
if any((a in arch) for a in ["gfx1201"]):
ENABLE_PYTORCH_ATTENTION = True
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx950"]): # TODO: more arches, "gfx942" gives error on pytorch nightly 2.10 1013 rocm7.0
@@ -475,7 +453,7 @@ def module_size(module):
sd = module.state_dict()
for k in sd:
t = sd[k]
module_mem += t.nbytes
module_mem += t.nelement() * t.element_size()
return module_mem
class LoadedModel:
@@ -586,15 +564,9 @@ WINDOWS = any(platform.win32_ver())
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
if WINDOWS:
import comfy.windows
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards
EXTRA_RESERVED_VRAM += 100 * 1024 * 1024
def get_free_ram():
return comfy.windows.get_free_ram()
else:
def get_free_ram():
return psutil.virtual_memory().available
if args.reserve_vram is not None:
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
@@ -606,7 +578,7 @@ def extra_reserved_memory():
def minimum_inference_memory():
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_required=0):
def free_memory(memory_required, device, keep_loaded=[]):
cleanup_models_gc()
unloaded_model = []
can_unload = []
@@ -621,23 +593,15 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
for x in sorted(can_unload):
i = x[-1]
memory_to_free = 1e32
ram_to_free = 1e32
memory_to_free = None
if not DISABLE_SMART_MEMORY:
memory_to_free = memory_required - get_free_memory(device)
ram_to_free = ram_required - get_free_ram()
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
#don't actually unload dynamic models for the sake of other dynamic models
#as that works on-demand.
memory_required -= current_loaded_models[i].model.loaded_size()
memory_to_free = 0
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
free_mem = get_free_memory(device)
if free_mem > memory_required:
break
memory_to_free = memory_required - free_mem
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
if current_loaded_models[i].model_unload(memory_to_free):
unloaded_model.append(i)
if ram_to_free > 0:
logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
current_loaded_models[i].model.partially_unload_ram(ram_to_free)
for i in sorted(unloaded_model, reverse=True):
unloaded_models.append(current_loaded_models.pop(i))
@@ -651,7 +615,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
soft_empty_cache()
return unloaded_models
def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
cleanup_models_gc()
global vram_state
@@ -672,10 +636,7 @@ def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, m
models_to_load = []
free_for_dynamic=True
for x in models:
if not x.is_dynamic():
free_for_dynamic = False
loaded_model = LoadedModel(x)
try:
loaded_model_index = current_loaded_models.index(loaded_model)
@@ -701,25 +662,19 @@ def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, m
model_to_unload.model.detach(unpatch_all=False)
model_to_unload.model_finalizer.detach()
total_memory_required = {}
total_ram_required = {}
for loaded_model in models_to_load:
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
#x2, one to make sure the OS can fit the model for loading in disk cache, and for us to do any pinning we
#want to do.
#FIXME: This should subtract off the to_load current pin consumption.
total_ram_required[loaded_model.device] = total_ram_required.get(loaded_model.device, 0) + loaded_model.model_memory() * 2
for device in total_memory_required:
if device != torch.device("cpu"):
free_memory(total_memory_required[device] * 1.1 + extra_mem, device, for_dynamic=free_for_dynamic, ram_required=total_ram_required[device])
free_memory(total_memory_required[device] * 1.1 + extra_mem, device)
for device in total_memory_required:
if device != torch.device("cpu"):
free_mem = get_free_memory(device)
if free_mem < minimum_memory_required:
models_l = free_memory(minimum_memory_required, device, for_dynamic=free_for_dynamic)
models_l = free_memory(minimum_memory_required, device)
logging.info("{} models unloaded.".format(len(models_l)))
for loaded_model in models_to_load:
@@ -747,26 +702,6 @@ def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, m
current_loaded_models.insert(0, loaded_model)
return
def load_models_gpu_thread(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load):
with torch.inference_mode():
load_models_gpu_orig(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
soft_empty_cache()
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
#Deliberately load models outside of the Aimdo mempool so they can be retained accross
#nodes. Use a dummy thread to do it as pytorch documents that mempool contexts are
#thread local. So exploit that to escape context
if enables_dynamic_vram():
t = threading.Thread(
target=load_models_gpu_thread,
args=(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
)
t.start()
t.join()
else:
load_models_gpu_orig(models, memory_required=memory_required, force_patch_weights=force_patch_weights,
minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
def load_model_gpu(model):
return load_models_gpu([model])
@@ -783,9 +718,6 @@ def loaded_models(only_currently_used=False):
def cleanup_models_gc():
do_gc = False
reset_cast_buffers()
for i in range(len(current_loaded_models)):
cur = current_loaded_models[i]
if cur.is_dead():
@@ -803,11 +735,6 @@ def cleanup_models_gc():
logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))
def archive_model_dtypes(model):
for name, module in model.named_modules():
for param_name, param in module.named_parameters(recurse=False):
setattr(module, f"{param_name}_comfy_model_dtype", param.dtype)
def cleanup_models():
to_delete = []
@@ -851,7 +778,7 @@ def unet_inital_load_device(parameters, dtype):
mem_dev = get_free_memory(torch_dev)
mem_cpu = get_free_memory(cpu_dev)
if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_allocator is None:
if mem_dev > mem_cpu and model_size < mem_dev:
return torch_dev
else:
return cpu_dev
@@ -1089,8 +1016,8 @@ NUM_STREAMS = 0
if args.async_offload is not None:
NUM_STREAMS = args.async_offload
else:
# Enable by default on Nvidia and AMD
if is_nvidia() or is_amd():
# Enable by default on Nvidia
if is_nvidia():
NUM_STREAMS = 2
if args.disable_async_offload:
@@ -1110,51 +1037,6 @@ def current_stream(device):
return None
stream_counters = {}
STREAM_CAST_BUFFERS = {}
LARGEST_CASTED_WEIGHT = (None, 0)
def get_cast_buffer(offload_stream, device, size, ref):
global LARGEST_CASTED_WEIGHT
if offload_stream is not None:
wf_context = offload_stream
if hasattr(wf_context, "as_context"):
wf_context = wf_context.as_context(offload_stream)
else:
wf_context = nullcontext()
cast_buffer = STREAM_CAST_BUFFERS.get(offload_stream, None)
if cast_buffer is None or cast_buffer.numel() < size:
if ref is LARGEST_CASTED_WEIGHT[0]:
#If there is one giant weight we do not want both streams to
#allocate a buffer for it. It's up to the caster to get the other
#offload stream in this corner case
return None
if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2):
#I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now
synchronize()
del STREAM_CAST_BUFFERS[offload_stream]
del cast_buffer
#FIXME: This doesn't work in Aimdo because mempool cant clear cache
soft_empty_cache()
with wf_context:
cast_buffer = torch.empty((size), dtype=torch.int8, device=device)
STREAM_CAST_BUFFERS[offload_stream] = cast_buffer
if size > LARGEST_CASTED_WEIGHT[1]:
LARGEST_CASTED_WEIGHT = (ref, size)
return cast_buffer
def reset_cast_buffers():
global LARGEST_CASTED_WEIGHT
LARGEST_CASTED_WEIGHT = (None, 0)
for offload_stream in STREAM_CAST_BUFFERS:
offload_stream.synchronize()
STREAM_CAST_BUFFERS.clear()
soft_empty_cache()
def get_offload_stream(device):
stream_counter = stream_counters.get(device, 0)
if NUM_STREAMS == 0:
@@ -1197,62 +1079,7 @@ def sync_stream(device, stream):
return
current_stream(device).wait_stream(stream)
def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
wf_context = nullcontext()
if stream is not None:
wf_context = stream
if hasattr(wf_context, "as_context"):
wf_context = wf_context.as_context(stream)
dest_views = comfy.memory_management.interpret_gathered_like(tensors, r)
with wf_context:
for tensor in tensors:
dest_view = dest_views.pop(0)
if tensor is None:
continue
dest_view.copy_(tensor, non_blocking=non_blocking)
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
if hasattr(weight, "_v"):
#Unexpected usage patterns. There is no reason these don't work but they
#have no testing and no callers do this.
assert r is None
assert stream is None
cast_geometry = comfy.memory_management.tensors_to_geometries([ weight ])
if dtype is None:
dtype = weight._model_dtype
r = torch.empty_like(weight, dtype=dtype, device=device)
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
if signature is not None:
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
if not comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
weight._v_signature = signature
#Send it over
v_tensor.copy_(weight, non_blocking=non_blocking)
#always take a deep copy even if _v is good, as we have no reasonable point to unpin
#a non comfy weight
r.copy_(v_tensor)
comfy_aimdo.model_vbar.vbar_unpin(weight._v)
return r
if weight.dtype != r.dtype and weight.dtype != weight._model_dtype:
#Offloaded casting could skip this, however it would make the quantizations
#inconsistent between loaded and offloaded weights. So force the double casting
#that would happen in regular flow to make offload deterministic.
cast_buffer = torch.empty_like(weight, dtype=weight._model_dtype, device=device)
cast_buffer.copy_(weight, non_blocking=non_blocking)
weight = cast_buffer
r.copy_(weight, non_blocking=non_blocking)
return r
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
if device is None or weight.device == device:
if not copy:
if dtype is None or weight.dtype == dtype:
@@ -1271,12 +1098,10 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
if hasattr(wf_context, "as_context"):
wf_context = wf_context.as_context(stream)
with wf_context:
if r is None:
r = torch.empty_like(weight, dtype=dtype, device=device)
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight, non_blocking=non_blocking)
else:
if r is None:
r = torch.empty_like(weight, dtype=dtype, device=device)
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight, non_blocking=non_blocking)
return r
@@ -1296,17 +1121,7 @@ if not args.disable_pinned_memory:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"])
def discard_cuda_async_error():
try:
a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
_ = a + b
synchronize()
except torch.AcceleratorError:
#Dump it! We already know about it from the synchronous return
pass
PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
def pin_memory(tensor):
global TOTAL_PINNED_MEMORY
@@ -1328,7 +1143,7 @@ def pin_memory(tensor):
if not tensor.is_contiguous():
return False
size = tensor.nbytes
size = tensor.numel() * tensor.element_size()
if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
return False
@@ -1340,9 +1155,6 @@ def pin_memory(tensor):
PINNED_MEMORY[ptr] = size
TOTAL_PINNED_MEMORY += size
return True
else:
logging.warning("Pin error.")
discard_cuda_async_error()
return False
@@ -1355,7 +1167,7 @@ def unpin_memory(tensor):
return False
ptr = tensor.data_ptr()
size = tensor.nbytes
size = tensor.numel() * tensor.element_size()
size_stored = PINNED_MEMORY.get(ptr, None)
if size_stored is None:
@@ -1371,9 +1183,6 @@ def unpin_memory(tensor):
if len(PINNED_MEMORY) == 0:
TOTAL_PINNED_MEMORY = 0
return True
else:
logging.warning("Unpin error.")
discard_cuda_async_error()
return False
@@ -1676,16 +1485,6 @@ def supports_fp8_compute(device=None):
return True
def supports_nvfp4_compute(device=None):
if not is_nvidia():
return False
props = torch.cuda.get_device_properties(device)
if props.major < 10:
return False
return True
def extended_fp16_support():
# TODO: check why some models work with fp16 on newer torch versions but not on older
if torch_version_numeric < (2, 7):
@@ -1707,12 +1506,6 @@ def lora_compute_dtype(device):
LORA_COMPUTE_DTYPES[device] = dtype
return dtype
def synchronize():
if is_intel_xpu():
torch.xpu.synchronize()
elif torch.cuda.is_available():
torch.cuda.synchronize()
def soft_empty_cache(force=False):
global cpu_state
if cpu_state == CPUState.MPS:
@@ -1724,17 +1517,15 @@ def soft_empty_cache(force=False):
elif is_mlu():
torch.mlu.empty_cache()
elif torch.cuda.is_available():
torch.cuda.synchronize()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
def unload_all_models():
free_memory(1e30, get_torch_device())
def debug_memory_summary():
if is_amd() or is_nvidia():
return torch.cuda.memory.memory_summary()
return ""
#TODO: might be cleaner to put this somewhere else
import threading
class InterruptProcessingException(Exception):
pass

View File

@@ -38,7 +38,19 @@ from comfy.comfy_types import UnetWrapperFunction
from comfy.quant_ops import QuantizedTensor
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
import comfy_aimdo.model_vbar
def string_to_seed(data):
crc = 0xFFFFFFFF
for byte in data:
if isinstance(byte, str):
byte = ord(byte)
crc ^= byte
for _ in range(8):
if crc & 1:
crc = (crc >> 1) ^ 0xEDB88320
else:
crc >>= 1
return crc ^ 0xFFFFFFFF
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
to = model_options["transformer_options"].copy()
@@ -111,10 +123,6 @@ def move_weight_functions(m, device):
memory += f.move_to(device=device)
return memory
def string_to_seed(data):
logging.warning("WARNING: string_to_seed has moved from comfy.model_patcher to comfy.utils")
return comfy.utils.string_to_seed(data)
class LowVramPatch:
def __init__(self, key, patches, convert_func=None, set_func=None):
self.key = key
@@ -161,11 +169,6 @@ def get_key_weight(model, key):
return weight, set_func, convert_func
def key_param_name_to_key(key, param):
if len(key) == 0:
return param
return "{}.{}".format(key, param)
class AutoPatcherEjector:
def __init__(self, model: 'ModelPatcher', skip_and_inject_on_exit_only=False):
self.model = model
@@ -209,27 +212,6 @@ class MemoryCounter:
def decrement(self, used: int):
self.value -= used
CustomTorchDevice = collections.namedtuple("FakeDevice", ["type", "index"])("comfy-lazy-caster", 0)
class LazyCastingParam(torch.nn.Parameter):
def __new__(cls, model, key, tensor):
return super().__new__(cls, tensor)
def __init__(self, model, key, tensor):
self.model = model
self.key = key
@property
def device(self):
return CustomTorchDevice
#safetensors will .to() us to the cpu which we catch here to cast on demand. The returned tensor is
#then just a short lived thing in the safetensors serialization logic inside its big for loop over
#all weights getting garbage collected per-weight
def to(self, *args, **kwargs):
return self.model.patch_weight_to_device(self.key, device_to=self.model.load_device, return_weight=True).to("cpu")
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
self.size = size
@@ -287,9 +269,6 @@ class ModelPatcher:
if not hasattr(self.model, 'model_offload_buffer_memory'):
self.model.model_offload_buffer_memory = 0
def is_dynamic(self):
return False
def model_size(self):
if self.size > 0:
return self.size
@@ -305,9 +284,6 @@ class ModelPatcher:
def lowvram_patch_counter(self):
return self.model.lowvram_patch_counter
def get_free_memory(self, device):
return comfy.model_management.get_free_memory(device)
def clone(self):
n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
n.patches = {}
@@ -635,14 +611,14 @@ class ModelPatcher:
sd.pop(k)
return sd
def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False):
weight, set_func, convert_func = get_key_weight(self.model, key)
def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
if key not in self.patches:
return weight
return
weight, set_func, convert_func = get_key_weight(self.model, key)
inplace_update = self.weight_inplace_update or inplace_update
if key not in self.backup and not return_weight:
if key not in self.backup:
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
temp_dtype = comfy.model_management.lora_compute_dtype(device_to)
@@ -655,15 +631,13 @@ class ModelPatcher:
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
if set_func is None:
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key))
if return_weight:
return out_weight
elif inplace_update:
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
if inplace_update:
comfy.utils.copy_to_param(self.model, key, out_weight)
else:
comfy.utils.set_attr_param(self.model, key, out_weight)
else:
return set_func(out_weight, inplace_update=inplace_update, seed=comfy.utils.string_to_seed(key), return_weight=return_weight)
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
def pin_weight_to_device(self, key):
weight, set_func, convert_func = get_key_weight(self.model, key)
@@ -680,7 +654,7 @@ class ModelPatcher:
for key in list(self.pinned):
self.unpin_weight(key)
def _load_list(self, prio_comfy_cast_weights=False):
def _load_list(self):
loading = []
for n, m in self.model.named_modules():
params = []
@@ -707,8 +681,7 @@ class ModelPatcher:
return 0
module_offload_mem += check_module_offload_mem("{}.weight".format(n))
module_offload_mem += check_module_offload_mem("{}.bias".format(n))
prepend = (not hasattr(m, "comfy_cast_weights"),) if prio_comfy_cast_weights else ()
loading.append(prepend + (module_offload_mem, module_mem, n, m, params))
loading.append((module_offload_mem, module_mem, n, m, params))
return loading
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
@@ -745,7 +718,6 @@ class ModelPatcher:
continue
cast_weight = self.force_cast_weights
m.comfy_force_cast_weights = self.force_cast_weights
if lowvram_weight:
if hasattr(m, "comfy_cast_weights"):
m.weight_function = []
@@ -800,7 +772,7 @@ class ModelPatcher:
continue
for param in params:
key = key_param_name_to_key(n, param)
key = "{}.{}".format(n, param)
self.unpin_weight(key)
self.patch_weight_to_device(key, device_to=device_to)
if comfy.model_management.is_device_cuda(device_to):
@@ -816,14 +788,13 @@ class ModelPatcher:
n = x[1]
params = x[3]
for param in params:
self.pin_weight_to_device(key_param_name_to_key(n, param))
self.pin_weight_to_device("{}.{}".format(n, param))
usable_stat = "{:.2f} MB usable,".format(lowvram_model_memory / (1024 * 1024)) if lowvram_model_memory < 1e32 else ""
if lowvram_counter > 0:
logging.info("loaded partially; {} {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(usable_stat, mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter))
logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter))
self.model.model_lowvram = True
else:
logging.info("loaded completely; {} {:.2f} MB loaded, full load: {}".format(usable_stat, mem_counter / (1024 * 1024), full_load))
logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
self.model.model_lowvram = False
if full_load:
self.model.to(device_to)
@@ -922,7 +893,7 @@ class ModelPatcher:
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
move_weight = True
for param in params:
key = key_param_name_to_key(n, param)
key = "{}.{}".format(n, param)
bk = self.backup.get(key, None)
if bk is not None:
if not lowvram_possible:
@@ -973,7 +944,7 @@ class ModelPatcher:
logging.debug("freed {}".format(n))
for param in params:
self.pin_weight_to_device(key_param_name_to_key(n, param))
self.pin_weight_to_device("{}.{}".format(n, param))
self.model.model_lowvram = True
@@ -1011,9 +982,6 @@ class ModelPatcher:
return self.model.model_loaded_weight_memory - current_used
def partially_unload_ram(self, ram_to_unload):
pass
def detach(self, unpatch_all=True):
self.eject_model()
self.model_patches_to(self.offload_device)
@@ -1347,10 +1315,10 @@ class ModelPatcher:
key, original_weights=original_weights)
del original_weights[key]
if set_func is None:
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key))
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
comfy.utils.copy_to_param(self.model, key, out_weight)
else:
set_func(out_weight, inplace_update=True, seed=comfy.utils.string_to_seed(key))
set_func(out_weight, inplace_update=True, seed=string_to_seed(key))
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
# TODO: disable caching if not enough system RAM to do so
target_device = self.offload_device
@@ -1385,249 +1353,7 @@ class ModelPatcher:
self.unpatch_hooks()
self.clear_cached_hook_weights()
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
unet_state_dict = self.model.diffusion_model.state_dict()
for k, v in unet_state_dict.items():
op_keys = k.rsplit('.', 1)
if (len(op_keys) < 2) or op_keys[1] not in ["weight", "bias"]:
continue
try:
op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0])
except:
continue
if not op or not hasattr(op, "comfy_cast_weights") or \
(hasattr(op, "comfy_patched_weights") and op.comfy_patched_weights == True):
continue
key = "diffusion_model." + k
unet_state_dict[k] = LazyCastingParam(self, key, comfy.utils.get_attr(self.model, key))
return self.model.state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
def __del__(self):
self.unpin_all_weights()
self.detach(unpatch_all=False)
class ModelPatcherDynamic(ModelPatcher):
def __new__(cls, model=None, load_device=None, offload_device=None, size=0, weight_inplace_update=False):
if load_device is not None and comfy.model_management.is_device_cpu(load_device):
#reroute to default MP for CPUs
return ModelPatcher(model, load_device, offload_device, size, weight_inplace_update)
return super().__new__(cls)
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
super().__init__(model, load_device, offload_device, size, weight_inplace_update)
#this is now way more dynamic and we dont support the same base model for both Dynamic
#and non-dynamic patchers.
if hasattr(self.model, "model_loaded_weight_memory"):
del self.model.model_loaded_weight_memory
if not hasattr(self.model, "dynamic_vbars"):
self.model.dynamic_vbars = {}
assert load_device is not None
def is_dynamic(self):
return True
def _vbar_get(self, create=False):
if self.load_device == torch.device("cpu"):
return None
vbar = self.model.dynamic_vbars.get(self.load_device, None)
if create and vbar is None:
# x10. We dont know what model defined type casts we have in the vbar, but virtual address
# space is pretty free. This will cover someone casting an entire model from FP4 to FP32
# with some left over.
vbar = comfy_aimdo.model_vbar.ModelVBAR(self.model_size() * 10, self.load_device.index)
self.model.dynamic_vbars[self.load_device] = vbar
return vbar
def loaded_size(self):
vbar = self._vbar_get()
if vbar is None:
return 0
return vbar.loaded_size()
def get_free_memory(self, device):
#NOTE: on high condition / batch counts, estimate should have already vacated
#all non-dynamic models so this is safe even if its not 100% true that this
#would all be avaiable for inference use.
return comfy.model_management.get_total_memory(device) - self.model_size()
#Pinning is deferred to ops time. Assert against this API to avoid pin leaks.
def pin_weight_to_device(self, key):
raise RuntimeError("pin_weight_to_device invalid for dymamic weight loading")
def unpin_weight(self, key):
raise RuntimeError("unpin_weight invalid for dymamic weight loading")
def unpin_all_weights(self):
self.partially_unload_ram(1e32)
def memory_required(self, input_shape):
#Pad this significantly. We are trying to get away from precise estimates. This
#estimate is only used when using the ModelPatcherDynamic after ModelPatcher. If you
#use all ModelPatcherDynamic this is ignored and its all done dynamically.
return super().memory_required(input_shape=input_shape) * 1.3 + (1024 ** 3)
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False, dirty=False):
#Force patching doesn't make sense in Dynamic loading, as you dont know what does and
#doesn't need to be forced at this stage. The only thing you could do would be patch
#it all on CPU which consumes huge RAM.
assert not force_patch_weights
#Full load doesn't make sense as we dont actually have any loader capability here and
#now.
assert not full_load
assert device_to == self.load_device
num_patches = 0
allocated_size = 0
with self.use_ejected():
self.unpatch_hooks()
vbar = self._vbar_get(create=True)
if vbar is not None:
vbar.prioritize()
#We have way more tools for acceleration on comfy weight offloading, so always
#prioritize the non-comfy weights (note the order reverse).
loading = self._load_list(prio_comfy_cast_weights=True)
loading.sort(reverse=True)
for x in loading:
_, _, _, n, m, params = x
def set_dirty(item, dirty):
if dirty or not hasattr(item, "_v_signature"):
item._v_signature = None
def setup_param(self, m, n, param_key):
nonlocal num_patches
key = key_param_name_to_key(n, param_key)
weight_function = []
weight, _, _ = get_key_weight(self.model, key)
if weight is None:
return 0
if key in self.patches:
setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
num_patches += 1
else:
setattr(m, param_key + "_lowvram_function", None)
if key in self.weight_wrapper_patches:
weight_function.extend(self.weight_wrapper_patches[key])
setattr(m, param_key + "_function", weight_function)
geometry = weight
if not isinstance(weight, QuantizedTensor):
model_dtype = getattr(m, param_key + "_comfy_model_dtype", weight.dtype)
weight._model_dtype = model_dtype
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
return comfy.memory_management.vram_aligned_size(geometry)
if hasattr(m, "comfy_cast_weights"):
m.comfy_cast_weights = True
m.pin_failed = False
m.seed_key = n
set_dirty(m, dirty)
v_weight_size = 0
v_weight_size += setup_param(self, m, n, "weight")
v_weight_size += setup_param(self, m, n, "bias")
if vbar is not None and not hasattr(m, "_v"):
m._v = vbar.alloc(v_weight_size)
allocated_size += v_weight_size
else:
for param in params:
key = key_param_name_to_key(n, param)
weight, _, _ = get_key_weight(self.model, key)
weight.seed_key = key
set_dirty(weight, dirty)
geometry = weight
model_dtype = getattr(m, param + "_comfy_model_dtype", weight.dtype)
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
weight_size = geometry.numel() * geometry.element_size()
if vbar is not None and not hasattr(weight, "_v"):
weight._v = vbar.alloc(weight_size)
weight._model_dtype = model_dtype
allocated_size += weight_size
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
self.model.device = device_to
self.model.current_weight_patches_uuid = self.patches_uuid
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
#These are all super dangerous. Who knows what the custom nodes actually do here...
callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load)
self.apply_hooks(self.forced_hooks, force_apply=True)
def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
assert not force_patch_weights #See above
assert self.load_device != torch.device("cpu")
vbar = self._vbar_get()
return 0 if vbar is None else vbar.free_memory(memory_to_free)
def partially_unload_ram(self, ram_to_unload):
loading = self._load_list(prio_comfy_cast_weights=True)
for x in loading:
_, _, _, _, m, _ = x
ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
if ram_to_unload <= 0:
return
def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
#This isn't used by the core at all and can only be to load a model out of
#the control of proper model_managment. If you are a custom node author reading
#this, the correct pattern is to call load_models_gpu() to get a proper
#managed load of your model.
assert not load_weights
return super().patch_model(load_weights=load_weights, force_patch_weights=force_patch_weights)
def unpatch_model(self, device_to=None, unpatch_weights=True):
super().unpatch_model(device_to=None, unpatch_weights=False)
if unpatch_weights:
self.partially_unload_ram(1e32)
self.partially_unload(None, 1e32)
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
assert not force_patch_weights #See above
with self.use_ejected(skip_and_inject_on_exit_only=True):
dirty = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid)
self.unpatch_model(self.offload_device, unpatch_weights=False)
self.patch_model(load_weights=False)
try:
self.load(device_to, dirty=dirty)
except Exception as e:
self.detach()
raise e
#ModelPatcher::partially_load returns a number on what got loaded but
#nothing in core uses this and we have no data in the Dynamic world. Hit
#the custom node devs with a None rather than a 0 that would mislead any
#logic they might have.
return None
def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
assert False #Should be unreachable - we dont ever cache in the new implementation
def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter):
if key not in combined_patches:
return
raise RuntimeError("Hooks not implemented in ModelPatcherDynamic. Please remove --fast arguments form ComfyUI startup")
def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
pass
CoreModelPatcher = ModelPatcher

View File

@@ -19,16 +19,10 @@
import torch
import logging
import comfy.model_management
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
from comfy.cli_args import args, PerformanceFeature
import comfy.float
import comfy.rmsnorm
import json
import comfy.memory_management
import comfy.pinned_memory
import comfy.utils
import comfy_aimdo.model_vbar
import comfy_aimdo.torch
def run_every_op():
if torch.compiler.is_compiling():
@@ -78,122 +72,14 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
offload_stream = None
xfer_dest = None
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
if signature is not None:
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
if not resident:
cast_dest = None
xfer_source = [ s.weight, s.bias ]
pin = comfy.pinned_memory.get_pin(s)
if pin is not None:
xfer_source = [ pin ]
for data, geometry in zip([ s.weight, s.bias ], cast_geometry):
if data is None:
continue
if data.dtype != geometry.dtype:
cast_dest = xfer_dest
if cast_dest is None:
cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device)
xfer_dest = None
break
dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
offload_stream = comfy.model_management.get_offload_stream(device)
if xfer_dest is None and offload_stream is not None:
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
if xfer_dest is None:
offload_stream = comfy.model_management.get_offload_stream(device)
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
if xfer_dest is None:
xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device)
offload_stream = None
if signature is None and pin is None:
comfy.pinned_memory.pin_memory(s)
pin = comfy.pinned_memory.get_pin(s)
else:
pin = None
if pin is not None:
comfy.model_management.cast_to_gathered(xfer_source, pin)
xfer_source = [ pin ]
#send it over
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
comfy.model_management.sync_stream(device, offload_stream)
if cast_dest is not None:
for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest),
comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)):
if post_cast is not None:
post_cast.copy_(pre_cast)
xfer_dest = cast_dest
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
weight = params[0]
bias = params[1]
def post_cast(s, param_key, x, dtype, resident, update_weight):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
fns = getattr(s, param_key + "_function", [])
orig = x
def to_dequant(tensor, dtype):
tensor = tensor.to(dtype=dtype)
if isinstance(tensor, QuantizedTensor):
tensor = tensor.dequantize()
return tensor
if orig.dtype != dtype or len(fns) > 0:
x = to_dequant(x, dtype)
if not resident and lowvram_fn is not None:
x = to_dequant(x, dtype if compute_dtype is None else compute_dtype)
#FIXME: this is not accurate, we need to be sensitive to the compute dtype
x = lowvram_fn(x)
if (isinstance(orig, QuantizedTensor) and
(orig.dtype == dtype and len(fns) == 0 or update_weight)):
seed = comfy.utils.string_to_seed(s.seed_key)
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
if orig.dtype == dtype and len(fns) == 0:
#The layer actually wants our freshly saved QT
x = y
else:
y = x
if update_weight:
orig.copy_(y)
for f in fns:
x = f(x)
return x
update_weight = signature is not None
weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
if s.bias is not None:
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
s._v_signature=signature
#FIXME: weird offload return protocol
return weight, bias, (offload_stream, device if signature is not None else None, None)
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None):
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
# NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
# will add async-offload support to your cast and improve performance.
if input is not None:
if dtype is None:
if isinstance(input, QuantizedTensor):
dtype = input.params.orig_dtype
dtype = input._layout_params["orig_dtype"]
else:
dtype = input.dtype
if bias_dtype is None:
@@ -201,38 +87,22 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
if device is None:
device = input.device
non_blocking = comfy.model_management.device_supports_non_blocking(device)
if hasattr(s, "_v"):
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype)
if offloadable and (device != s.weight.device or
(s.bias is not None and device != s.bias.device)):
offload_stream = comfy.model_management.get_offload_stream(device)
else:
offload_stream = None
bias = None
weight = None
if offload_stream is not None and not args.cuda_malloc:
cast_buffer_size = comfy.memory_management.vram_aligned_size([ s.weight, s.bias ])
cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s)
#The streams can be uneven in buffer capability and reject us. Retry to get the other stream
if cast_buffer is None:
offload_stream = comfy.model_management.get_offload_stream(device)
cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s)
params = comfy.memory_management.interpret_gathered_like([ s.weight, s.bias ], cast_buffer)
weight = params[0]
bias = params[1]
non_blocking = comfy.model_management.device_supports_non_blocking(device)
weight_has_function = len(s.weight_function) > 0
bias_has_function = len(s.bias_function) > 0
weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream, r=weight)
weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
bias = None
if s.bias is not None:
bias = comfy.model_management.cast_to(s.bias, None, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream, r=bias)
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
comfy.model_management.sync_stream(device, offload_stream)
@@ -240,7 +110,6 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
weight_a = weight
if s.bias is not None:
bias = bias.to(dtype=bias_dtype)
for f in s.bias_function:
bias = f(bias)
@@ -262,20 +131,14 @@ def uncast_bias_weight(s, weight, bias, offload_stream):
if offload_stream is None:
return
os, weight_a, bias_a = offload_stream
device=None
#FIXME: This is not good RTTI
if not isinstance(weight_a, torch.Tensor):
comfy_aimdo.model_vbar.vbar_unpin(s._v)
device = weight_a
if os is None:
return
if device is None:
if weight_a is not None:
device = weight_a.device
else:
if bias_a is None:
return
device = bias_a.device
if weight_a is not None:
device = weight_a.device
else:
if bias_a is None:
return
device = bias_a.device
os.wait_stream(comfy.model_management.current_stream(device))
@@ -286,57 +149,6 @@ class CastWeightBiasOp:
class disable_weight_init:
class Linear(torch.nn.Linear, CastWeightBiasOp):
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
if not comfy.model_management.WINDOWS or not enables_dynamic_vram():
super().__init__(in_features, out_features, bias, device, dtype)
return
# Issue is with `torch.empty` still reserving the full memory for the layer.
# Windows doesn't over-commit memory so without this, We are momentarily commit
# charged for the weight even though we might zero-copy it when we load the
# state dict. If the commit charge exceeds the ceiling we can destabilize the
# system.
torch.nn.Module.__init__(self)
self.in_features = in_features
self.out_features = out_features
self.weight = None
self.bias = None
self.comfy_need_lazy_init_bias=bias
self.weight_comfy_model_dtype = dtype
self.bias_comfy_model_dtype = dtype
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
if not comfy.model_management.WINDOWS or not enables_dynamic_vram():
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
prefix_len = len(prefix)
for k,v in state_dict.items():
if k[prefix_len:] == "weight":
if not assign_to_params_buffers:
v = v.clone()
self.weight = torch.nn.Parameter(v, requires_grad=False)
elif k[prefix_len:] == "bias" and v is not None:
if not assign_to_params_buffers:
v = v.clone()
self.bias = torch.nn.Parameter(v, requires_grad=False)
else:
unexpected_keys.append(k)
#Reconcile default construction of the weight if its missing.
if self.weight is None:
v = torch.zeros(self.in_features, self.out_features)
self.weight = torch.nn.Parameter(v, requires_grad=False)
missing_keys.append(prefix+"weight")
if self.bias is None and self.comfy_need_lazy_init_bias:
v = torch.zeros(self.out_features,)
self.bias = torch.nn.Parameter(v, requires_grad=False)
missing_keys.append(prefix+"bias")
def reset_parameters(self):
return None
@@ -391,9 +203,7 @@ class disable_weight_init:
def reset_parameters(self):
return None
def _conv_forward(self, input, weight, bias, autopad=None, *args, **kwargs):
if autopad == "causal_zero":
weight = weight[:, :, -input.shape[2]:, :, :]
def _conv_forward(self, input, weight, bias, *args, **kwargs):
if NVIDIA_MEMORY_CONV_BUG_WORKAROUND and weight.dtype in (torch.float16, torch.bfloat16):
out = torch.cudnn_convolution(input, weight, self.padding, self.stride, self.dilation, self.groups, benchmark=False, deterministic=False, allow_tf32=True)
if bias is not None:
@@ -402,15 +212,15 @@ class disable_weight_init:
else:
return super()._conv_forward(input, weight, bias, *args, **kwargs)
def forward_comfy_cast_weights(self, input, autopad=None):
def forward_comfy_cast_weights(self, input):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._conv_forward(input, weight, bias, autopad=autopad)
x = self._conv_forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or "autopad" in kwargs:
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@@ -602,34 +412,26 @@ def fp8_linear(self, input):
return None
input_dtype = input.dtype
input_shape = input.shape
tensor_3d = input.ndim == 3
if tensor_3d:
input = input.reshape(-1, input_shape[2])
if input.ndim == 3 or input.ndim == 2:
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
if input.ndim != 2:
return None
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
input = torch.clamp(input, min=-448, max=448, out=input)
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
input = torch.clamp(input, min=-448, max=448, out=input)
input_fp8 = input.to(dtype).contiguous()
layout_params_input = TensorCoreFP8Layout.Params(scale=scale_input, orig_dtype=input_dtype, orig_shape=tuple(input_fp8.shape))
quantized_input = QuantizedTensor(input_fp8, "TensorCoreFP8Layout", layout_params_input)
# Wrap weight in QuantizedTensor - this enables unified dispatch
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
# Wrap weight in QuantizedTensor - this enables unified dispatch
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
layout_params_weight = TensorCoreFP8Layout.Params(scale=scale_weight, orig_dtype=input_dtype, orig_shape=tuple(w.shape))
quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
uncast_bias_weight(self, w, bias, offload_stream)
return o
uncast_bias_weight(self, w, bias, offload_stream)
if tensor_3d:
o = o.reshape((input_shape[0], input_shape[1], w.shape[0]))
return o
return None
class fp8_ops(manual_cast):
class Linear(manual_cast.Linear):
@@ -675,20 +477,14 @@ if CUBLAS_IS_AVAILABLE:
# ==============================================================================
# Mixed Precision Operations
# ==============================================================================
from .quant_ops import (
QuantizedTensor,
QUANT_ALGOS,
TensorCoreFP8Layout,
get_layout_class,
)
from .quant_ops import QuantizedTensor, QUANT_ALGOS
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
class MixedPrecisionOps(manual_cast):
_quant_config = quant_config
_compute_dtype = compute_dtype
_full_precision_mm = full_precision_mm
_disabled = disabled
class Linear(torch.nn.Module, CastWeightBiasOp):
def __init__(
@@ -701,33 +497,21 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
) -> None:
super().__init__()
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
# self.factory_kwargs = {"device": device, "dtype": dtype}
if dtype is None:
dtype = MixedPrecisionOps._compute_dtype
self.factory_kwargs = {"device": device, "dtype": dtype}
self.in_features = in_features
self.out_features = out_features
if bias:
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
else:
self.register_parameter("bias", None)
self._has_bias = bias
self.tensor_class = None
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
self._full_precision_mm_config = False
def reset_parameters(self):
return None
def _load_scale_param(self, state_dict, prefix, param_name, device, manually_loaded_keys, dtype=None):
key = f"{prefix}{param_name}"
value = state_dict.pop(key, None)
if value is not None:
value = value.to(device=device)
if dtype is not None:
value = value.view(dtype=dtype)
manually_loaded_keys.append(key)
return value
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
@@ -736,8 +520,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
weight_key = f"{prefix}weight"
weight = state_dict.pop(weight_key, None)
if weight is None:
logging.warning(f"Missing weight for layer {layer_name}")
return
raise ValueError(f"Missing weight for layer {layer_name}")
manually_loaded_keys = [weight_key]
@@ -746,61 +529,49 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
layer_conf = json.loads(layer_conf.numpy().tobytes())
if layer_conf is None:
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
dtype = self.factory_kwargs["dtype"]
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False)
if dtype != MixedPrecisionOps._compute_dtype:
self.comfy_cast_weights = True
if self._has_bias:
self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=dtype))
else:
self.register_parameter("bias", None)
else:
self.quant_format = layer_conf.get("format", None)
self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
if not self._full_precision_mm:
self._full_precision_mm = self._full_precision_mm_config
if self.quant_format in MixedPrecisionOps._disabled:
self._full_precision_mm = True
self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False)
if self.quant_format is None:
raise ValueError(f"Unknown quantization format for layer {layer_name}")
qconfig = QUANT_ALGOS[self.quant_format]
self.layout_type = qconfig["comfy_tensor_layout"]
layout_cls = get_layout_class(self.layout_type)
# Load format-specific parameters
if self.quant_format in ["float8_e4m3fn", "float8_e5m2"]:
# FP8: single tensor scale
scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
weight_scale_key = f"{prefix}weight_scale"
scale = state_dict.pop(weight_scale_key, None)
if scale is not None:
scale = scale.to(device)
layout_params = {
'scale': scale,
'orig_dtype': MixedPrecisionOps._compute_dtype,
'block_size': qconfig.get("group_size", None),
}
params = layout_cls.Params(
scale=scale,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.out_features, self.in_features),
)
elif self.quant_format == "nvfp4":
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
dtype=torch.float8_e4m3fn)
if tensor_scale is None or block_scale is None:
raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
params = layout_cls.Params(
scale=tensor_scale,
block_scale=block_scale,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.out_features, self.in_features),
)
else:
raise ValueError(f"Unsupported quantization format: {self.quant_format}")
if scale is not None:
manually_loaded_keys.append(weight_scale_key)
self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params),
requires_grad=False
)
for param_name in qconfig["parameters"]:
if param_name in {"weight_scale", "weight_scale_2"}:
continue # Already handled above
if self._has_bias:
self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=MixedPrecisionOps._compute_dtype))
else:
self.register_parameter("bias", None)
for param_name in qconfig["parameters"]:
param_key = f"{prefix}{param_name}"
_v = state_dict.pop(param_key, None)
if _v is None:
@@ -815,36 +586,20 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
missing_keys.remove(key)
def state_dict(self, *args, destination=None, prefix="", **kwargs):
if destination is not None:
sd = destination
else:
sd = {}
if self.bias is not None:
sd["{}bias".format(prefix)] = self.bias
sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
if isinstance(self.weight, QuantizedTensor):
sd_out = self.weight.state_dict("{}weight".format(prefix))
for k in sd_out:
sd[k] = sd_out[k]
sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale']
quant_conf = {"format": self.quant_format}
if self._full_precision_mm_config:
if self._full_precision_mm:
quant_conf["full_precision_matrix_mult"] = True
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
input_scale = getattr(self, 'input_scale', None)
if input_scale is not None:
sd["{}input_scale".format(prefix)] = input_scale
else:
sd["{}weight".format(prefix)] = self.weight
return sd
def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias)
def forward_comfy_cast_weights(self, input, compute_dtype=None):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype)
def forward_comfy_cast_weights(self, input):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
@@ -852,36 +607,12 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def forward(self, input, *args, **kwargs):
run_every_op()
input_shape = input.shape
reshaped_3d = False
#If cast needs to apply lora, it should be done in the compute dtype
compute_dtype = input.dtype
if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(input, *args, **kwargs)
if (getattr(self, 'layout_type', None) is not None and
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
not getattr(self, 'comfy_force_cast_weights', False) and
len(self.weight_function) == 0 and len(self.bias_function) == 0):
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
# Fall back to non-quantized for non-2D tensors
if input_reshaped.ndim == 2:
reshaped_3d = input.ndim == 3
# dtype is now implicit in the layout class
scale = getattr(self, 'input_scale', None)
if scale is not None:
scale = comfy.model_management.cast_to_device(scale, input.device, None)
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
output = self.forward_comfy_cast_weights(input, compute_dtype)
# Reshape output back to 3D if input was 3D
if reshaped_3d:
output = output.reshape((input_shape[0], input_shape[1], self.weight.shape[0]))
return output
not isinstance(input, QuantizedTensor)):
input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype)
return self._forward(input, self.weight, self.bias)
def convert_weight(self, weight, inplace=False, **kwargs):
if isinstance(weight, QuantizedTensor):
@@ -891,8 +622,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
if getattr(self, 'layout_type', None) is not None:
# dtype is now implicit in the layout class
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
else:
weight = weight.to(self.weight.dtype)
if return_weight:
@@ -919,17 +649,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
logging.info("Using mixed precision operations")
disabled = set()
if not nvfp4_compute:
disabled.add("nvfp4")
if not fp8_compute:
disabled.add("float8_e4m3fn")
disabled.add("float8_e5m2")
return mixed_precision_ops(model_config.quant_config, compute_dtype, disabled=disabled)
return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute)
if (
fp8_compute and

View File

@@ -1,29 +0,0 @@
import torch
import comfy.model_management
import comfy.memory_management
from comfy.cli_args import args
def get_pin(module):
return getattr(module, "_pin", None)
def pin_memory(module):
if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
return
#FIXME: This is a RAM cache trigger event
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
pin = torch.empty((size,), dtype=torch.uint8)
if comfy.model_management.pin_memory(pin):
module._pin = pin
else:
module.pin_failed = True
return False
return True
def unpin_memory(module):
if get_pin(module) is None:
return 0
size = module._pin.numel() * module._pin.element_size()
comfy.model_management.unpin_memory(module._pin)
del module._pin
return size

View File

@@ -1,174 +1,580 @@
import torch
import logging
try:
import comfy_kitchen as ck
from comfy_kitchen.tensor import (
QuantizedTensor,
QuantizedLayout,
TensorCoreFP8Layout as _CKFp8Layout,
TensorCoreNVFP4Layout as _CKNvfp4Layout,
register_layout_op,
register_layout_class,
get_layout_class,
)
_CK_AVAILABLE = True
if torch.version.cuda is None:
ck.registry.disable("cuda")
else:
cuda_version = tuple(map(int, str(torch.version.cuda).split('.')))
if cuda_version < (13,):
ck.registry.disable("cuda")
logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.")
ck.registry.disable("triton")
for k, v in ck.list_backends().items():
logging.info(f"Found comfy_kitchen backend {k}: {v}")
except ImportError as e:
logging.error(f"Failed to import comfy_kitchen, Error: {e}, fp8 and fp4 support will not be available.")
_CK_AVAILABLE = False
class QuantizedTensor:
pass
class _CKFp8Layout:
pass
class _CKNvfp4Layout:
pass
def register_layout_class(name, cls):
pass
def get_layout_class(name):
return None
from typing import Tuple, Dict
import comfy.float
# ==============================================================================
# FP8 Layouts with Comfy-Specific Extensions
# ==============================================================================
_LAYOUT_REGISTRY = {}
_GENERIC_UTILS = {}
class _TensorCoreFP8LayoutBase(_CKFp8Layout):
FP8_DTYPE = None # Must be overridden in subclass
def register_layout_op(torch_op, layout_type):
"""
Decorator to register a layout-specific operation handler.
Args:
torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default)
layout_type: Layout class (e.g., TensorCoreFP8Layout)
Example:
@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
def fp8_linear(func, args, kwargs):
# FP8-specific linear implementation
...
"""
def decorator(handler_func):
if torch_op not in _LAYOUT_REGISTRY:
_LAYOUT_REGISTRY[torch_op] = {}
_LAYOUT_REGISTRY[torch_op][layout_type] = handler_func
return handler_func
return decorator
def register_generic_util(torch_op):
"""
Decorator to register a generic utility that works for all layouts.
Args:
torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default)
Example:
@register_generic_util(torch.ops.aten.detach.default)
def generic_detach(func, args, kwargs):
# Works for any layout
...
"""
def decorator(handler_func):
_GENERIC_UTILS[torch_op] = handler_func
return handler_func
return decorator
def _get_layout_from_args(args):
for arg in args:
if isinstance(arg, QuantizedTensor):
return arg._layout_type
elif isinstance(arg, (list, tuple)):
for item in arg:
if isinstance(item, QuantizedTensor):
return item._layout_type
return None
def _move_layout_params_to_device(params, device):
new_params = {}
for k, v in params.items():
if isinstance(v, torch.Tensor):
new_params[k] = v.to(device=device)
else:
new_params[k] = v
return new_params
def _copy_layout_params(params):
new_params = {}
for k, v in params.items():
if isinstance(v, torch.Tensor):
new_params[k] = v.clone()
else:
new_params[k] = v
return new_params
def _copy_layout_params_inplace(src, dst, non_blocking=False):
for k, v in src.items():
if isinstance(v, torch.Tensor):
dst[k].copy_(v, non_blocking=non_blocking)
else:
dst[k] = v
class QuantizedLayout:
"""
Base class for quantization layouts.
A layout encapsulates the format-specific logic for quantization/dequantization
and provides a uniform interface for extracting raw tensors needed for computation.
New quantization formats should subclass this and implement the required methods.
"""
@classmethod
def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]:
raise NotImplementedError(f"{cls.__name__} must implement quantize()")
@staticmethod
def dequantize(qdata, **layout_params) -> torch.Tensor:
raise NotImplementedError("TensorLayout must implement dequantize()")
@classmethod
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
if cls.FP8_DTYPE is None:
raise NotImplementedError(f"{cls.__name__} must define FP8_DTYPE")
def get_plain_tensors(cls, qtensor) -> torch.Tensor:
raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()")
class QuantizedTensor(torch.Tensor):
"""
Universal quantized tensor that works with any layout.
This tensor subclass uses a pluggable layout system to support multiple
quantization formats (FP8, INT4, INT8, etc.) without code duplication.
The layout_type determines format-specific behavior, while common operations
(detach, clone, to) are handled generically.
Attributes:
_qdata: The quantized tensor data
_layout_type: Layout class (e.g., TensorCoreFP8Layout)
_layout_params: Dict with layout-specific params (scale, zero_point, etc.)
"""
@staticmethod
def __new__(cls, qdata, layout_type, layout_params):
"""
Create a quantized tensor.
Args:
qdata: The quantized data tensor
layout_type: Layout class (subclass of QuantizedLayout)
layout_params: Dict with layout-specific parameters
"""
return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
def __init__(self, qdata, layout_type, layout_params):
self._qdata = qdata
self._layout_type = layout_type
self._layout_params = layout_params
def __repr__(self):
layout_name = self._layout_type
param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"
@property
def layout_type(self):
return self._layout_type
def __tensor_flatten__(self):
"""
Tensor flattening protocol for proper device movement.
"""
inner_tensors = ["_qdata"]
ctx = {
"layout_type": self._layout_type,
}
tensor_params = {}
non_tensor_params = {}
for k, v in self._layout_params.items():
if isinstance(v, torch.Tensor):
tensor_params[k] = v
else:
non_tensor_params[k] = v
ctx["tensor_param_keys"] = list(tensor_params.keys())
ctx["non_tensor_params"] = non_tensor_params
for k, v in tensor_params.items():
attr_name = f"_layout_param_{k}"
object.__setattr__(self, attr_name, v)
inner_tensors.append(attr_name)
return inner_tensors, ctx
@staticmethod
def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
"""
Tensor unflattening protocol for proper device movement.
Reconstructs the QuantizedTensor after device movement.
"""
layout_type = ctx["layout_type"]
layout_params = dict(ctx["non_tensor_params"])
for key in ctx["tensor_param_keys"]:
attr_name = f"_layout_param_{key}"
layout_params[key] = inner_tensors[attr_name]
return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params)
@classmethod
def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs)
return cls(qdata, layout_type, layout_params)
def dequantize(self) -> torch.Tensor:
return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
# Step 1: Check generic utilities first (detach, clone, to, etc.)
if func in _GENERIC_UTILS:
return _GENERIC_UTILS[func](func, args, kwargs)
# Step 2: Check layout-specific handlers (linear, matmul, etc.)
layout_type = _get_layout_from_args(args)
if layout_type and func in _LAYOUT_REGISTRY:
handler = _LAYOUT_REGISTRY[func].get(layout_type)
if handler:
return handler(func, args, kwargs)
# Step 3: Fallback to dequantization
if isinstance(args[0] if args else None, QuantizedTensor):
logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
return cls._dequant_and_fallback(func, args, kwargs)
@classmethod
def _dequant_and_fallback(cls, func, args, kwargs):
def dequant_arg(arg):
if isinstance(arg, QuantizedTensor):
return arg.dequantize()
elif isinstance(arg, (list, tuple)):
return type(arg)(dequant_arg(a) for a in arg)
return arg
new_args = dequant_arg(args)
new_kwargs = dequant_arg(kwargs)
return func(*new_args, **new_kwargs)
def data_ptr(self):
return self._qdata.data_ptr()
def is_pinned(self):
return self._qdata.is_pinned()
def is_contiguous(self, *arg, **kwargs):
return self._qdata.is_contiguous(*arg, **kwargs)
def storage(self):
return self._qdata.storage()
# ==============================================================================
# Generic Utilities (Layout-Agnostic Operations)
# ==============================================================================
def _create_transformed_qtensor(qt, transform_fn):
new_data = transform_fn(qt._qdata)
new_params = _copy_layout_params(qt._layout_params)
return QuantizedTensor(new_data, qt._layout_type, new_params)
def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
if target_layout is not None and target_layout != torch.strided:
logging.warning(
f"QuantizedTensor: layout change requested to {target_layout}, "
f"but not supported. Ignoring layout."
)
# Handle device transfer
current_device = qt._qdata.device
if target_device is not None:
# Normalize device for comparison
if isinstance(target_device, str):
target_device = torch.device(target_device)
if isinstance(current_device, str):
current_device = torch.device(current_device)
if target_device != current_device:
logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
new_q_data = qt._qdata.to(device=target_device)
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
if target_dtype is not None:
new_params["orig_dtype"] = target_dtype
new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
return new_qt
logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original")
return qt
@register_generic_util(torch.ops.aten.detach.default)
def generic_detach(func, args, kwargs):
"""Detach operation - creates a detached copy of the quantized tensor."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _create_transformed_qtensor(qt, lambda x: x.detach())
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.clone.default)
def generic_clone(func, args, kwargs):
"""Clone operation - creates a deep copy of the quantized tensor."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _create_transformed_qtensor(qt, lambda x: x.clone())
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten._to_copy.default)
def generic_to_copy(func, args, kwargs):
"""Device/dtype transfer operation - handles .to(device) calls."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _handle_device_transfer(
qt,
target_device=kwargs.get('device', None),
target_dtype=kwargs.get('dtype', None),
op_name="_to_copy"
)
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.to.dtype_layout)
def generic_to_dtype_layout(func, args, kwargs):
"""Handle .to(device) calls using the dtype_layout variant."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _handle_device_transfer(
qt,
target_device=kwargs.get('device', None),
target_dtype=kwargs.get('dtype', None),
target_layout=kwargs.get('layout', None),
op_name="to"
)
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.copy_.default)
def generic_copy_(func, args, kwargs):
qt_dest = args[0]
src = args[1]
non_blocking = args[2] if len(args) > 2 else False
if isinstance(qt_dest, QuantizedTensor):
if isinstance(src, QuantizedTensor):
# Copy from another quantized tensor
qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking)
qt_dest._layout_type = src._layout_type
orig_dtype = qt_dest._layout_params["orig_dtype"]
_copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking)
qt_dest._layout_params["orig_dtype"] = orig_dtype
else:
# Copy from regular tensor - just copy raw data
qt_dest._qdata.copy_(src)
return qt_dest
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.to.dtype)
def generic_to_dtype(func, args, kwargs):
"""Handle .to(dtype) calls - dtype conversion only."""
src = args[0]
if isinstance(src, QuantizedTensor):
# For dtype-only conversion, just change the orig_dtype, no real cast is needed
target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype')
src._layout_params["orig_dtype"] = target_dtype
return src
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
def generic_has_compatible_shallow_copy_type(func, args, kwargs):
return True
@register_generic_util(torch.ops.aten.empty_like.default)
def generic_empty_like(func, args, kwargs):
"""Empty_like operation - creates an empty tensor with the same quantized structure."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
# Create empty tensor with same shape and dtype as the quantized data
hp_dtype = kwargs.pop('dtype', qt._layout_params["orig_dtype"])
new_qdata = torch.empty_like(qt._qdata, **kwargs)
# Handle device transfer for layout params
target_device = kwargs.get('device', new_qdata.device)
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
# Update orig_dtype if dtype is specified
new_params['orig_dtype'] = hp_dtype
return QuantizedTensor(new_qdata, qt._layout_type, new_params)
return func(*args, **kwargs)
# ==============================================================================
# FP8 Layout + Operation Handlers
# ==============================================================================
class TensorCoreFP8Layout(QuantizedLayout):
"""
Storage format:
- qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
- scale: Scalar tensor (float32) for dequantization
- orig_dtype: Original dtype before quantization (for casting back)
"""
@classmethod
def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False):
orig_dtype = tensor.dtype
orig_shape = tuple(tensor.shape)
if isinstance(scale, str) and scale == "recalculate":
scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(cls.FP8_DTYPE).max
scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(dtype).max
if tensor.dtype not in [torch.float32, torch.bfloat16]: # Prevent scale from being too small
tensor_info = torch.finfo(tensor.dtype)
scale = (1.0 / torch.clamp((1.0 / scale), min=tensor_info.min, max=tensor_info.max))
if scale is None:
scale = torch.ones((), device=tensor.device, dtype=torch.float32)
if not isinstance(scale, torch.Tensor):
scale = torch.tensor(scale, device=tensor.device, dtype=torch.float32)
if scale is not None:
if not isinstance(scale, torch.Tensor):
scale = torch.tensor(scale)
scale = scale.to(device=tensor.device, dtype=torch.float32)
if stochastic_rounding > 0:
if inplace_ops:
tensor *= (1.0 / scale).to(tensor.dtype)
else:
tensor = tensor * (1.0 / scale).to(tensor.dtype)
qdata = comfy.float.stochastic_rounding(tensor, dtype=cls.FP8_DTYPE, seed=stochastic_rounding)
else:
qdata = ck.quantize_per_tensor_fp8(tensor, scale, cls.FP8_DTYPE)
params = cls.Params(scale=scale.float(), orig_dtype=orig_dtype, orig_shape=orig_shape)
return qdata, params
class TensorCoreNVFP4Layout(_CKNvfp4Layout):
@classmethod
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
if tensor.dim() != 2:
raise ValueError(f"NVFP4 requires 2D tensor, got {tensor.dim()}D")
orig_dtype = tensor.dtype
orig_shape = tuple(tensor.shape)
if scale is None or (isinstance(scale, str) and scale == "recalculate"):
scale = torch.amax(tensor.abs()) / (ck.float_utils.F8_E4M3_MAX * ck.float_utils.F4_E2M1_MAX)
if not isinstance(scale, torch.Tensor):
scale = torch.tensor(scale)
scale = scale.to(device=tensor.device, dtype=torch.float32)
padded_shape = cls.get_padded_shape(orig_shape)
needs_padding = padded_shape != orig_shape
scale = torch.ones((), device=tensor.device, dtype=torch.float32)
if stochastic_rounding > 0:
qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4_by_block(tensor, scale, pad_16x=needs_padding, seed=stochastic_rounding)
tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding)
else:
qdata, block_scale = ck.quantize_nvfp4(tensor, scale, pad_16x=needs_padding)
lp_amax = torch.finfo(dtype).max
torch.clamp(tensor, min=-lp_amax, max=lp_amax, out=tensor)
tensor = tensor.to(dtype, memory_format=torch.contiguous_format)
params = cls.Params(
scale=scale,
orig_dtype=orig_dtype,
orig_shape=orig_shape,
block_scale=block_scale,
)
return qdata, params
layout_params = {
'scale': scale,
'orig_dtype': orig_dtype
}
return tensor, layout_params
@staticmethod
def dequantize(qdata, scale, orig_dtype, **kwargs):
plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
plain_tensor.mul_(scale)
return plain_tensor
class TensorCoreFP8E4M3Layout(_TensorCoreFP8LayoutBase):
FP8_DTYPE = torch.float8_e4m3fn
class TensorCoreFP8E5M2Layout(_TensorCoreFP8LayoutBase):
FP8_DTYPE = torch.float8_e5m2
# Backward compatibility alias - default to E4M3
TensorCoreFP8Layout = TensorCoreFP8E4M3Layout
# ==============================================================================
# Registry
# ==============================================================================
register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
@classmethod
def get_plain_tensors(cls, qtensor):
return qtensor._qdata, qtensor._layout_params['scale']
QUANT_ALGOS = {
"float8_e4m3fn": {
"storage_t": torch.float8_e4m3fn,
"parameters": {"weight_scale", "input_scale"},
"comfy_tensor_layout": "TensorCoreFP8E4M3Layout",
},
"float8_e5m2": {
"storage_t": torch.float8_e5m2,
"parameters": {"weight_scale", "input_scale"},
"comfy_tensor_layout": "TensorCoreFP8E5M2Layout",
},
"nvfp4": {
"storage_t": torch.uint8,
"parameters": {"weight_scale", "weight_scale_2", "input_scale"},
"comfy_tensor_layout": "TensorCoreNVFP4Layout",
"group_size": 16,
"comfy_tensor_layout": "TensorCoreFP8Layout",
},
}
LAYOUTS = {
"TensorCoreFP8Layout": TensorCoreFP8Layout,
}
# ==============================================================================
# Re-exports for backward compatibility
# ==============================================================================
__all__ = [
"QuantizedTensor",
"QuantizedLayout",
"TensorCoreFP8Layout",
"TensorCoreFP8E4M3Layout",
"TensorCoreFP8E5M2Layout",
"TensorCoreNVFP4Layout",
"QUANT_ALGOS",
"register_layout_op",
]
@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout")
def fp8_linear(func, args, kwargs):
input_tensor = args[0]
weight = args[1]
bias = args[2] if len(args) > 2 else None
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
out_dtype = kwargs.get("out_dtype")
if out_dtype is None:
out_dtype = input_tensor._layout_params['orig_dtype']
weight_t = plain_weight.t()
tensor_2d = False
if len(plain_input.shape) == 2:
tensor_2d = True
plain_input = plain_input.unsqueeze(1)
input_shape = plain_input.shape
if len(input_shape) != 3:
return None
try:
output = torch._scaled_mm(
plain_input.reshape(-1, input_shape[2]).contiguous(),
weight_t,
bias=bias,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype,
)
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
output = output[0]
if not tensor_2d:
output = output.reshape((-1, input_shape[1], weight.shape[0]))
if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
output_scale = scale_a * scale_b
output_params = {
'scale': output_scale,
'orig_dtype': input_tensor._layout_params['orig_dtype']
}
return QuantizedTensor(output, "TensorCoreFP8Layout", output_params)
else:
return output
except Exception as e:
raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
# Case 2: DQ Fallback
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
if isinstance(input_tensor, QuantizedTensor):
input_tensor = input_tensor.dequantize()
return torch.nn.functional.linear(input_tensor, weight, bias)
def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None):
if out_dtype is None:
out_dtype = input_tensor._layout_params['orig_dtype']
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
output = torch._scaled_mm(
plain_input.contiguous(),
plain_weight,
bias=bias,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype,
)
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
output = output[0]
return output
@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout")
def fp8_addmm(func, args, kwargs):
input_tensor = args[1]
weight = args[2]
bias = args[0]
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None))
a = list(args)
if isinstance(args[0], QuantizedTensor):
a[0] = args[0].dequantize()
if isinstance(args[1], QuantizedTensor):
a[1] = args[1].dequantize()
if isinstance(args[2], QuantizedTensor):
a[2] = args[2].dequantize()
return func(*a, **kwargs)
@register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout")
def fp8_mm(func, args, kwargs):
input_tensor = args[0]
weight = args[1]
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None))
a = list(args)
if isinstance(args[0], QuantizedTensor):
a[0] = args[0].dequantize()
if isinstance(args[1], QuantizedTensor):
a[1] = args[1].dequantize()
return func(*a, **kwargs)
@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
def fp8_func(func, args, kwargs):
input_tensor = args[0]
if isinstance(input_tensor, QuantizedTensor):
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
ar = list(args)
ar[0] = plain_input
return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
return func(*args, **kwargs)

View File

@@ -37,18 +37,12 @@ def prepare_noise(latent_image, seed, noise_inds=None):
return noises
def fix_empty_latent_channels(model, latent_image, downscale_ratio_spacial=None):
def fix_empty_latent_channels(model, latent_image):
if latent_image.is_nested:
return latent_image
latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
if torch.count_nonzero(latent_image) == 0:
if latent_format.latent_channels != latent_image.shape[1]:
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
if downscale_ratio_spacial is not None:
if downscale_ratio_spacial != latent_format.spacial_downscale_ratio:
ratio = downscale_ratio_spacial / latent_format.spacial_downscale_ratio
latent_image = comfy.utils.common_upscale(latent_image, round(latent_image.shape[-1] * ratio), round(latent_image.shape[-2] * ratio), "nearest-exact", crop="disabled")
if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
if latent_format.latent_dimensions == 3 and latent_image.ndim == 4:
latent_image = latent_image.unsqueeze(2)
return latent_image

View File

@@ -122,20 +122,20 @@ def estimate_memory(model, noise_shape, conds):
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
return memory_required, minimum_memory_required
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
_prepare_sampling,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
)
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load)
return executor.execute(model, noise_shape, conds, model_options=model_options)
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
real_model: BaseModel = None
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load)
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory)
real_model = model.model
return real_model, conds, models

View File

@@ -9,6 +9,7 @@ if TYPE_CHECKING:
import torch
from functools import partial
import collections
from comfy import model_management
import math
import logging
import comfy.sampler_helpers
@@ -259,7 +260,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
to_batch_temp.reverse()
to_batch = to_batch_temp[:1]
free_memory = model.current_patcher.get_free_memory(x_in.device)
free_memory = model_management.get_free_memory(x_in.device)
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
@@ -719,7 +720,7 @@ class Sampler:
sigma = float(sigmas[0])
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2", "exp_heun_2_x0", "exp_heun_2_x0_sde", "dpm_2", "dpm_2_ancestral",
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
@@ -983,6 +984,9 @@ class CFGGuider:
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device
if denoise_mask is not None:
denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device)
noise = noise.to(device)
latent_image = latent_image.to(device)
sigmas = sigmas.to(device)
@@ -1009,24 +1013,6 @@ class CFGGuider:
else:
latent_shapes = [latent_image.shape]
if denoise_mask is not None:
if denoise_mask.is_nested:
denoise_masks = denoise_mask.unbind()
denoise_masks = denoise_masks[:len(latent_shapes)]
else:
denoise_masks = [denoise_mask]
for i in range(len(denoise_masks), len(latent_shapes)):
denoise_masks.append(torch.ones(latent_shapes[i]))
for i in range(len(denoise_masks)):
denoise_masks[i] = comfy.sampler_helpers.prepare_mask(denoise_masks[i], latent_shapes[i], self.model_patcher.load_device)
if len(denoise_masks) > 1:
denoise_mask, _ = comfy.utils.pack_latents(denoise_masks)
else:
denoise_mask = denoise_masks[0]
self.conds = {}
for k in self.original_conds:
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))

View File

@@ -20,7 +20,6 @@ import comfy.ldm.ace.vae.music_dcae_pipeline
import comfy.ldm.hunyuan_video.vae
import comfy.ldm.mmaudio.vae.autoencoder
import comfy.pixel_space_convert
import comfy.weight_adapter
import yaml
import math
import os
@@ -56,10 +55,6 @@ import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.z_image
import comfy.text_encoders.ovis
import comfy.text_encoders.kandinsky5
import comfy.text_encoders.jina_clip_2
import comfy.text_encoders.newbie
import comfy.text_encoders.anima
import comfy.text_encoders.ace15
import comfy.model_patcher
import comfy.lora
@@ -103,105 +98,6 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
return (new_modelpatcher, new_clip)
def load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip):
"""
Load LoRA in bypass mode without modifying base model weights.
Instead of patching weights, this injects the LoRA computation into the
forward pass: output = base_forward(x) + lora_path(x)
Non-adapter patches (bias diff, weight diff, etc.) are applied as regular patches.
This is useful for training and when model weights are offloaded.
"""
key_map = {}
if model is not None:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
if clip is not None:
key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
logging.debug(f"[BypassLoRA] key_map has {len(key_map)} entries")
lora = comfy.lora_convert.convert_lora(lora)
loaded = comfy.lora.load_lora(lora, key_map)
logging.debug(f"[BypassLoRA] loaded has {len(loaded)} entries")
# Separate adapters (for bypass) from other patches (for regular patching)
bypass_patches = {} # WeightAdapterBase instances -> bypass mode
regular_patches = {} # diff, set, bias patches -> regular weight patching
for key, patch_data in loaded.items():
if isinstance(patch_data, comfy.weight_adapter.WeightAdapterBase):
bypass_patches[key] = patch_data
else:
regular_patches[key] = patch_data
logging.debug(f"[BypassLoRA] {len(bypass_patches)} bypass adapters, {len(regular_patches)} regular patches")
k = set()
k1 = set()
if model is not None:
new_modelpatcher = model.clone()
# Apply regular patches (bias diff, weight diff, etc.) via normal patching
if regular_patches:
patched_keys = new_modelpatcher.add_patches(regular_patches, strength_model)
k.update(patched_keys)
# Apply adapter patches via bypass injection
manager = comfy.weight_adapter.BypassInjectionManager()
model_sd_keys = set(new_modelpatcher.model.state_dict().keys())
for key, adapter in bypass_patches.items():
if key in model_sd_keys:
manager.add_adapter(key, adapter, strength=strength_model)
k.add(key)
else:
logging.warning(f"[BypassLoRA] Adapter key not in model state_dict: {key}")
injections = manager.create_injections(new_modelpatcher.model)
if manager.get_hook_count() > 0:
new_modelpatcher.set_injections("bypass_lora", injections)
else:
new_modelpatcher = None
if clip is not None:
new_clip = clip.clone()
# Apply regular patches to clip
if regular_patches:
patched_keys = new_clip.add_patches(regular_patches, strength_clip)
k1.update(patched_keys)
# Apply adapter patches via bypass injection
clip_manager = comfy.weight_adapter.BypassInjectionManager()
clip_sd_keys = set(new_clip.cond_stage_model.state_dict().keys())
for key, adapter in bypass_patches.items():
if key in clip_sd_keys:
clip_manager.add_adapter(key, adapter, strength=strength_clip)
k1.add(key)
clip_injections = clip_manager.create_injections(new_clip.cond_stage_model)
if clip_manager.get_hook_count() > 0:
new_clip.patcher.set_injections("bypass_lora", clip_injections)
else:
new_clip = None
for x in loaded:
if (x not in k) and (x not in k1):
patch_data = loaded[x]
patch_type = type(patch_data).__name__
if isinstance(patch_data, tuple):
patch_type = f"tuple({patch_data[0]})"
logging.warning(f"NOT LOADED: {x} (type={patch_type})")
return (new_modelpatcher, new_clip)
class CLIP:
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}):
if no_init:
@@ -229,10 +125,8 @@ class CLIP:
self.cond_stage_model.to(offload_device)
logging.warning("Had to shift TE back.")
model_management.archive_model_dtypes(self.cond_stage_model)
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.patcher = comfy.model_patcher.CoreModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
#Match torch.float32 hardcode upcast in TE implemention
self.patcher.set_model_compute_dtype(torch.float32)
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
@@ -322,7 +216,7 @@ class CLIP:
if unprojected:
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model(tokens)
self.load_model()
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
all_hooks.reset()
self.patcher.patch_hooks(None)
@@ -370,7 +264,7 @@ class CLIP:
if return_pooled == "unprojected":
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model(tokens)
self.load_model()
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
o = self.cond_stage_model.encode_token_weights(tokens)
cond, pooled = o[:2]
@@ -392,18 +286,8 @@ class CLIP:
def load_sd(self, sd, full_model=False):
if full_model:
return self.cond_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
return self.cond_stage_model.load_state_dict(sd, strict=False)
else:
can_assign = self.patcher.is_dynamic()
self.cond_stage_model.can_assign_sd = can_assign
# The CLIP models are a pretty complex web of wrappers and its
# a bit of an API change to plumb this all the way through.
# So spray paint the model with this flag that the loading
# nn.Module can then inspect for itself.
for m in self.cond_stage_model.modules():
m.can_assign_sd = can_assign
return self.cond_stage_model.load_sd(sd)
def get_sd(self):
@@ -413,11 +297,8 @@ class CLIP:
sd_clip[k] = sd_tokenizer[k]
return sd_clip
def load_model(self, tokens={}):
memory_used = 0
if hasattr(self.cond_stage_model, "memory_estimation_function"):
memory_used = self.cond_stage_model.memory_estimation_function(tokens, device=self.patcher.load_device)
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
def load_model(self):
model_management.load_model_gpu(self.patcher)
return self.patcher
def get_key_patches(self):
@@ -440,7 +321,6 @@ class VAE:
self.latent_channels = 4
self.latent_dim = 2
self.output_channels = 3
self.pad_channel_value = None
self.process_input = lambda image: image * 2.0 - 1.0
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
self.working_dtypes = [torch.bfloat16, torch.float32]
@@ -453,8 +333,6 @@ class VAE:
self.extra_1d_channel = None
self.crop_input = True
self.audio_sample_rate = 44100
if config is None:
if "decoder.mid.block_1.mix_factor" in sd:
encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
@@ -552,27 +430,13 @@ class VAE:
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
elif "decoder.layers.1.layers.0.beta" in sd:
config = {}
param_key = None
self.upscale_ratio = 2048
self.downscale_ratio = 2048
if "decoder.layers.2.layers.1.weight_v" in sd:
param_key = "decoder.layers.2.layers.1.weight_v"
if "decoder.layers.2.layers.1.parametrizations.weight.original1" in sd:
param_key = "decoder.layers.2.layers.1.parametrizations.weight.original1"
if param_key is not None:
if sd[param_key].shape[-1] == 12:
config["strides"] = [2, 4, 4, 6, 10]
self.audio_sample_rate = 48000
self.upscale_ratio = 1920
self.downscale_ratio = 1920
self.first_stage_model = AudioOobleckVAE(**config)
self.first_stage_model = AudioOobleckVAE()
self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype)
self.latent_channels = 64
self.output_channels = 2
self.pad_channel_value = "replicate"
self.upscale_ratio = 2048
self.downscale_ratio = 2048
self.latent_dim = 1
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
@@ -608,8 +472,8 @@ class VAE:
self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE(version=version, config=vae_config)
self.latent_channels = 128
self.latent_dim = 3
self.memory_used_decode = lambda shape, dtype: (1200 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (80 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
self.upscale_index_formula = (8, 32, 32)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
@@ -682,9 +546,7 @@ class VAE:
self.downscale_index_formula = (4, 8, 8)
self.latent_dim = 3
self.latent_channels = 16
self.output_channels = sd["encoder.conv1.weight"].shape[1]
self.pad_channel_value = 1.0
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0}
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype)
@@ -720,7 +582,6 @@ class VAE:
self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
self.latent_channels = 8
self.output_channels = 2
self.pad_channel_value = "replicate"
self.upscale_ratio = 4096
self.downscale_ratio = 4096
self.latent_dim = 2
@@ -764,13 +625,14 @@ class VAE:
self.upscale_index_formula = (4, 16, 16)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
self.downscale_index_formula = (4, 16, 16)
if self.latent_channels in [48, 128]: # Wan 2.2 and LTX2
if self.latent_channels == 48: # Wan 2.2
self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=None) # taehv doesn't need scaling
self.process_input = self.process_output = lambda image: image
self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
self.process_output = lambda image: image
self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype))
elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15
self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=comfy.latent_formats.HunyuanVideo15)
self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
else:
if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical
@@ -793,7 +655,12 @@ class VAE:
self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval()
model_management.archive_model_dtypes(self.first_stage_model)
m, u = self.first_stage_model.load_state_dict(sd, strict=False)
if len(m) > 0:
logging.warning("Missing VAE keys {}".format(m))
if len(u) > 0:
logging.debug("Leftover VAE keys {}".format(u))
if device is None:
device = model_management.vae_device()
@@ -805,18 +672,7 @@ class VAE:
self.first_stage_model.to(self.vae_dtype)
self.output_device = model_management.intermediate_device()
mp = comfy.model_patcher.CoreModelPatcher
if self.disable_offload:
mp = comfy.model_patcher.ModelPatcher
self.patcher = mp(self.first_stage_model, load_device=self.device, offload_device=offload_device)
m, u = self.first_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
if len(m) > 0:
logging.warning("Missing VAE keys {}".format(m))
if len(u) > 0:
logging.debug("Leftover VAE keys {}".format(u))
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
self.model_size()
@@ -834,28 +690,17 @@ class VAE:
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
def vae_encode_crop_pixels(self, pixels):
if self.crop_input:
downscale_ratio = self.spacial_compression_encode()
if not self.crop_input:
return pixels
dims = pixels.shape[1:-1]
for d in range(len(dims)):
x = (dims[d] // downscale_ratio) * downscale_ratio
x_offset = (dims[d] % downscale_ratio) // 2
if x != dims[d]:
pixels = pixels.narrow(d + 1, x_offset, x)
downscale_ratio = self.spacial_compression_encode()
if pixels.shape[-1] > self.output_channels:
pixels = pixels[..., :self.output_channels]
elif pixels.shape[-1] < self.output_channels:
if self.pad_channel_value is not None:
if isinstance(self.pad_channel_value, str):
mode = self.pad_channel_value
value = None
else:
mode = "constant"
value = self.pad_channel_value
pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value)
dims = pixels.shape[1:-1]
for d in range(len(dims)):
x = (dims[d] // downscale_ratio) * downscale_ratio
x_offset = (dims[d] % downscale_ratio) // 2
if x != dims[d]:
pixels = pixels.narrow(d + 1, x_offset, x)
return pixels
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
@@ -872,7 +717,7 @@ class VAE:
/ 3.0)
return output
def decode_tiled_1d(self, samples, tile_x=256, overlap=32):
def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
if samples.ndim == 3:
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
else:
@@ -931,7 +776,7 @@ class VAE:
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
free_memory = self.patcher.get_free_memory(self.device)
free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
@@ -1005,7 +850,7 @@ class VAE:
try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
free_memory = self.patcher.get_free_memory(self.device)
free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / max(1, memory_used))
batch_number = max(1, batch_number)
samples = None
@@ -1147,8 +992,6 @@ class CLIPType(Enum):
OVIS = 21
KANDINSKY5 = 22
KANDINSKY5_IMAGE = 23
NEWBIE = 24
FLUX2 = 25
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -1179,10 +1022,6 @@ class TEModel(Enum):
MISTRAL3_24B_PRUNED_FLUX2 = 15
QWEN3_4B = 16
QWEN3_2B = 17
GEMMA_3_12B = 18
JINA_CLIP_2 = 19
QWEN3_8B = 20
QWEN3_06B = 21
def detect_te_model(sd):
@@ -1192,13 +1031,11 @@ def detect_te_model(sd):
return TEModel.CLIP_H
if "text_model.encoder.layers.0.mlp.fc1.weight" in sd:
return TEModel.CLIP_L
if "model.encoder.layers.0.mixer.Wqkv.weight" in sd:
return TEModel.JINA_CLIP_2
if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
if weight.shape[0] == 10240:
if weight.shape[-1] == 4096:
return TEModel.T5_XXL
elif weight.shape[0] == 5120:
elif weight.shape[-1] == 2048:
return TEModel.T5_XL
if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd:
return TEModel.T5_XXL_OLD
@@ -1208,8 +1045,6 @@ def detect_te_model(sd):
return TEModel.BYT5_SMALL_GLYPH
return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
if 'model.layers.47.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_3_12B
if 'model.layers.0.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_3_4B
return TEModel.GEMMA_2_2B
@@ -1226,10 +1061,6 @@ def detect_te_model(sd):
return TEModel.QWEN3_4B
elif weight.shape[0] == 2048:
return TEModel.QWEN3_2B
elif weight.shape[0] == 4096:
return TEModel.QWEN3_8B
elif weight.shape[0] == 1024:
return TEModel.QWEN3_06B
if weight.shape[0] == 5120:
if "model.layers.39.post_attention_layernorm.weight" in sd:
return TEModel.MISTRAL3_24B
@@ -1355,24 +1186,11 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer
tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None)
elif te_model == TEModel.QWEN3_4B:
if clip_type == CLIPType.FLUX or clip_type == CLIPType.FLUX2:
clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_4b")
clip_target.tokenizer = comfy.text_encoders.flux.KleinTokenizer
else:
clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer
clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer
elif te_model == TEModel.QWEN3_2B:
clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer
elif te_model == TEModel.QWEN3_8B:
clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_8b")
clip_target.tokenizer = comfy.text_encoders.flux.KleinTokenizer8B
elif te_model == TEModel.JINA_CLIP_2:
clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper
elif te_model == TEModel.QWEN3_06B:
clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer
else:
# clip_l
if clip_type == CLIPType.SD3:
@@ -1428,29 +1246,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif clip_type == CLIPType.KANDINSKY5_IMAGE:
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
elif clip_type == CLIPType.LTXV:
clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lt.LTXAVGemmaTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif clip_type == CLIPType.NEWBIE:
clip_target.clip = comfy.text_encoders.newbie.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.newbie.NewBieTokenizer
if "model.layers.0.self_attn.q_norm.weight" in clip_data[0]:
clip_data_gemma = clip_data[0]
clip_data_jina = clip_data[1]
else:
clip_data_gemma = clip_data[1]
clip_data_jina = clip_data[0]
tokenizer_data["gemma_spiece_model"] = clip_data_gemma.get("spiece_model", None)
tokenizer_data["jina_spiece_model"] = clip_data_jina.get("spiece_model", None)
elif clip_type == CLIPType.ACE:
te_models = [detect_te_model(clip_data[0]), detect_te_model(clip_data[1])]
if TEModel.QWEN3_4B in te_models:
model_type = "qwen3_4b"
else:
model_type = "qwen3_2b"
clip_target.clip = comfy.text_encoders.ace15.te(lm_model=model_type, **llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.ace15.ACE15Tokenizer
else:
clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
@@ -1474,7 +1269,7 @@ def load_gligen(ckpt_path):
model = gligen.load_gligen(data)
if model_management.should_use_fp16():
model = model.half()
return comfy.model_patcher.CoreModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
def model_detection_error_hint(path, state_dict):
filename = os.path.basename(path)
@@ -1562,8 +1357,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
if output_model:
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())
model.load_model_weights(sd, diffusion_model_prefix)
if output_vae:
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
@@ -1606,6 +1400,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
logging.debug("left over keys: {}".format(left_over))
if output_model:
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
if inital_load_device != torch.device("cpu"):
logging.info("loaded diffusion model directly to GPU")
model_management.load_models_gpu([model_patcher], force_full_load=True)
@@ -1697,14 +1492,13 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
model_config.optimizations["fp8"] = True
model = model_config.get_model(new_sd, "")
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=offload_device)
if not model_management.is_device_cpu(offload_device):
model.to(offload_device)
model.load_model_weights(new_sd, "", assign=model_patcher.is_dynamic())
model = model.to(offload_device)
model.load_model_weights(new_sd, "")
left_over = sd.keys()
if len(left_over) > 0:
logging.info("left over keys in diffusion model: {}".format(left_over))
return model_patcher
return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
def load_diffusion_model(unet_path, model_options={}):
sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
@@ -1735,9 +1529,9 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
if metadata is None:
metadata = {}
model_management.load_models_gpu(load_models)
model_management.load_models_gpu(load_models, force_patch_weights=True)
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
sd = model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
for k in extra_keys:
sd[k] = extra_keys[k]

View File

@@ -155,8 +155,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.execution_device = options.get("execution_device", self.execution_device)
if isinstance(self.layer, list) or self.layer == "all":
pass
elif isinstance(layer_idx, list):
self.layer = layer_idx
elif layer_idx is None or abs(layer_idx) > self.num_layers:
self.layer = "last"
else:
@@ -299,7 +297,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
return self(tokens)
def load_sd(self, sd):
return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))
return self.transformer.load_state_dict(sd, strict=False)
def parse_parentheses(string):
result = []
@@ -468,7 +466,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
return embed_out
class SDTokenizer:
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, start_token=None, min_padding=None, pad_left=False, disable_weights=False, tokenizer_data={}, tokenizer_args={}):
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, tokenizer_data={}, tokenizer_args={}):
if tokenizer_path is None:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
@@ -481,15 +479,8 @@ class SDTokenizer:
empty = self.tokenizer('')["input_ids"]
self.tokenizer_adds_end_token = has_end_token
if has_start_token:
if len(empty) > 0:
self.tokens_start = 1
self.start_token = empty[0]
else:
self.tokens_start = 0
self.start_token = start_token
if start_token is None:
logging.warning("WARNING: There's something wrong with your tokenizers.'")
self.tokens_start = 1
self.start_token = empty[0]
if end_token is not None:
self.end_token = end_token
else:
@@ -497,7 +488,7 @@ class SDTokenizer:
self.end_token = empty[1]
else:
self.tokens_start = 0
self.start_token = start_token
self.start_token = None
if end_token is not None:
self.end_token = end_token
else:
@@ -522,8 +513,6 @@ class SDTokenizer:
self.embedding_size = embedding_size
self.embedding_key = embedding_key
self.disable_weights = disable_weights
def _try_get_embedding(self, embedding_name:str):
'''
Takes a potential embedding name and tries to retrieve it.
@@ -558,7 +547,7 @@ class SDTokenizer:
min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)
text = escape_important(text)
if kwargs.get("disable_weights", self.disable_weights):
if kwargs.get("disable_weights", False):
parsed_weights = [(text, 1.0)]
else:
parsed_weights = token_weights(text, 1.0)

View File

@@ -23,14 +23,11 @@ import comfy.text_encoders.qwen_image
import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.kandinsky5
import comfy.text_encoders.z_image
import comfy.text_encoders.anima
import comfy.text_encoders.ace15
from . import supported_models_base
from . import latent_formats
from . import diffusers_convert
import comfy.model_management
class SD15(supported_models_base.BASE):
unet_config = {
@@ -765,31 +762,17 @@ class Flux2(Flux):
def __init__(self, unet_config):
super().__init__(unet_config)
self.memory_usage_factor = self.memory_usage_factor * (2.0 * 2.0) * (unet_config['hidden_size'] / 2604)
self.memory_usage_factor = self.memory_usage_factor * (2.0 * 2.0) * 2.36
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Flux2(self, device=device)
return out
def clip_target(self, state_dict={}):
return None # TODO
pref = self.text_encoder_key_prefix[0]
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
if len(detect) > 0:
detect["model_type"] = "qwen3_4b"
return supported_models_base.ClipTarget(comfy.text_encoders.flux.KleinTokenizer, comfy.text_encoders.flux.klein_te(**detect))
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_8b.transformer.".format(pref))
if len(detect) > 0:
detect["model_type"] = "qwen3_8b"
return supported_models_base.ClipTarget(comfy.text_encoders.flux.KleinTokenizer8B, comfy.text_encoders.flux.klein_te(**detect))
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}mistral3_24b.transformer.".format(pref))
if len(detect) > 0:
if "{}mistral3_24b.transformer.model.layers.39.post_attention_layernorm.weight".format(pref) not in state_dict:
detect["pruned"] = True
return supported_models_base.ClipTarget(comfy.text_encoders.flux.Flux2Tokenizer, comfy.text_encoders.flux.flux2_te(**detect))
return None
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect))
class GenmoMochi(supported_models_base.BASE):
unet_config = {
@@ -852,21 +835,6 @@ class LTXV(supported_models_base.BASE):
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.lt.LTXVT5Tokenizer, comfy.text_encoders.lt.ltxv_te(**t5_detect))
class LTXAV(LTXV):
unet_config = {
"image_model": "ltxav",
}
latent_format = latent_formats.LTXAV
def __init__(self, unet_config):
super().__init__(unet_config)
self.memory_usage_factor = 0.077 # TODO
def get_model(self, state_dict, prefix="", device=None):
out = model_base.LTXAV(self, device=device)
return out
class HunyuanVideo(supported_models_base.BASE):
unet_config = {
"image_model": "hunyuan_video",
@@ -1008,36 +976,6 @@ class CosmosT2IPredict2(supported_models_base.BASE):
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect))
class Anima(supported_models_base.BASE):
unet_config = {
"image_model": "anima",
}
sampling_settings = {
"multiplier": 1.0,
"shift": 3.0,
}
unet_extra_config = {}
latent_format = latent_formats.Wan21
memory_usage_factor = 1.0
supported_inference_dtypes = [torch.bfloat16, torch.float32]
def __init__(self, unet_config):
super().__init__(unet_config)
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Anima(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_06b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.anima.AnimaTokenizer, comfy.text_encoders.anima.te(**detect))
class CosmosI2VPredict2(CosmosT2IPredict2):
unet_config = {
"image_model": "cosmos_predict2",
@@ -1088,15 +1026,9 @@ class ZImage(Lumina2):
"shift": 3.0,
}
memory_usage_factor = 2.8
memory_usage_factor = 2.0
supported_inference_dtypes = [torch.bfloat16, torch.float32]
def __init__(self, unet_config):
super().__init__(unet_config)
if comfy.model_management.extended_fp16_support() and unet_config.get("allow_fp16", False):
self.supported_inference_dtypes = self.supported_inference_dtypes.copy()
self.supported_inference_dtypes.insert(1, torch.float16)
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
@@ -1597,46 +1529,6 @@ class Kandinsky5Image(Kandinsky5):
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
class ACEStep15(supported_models_base.BASE):
unet_config = {
"audio_model": "ace1.5",
}
unet_extra_config = {
}
sampling_settings = {
"multiplier": 1.0,
"shift": 3.0,
}
latent_format = comfy.latent_formats.ACEAudio15
memory_usage_factor = 4.7
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.ACEStep15(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
detect_2b = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_2b.transformer.".format(pref))
detect_4b = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
if "dtype_llama" in detect_2b:
detect = detect_2b
detect["lm_model"] = "qwen3_2b"
elif "dtype_llama" in detect_4b:
detect = detect_4b
detect["lm_model"] = "qwen3_4b"
return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5]
models += [SVD_img2vid]

View File

@@ -112,8 +112,7 @@ def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
class TAEHV(nn.Module):
def __init__(self, latent_channels, parallel=False, encoder_time_downscale=(True, True, False), decoder_time_upscale=(False, True, True), decoder_space_upscale=(True, True, True),
latent_format=None, show_progress_bar=False):
def __init__(self, latent_channels, parallel=False, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), latent_format=None, show_progress_bar=True):
super().__init__()
self.image_channels = 3
self.patch_size = 1
@@ -125,9 +124,6 @@ class TAEHV(nn.Module):
self.process_out = latent_format().process_out if latent_format is not None else (lambda x: x)
if self.latent_channels in [48, 32]: # Wan 2.2 and HunyuanVideo1.5
self.patch_size = 2
elif self.latent_channels == 128: # LTX2
self.patch_size, self.latent_channels, encoder_time_downscale, decoder_time_upscale = 4, 128, (True, True, True), (True, True, True)
if self.latent_channels == 32: # HunyuanVideo1.5
act_func = nn.LeakyReLU(0.2, inplace=True)
else: # HunyuanVideo, Wan 2.1
@@ -135,54 +131,41 @@ class TAEHV(nn.Module):
self.encoder = nn.Sequential(
conv(self.image_channels*self.patch_size**2, 64), act_func,
TPool(64, 2 if encoder_time_downscale[0] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
TPool(64, 2 if encoder_time_downscale[1] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
TPool(64, 2 if encoder_time_downscale[2] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
conv(64, self.latent_channels),
)
n_f = [256, 128, 64, 64]
self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
self.decoder = nn.Sequential(
Clamp(), conv(self.latent_channels, n_f[0]), act_func,
MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 2 if decoder_time_upscale[0] else 1), conv(n_f[0], n_f[1], bias=False),
MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[1] else 1), conv(n_f[1], n_f[2], bias=False),
MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[2] else 1), conv(n_f[2], n_f[3], bias=False),
MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False),
MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
act_func, conv(n_f[3], self.image_channels*self.patch_size**2),
)
@property
def show_progress_bar(self):
return self._show_progress_bar
self.t_downscale = 2**sum(t.stride == 2 for t in self.encoder if isinstance(t, TPool))
self.t_upscale = 2**sum(t.stride == 2 for t in self.decoder if isinstance(t, TGrow))
self.frames_to_trim = self.t_upscale - 1
self._show_progress_bar = show_progress_bar
@property
def show_progress_bar(self):
return self._show_progress_bar
@show_progress_bar.setter
def show_progress_bar(self, value):
self._show_progress_bar = value
@show_progress_bar.setter
def show_progress_bar(self, value):
self._show_progress_bar = value
def encode(self, x, **kwargs):
if self.patch_size > 1: x = F.pixel_unshuffle(x, self.patch_size)
x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
if self.patch_size > 1:
B, T, C, H, W = x.shape
x = x.reshape(B * T, C, H, W)
x = F.pixel_unshuffle(x, self.patch_size)
x = x.reshape(B, T, C * self.patch_size ** 2, H // self.patch_size, W // self.patch_size)
if x.shape[1] % self.t_downscale != 0:
# pad at end to multiple of t_downscale
n_pad = self.t_downscale - x.shape[1] % self.t_downscale
if x.shape[1] % 4 != 0:
# pad at end to multiple of 4
n_pad = 4 - x.shape[1] % 4
padding = x[:, -1:].repeat_interleave(n_pad, dim=1)
x = torch.cat([x, padding], 1)
x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1)
return self.process_out(x)
def decode(self, x, **kwargs):
x = x.unsqueeze(0) if x.ndim == 4 else x # [T, C, H, W] -> [1, T, C, H, W]
x = x.movedim(1, 2) if x.shape[1] != self.latent_channels else x # [B, T, C, H, W] or [B, C, T, H, W]
x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar)
if self.patch_size > 1:
x = F.pixel_shuffle(x, self.patch_size)
if self.patch_size > 1: x = F.pixel_shuffle(x, self.patch_size)
return x[:, self.frames_to_trim:].movedim(2, 1)

View File

@@ -1,249 +0,0 @@
from .anima import Qwen3Tokenizer
import comfy.text_encoders.llama
from comfy import sd1_clip
import torch
import math
import comfy.utils
def sample_manual_loop_no_classes(
model,
ids=None,
paddings=[],
execution_dtype=None,
cfg_scale: float = 2.0,
temperature: float = 0.85,
top_p: float = 0.9,
top_k: int = None,
seed: int = 1,
min_tokens: int = 1,
max_new_tokens: int = 2048,
audio_start_id: int = 151669, # The cutoff ID for audio codes
audio_end_id: int = 215669,
eos_token_id: int = 151645,
):
device = model.execution_device
if execution_dtype is None:
if comfy.model_management.should_use_bf16(device):
execution_dtype = torch.bfloat16
else:
execution_dtype = torch.float32
embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device)
for i, t in enumerate(paddings):
attention_mask[i, :t] = 0
attention_mask[i, t:] = 1
output_audio_codes = []
past_key_values = []
generator = torch.Generator(device=device)
generator.manual_seed(seed)
model_config = model.transformer.model.config
for x in range(model_config.num_hidden_layers):
past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), 0))
progress_bar = comfy.utils.ProgressBar(max_new_tokens)
for step in range(max_new_tokens):
outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values)
next_token_logits = model.transformer.logits(outputs[0])[:, -1]
past_key_values = outputs[2]
cond_logits = next_token_logits[0:1]
uncond_logits = next_token_logits[1:2]
cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
eos_score = cfg_logits[:, eos_token_id].clone()
remove_logit_value = torch.finfo(cfg_logits.dtype).min
# Only generate audio tokens
cfg_logits[:, :audio_start_id] = remove_logit_value
cfg_logits[:, audio_end_id:] = remove_logit_value
if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
cfg_logits[:, eos_token_id] = eos_score
if top_k is not None and top_k > 0:
top_k_vals, _ = torch.topk(cfg_logits, top_k)
min_val = top_k_vals[..., -1, None]
cfg_logits[cfg_logits < min_val] = remove_logit_value
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
cfg_logits[indices_to_remove] = remove_logit_value
if temperature > 0:
cfg_logits = cfg_logits / temperature
next_token = torch.multinomial(torch.softmax(cfg_logits, dim=-1), num_samples=1, generator=generator).squeeze(1)
else:
next_token = torch.argmax(cfg_logits, dim=-1)
token = next_token.item()
if token == eos_token_id:
break
embed, _, _, _ = model.process_tokens([[token]], device)
embeds = embed.repeat(2, 1, 1)
attention_mask = torch.cat([attention_mask, torch.ones((2, 1), device=device, dtype=attention_mask.dtype)], dim=1)
output_audio_codes.append(token - audio_start_id)
progress_bar.update_absolute(step)
return output_audio_codes
def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0):
cfg_scale = 2.0
positive = [[token for token, _ in inner_list] for inner_list in positive]
negative = [[token for token, _ in inner_list] for inner_list in negative]
positive = positive[0]
negative = negative[0]
neg_pad = 0
if len(negative) < len(positive):
neg_pad = (len(positive) - len(negative))
negative = [model.special_tokens["pad"]] * neg_pad + negative
pos_pad = 0
if len(negative) > len(positive):
pos_pad = (len(negative) - len(positive))
positive = [model.special_tokens["pad"]] * pos_pad + positive
paddings = [pos_pad, neg_pad]
return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_06b", tokenizer=Qwen3Tokenizer)
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
out = {}
lyrics = kwargs.get("lyrics", "")
bpm = kwargs.get("bpm", 120)
duration = kwargs.get("duration", 120)
keyscale = kwargs.get("keyscale", "C major")
timesignature = kwargs.get("timesignature", 2)
language = kwargs.get("language", "en")
seed = kwargs.get("seed", 0)
duration = math.ceil(duration)
meta_lm = 'bpm: {}\nduration: {}\nkeyscale: {}\ntimesignature: {}'.format(bpm, duration, keyscale, timesignature)
lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n{}\n<|im_end|>\n<|im_start|>assistant\n<think>\n{}\n</think>\n\n<|im_end|>\n"
meta_cap = '- bpm: {}\n- timesignature: {}\n- keyscale: {}\n- duration: {}\n'.format(bpm, timesignature, keyscale, duration)
out["lm_prompt"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, meta_lm), disable_weights=True)
out["lm_prompt_negative"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, ""), disable_weights=True)
out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric{}<|endoftext|><|endoftext|>".format(language, lyrics), return_word_ids, disable_weights=True, **kwargs)
out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}# Metas\n{}<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs)
out["lm_metadata"] = {"min_tokens": duration * 5, "seed": seed}
return out
class Qwen3_06BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_06B_ACE15, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class Qwen3_2B_ACE15(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_2B_ACE15_lm, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class Qwen3_4B_ACE15(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B_ACE15_lm, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class ACE15TEModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, dtype_llama=None, lm_model=None, model_options={}):
super().__init__()
if dtype_llama is None:
dtype_llama = dtype
model = None
self.constant = 0.4375
if lm_model == "qwen3_4b":
model = Qwen3_4B_ACE15
self.constant = 0.5625
elif lm_model == "qwen3_2b":
model = Qwen3_2B_ACE15
self.lm_model = lm_model
self.qwen3_06b = Qwen3_06BModel(device=device, dtype=dtype, model_options=model_options)
if model is not None:
setattr(self, self.lm_model, model(device=device, dtype=dtype_llama, model_options=model_options))
self.dtypes = set([dtype, dtype_llama])
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs_base = token_weight_pairs["qwen3_06b"]
token_weight_pairs_lyrics = token_weight_pairs["lyrics"]
self.qwen3_06b.set_clip_options({"layer": None})
base_out, _, extra = self.qwen3_06b.encode_token_weights(token_weight_pairs_base)
self.qwen3_06b.set_clip_options({"layer": [0]})
lyrics_embeds, _, extra_l = self.qwen3_06b.encode_token_weights(token_weight_pairs_lyrics)
lm_metadata = token_weight_pairs["lm_metadata"]
audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"])
return base_out, None, {"conditioning_lyrics": lyrics_embeds[:, 0], "audio_codes": [audio_codes]}
def set_clip_options(self, options):
self.qwen3_06b.set_clip_options(options)
lm_model = getattr(self, self.lm_model, None)
if lm_model is not None:
lm_model.set_clip_options(options)
def reset_clip_options(self):
self.qwen3_06b.reset_clip_options()
lm_model = getattr(self, self.lm_model, None)
if lm_model is not None:
lm_model.reset_clip_options()
def load_sd(self, sd):
if "model.layers.0.post_attention_layernorm.weight" in sd:
shape = sd["model.layers.0.post_attention_layernorm.weight"].shape
if shape[0] == 1024:
return self.qwen3_06b.load_sd(sd)
else:
return getattr(self, self.lm_model).load_sd(sd)
def memory_estimation_function(self, token_weight_pairs, device=None):
lm_metadata = token_weight_pairs["lm_metadata"]
constant = self.constant
if comfy.model_management.should_use_bf16(device):
constant *= 0.5
token_weight_pairs = token_weight_pairs.get("lm_prompt", [])
num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
num_tokens += lm_metadata['min_tokens']
return num_tokens * constant * 1024 * 1024
def te(dtype_llama=None, llama_quantization_metadata=None, lm_model="qwen3_2b"):
class ACE15TEModel_(ACE15TEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["llama_quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, dtype_llama=dtype_llama, lm_model=lm_model, dtype=dtype, model_options=model_options)
return ACE15TEModel_

View File

@@ -1,61 +0,0 @@
from transformers import Qwen2Tokenizer, T5TokenizerFast
import comfy.text_encoders.llama
from comfy import sd1_clip
import os
import torch
class Qwen3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024, embedding_key='qwen3_06b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data)
class AnimaTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.qwen3_06b = Qwen3Tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
qwen_ids = self.qwen3_06b.tokenize_with_weights(text, return_word_ids, **kwargs)
out["qwen3_06b"] = [[(token, 1.0) for token, _ in inner_list] for inner_list in qwen_ids] # Set weights to 1.0
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
return out
def untokenize(self, token_weight_pair):
return self.t5xxl.untokenize(token_weight_pair)
def state_dict(self):
return {}
class Qwen3_06BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_06B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class AnimaTEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="qwen3_06b", clip_model=Qwen3_06BModel, model_options=model_options)
def encode_token_weights(self, token_weight_pairs):
out = super().encode_token_weights(token_weight_pairs)
out[2]["t5xxl_ids"] = torch.tensor(list(map(lambda a: a[0], token_weight_pairs["t5xxl"][0])), dtype=torch.int)
out[2]["t5xxl_weights"] = torch.tensor(list(map(lambda a: a[1], token_weight_pairs["t5xxl"][0])))
return out
def te(dtype_llama=None, llama_quantization_metadata=None):
class AnimaTEModel_(AnimaTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if dtype_llama is not None:
dtype = dtype_llama
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, dtype=dtype, model_options=model_options)
return AnimaTEModel_

View File

@@ -36,7 +36,7 @@ def te(dtype_t5=None, t5_quantization_metadata=None):
if t5_quantization_metadata is not None:
model_options = model_options.copy()
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
if dtype_t5 is not None:
if dtype is None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)
return CosmosTEModel_

View File

@@ -3,7 +3,7 @@ import comfy.text_encoders.t5
import comfy.text_encoders.sd3_clip
import comfy.text_encoders.llama
import comfy.model_management
from transformers import T5TokenizerFast, LlamaTokenizerFast, Qwen2Tokenizer
from transformers import T5TokenizerFast, LlamaTokenizerFast
import torch
import os
import json
@@ -118,7 +118,7 @@ class MistralTokenizerClass:
class Mistral3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.tekken_data = tokenizer_data.get("tekken_model", None)
super().__init__("", pad_with_end=False, embedding_directory=embedding_directory, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, start_token=1, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data)
super().__init__("", pad_with_end=False, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data)
def state_dict(self):
return {"tekken_model": self.tekken_data}
@@ -172,60 +172,3 @@ def flux2_te(dtype_llama=None, llama_quantization_metadata=None, pruned=False):
model_options["num_layers"] = 30
super().__init__(device=device, dtype=dtype, model_options=model_options)
return Flux2TEModel_
class Qwen3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data)
class Qwen3Tokenizer8B(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=4096, embedding_key='qwen3_8b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data)
class KleinTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}, name="qwen3_4b"):
if name == "qwen3_4b":
tokenizer = Qwen3Tokenizer
elif name == "qwen3_8b":
tokenizer = Qwen3Tokenizer8B
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name=name, tokenizer=tokenizer)
self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
if llama_template is None:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
return tokens
class KleinTokenizer8B(KleinTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}, name="qwen3_8b"):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name=name)
class Qwen3_4BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer=[9, 18, 27], layer_idx=None, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class Qwen3_8BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer=[9, 18, 27], layer_idx=None, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_8B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
def klein_te(dtype_llama=None, llama_quantization_metadata=None, model_type="qwen3_4b"):
if model_type == "qwen3_4b":
model = Qwen3_4BModel
elif model_type == "qwen3_8b":
model = Qwen3_8BModel
class Flux2TEModel_(Flux2TEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, name=model_type, model_options=model_options, clip_model=model)
return Flux2TEModel_

View File

@@ -32,7 +32,7 @@ def mochi_te(dtype_t5=None, t5_quantization_metadata=None):
if t5_quantization_metadata is not None:
model_options = model_options.copy()
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
if dtype_t5 is not None:
if dtype is None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)
return MochiTEModel_

View File

@@ -10,11 +10,9 @@ import comfy.utils
def llama_detect(state_dict, prefix=""):
out = {}
norm_keys = ["{}model.norm.weight".format(prefix), "{}model.layers.0.input_layernorm.weight".format(prefix)]
for norm_key in norm_keys:
if norm_key in state_dict:
out["dtype_llama"] = state_dict[norm_key].dtype
break
t5_key = "{}model.norm.weight".format(prefix)
if t5_key in state_dict:
out["dtype_llama"] = state_dict[t5_key].dtype
quant = comfy.utils.detect_layer_quantization(state_dict, prefix)
if quant is not None:

View File

@@ -1,219 +0,0 @@
# Jina CLIP v2 and Jina Embeddings v3 both use their modified XLM-RoBERTa architecture. Reference implementation:
# Jina CLIP v2 (both text and vision): https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/modeling_clip.py
# Jina XLM-RoBERTa (text only): http://huggingface.co/jinaai/xlm-roberta-flash-implementation/blob/2b6bc3f30750b3a9648fe9b63448c09920efe9be/modeling_xlm_roberta.py
from dataclasses import dataclass
import torch
from torch import nn as nn
from torch.nn import functional as F
import comfy.model_management
import comfy.ops
from comfy import sd1_clip
from .spiece_tokenizer import SPieceTokenizer
class JinaClip2Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
# The official NewBie uses max_length=8000, but Jina Embeddings v3 actually supports 8192
super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='jina_clip_2', tokenizer_class=SPieceTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=False, max_length=8192, min_length=1, pad_token=1, end_token=2, tokenizer_args={"add_bos": True, "add_eos": True}, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class JinaClip2TokenizerWrapper(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=JinaClip2Tokenizer, name="jina_clip_2")
# https://huggingface.co/jinaai/jina-embeddings-v3/blob/343dbf534c76fe845f304fa5c2d1fd87e1e78918/config.json
@dataclass
class XLMRobertaConfig:
vocab_size: int = 250002
type_vocab_size: int = 1
hidden_size: int = 1024
num_hidden_layers: int = 24
num_attention_heads: int = 16
rotary_emb_base: float = 20000.0
intermediate_size: int = 4096
hidden_act: str = "gelu"
hidden_dropout_prob: float = 0.1
attention_probs_dropout_prob: float = 0.1
layer_norm_eps: float = 1e-05
bos_token_id: int = 0
eos_token_id: int = 2
pad_token_id: int = 1
class XLMRobertaEmbeddings(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
embed_dim = config.hidden_size
self.word_embeddings = ops.Embedding(config.vocab_size, embed_dim, padding_idx=config.pad_token_id, device=device, dtype=dtype)
self.token_type_embeddings = ops.Embedding(config.type_vocab_size, embed_dim, device=device, dtype=dtype)
def forward(self, input_ids=None, embeddings=None):
if input_ids is not None and embeddings is None:
embeddings = self.word_embeddings(input_ids)
if embeddings is not None:
token_type_ids = torch.zeros(embeddings.shape[1], device=embeddings.device, dtype=torch.int32)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = embeddings + token_type_embeddings
return embeddings
class RotaryEmbedding(nn.Module):
def __init__(self, dim, base, device=None):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
if seqlen > self._seq_len_cached or self._cos_cached is None or self._cos_cached.device != device or self._cos_cached.dtype != dtype:
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=torch.float32)
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
emb = torch.cat((freqs, freqs), dim=-1)
self._cos_cached = emb.cos().to(dtype)
self._sin_cached = emb.sin().to(dtype)
def forward(self, q, k):
batch, seqlen, heads, head_dim = q.shape
self._update_cos_sin_cache(seqlen, device=q.device, dtype=q.dtype)
cos = self._cos_cached[:seqlen].view(1, seqlen, 1, head_dim)
sin = self._sin_cached[:seqlen].view(1, seqlen, 1, head_dim)
def rotate_half(x):
size = x.shape[-1] // 2
x1, x2 = x[..., :size], x[..., size:]
return torch.cat((-x2, x1), dim=-1)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class MHA(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = embed_dim // config.num_attention_heads
self.rotary_emb = RotaryEmbedding(self.head_dim, config.rotary_emb_base, device=device)
self.Wqkv = ops.Linear(embed_dim, 3 * embed_dim, device=device, dtype=dtype)
self.out_proj = ops.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
def forward(self, x, mask=None, optimized_attention=None):
qkv = self.Wqkv(x)
batch_size, seq_len, _ = qkv.shape
qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
q, k, v = qkv.unbind(2)
q, k = self.rotary_emb(q, k)
# NHD -> HND
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
out = optimized_attention(q, k, v, heads=self.num_heads, mask=mask, skip_reshape=True)
return self.out_proj(out)
class MLP(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.fc1 = ops.Linear(config.hidden_size, config.intermediate_size, device=device, dtype=dtype)
self.activation = F.gelu
self.fc2 = ops.Linear(config.intermediate_size, config.hidden_size, device=device, dtype=dtype)
def forward(self, x):
x = self.fc1(x)
x = self.activation(x)
x = self.fc2(x)
return x
class Block(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.mixer = MHA(config, device=device, dtype=dtype, ops=ops)
self.dropout1 = nn.Dropout(config.hidden_dropout_prob)
self.norm1 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
self.dropout2 = nn.Dropout(config.hidden_dropout_prob)
self.norm2 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
def forward(self, hidden_states, mask=None, optimized_attention=None):
mixer_out = self.mixer(hidden_states, mask=mask, optimized_attention=optimized_attention)
hidden_states = self.norm1(self.dropout1(mixer_out) + hidden_states)
mlp_out = self.mlp(hidden_states)
hidden_states = self.norm2(self.dropout2(mlp_out) + hidden_states)
return hidden_states
class XLMRobertaEncoder(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.layers = nn.ModuleList([Block(config, device=device, dtype=dtype, ops=ops) for _ in range(config.num_hidden_layers)])
def forward(self, hidden_states, attention_mask=None):
optimized_attention = comfy.ldm.modules.attention.optimized_attention_for_device(hidden_states.device, mask=attention_mask is not None, small_input=True)
for layer in self.layers:
hidden_states = layer(hidden_states, mask=attention_mask, optimized_attention=optimized_attention)
return hidden_states
class XLMRobertaModel_(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.embeddings = XLMRobertaEmbeddings(config, device=device, dtype=dtype, ops=ops)
self.emb_ln = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
self.encoder = XLMRobertaEncoder(config, device=device, dtype=dtype, ops=ops)
def forward(self, input_ids, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
x = self.embeddings(input_ids=input_ids, embeddings=embeds)
x = self.emb_ln(x)
x = self.emb_drop(x)
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, 1, attention_mask.shape[-1]))
mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
sequence_output = self.encoder(x, attention_mask=mask)
# Mean pool, see https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/hf_model.py
pooled_output = None
if attention_mask is None:
pooled_output = sequence_output.mean(dim=1)
else:
attention_mask = attention_mask.to(sequence_output.dtype)
pooled_output = (sequence_output * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=-1, keepdim=True)
# Intermediate output is not yet implemented, use None for placeholder
return sequence_output, None, pooled_output
class XLMRobertaModel(nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
self.config = XLMRobertaConfig(**config_dict)
self.model = XLMRobertaModel_(self.config, device=device, dtype=dtype, ops=operations)
self.num_layers = self.config.num_hidden_layers
def get_input_embeddings(self):
return self.model.embeddings.word_embeddings
def set_input_embeddings(self, embeddings):
self.model.embeddings.word_embeddings = embeddings
def forward(self, *args, **kwargs):
return self.model(*args, **kwargs)
class JinaClip2TextModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, textmodel_json_config={}, model_class=XLMRobertaModel, special_tokens={"start": 0, "end": 2, "pad": 1}, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
class JinaClip2TextModelWrapper(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, clip_model=JinaClip2TextModel, name="jina_clip_2", model_options=model_options)

View File

@@ -1,15 +1,15 @@
import torch
import torch.nn as nn
from dataclasses import dataclass
from typing import Optional, Any, Tuple
from typing import Optional, Any
import math
import logging
from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.model_management
import comfy.ops
import comfy.ldm.common_dit
import comfy.clip_model
import comfy.model_management
from . import qwen_vl
@dataclass
@@ -33,7 +33,6 @@ class Llama2Config:
k_norm = None
rope_scale = None
final_norm: bool = True
lm_head: bool = False
@dataclass
class Mistral3Small24BConfig:
@@ -56,7 +55,6 @@ class Mistral3Small24BConfig:
k_norm = None
rope_scale = None
final_norm: bool = True
lm_head: bool = False
@dataclass
class Qwen25_3BConfig:
@@ -79,99 +77,6 @@ class Qwen25_3BConfig:
k_norm = None
rope_scale = None
final_norm: bool = True
lm_head: bool = False
@dataclass
class Qwen3_06BConfig:
vocab_size: int = 151936
hidden_size: int = 1024
intermediate_size: int = 3072
num_hidden_layers: int = 28
num_attention_heads: int = 16
num_key_value_heads: int = 8
max_position_embeddings: int = 32768
rms_norm_eps: float = 1e-6
rope_theta: float = 1000000.0
transformer_type: str = "llama"
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = False
rope_dims = None
q_norm = "gemma3"
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
lm_head: bool = False
@dataclass
class Qwen3_06B_ACE15_Config:
vocab_size: int = 151669
hidden_size: int = 1024
intermediate_size: int = 3072
num_hidden_layers: int = 28
num_attention_heads: int = 16
num_key_value_heads: int = 8
max_position_embeddings: int = 32768
rms_norm_eps: float = 1e-6
rope_theta: float = 1000000.0
transformer_type: str = "llama"
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = False
rope_dims = None
q_norm = "gemma3"
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
lm_head: bool = False
@dataclass
class Qwen3_2B_ACE15_lm_Config:
vocab_size: int = 217204
hidden_size: int = 2048
intermediate_size: int = 6144
num_hidden_layers: int = 28
num_attention_heads: int = 16
num_key_value_heads: int = 8
max_position_embeddings: int = 40960
rms_norm_eps: float = 1e-6
rope_theta: float = 1000000.0
transformer_type: str = "llama"
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = False
rope_dims = None
q_norm = "gemma3"
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
lm_head: bool = False
@dataclass
class Qwen3_4B_ACE15_lm_Config:
vocab_size: int = 217204
hidden_size: int = 2560
intermediate_size: int = 9728
num_hidden_layers: int = 36
num_attention_heads: int = 32
num_key_value_heads: int = 8
max_position_embeddings: int = 40960
rms_norm_eps: float = 1e-6
rope_theta: float = 1000000.0
transformer_type: str = "llama"
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = False
rope_dims = None
q_norm = "gemma3"
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
lm_head: bool = False
@dataclass
class Qwen3_4BConfig:
@@ -194,30 +99,6 @@ class Qwen3_4BConfig:
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
lm_head: bool = False
@dataclass
class Qwen3_8BConfig:
vocab_size: int = 151936
hidden_size: int = 4096
intermediate_size: int = 12288
num_hidden_layers: int = 36
num_attention_heads: int = 32
num_key_value_heads: int = 8
max_position_embeddings: int = 40960
rms_norm_eps: float = 1e-6
rope_theta: float = 1000000.0
transformer_type: str = "llama"
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = False
rope_dims = None
q_norm = "gemma3"
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
lm_head: bool = False
@dataclass
class Ovis25_2BConfig:
@@ -240,7 +121,6 @@ class Ovis25_2BConfig:
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
lm_head: bool = False
@dataclass
class Qwen25_7BVLI_Config:
@@ -263,7 +143,6 @@ class Qwen25_7BVLI_Config:
k_norm = None
rope_scale = None
final_norm: bool = True
lm_head: bool = False
@dataclass
class Gemma2_2B_Config:
@@ -287,7 +166,6 @@ class Gemma2_2B_Config:
sliding_attention = None
rope_scale = None
final_norm: bool = True
lm_head: bool = False
@dataclass
class Gemma3_4B_Config:
@@ -299,7 +177,7 @@ class Gemma3_4B_Config:
num_key_value_heads: int = 4
max_position_embeddings: int = 131072
rms_norm_eps: float = 1e-6
rope_theta = [1000000.0, 10000.0]
rope_theta = [10000.0, 1000000.0]
transformer_type: str = "gemma3"
head_dim = 256
rms_norm_add = True
@@ -308,36 +186,9 @@ class Gemma3_4B_Config:
rope_dims = None
q_norm = "gemma3"
k_norm = "gemma3"
sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
rope_scale = [8.0, 1.0]
sliding_attention = [False, False, False, False, False, 1024]
rope_scale = [1.0, 8.0]
final_norm: bool = True
lm_head: bool = False
@dataclass
class Gemma3_12B_Config:
vocab_size: int = 262208
hidden_size: int = 3840
intermediate_size: int = 15360
num_hidden_layers: int = 48
num_attention_heads: int = 16
num_key_value_heads: int = 8
max_position_embeddings: int = 131072
rms_norm_eps: float = 1e-6
rope_theta = [1000000.0, 10000.0]
transformer_type: str = "gemma3"
head_dim = 256
rms_norm_add = True
mlp_activation = "gelu_pytorch_tanh"
qkv_bias = False
rope_dims = None
q_norm = "gemma3"
k_norm = "gemma3"
sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
rope_scale = [8.0, 1.0]
final_norm: bool = True
lm_head: bool = False
vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
mm_tokens_per_image = 256
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
@@ -437,7 +288,6 @@ class Attention(nn.Module):
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[torch.Tensor] = None,
optimized_attention=None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
):
batch_size, seq_length, _ = hidden_states.shape
xq = self.q_proj(hidden_states)
@@ -455,30 +305,11 @@ class Attention(nn.Module):
xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis)
present_key_value = None
if past_key_value is not None:
index = 0
num_tokens = xk.shape[2]
if len(past_key_value) > 0:
past_key, past_value, index = past_key_value
if past_key.shape[2] >= (index + num_tokens):
past_key[:, :, index:index + xk.shape[2]] = xk
past_value[:, :, index:index + xv.shape[2]] = xv
xk = past_key[:, :, :index + xk.shape[2]]
xv = past_value[:, :, :index + xv.shape[2]]
present_key_value = (past_key, past_value, index + num_tokens)
else:
xk = torch.cat((past_key[:, :, :index], xk), dim=2)
xv = torch.cat((past_value[:, :, :index], xv), dim=2)
present_key_value = (xk, xv, index + num_tokens)
else:
present_key_value = (xk, xv, index + num_tokens)
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
return self.o_proj(output), present_key_value
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
@@ -509,17 +340,15 @@ class TransformerBlock(nn.Module):
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[torch.Tensor] = None,
optimized_attention=None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
):
# Self Attention
residual = x
x = self.input_layernorm(x)
x, present_key_value = self.self_attn(
x = self.self_attn(
hidden_states=x,
attention_mask=attention_mask,
freqs_cis=freqs_cis,
optimized_attention=optimized_attention,
past_key_value=past_key_value,
)
x = residual + x
@@ -529,7 +358,7 @@ class TransformerBlock(nn.Module):
x = self.mlp(x)
x = residual + x
return x, present_key_value
return x
class TransformerBlockGemma2(nn.Module):
def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None):
@@ -541,7 +370,7 @@ class TransformerBlockGemma2(nn.Module):
self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
if config.sliding_attention is not None:
if config.sliding_attention is not None: # TODO: implement. (Not that necessary since models are trained on less than 1024 tokens)
self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)]
else:
self.sliding_attention = False
@@ -554,17 +383,11 @@ class TransformerBlockGemma2(nn.Module):
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[torch.Tensor] = None,
optimized_attention=None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
):
if self.transformer_type == 'gemma3':
if self.sliding_attention:
if x.shape[1] > self.sliding_attention:
sliding_mask = torch.full((x.shape[1], x.shape[1]), float("-inf"), device=x.device, dtype=x.dtype)
sliding_mask.tril_(diagonal=-self.sliding_attention)
if attention_mask is not None:
attention_mask = attention_mask + sliding_mask
else:
attention_mask = sliding_mask
logging.warning("Warning: sliding attention not implemented, results may be incorrect")
freqs_cis = freqs_cis[1]
else:
freqs_cis = freqs_cis[0]
@@ -572,12 +395,11 @@ class TransformerBlockGemma2(nn.Module):
# Self Attention
residual = x
x = self.input_layernorm(x)
x, present_key_value = self.self_attn(
x = self.self_attn(
hidden_states=x,
attention_mask=attention_mask,
freqs_cis=freqs_cis,
optimized_attention=optimized_attention,
past_key_value=past_key_value,
)
x = self.post_attention_layernorm(x)
@@ -590,7 +412,7 @@ class TransformerBlockGemma2(nn.Module):
x = self.post_feedforward_layernorm(x)
x = residual + x
return x, present_key_value
return x
class Llama2_(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
@@ -621,10 +443,9 @@ class Llama2_(nn.Module):
else:
self.norm = None
if config.lm_head:
self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None):
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]):
if embeds is not None:
x = embeds
else:
@@ -633,13 +454,8 @@ class Llama2_(nn.Module):
if self.normalize_in:
x *= self.config.hidden_size ** 0.5
seq_len = x.shape[1]
past_len = 0
if past_key_values is not None and len(past_key_values) > 0:
past_len = past_key_values[0][2]
if position_ids is None:
position_ids = torch.arange(past_len, past_len + seq_len, device=x.device).unsqueeze(0)
position_ids = torch.arange(0, x.shape[1], device=x.device).unsqueeze(0)
freqs_cis = precompute_freqs_cis(self.config.head_dim,
position_ids,
@@ -650,16 +466,14 @@ class Llama2_(nn.Module):
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min / 4)
if seq_len > 1:
causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(torch.finfo(x.dtype).min / 4).triu_(1)
if mask is not None:
mask += causal_mask
else:
mask = causal_mask
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
if mask is not None:
mask += causal_mask
else:
mask = causal_mask
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
intermediate = None
@@ -675,27 +489,16 @@ class Llama2_(nn.Module):
elif intermediate_output < 0:
intermediate_output = len(self.layers) + intermediate_output
next_key_values = []
for i, layer in enumerate(self.layers):
if all_intermediate is not None:
if only_layers is None or (i in only_layers):
all_intermediate.append(x.unsqueeze(1).clone())
past_kv = None
if past_key_values is not None:
past_kv = past_key_values[i] if len(past_key_values) > 0 else []
x, current_kv = layer(
x = layer(
x=x,
attention_mask=mask,
freqs_cis=freqs_cis,
optimized_attention=optimized_attention,
past_key_value=past_kv,
)
if current_kv is not None:
next_key_values.append(current_kv)
if i == intermediate_output:
intermediate = x.clone()
@@ -712,45 +515,7 @@ class Llama2_(nn.Module):
if intermediate is not None and final_layer_norm_intermediate and self.norm is not None:
intermediate = self.norm(intermediate)
if len(next_key_values) > 0:
return x, intermediate, next_key_values
else:
return x, intermediate
class Gemma3MultiModalProjector(torch.nn.Module):
def __init__(self, config, dtype, device, operations):
super().__init__()
self.mm_input_projection_weight = nn.Parameter(
torch.empty(config.vision_config["hidden_size"], config.hidden_size, device=device, dtype=dtype)
)
self.mm_soft_emb_norm = RMSNorm(config.vision_config["hidden_size"], eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
self.patches_per_image = int(config.vision_config["image_size"] // config.vision_config["patch_size"])
self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
self.kernel_size = self.patches_per_image // self.tokens_per_side
self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size)
def forward(self, vision_outputs: torch.Tensor):
batch_size, _, seq_length = vision_outputs.shape
reshaped_vision_outputs = vision_outputs.transpose(1, 2)
reshaped_vision_outputs = reshaped_vision_outputs.reshape(
batch_size, seq_length, self.patches_per_image, self.patches_per_image
)
reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
pooled_vision_outputs = pooled_vision_outputs.flatten(2)
pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)
normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)
projected_vision_outputs = torch.matmul(normed_vision_outputs, comfy.model_management.cast_to_device(self.mm_input_projection_weight, device=normed_vision_outputs.device, dtype=normed_vision_outputs.dtype))
return projected_vision_outputs.type_as(vision_outputs)
return x, intermediate
class BaseLlama:
def get_input_embeddings(self):
@@ -762,21 +527,6 @@ class BaseLlama:
def forward(self, input_ids, *args, **kwargs):
return self.model(input_ids, *args, **kwargs)
class BaseQwen3:
def logits(self, x):
input = x[:, -1:]
module = self.model.embed_tokens
offload_stream = None
if module.comfy_cast_weights:
weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True)
else:
weight = self.model.embed_tokens.weight.to(x)
x = torch.nn.functional.linear(input, weight, None)
comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
return x
class Llama2(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
@@ -805,34 +555,7 @@ class Qwen25_3B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen3_06B(BaseLlama, BaseQwen3, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_06BConfig(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen3_06B_ACE15(BaseLlama, BaseQwen3, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_06B_ACE15_Config(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen3_2B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_2B_ACE15_lm_Config(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen3_4B(BaseLlama, BaseQwen3, torch.nn.Module):
class Qwen3_4B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_4BConfig(**config_dict)
@@ -841,24 +564,6 @@ class Qwen3_4B(BaseLlama, BaseQwen3, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen3_4B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_4B_ACE15_lm_Config(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen3_8B(BaseLlama, BaseQwen3, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_8BConfig(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Ovis25_2B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
@@ -928,21 +633,3 @@ class Gemma3_4B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Gemma3_12B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Gemma3_12B_Config(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.multi_modal_projector = Gemma3MultiModalProjector(config, dtype, device, operations)
self.vision_model = comfy.clip_model.CLIPVision(config.vision_config, dtype, device, operations)
self.dtype = dtype
self.image_size = config.vision_config["image_size"]
def preprocess_embed(self, embed, device):
if embed["type"] == "image":
image = comfy.clip_model.clip_preprocess(embed["data"], size=self.image_size, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True)
return self.multi_modal_projector(self.vision_model(image.to(device, dtype=torch.float32))[0]), None
return None, None

View File

@@ -1,11 +1,7 @@
from comfy import sd1_clip
import os
from transformers import T5TokenizerFast
from .spiece_tokenizer import SPieceTokenizer
import comfy.text_encoders.genmo
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
import torch
import comfy.utils
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -20,133 +16,3 @@ class LTXVT5Tokenizer(sd1_clip.SD1Tokenizer):
def ltxv_te(*args, **kwargs):
return comfy.text_encoders.genmo.mochi_te(*args, **kwargs)
class Gemma3_12BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_12b", tokenizer=Gemma3_12BTokenizer)
class Gemma3_12BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
def tokenize_with_weights(self, text, return_word_ids=False, llama_template="{}", image_embeds=None, **kwargs):
text = llama_template.format(text)
text_tokens = super().tokenize_with_weights(text, return_word_ids)
embed_count = 0
for k in text_tokens:
tt = text_tokens[k]
for r in tt:
for i in range(len(r)):
if r[i][0] == 262144:
if image_embeds is not None and embed_count < image_embeds.shape[0]:
r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image"},) + r[i][1:]
embed_count += 1
return text_tokens
class LTXAVTEModel(torch.nn.Module):
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
super().__init__()
self.dtypes = set()
self.dtypes.add(dtype)
self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None)
self.dtypes.add(dtype_llama)
operations = self.gemma3_12b.operations # TODO
self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
self.audio_embeddings_connector = Embeddings1DConnector(
split_rope=True,
double_precision_rope=True,
dtype=dtype,
device=device,
operations=operations,
)
self.video_embeddings_connector = Embeddings1DConnector(
split_rope=True,
double_precision_rope=True,
dtype=dtype,
device=device,
operations=operations,
)
def set_clip_options(self, options):
self.execution_device = options.get("execution_device", self.execution_device)
self.gemma3_12b.set_clip_options(options)
def reset_clip_options(self):
self.gemma3_12b.reset_clip_options()
self.execution_device = None
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs = token_weight_pairs["gemma3_12b"]
out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs)
out_device = out.device
if comfy.model_management.should_use_bf16(self.execution_device):
out = out.to(device=self.execution_device, dtype=torch.bfloat16)
out = out.movedim(1, -1).to(self.execution_device)
out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
out = out.reshape((out.shape[0], out.shape[1], -1))
out = self.text_embedding_projection(out)
out = out.float()
out_vid = self.video_embeddings_connector(out)[0]
out_audio = self.audio_embeddings_connector(out)[0]
out = torch.concat((out_vid, out_audio), dim=-1)
return out.to(out_device), pooled
def load_sd(self, sd):
if "model.layers.47.self_attn.q_norm.weight" in sd:
return self.gemma3_12b.load_sd(sd)
else:
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True)
if len(sdo) == 0:
sdo = sd
missing_all = []
unexpected_all = []
for prefix, component in [("text_embedding_projection.", self.text_embedding_projection), ("video_embeddings_connector.", self.video_embeddings_connector), ("audio_embeddings_connector.", self.audio_embeddings_connector)]:
component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)}
if component_sd:
missing, unexpected = component.load_state_dict(component_sd, strict=False, assign=getattr(self, "can_assign_sd", False))
missing_all.extend([f"{prefix}{k}" for k in missing])
unexpected_all.extend([f"{prefix}{k}" for k in unexpected])
return (missing_all, unexpected_all)
def memory_estimation_function(self, token_weight_pairs, device=None):
constant = 6.0
if comfy.model_management.should_use_bf16(device):
constant /= 2.0
token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
return num_tokens * constant * 1024 * 1024
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
class LTXAVTEModel_(LTXAVTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["llama_quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
return LTXAVTEModel_

Some files were not shown because too many files have changed in this diff Show More