mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-12 11:10:03 +00:00
Compare commits
246 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
59afc39848 | ||
|
|
028e17dd7a | ||
|
|
30c259cac8 | ||
|
|
1cb7e22a95 | ||
|
|
2640acb31c | ||
|
|
7dbd5dfe91 | ||
|
|
f8b981ae9a | ||
|
|
4967f81778 | ||
|
|
0a6746898d | ||
|
|
5151cff293 | ||
|
|
af96d9812d | ||
|
|
52a32e2b32 | ||
|
|
b907085709 | ||
|
|
065a2fbbec | ||
|
|
0ff0457892 | ||
|
|
6484ac89dc | ||
|
|
f55c98a89f | ||
|
|
ca7808f240 | ||
|
|
52e778fff3 | ||
|
|
9d8a817985 | ||
|
|
b59750a86a | ||
|
|
3f382a4f98 | ||
|
|
f17251bec6 | ||
|
|
c38e7d6599 | ||
|
|
eaf68c9b5b | ||
|
|
cc6a8dcd1a | ||
|
|
a2d60aad0f | ||
|
|
d8433c63fd | ||
|
|
dd41b74549 | ||
|
|
55f654db3d | ||
|
|
58c6ed541d | ||
|
|
234c3dc85f | ||
|
|
8908ee2628 | ||
|
|
1105e0d139 | ||
|
|
8938aa3f30 | ||
|
|
f16219e3aa | ||
|
|
8402c8700a | ||
|
|
58b8574661 | ||
|
|
90b3995ec8 | ||
|
|
bdb10a583f | ||
|
|
0e24dbb19f | ||
|
|
e9aae31fa2 | ||
|
|
0c18842acb | ||
|
|
d196a905bb | ||
|
|
18b79acba9 | ||
|
|
dff996ca39 | ||
|
|
828b1b9953 | ||
|
|
af81cb962d | ||
|
|
5c7b08ca58 | ||
|
|
6b573ae0cb | ||
|
|
015a0599d0 | ||
|
|
acfaa5c4a1 | ||
|
|
b6805429b9 | ||
|
|
25022e0b09 | ||
|
|
22a2644e57 | ||
|
|
b2ef58e2b1 | ||
|
|
6a6d456c88 | ||
|
|
3d1fdaf9f4 | ||
|
|
1286fcfe40 | ||
|
|
3bd71554a2 | ||
|
|
f66183a541 | ||
|
|
cbd68e3d58 | ||
|
|
d89c29f259 | ||
|
|
a9c35256bc | ||
|
|
532938b16b | ||
|
|
ecb683b057 | ||
|
|
c55fd74816 | ||
|
|
3398123752 | ||
|
|
943b3b615d | ||
|
|
10e90a5757 | ||
|
|
b75d349f25 | ||
|
|
7b8389578e | ||
|
|
9e00ce5b76 | ||
|
|
f5e66d5e47 | ||
|
|
87b0359392 | ||
|
|
cb96d4d18c | ||
|
|
394348f5ca | ||
|
|
7601e89255 | ||
|
|
6a1d3a1ae1 | ||
|
|
65ee24c978 | ||
|
|
17027f2a6a | ||
|
|
b5c8be8b1d | ||
|
|
24fdb92edf | ||
|
|
d526974576 | ||
|
|
e1ab6bb394 | ||
|
|
048f49adbd | ||
|
|
47bfd5a33f | ||
|
|
fdf49a2861 | ||
|
|
f41e5f398d | ||
|
|
27cbac865e | ||
|
|
3d0003c24c | ||
|
|
7d6103325e | ||
|
|
2d4a08b717 | ||
|
|
9a02382568 | ||
|
|
bd01d9f7fd | ||
|
|
443056c401 | ||
|
|
f60923590c | ||
|
|
1ef328c007 | ||
|
|
94c298f962 | ||
|
|
2fde9597f4 | ||
|
|
f91078b1ff | ||
|
|
3b3ef9a77a | ||
|
|
8b0b93df51 | ||
|
|
1c7eaeca10 | ||
|
|
18e7d6dba5 | ||
|
|
e1d85e7577 | ||
|
|
1199411747 | ||
|
|
5ebcab3c7d | ||
|
|
c350009236 | ||
|
|
dea899f221 | ||
|
|
e632e5de28 | ||
|
|
2abd2b5c20 | ||
|
|
a1a70362ca | ||
|
|
cf97b033ee | ||
|
|
eb1c42f649 | ||
|
|
e05c907126 | ||
|
|
09dc24c8a9 | ||
|
|
1d69245981 | ||
|
|
97f198e421 | ||
|
|
bda0eb2448 | ||
|
|
c4a6b389de | ||
|
|
4cd881866b | ||
|
|
265adad858 | ||
|
|
7f3e4d486c | ||
|
|
a389ee01bb | ||
|
|
9c71a66790 | ||
|
|
af4b7b5edb | ||
|
|
0f4ef3afa0 | ||
|
|
6b88478f9f | ||
|
|
e199c8cc67 | ||
|
|
0652cb8e2d | ||
|
|
958a17199a | ||
|
|
e974e554ca | ||
|
|
4e2110c794 | ||
|
|
e617cddf24 | ||
|
|
1f3f7a2823 | ||
|
|
88df172790 | ||
|
|
6d6a18b0b7 | ||
|
|
97ff9fae7e | ||
|
|
135fa49ec2 | ||
|
|
44869ff786 | ||
|
|
20182a393f | ||
|
|
5f109fe6a0 | ||
|
|
c58c13b2ba | ||
|
|
7f374e42c8 | ||
|
|
27d1bd8829 | ||
|
|
614cf9805e | ||
|
|
513b0c46fb | ||
|
|
dfac94695b | ||
|
|
163b629c70 | ||
|
|
998bf60beb | ||
|
|
906c089957 | ||
|
|
25de7b1bfa | ||
|
|
ab7ab5be23 | ||
|
|
ec4fc2a09a | ||
|
|
1a58087ac2 | ||
|
|
6c14f3afac | ||
|
|
e525673f72 | ||
|
|
3fa7a5c04a | ||
|
|
210f7a1ba5 | ||
|
|
d202c2ba74 | ||
|
|
8817f8fc14 | ||
|
|
22e40d2ace | ||
|
|
3bea4efc6b | ||
|
|
8cf2ba4ba6 | ||
|
|
b61a40cbc9 | ||
|
|
f2bb3230b7 | ||
|
|
614b8d3345 | ||
|
|
6abc30aae9 | ||
|
|
55bad30375 | ||
|
|
c305deed56 | ||
|
|
601ee1775a | ||
|
|
c170fd2db5 | ||
|
|
9d529e5308 | ||
|
|
f6bbc1ac84 | ||
|
|
098a352f13 | ||
|
|
e86b79ab9e | ||
|
|
426cde37f1 | ||
|
|
dd5af0c587 | ||
|
|
388b306a2b | ||
|
|
24188b3141 | ||
|
|
1bcda6df98 | ||
|
|
a1864c01f2 | ||
|
|
4739d7717f | ||
|
|
f13cff0be6 | ||
|
|
9cdc64998f | ||
|
|
560b1bdfca | ||
|
|
b7992f871a | ||
|
|
2c2aa409b0 | ||
|
|
a4787ac83b | ||
|
|
b5c59b763c | ||
|
|
b4f30bd408 | ||
|
|
dad076aee6 | ||
|
|
0cf33953a7 | ||
|
|
5b80addafd | ||
|
|
9da397ea2f | ||
|
|
92d97380bd | ||
|
|
99ce2a1f66 | ||
|
|
b1467da480 | ||
|
|
d8d60b5609 | ||
|
|
b1293d50ef | ||
|
|
19b466160c | ||
|
|
bc0ad9bb49 | ||
|
|
4054b4bf38 | ||
|
|
55ac7d333c | ||
|
|
afa8a24fe1 | ||
|
|
493b81e48f | ||
|
|
6b035bfce2 | ||
|
|
74b7f0b04b | ||
|
|
f72c6616b2 | ||
|
|
1c10b33f9b | ||
|
|
ddfce1af4f | ||
|
|
7a883849ea | ||
|
|
84867067ea | ||
|
|
3374e900d0 | ||
|
|
51696e3fdc | ||
|
|
dfff7e5332 | ||
|
|
e4ea393666 | ||
|
|
c8674bc6e9 | ||
|
|
3dfdcf66b6 | ||
|
|
95ca2e56c8 | ||
|
|
27ffd12c45 | ||
|
|
e693e4db6a | ||
|
|
d68ece7301 | ||
|
|
894837de9a | ||
|
|
fdc92863b6 | ||
|
|
a125cd84b0 | ||
|
|
84e9ce32c6 | ||
|
|
f43b8ab2a2 | ||
|
|
14d642acd6 | ||
|
|
aa895db7e8 | ||
|
|
cdfc25a160 | ||
|
|
81e4dac107 | ||
|
|
90853fb9cd | ||
|
|
f1dd6e50f8 | ||
|
|
fc0fbf141c | ||
|
|
f3d5d328a3 | ||
|
|
139addd53c | ||
|
|
cbee7d3390 | ||
|
|
6732014a0a | ||
|
|
989f715d92 | ||
|
|
2ba8d7cce8 | ||
|
|
51fb505ffa | ||
|
|
72c2071972 | ||
|
|
6e59934089 | ||
|
|
3e0eb8d33f |
@@ -1,5 +1,5 @@
|
||||
As of the time of writing this you need this preview driver for best results:
|
||||
https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-PREVIEW.html
|
||||
As of the time of writing this you need this driver for best results:
|
||||
https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-7-1-1.html
|
||||
|
||||
HOW TO RUN:
|
||||
|
||||
@@ -25,3 +25,4 @@ In the ComfyUI directory you will find a file: extra_model_paths.yaml.example
|
||||
Rename this file to: extra_model_paths.yaml and edit it with your favorite text editor.
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes
|
||||
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
|
||||
pause
|
||||
@@ -1,2 +1,3 @@
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
|
||||
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
|
||||
pause
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
|
||||
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
|
||||
pause
|
||||
|
||||
8
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
8
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@@ -8,13 +8,15 @@ body:
|
||||
Before submitting a **Bug Report**, please ensure the following:
|
||||
|
||||
- **1:** You are running the latest version of ComfyUI.
|
||||
- **2:** You have looked at the existing bug reports and made sure this isn't already reported.
|
||||
- **2:** You have your ComfyUI logs and relevant workflow on hand and will post them in this bug report.
|
||||
- **3:** You confirmed that the bug is not caused by a custom node. You can disable all custom nodes by passing
|
||||
`--disable-all-custom-nodes` command line argument.
|
||||
`--disable-all-custom-nodes` command line argument. If you have custom node try updating them to the latest version.
|
||||
- **4:** This is an actual bug in ComfyUI, not just a support question. A bug is when you can specify exact
|
||||
steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen.
|
||||
|
||||
If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
|
||||
## Very Important
|
||||
|
||||
Please make sure that you post ALL your ComfyUI logs in the bug report. A bug report without logs will likely be ignored.
|
||||
- type: checkboxes
|
||||
id: custom-nodes-test
|
||||
attributes:
|
||||
|
||||
21
.github/PULL_REQUEST_TEMPLATE/api-node.md
vendored
Normal file
21
.github/PULL_REQUEST_TEMPLATE/api-node.md
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
<!-- API_NODE_PR_CHECKLIST: do not remove -->
|
||||
|
||||
## API Node PR Checklist
|
||||
|
||||
### Scope
|
||||
- [ ] **Is API Node Change**
|
||||
|
||||
### Pricing & Billing
|
||||
- [ ] **Need pricing update**
|
||||
- [ ] **No pricing update**
|
||||
|
||||
If **Need pricing update**:
|
||||
- [ ] Metronome rate cards updated
|
||||
- [ ] Auto‑billing tests updated and passing
|
||||
|
||||
### QA
|
||||
- [ ] **QA done**
|
||||
- [ ] **QA not required**
|
||||
|
||||
### Comms
|
||||
- [ ] Informed **Kosinkadink**
|
||||
58
.github/workflows/api-node-template.yml
vendored
Normal file
58
.github/workflows/api-node-template.yml
vendored
Normal file
@@ -0,0 +1,58 @@
|
||||
name: Append API Node PR template
|
||||
|
||||
on:
|
||||
pull_request_target:
|
||||
types: [opened, reopened, synchronize, ready_for_review]
|
||||
paths:
|
||||
- 'comfy_api_nodes/**' # only run if these files changed
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
inject:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Ensure template exists and append to PR body
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const { owner, repo } = context.repo;
|
||||
const number = context.payload.pull_request.number;
|
||||
const templatePath = '.github/PULL_REQUEST_TEMPLATE/api-node.md';
|
||||
const marker = '<!-- API_NODE_PR_CHECKLIST: do not remove -->';
|
||||
|
||||
const { data: pr } = await github.rest.pulls.get({ owner, repo, pull_number: number });
|
||||
|
||||
let templateText;
|
||||
try {
|
||||
const res = await github.rest.repos.getContent({
|
||||
owner,
|
||||
repo,
|
||||
path: templatePath,
|
||||
ref: pr.base.ref
|
||||
});
|
||||
const buf = Buffer.from(res.data.content, res.data.encoding || 'base64');
|
||||
templateText = buf.toString('utf8');
|
||||
} catch (e) {
|
||||
core.setFailed(`Required PR template not found at "${templatePath}" on ${pr.base.ref}. Please add it to the repo.`);
|
||||
return;
|
||||
}
|
||||
|
||||
// Enforce the presence of the marker inside the template (for idempotence)
|
||||
if (!templateText.includes(marker)) {
|
||||
core.setFailed(`Template at "${templatePath}" does not contain the required marker:\n${marker}\nAdd it so we can detect duplicates safely.`);
|
||||
return;
|
||||
}
|
||||
|
||||
// If the PR already contains the marker, do not append again.
|
||||
const body = pr.body || '';
|
||||
if (body.includes(marker)) {
|
||||
core.info('Template already present in PR body; nothing to inject.');
|
||||
return;
|
||||
}
|
||||
|
||||
const newBody = (body ? body + '\n\n' : '') + templateText + '\n';
|
||||
await github.rest.pulls.update({ owner, repo, pull_number: number, body: newBody });
|
||||
core.notice('API Node template appended to PR description.');
|
||||
27
.github/workflows/release-stable-all.yml
vendored
27
.github/workflows/release-stable-all.yml
vendored
@@ -14,13 +14,13 @@ jobs:
|
||||
contents: "write"
|
||||
packages: "write"
|
||||
pull-requests: "read"
|
||||
name: "Release NVIDIA Default (cu129)"
|
||||
name: "Release NVIDIA Default (cu130)"
|
||||
uses: ./.github/workflows/stable-release.yml
|
||||
with:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
cache_tag: "cu129"
|
||||
cache_tag: "cu130"
|
||||
python_minor: "13"
|
||||
python_patch: "6"
|
||||
python_patch: "9"
|
||||
rel_name: "nvidia"
|
||||
rel_extra_name: ""
|
||||
test_release: true
|
||||
@@ -43,16 +43,33 @@ jobs:
|
||||
test_release: true
|
||||
secrets: inherit
|
||||
|
||||
release_nvidia_cu126:
|
||||
permissions:
|
||||
contents: "write"
|
||||
packages: "write"
|
||||
pull-requests: "read"
|
||||
name: "Release NVIDIA cu126"
|
||||
uses: ./.github/workflows/stable-release.yml
|
||||
with:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
cache_tag: "cu126"
|
||||
python_minor: "12"
|
||||
python_patch: "10"
|
||||
rel_name: "nvidia"
|
||||
rel_extra_name: "_cu126"
|
||||
test_release: true
|
||||
secrets: inherit
|
||||
|
||||
release_amd_rocm:
|
||||
permissions:
|
||||
contents: "write"
|
||||
packages: "write"
|
||||
pull-requests: "read"
|
||||
name: "Release AMD ROCm 6.4.4"
|
||||
name: "Release AMD ROCm 7.1.1"
|
||||
uses: ./.github/workflows/stable-release.yml
|
||||
with:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
cache_tag: "rocm644"
|
||||
cache_tag: "rocm711"
|
||||
python_minor: "12"
|
||||
python_patch: "10"
|
||||
rel_name: "amd"
|
||||
|
||||
20
.github/workflows/test-ci.yml
vendored
20
.github/workflows/test-ci.yml
vendored
@@ -21,14 +21,15 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
# os: [macos, linux, windows]
|
||||
os: [macos, linux]
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12"]
|
||||
# os: [macos, linux]
|
||||
os: [linux]
|
||||
python_version: ["3.10", "3.11", "3.12"]
|
||||
cuda_version: ["12.1"]
|
||||
torch_version: ["stable"]
|
||||
include:
|
||||
- os: macos
|
||||
runner_label: [self-hosted, macOS]
|
||||
flags: "--use-pytorch-cross-attention"
|
||||
# - os: macos
|
||||
# runner_label: [self-hosted, macOS]
|
||||
# flags: "--use-pytorch-cross-attention"
|
||||
- os: linux
|
||||
runner_label: [self-hosted, Linux]
|
||||
flags: ""
|
||||
@@ -73,14 +74,15 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [macos, linux]
|
||||
# os: [macos, linux]
|
||||
os: [linux]
|
||||
python_version: ["3.11"]
|
||||
cuda_version: ["12.1"]
|
||||
torch_version: ["nightly"]
|
||||
include:
|
||||
- os: macos
|
||||
runner_label: [self-hosted, macOS]
|
||||
flags: "--use-pytorch-cross-attention"
|
||||
# - os: macos
|
||||
# runner_label: [self-hosted, macOS]
|
||||
# flags: "--use-pytorch-cross-attention"
|
||||
- os: linux
|
||||
runner_label: [self-hosted, Linux]
|
||||
flags: ""
|
||||
|
||||
@@ -17,7 +17,7 @@ on:
|
||||
description: 'cuda version'
|
||||
required: true
|
||||
type: string
|
||||
default: "129"
|
||||
default: "130"
|
||||
|
||||
python_minor:
|
||||
description: 'python minor version'
|
||||
@@ -29,7 +29,7 @@ on:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "6"
|
||||
default: "9"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
|
||||
168
QUANTIZATION.md
Normal file
168
QUANTIZATION.md
Normal file
@@ -0,0 +1,168 @@
|
||||
# The Comfy guide to Quantization
|
||||
|
||||
|
||||
## How does quantization work?
|
||||
|
||||
Quantization aims to map a high-precision value x_f to a lower precision format with minimal loss in accuracy. These smaller formats then serve to reduce the models memory footprint and increase throughput by using specialized hardware.
|
||||
|
||||
When simply converting a value from FP16 to FP8 using the round-nearest method we might hit two issues:
|
||||
- The dynamic range of FP16 (-65,504, 65,504) far exceeds FP8 formats like E4M3 (-448, 448) or E5M2 (-57,344, 57,344), potentially resulting in clipped values
|
||||
- The original values are concentrated in a small range (e.g. -1,1) leaving many FP8-bits "unused"
|
||||
|
||||
By using a scaling factor, we aim to map these values into the quantized-dtype range, making use of the full spectrum. One of the easiest approaches, and common, is using per-tensor absolute-maximum scaling.
|
||||
|
||||
```
|
||||
absmax = max(abs(tensor))
|
||||
scale = amax / max_dynamic_range_low_precision
|
||||
|
||||
# Quantization
|
||||
tensor_q = (tensor / scale).to(low_precision_dtype)
|
||||
|
||||
# De-Quantization
|
||||
tensor_dq = tensor_q.to(fp16) * scale
|
||||
|
||||
tensor_dq ~ tensor
|
||||
```
|
||||
|
||||
Given that additional information (scaling factor) is needed to "interpret" the quantized values, we describe those as derived datatypes.
|
||||
|
||||
|
||||
## Quantization in Comfy
|
||||
|
||||
```
|
||||
QuantizedTensor (torch.Tensor subclass)
|
||||
↓ __torch_dispatch__
|
||||
Two-Level Registry (generic + layout handlers)
|
||||
↓
|
||||
MixedPrecisionOps + Metadata Detection
|
||||
```
|
||||
|
||||
### Representation
|
||||
|
||||
To represent these derived datatypes, ComfyUI uses a subclass of torch.Tensor to implements these using the `QuantizedTensor` class found in `comfy/quant_ops.py`
|
||||
|
||||
A `Layout` class defines how a specific quantization format behaves:
|
||||
- Required parameters
|
||||
- Quantize method
|
||||
- De-Quantize method
|
||||
|
||||
```python
|
||||
from comfy.quant_ops import QuantizedLayout
|
||||
|
||||
class MyLayout(QuantizedLayout):
|
||||
@classmethod
|
||||
def quantize(cls, tensor, **kwargs):
|
||||
# Convert to quantized format
|
||||
qdata = ...
|
||||
params = {'scale': ..., 'orig_dtype': tensor.dtype}
|
||||
return qdata, params
|
||||
|
||||
@staticmethod
|
||||
def dequantize(qdata, scale, orig_dtype, **kwargs):
|
||||
return qdata.to(orig_dtype) * scale
|
||||
```
|
||||
|
||||
To then run operations using these QuantizedTensors we use two registry systems to define supported operations.
|
||||
The first is a **generic registry** that handles operations common to all quantized formats (e.g., `.to()`, `.clone()`, `.reshape()`).
|
||||
|
||||
The second registry is layout-specific and allows to implement fast-paths like nn.Linear.
|
||||
```python
|
||||
from comfy.quant_ops import register_layout_op
|
||||
|
||||
@register_layout_op(torch.ops.aten.linear.default, MyLayout)
|
||||
def my_linear(func, args, kwargs):
|
||||
# Extract tensors, call optimized kernel
|
||||
...
|
||||
```
|
||||
When `torch.nn.functional.linear()` is called with QuantizedTensor arguments, `__torch_dispatch__` automatically routes to the registered implementation.
|
||||
For any unsupported operation, QuantizedTensor will fallback to call `dequantize` and dispatch using the high-precision implementation.
|
||||
|
||||
|
||||
### Mixed Precision
|
||||
|
||||
The `MixedPrecisionOps` class (lines 542-648 in `comfy/ops.py`) enables per-layer quantization decisions, allowing different layers in a model to use different precisions. This is activated when a model config contains a `layer_quant_config` dictionary that specifies which layers should be quantized and how.
|
||||
|
||||
**Architecture:**
|
||||
|
||||
```python
|
||||
class MixedPrecisionOps(disable_weight_init):
|
||||
_layer_quant_config = {} # Maps layer names to quantization configs
|
||||
_compute_dtype = torch.bfloat16 # Default compute / dequantize precision
|
||||
```
|
||||
|
||||
**Key mechanism:**
|
||||
|
||||
The custom `Linear._load_from_state_dict()` method inspects each layer during model loading:
|
||||
- If the layer name is **not** in `_layer_quant_config`: load weight as regular tensor in `_compute_dtype`
|
||||
- If the layer name **is** in `_layer_quant_config`:
|
||||
- Load weight as `QuantizedTensor` with the specified layout (e.g., `TensorCoreFP8Layout`)
|
||||
- Load associated quantization parameters (scales, block_size, etc.)
|
||||
|
||||
**Why it's needed:**
|
||||
|
||||
Not all layers tolerate quantization equally. Sensitive operations like final projections can be kept in higher precision, while compute-heavy matmuls are quantized. This provides most of the performance benefits while maintaining quality.
|
||||
|
||||
The system is selected in `pick_operations()` when `model_config.layer_quant_config` is present, making it the highest-priority operation mode.
|
||||
|
||||
|
||||
## Checkpoint Format
|
||||
|
||||
Quantized checkpoints are stored as standard safetensors files with quantized weight tensors and associated scaling parameters, plus a `_quantization_metadata` JSON entry describing the quantization scheme.
|
||||
|
||||
The quantized checkpoint will contain the same layers as the original checkpoint but:
|
||||
- The weights are stored as quantized values, sometimes using a different storage datatype. E.g. uint8 container for fp8.
|
||||
- For each quantized weight a number of additional scaling parameters are stored alongside depending on the recipe.
|
||||
- We store a metadata.json in the metadata of the final safetensor containing the `_quantization_metadata` describing which layers are quantized and what layout has been used.
|
||||
|
||||
### Scaling Parameters details
|
||||
We define 4 possible scaling parameters that should cover most recipes in the near-future:
|
||||
- **weight_scale**: quantization scalers for the weights
|
||||
- **weight_scale_2**: global scalers in the context of double scaling
|
||||
- **pre_quant_scale**: scalers used for smoothing salient weights
|
||||
- **input_scale**: quantization scalers for the activations
|
||||
|
||||
| Format | Storage dtype | weight_scale | weight_scale_2 | pre_quant_scale | input_scale |
|
||||
|--------|---------------|--------------|----------------|-----------------|-------------|
|
||||
| float8_e4m3fn | float32 | float32 (scalar) | - | - | float32 (scalar) |
|
||||
|
||||
You can find the defined formats in `comfy/quant_ops.py` (QUANT_ALGOS).
|
||||
|
||||
### Quantization Metadata
|
||||
|
||||
The metadata stored alongside the checkpoint contains:
|
||||
- **format_version**: String to define a version of the standard
|
||||
- **layers**: A dictionary mapping layer names to their quantization format. The format string maps to the definitions found in `QUANT_ALGOS`.
|
||||
|
||||
Example:
|
||||
```json
|
||||
{
|
||||
"_quantization_metadata": {
|
||||
"format_version": "1.0",
|
||||
"layers": {
|
||||
"model.layers.0.mlp.up_proj": "float8_e4m3fn",
|
||||
"model.layers.0.mlp.down_proj": "float8_e4m3fn",
|
||||
"model.layers.1.mlp.up_proj": "float8_e4m3fn"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
## Creating Quantized Checkpoints
|
||||
|
||||
To create compatible checkpoints, use any quantization tool provided the output follows the checkpoint format described above and uses a layout defined in `QUANT_ALGOS`.
|
||||
|
||||
### Weight Quantization
|
||||
|
||||
Weight quantization is straightforward - compute the scaling factor directly from the weight tensor using the absolute maximum method described earlier. Each layer's weights are quantized independently and stored with their corresponding `weight_scale` parameter.
|
||||
|
||||
### Calibration (for Activation Quantization)
|
||||
|
||||
Activation quantization (e.g., for FP8 Tensor Core operations) requires `input_scale` parameters that cannot be determined from static weights alone. Since activation values depend on actual inputs, we use **post-training calibration (PTQ)**:
|
||||
|
||||
1. **Collect statistics**: Run inference on N representative samples
|
||||
2. **Track activations**: Record the absolute maximum (`amax`) of inputs to each quantized layer
|
||||
3. **Compute scales**: Derive `input_scale` from collected statistics
|
||||
4. **Store in checkpoint**: Save `input_scale` parameters alongside weights
|
||||
|
||||
The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters.
|
||||
31
README.md
31
README.md
@@ -67,6 +67,8 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
- [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
|
||||
- [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/)
|
||||
- [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/)
|
||||
- [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/)
|
||||
- [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/)
|
||||
- Image Editing Models
|
||||
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
|
||||
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
|
||||
@@ -112,10 +114,11 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
||||
|
||||
## Release Process
|
||||
|
||||
ComfyUI follows a weekly release cycle targeting Friday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
|
||||
ComfyUI follows a weekly release cycle targeting Monday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
|
||||
|
||||
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
|
||||
- Releases a new stable version (e.g., v0.7.0)
|
||||
- Releases a new stable version (e.g., v0.7.0) roughly every week.
|
||||
- Commits outside of the stable release tags may be very unstable and break many custom nodes.
|
||||
- Serves as the foundation for the desktop release
|
||||
|
||||
2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)**
|
||||
@@ -172,15 +175,19 @@ There is a portable standalone build for Windows that should work for running on
|
||||
|
||||
### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia.7z)
|
||||
|
||||
Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints
|
||||
Simply download, extract with [7-Zip](https://7-zip.org) or with the windows explorer on recent windows versions and run. For smaller models you normally only need to put the checkpoints (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints but many of the larger models have multiple files. Make sure to follow the instructions to know which subfolder to put them in ComfyUI\models\
|
||||
|
||||
If you have trouble extracting it, right click the file -> properties -> unblock
|
||||
|
||||
Update your Nvidia drivers if it doesn't start.
|
||||
|
||||
#### Alternative Downloads:
|
||||
|
||||
[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
|
||||
|
||||
[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z) (Supports Nvidia 10 series and older GPUs).
|
||||
[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z).
|
||||
|
||||
[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
|
||||
|
||||
#### How do I share models between another UI and ComfyUI?
|
||||
|
||||
@@ -197,7 +204,11 @@ comfy install
|
||||
|
||||
## Manual Install (Windows, Linux)
|
||||
|
||||
Python 3.13 is very well supported. If you have trouble with some custom node dependencies you can try 3.12
|
||||
Python 3.14 works but you may encounter issues with the torch compile node. The free threaded variant is still missing some dependencies.
|
||||
|
||||
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
|
||||
|
||||
### Instructions:
|
||||
|
||||
Git clone this repo.
|
||||
|
||||
@@ -214,7 +225,7 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins
|
||||
|
||||
This is the command to install the nightly with ROCm 7.0 which might have some performance improvements:
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.0```
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1```
|
||||
|
||||
|
||||
### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only.
|
||||
@@ -235,7 +246,7 @@ RDNA 4 (RX 9000 series):
|
||||
|
||||
### Intel GPUs (Windows and Linux)
|
||||
|
||||
(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
|
||||
Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
|
||||
|
||||
1. To install PyTorch xpu, use the following command:
|
||||
|
||||
@@ -245,15 +256,11 @@ This is the command to install the Pytorch xpu nightly which might have some per
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu```
|
||||
|
||||
(Option 2) Alternatively, Intel GPUs supported by Intel Extension for PyTorch (IPEX) can leverage IPEX for improved performance.
|
||||
|
||||
1. visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information.
|
||||
|
||||
### NVIDIA
|
||||
|
||||
Nvidia users should install stable pytorch using this command:
|
||||
|
||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu129```
|
||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu130```
|
||||
|
||||
This is the command to install pytorch nightly instead which might have performance improvements.
|
||||
|
||||
|
||||
@@ -10,7 +10,8 @@ import importlib
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import TypedDict, Optional
|
||||
from typing import Dict, TypedDict, Optional
|
||||
from aiohttp import web
|
||||
from importlib.metadata import version
|
||||
|
||||
import requests
|
||||
@@ -257,7 +258,54 @@ comfyui-frontend-package is not installed.
|
||||
sys.exit(-1)
|
||||
|
||||
@classmethod
|
||||
def templates_path(cls) -> str:
|
||||
def template_asset_map(cls) -> Optional[Dict[str, str]]:
|
||||
"""Return a mapping of template asset names to their absolute paths."""
|
||||
try:
|
||||
from comfyui_workflow_templates import (
|
||||
get_asset_path,
|
||||
iter_templates,
|
||||
)
|
||||
except ImportError:
|
||||
logging.error(
|
||||
f"""
|
||||
********** ERROR ***********
|
||||
|
||||
comfyui-workflow-templates is not installed.
|
||||
|
||||
{frontend_install_warning_message()}
|
||||
|
||||
********** ERROR ***********
|
||||
""".strip()
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
template_entries = list(iter_templates())
|
||||
except Exception as exc:
|
||||
logging.error(f"Failed to enumerate workflow templates: {exc}")
|
||||
return None
|
||||
|
||||
asset_map: Dict[str, str] = {}
|
||||
try:
|
||||
for entry in template_entries:
|
||||
for asset in entry.assets:
|
||||
asset_map[asset.filename] = get_asset_path(
|
||||
entry.template_id, asset.filename
|
||||
)
|
||||
except Exception as exc:
|
||||
logging.error(f"Failed to resolve template asset paths: {exc}")
|
||||
return None
|
||||
|
||||
if not asset_map:
|
||||
logging.error("No workflow template assets found. Did the packages install correctly?")
|
||||
return None
|
||||
|
||||
return asset_map
|
||||
|
||||
|
||||
@classmethod
|
||||
def legacy_templates_path(cls) -> Optional[str]:
|
||||
"""Return the legacy templates directory shipped inside the meta package."""
|
||||
try:
|
||||
import comfyui_workflow_templates
|
||||
|
||||
@@ -276,6 +324,7 @@ comfyui-workflow-templates is not installed.
|
||||
********** ERROR ***********
|
||||
""".strip()
|
||||
)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def embedded_docs_path(cls) -> str:
|
||||
@@ -392,3 +441,17 @@ comfyui-workflow-templates is not installed.
|
||||
logging.info("Falling back to the default frontend.")
|
||||
check_frontend_version()
|
||||
return cls.default_frontend_path()
|
||||
@classmethod
|
||||
def template_asset_handler(cls):
|
||||
assets = cls.template_asset_map()
|
||||
if not assets:
|
||||
return None
|
||||
|
||||
async def serve_template(request: web.Request) -> web.StreamResponse:
|
||||
rel_path = request.match_info.get("path", "")
|
||||
target = assets.get(rel_path)
|
||||
if target is None:
|
||||
raise web.HTTPNotFound()
|
||||
return web.FileResponse(target)
|
||||
|
||||
return serve_template
|
||||
|
||||
112
app/subgraph_manager.py
Normal file
112
app/subgraph_manager.py
Normal file
@@ -0,0 +1,112 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TypedDict
|
||||
import os
|
||||
import folder_paths
|
||||
import glob
|
||||
from aiohttp import web
|
||||
import hashlib
|
||||
|
||||
|
||||
class Source:
|
||||
custom_node = "custom_node"
|
||||
|
||||
class SubgraphEntry(TypedDict):
|
||||
source: str
|
||||
"""
|
||||
Source of subgraph - custom_nodes vs templates.
|
||||
"""
|
||||
path: str
|
||||
"""
|
||||
Relative path of the subgraph file.
|
||||
For custom nodes, will be the relative directory like <custom_node_dir>/subgraphs/<name>.json
|
||||
"""
|
||||
name: str
|
||||
"""
|
||||
Name of subgraph file.
|
||||
"""
|
||||
info: CustomNodeSubgraphEntryInfo
|
||||
"""
|
||||
Additional info about subgraph; in the case of custom_nodes, will contain nodepack name
|
||||
"""
|
||||
data: str
|
||||
|
||||
class CustomNodeSubgraphEntryInfo(TypedDict):
|
||||
node_pack: str
|
||||
"""Node pack name."""
|
||||
|
||||
class SubgraphManager:
|
||||
def __init__(self):
|
||||
self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None
|
||||
|
||||
async def load_entry_data(self, entry: SubgraphEntry):
|
||||
with open(entry['path'], 'r') as f:
|
||||
entry['data'] = f.read()
|
||||
return entry
|
||||
|
||||
async def sanitize_entry(self, entry: SubgraphEntry | None, remove_data=False) -> SubgraphEntry | None:
|
||||
if entry is None:
|
||||
return None
|
||||
entry = entry.copy()
|
||||
entry.pop('path', None)
|
||||
if remove_data:
|
||||
entry.pop('data', None)
|
||||
return entry
|
||||
|
||||
async def sanitize_entries(self, entries: dict[str, SubgraphEntry], remove_data=False) -> dict[str, SubgraphEntry]:
|
||||
entries = entries.copy()
|
||||
for key in list(entries.keys()):
|
||||
entries[key] = await self.sanitize_entry(entries[key], remove_data)
|
||||
return entries
|
||||
|
||||
async def get_custom_node_subgraphs(self, loadedModules, force_reload=False):
|
||||
# if not forced to reload and cached, return cache
|
||||
if not force_reload and self.cached_custom_node_subgraphs is not None:
|
||||
return self.cached_custom_node_subgraphs
|
||||
# Load subgraphs from custom nodes
|
||||
subfolder = "subgraphs"
|
||||
subgraphs_dict: dict[SubgraphEntry] = {}
|
||||
|
||||
for folder in folder_paths.get_folder_paths("custom_nodes"):
|
||||
pattern = os.path.join(folder, f"*/{subfolder}/*.json")
|
||||
matched_files = glob.glob(pattern)
|
||||
for file in matched_files:
|
||||
# replace backslashes with forward slashes
|
||||
file = file.replace('\\', '/')
|
||||
info: CustomNodeSubgraphEntryInfo = {
|
||||
"node_pack": "custom_nodes." + file.split('/')[-3]
|
||||
}
|
||||
source = Source.custom_node
|
||||
# hash source + path to make sure id will be as unique as possible, but
|
||||
# reproducible across backend reloads
|
||||
id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
|
||||
entry: SubgraphEntry = {
|
||||
"source": Source.custom_node,
|
||||
"name": os.path.splitext(os.path.basename(file))[0],
|
||||
"path": file,
|
||||
"info": info,
|
||||
}
|
||||
subgraphs_dict[id] = entry
|
||||
self.cached_custom_node_subgraphs = subgraphs_dict
|
||||
return subgraphs_dict
|
||||
|
||||
async def get_custom_node_subgraph(self, id: str, loadedModules):
|
||||
subgraphs = await self.get_custom_node_subgraphs(loadedModules)
|
||||
entry: SubgraphEntry = subgraphs.get(id, None)
|
||||
if entry is not None and entry.get('data', None) is None:
|
||||
await self.load_entry_data(entry)
|
||||
return entry
|
||||
|
||||
def add_routes(self, routes, loadedModules):
|
||||
@routes.get("/global_subgraphs")
|
||||
async def get_global_subgraphs(request):
|
||||
subgraphs_dict = await self.get_custom_node_subgraphs(loadedModules)
|
||||
# NOTE: we may want to include other sources of global subgraphs such as templates in the future;
|
||||
# that's the reasoning for the current implementation
|
||||
return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True))
|
||||
|
||||
@routes.get("/global_subgraphs/{id}")
|
||||
async def get_global_subgraph(request):
|
||||
id = request.match_info.get("id", None)
|
||||
subgraph = await self.get_custom_node_subgraph(id, loadedModules)
|
||||
return web.json_response(await self.sanitize_entry(subgraph))
|
||||
@@ -59,6 +59,9 @@ class UserManager():
|
||||
user = "default"
|
||||
if args.multi_user and "comfy-user" in request.headers:
|
||||
user = request.headers["comfy-user"]
|
||||
# Block System Users (use same error message to prevent probing)
|
||||
if user.startswith(folder_paths.SYSTEM_USER_PREFIX):
|
||||
raise KeyError("Unknown user: " + user)
|
||||
|
||||
if user not in self.users:
|
||||
raise KeyError("Unknown user: " + user)
|
||||
@@ -66,15 +69,16 @@ class UserManager():
|
||||
return user
|
||||
|
||||
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
|
||||
user_directory = folder_paths.get_user_directory()
|
||||
|
||||
if type == "userdata":
|
||||
root_dir = user_directory
|
||||
root_dir = folder_paths.get_user_directory()
|
||||
else:
|
||||
raise KeyError("Unknown filepath type:" + type)
|
||||
|
||||
user = self.get_request_user_id(request)
|
||||
path = user_root = os.path.abspath(os.path.join(root_dir, user))
|
||||
user_root = folder_paths.get_public_user_directory(user)
|
||||
if user_root is None:
|
||||
return None
|
||||
path = user_root
|
||||
|
||||
# prevent leaving /{type}
|
||||
if os.path.commonpath((root_dir, user_root)) != root_dir:
|
||||
@@ -101,7 +105,11 @@ class UserManager():
|
||||
name = name.strip()
|
||||
if not name:
|
||||
raise ValueError("username not provided")
|
||||
if name.startswith(folder_paths.SYSTEM_USER_PREFIX):
|
||||
raise ValueError("System User prefix not allowed")
|
||||
user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name)
|
||||
if user_id.startswith(folder_paths.SYSTEM_USER_PREFIX):
|
||||
raise ValueError("System User prefix not allowed")
|
||||
user_id = user_id + "_" + str(uuid.uuid4())
|
||||
|
||||
self.users[user_id] = name
|
||||
@@ -132,7 +140,10 @@ class UserManager():
|
||||
if username in self.users.values():
|
||||
return web.json_response({"error": "Duplicate username."}, status=400)
|
||||
|
||||
user_id = self.add_user(username)
|
||||
try:
|
||||
user_id = self.add_user(username)
|
||||
except ValueError as e:
|
||||
return web.json_response({"error": str(e)}, status=400)
|
||||
return web.json_response(user_id)
|
||||
|
||||
@routes.get("/userdata")
|
||||
@@ -424,7 +435,7 @@ class UserManager():
|
||||
return source
|
||||
|
||||
dest = get_user_data_path(request, check_exists=False, param="dest")
|
||||
if not isinstance(source, str):
|
||||
if not isinstance(dest, str):
|
||||
return dest
|
||||
|
||||
overwrite = request.query.get("overwrite", 'true') != "false"
|
||||
|
||||
@@ -413,7 +413,8 @@ class ControlNet(nn.Module):
|
||||
out_middle = []
|
||||
|
||||
if self.num_classes is not None:
|
||||
assert y.shape[0] == x.shape[0]
|
||||
if y is None:
|
||||
raise ValueError("y is None, did you try using a controlnet for SDXL on SD1?")
|
||||
emb = emb + self.label_emb(y)
|
||||
|
||||
h = x
|
||||
|
||||
@@ -105,6 +105,7 @@ cache_group = parser.add_mutually_exclusive_group()
|
||||
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
||||
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
||||
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
|
||||
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
|
||||
|
||||
attn_group = parser.add_mutually_exclusive_group()
|
||||
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
||||
@@ -130,7 +131,8 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e
|
||||
|
||||
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
|
||||
|
||||
parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")
|
||||
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
|
||||
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
|
||||
|
||||
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
|
||||
|
||||
@@ -145,7 +147,9 @@ class PerformanceFeature(enum.Enum):
|
||||
CublasOps = "cublas_ops"
|
||||
AutoTune = "autotune"
|
||||
|
||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
|
||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
|
||||
|
||||
parser.add_argument("--disable-pinned-memory", action="store_true", help="Disable pinned memory use.")
|
||||
|
||||
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
|
||||
parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")
|
||||
@@ -157,7 +161,7 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win
|
||||
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
|
||||
parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
|
||||
parser.add_argument("--whitelist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.")
|
||||
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.")
|
||||
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes. Also prevents the frontend from communicating with the internet.")
|
||||
|
||||
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
||||
|
||||
|
||||
@@ -310,11 +310,13 @@ class ControlLoraOps:
|
||||
self.bias = None
|
||||
|
||||
def forward(self, input):
|
||||
weight, bias = comfy.ops.cast_bias_weight(self, input)
|
||||
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
|
||||
if self.up is not None:
|
||||
return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
|
||||
x = torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
|
||||
else:
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
x = torch.nn.functional.linear(input, weight, bias)
|
||||
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
|
||||
def __init__(
|
||||
@@ -350,12 +352,13 @@ class ControlLoraOps:
|
||||
|
||||
|
||||
def forward(self, input):
|
||||
weight, bias = comfy.ops.cast_bias_weight(self, input)
|
||||
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
|
||||
if self.up is not None:
|
||||
return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
x = torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
else:
|
||||
return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
x = torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
class ControlLora(ControlNet):
|
||||
def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options
|
||||
|
||||
@@ -6,6 +6,7 @@ class LatentFormat:
|
||||
latent_dimensions = 2
|
||||
latent_rgb_factors = None
|
||||
latent_rgb_factors_bias = None
|
||||
latent_rgb_factors_reshape = None
|
||||
taesd_decoder_name = None
|
||||
|
||||
def process_in(self, latent):
|
||||
@@ -178,6 +179,54 @@ class Flux(SD3):
|
||||
def process_out(self, latent):
|
||||
return (latent / self.scale_factor) + self.shift_factor
|
||||
|
||||
class Flux2(LatentFormat):
|
||||
latent_channels = 128
|
||||
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors =[
|
||||
[0.0058, 0.0113, 0.0073],
|
||||
[0.0495, 0.0443, 0.0836],
|
||||
[-0.0099, 0.0096, 0.0644],
|
||||
[0.2144, 0.3009, 0.3652],
|
||||
[0.0166, -0.0039, -0.0054],
|
||||
[0.0157, 0.0103, -0.0160],
|
||||
[-0.0398, 0.0902, -0.0235],
|
||||
[-0.0052, 0.0095, 0.0109],
|
||||
[-0.3527, -0.2712, -0.1666],
|
||||
[-0.0301, -0.0356, -0.0180],
|
||||
[-0.0107, 0.0078, 0.0013],
|
||||
[0.0746, 0.0090, -0.0941],
|
||||
[0.0156, 0.0169, 0.0070],
|
||||
[-0.0034, -0.0040, -0.0114],
|
||||
[0.0032, 0.0181, 0.0080],
|
||||
[-0.0939, -0.0008, 0.0186],
|
||||
[0.0018, 0.0043, 0.0104],
|
||||
[0.0284, 0.0056, -0.0127],
|
||||
[-0.0024, -0.0022, -0.0030],
|
||||
[0.1207, -0.0026, 0.0065],
|
||||
[0.0128, 0.0101, 0.0142],
|
||||
[0.0137, -0.0072, -0.0007],
|
||||
[0.0095, 0.0092, -0.0059],
|
||||
[0.0000, -0.0077, -0.0049],
|
||||
[-0.0465, -0.0204, -0.0312],
|
||||
[0.0095, 0.0012, -0.0066],
|
||||
[0.0290, -0.0034, 0.0025],
|
||||
[0.0220, 0.0169, -0.0048],
|
||||
[-0.0332, -0.0457, -0.0468],
|
||||
[-0.0085, 0.0389, 0.0609],
|
||||
[-0.0076, 0.0003, -0.0043],
|
||||
[-0.0111, -0.0460, -0.0614],
|
||||
]
|
||||
|
||||
self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
|
||||
self.latent_rgb_factors_reshape = lambda t: t.reshape(t.shape[0], 32, 2, 2, t.shape[-2], t.shape[-1]).permute(0, 1, 4, 2, 5, 3).reshape(t.shape[0], 32, t.shape[-2] * 2, t.shape[-1] * 2)
|
||||
|
||||
def process_in(self, latent):
|
||||
return latent
|
||||
|
||||
def process_out(self, latent):
|
||||
return latent
|
||||
|
||||
class Mochi(LatentFormat):
|
||||
latent_channels = 12
|
||||
latent_dimensions = 3
|
||||
@@ -382,6 +431,7 @@ class HunyuanVideo(LatentFormat):
|
||||
]
|
||||
|
||||
latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761]
|
||||
taesd_decoder_name = "taehv"
|
||||
|
||||
class Cosmos1CV8x8x8(LatentFormat):
|
||||
latent_channels = 16
|
||||
@@ -445,7 +495,7 @@ class Wan21(LatentFormat):
|
||||
]).view(1, self.latent_channels, 1, 1, 1)
|
||||
|
||||
|
||||
self.taesd_decoder_name = None #TODO
|
||||
self.taesd_decoder_name = "lighttaew2_1"
|
||||
|
||||
def process_in(self, latent):
|
||||
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
|
||||
@@ -516,6 +566,7 @@ class Wan22(Wan21):
|
||||
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.0
|
||||
self.taesd_decoder_name = "lighttaew2_2"
|
||||
self.latents_mean = torch.tensor([
|
||||
-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
|
||||
-0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
|
||||
@@ -611,6 +662,67 @@ class HunyuanImage21Refiner(LatentFormat):
|
||||
latent_dimensions = 3
|
||||
scale_factor = 1.03682
|
||||
|
||||
def process_in(self, latent):
|
||||
out = latent * self.scale_factor
|
||||
out = torch.cat((out[:, :, :1], out), dim=2)
|
||||
out = out.permute(0, 2, 1, 3, 4)
|
||||
b, f_times_2, c, h, w = out.shape
|
||||
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
|
||||
out = out.permute(0, 2, 1, 3, 4).contiguous()
|
||||
return out
|
||||
|
||||
def process_out(self, latent):
|
||||
z = latent / self.scale_factor
|
||||
z = z.permute(0, 2, 1, 3, 4)
|
||||
b, f, c, h, w = z.shape
|
||||
z = z.reshape(b, f, 2, c // 2, h, w)
|
||||
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
|
||||
z = z.permute(0, 2, 1, 3, 4)
|
||||
z = z[:, :, 1:]
|
||||
return z
|
||||
|
||||
class HunyuanVideo15(LatentFormat):
|
||||
latent_rgb_factors = [
|
||||
[ 0.0568, -0.0521, -0.0131],
|
||||
[ 0.0014, 0.0735, 0.0326],
|
||||
[ 0.0186, 0.0531, -0.0138],
|
||||
[-0.0031, 0.0051, 0.0288],
|
||||
[ 0.0110, 0.0556, 0.0432],
|
||||
[-0.0041, -0.0023, -0.0485],
|
||||
[ 0.0530, 0.0413, 0.0253],
|
||||
[ 0.0283, 0.0251, 0.0339],
|
||||
[ 0.0277, -0.0372, -0.0093],
|
||||
[ 0.0393, 0.0944, 0.1131],
|
||||
[ 0.0020, 0.0251, 0.0037],
|
||||
[-0.0017, 0.0012, 0.0234],
|
||||
[ 0.0468, 0.0436, 0.0203],
|
||||
[ 0.0354, 0.0439, -0.0233],
|
||||
[ 0.0090, 0.0123, 0.0346],
|
||||
[ 0.0382, 0.0029, 0.0217],
|
||||
[ 0.0261, -0.0300, 0.0030],
|
||||
[-0.0088, -0.0220, -0.0283],
|
||||
[-0.0272, -0.0121, -0.0363],
|
||||
[-0.0664, -0.0622, 0.0144],
|
||||
[ 0.0414, 0.0479, 0.0529],
|
||||
[ 0.0355, 0.0612, -0.0247],
|
||||
[ 0.0147, 0.0264, 0.0174],
|
||||
[ 0.0438, 0.0038, 0.0542],
|
||||
[ 0.0431, -0.0573, -0.0033],
|
||||
[-0.0162, -0.0211, -0.0406],
|
||||
[-0.0487, -0.0295, -0.0393],
|
||||
[ 0.0005, -0.0109, 0.0253],
|
||||
[ 0.0296, 0.0591, 0.0353],
|
||||
[ 0.0119, 0.0181, -0.0306],
|
||||
[-0.0085, -0.0362, 0.0229],
|
||||
[ 0.0005, -0.0106, 0.0242]
|
||||
]
|
||||
|
||||
latent_rgb_factors_bias = [ 0.0456, -0.0202, -0.0644]
|
||||
latent_channels = 32
|
||||
latent_dimensions = 3
|
||||
scale_factor = 1.03682
|
||||
taesd_decoder_name = "lighttaehy1_5"
|
||||
|
||||
class Hunyuan3Dv2(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from comfy.ldm.flux.math import attention
|
||||
from comfy.ldm.flux.layers import (
|
||||
MLPEmbedder,
|
||||
RMSNorm,
|
||||
QKNorm,
|
||||
SelfAttention,
|
||||
ModulationOut,
|
||||
)
|
||||
|
||||
# TODO: remove this in a few months
|
||||
SingleStreamBlock = None
|
||||
DoubleStreamBlock = None
|
||||
|
||||
|
||||
class ChromaModulationOut(ModulationOut):
|
||||
@@ -48,124 +48,6 @@ class Approximator(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.img_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.txt_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.flipped_img_txt = flipped_img_txt
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}):
|
||||
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img))
|
||||
img_qkv = self.img_attn.qkv(img_modulated)
|
||||
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||
|
||||
# prepare txt for attention
|
||||
txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt))
|
||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
|
||||
# run actual attention
|
||||
attn = attention(torch.cat((txt_q, img_q), dim=2),
|
||||
torch.cat((txt_k, img_k), dim=2),
|
||||
torch.cat((txt_v, img_v), dim=2),
|
||||
pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||
|
||||
# calculate the img bloks
|
||||
img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn))
|
||||
img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img))))
|
||||
|
||||
# calculate the txt bloks
|
||||
txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
|
||||
txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt))))
|
||||
|
||||
if txt.dtype == torch.float16:
|
||||
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
|
||||
|
||||
return img, txt
|
||||
|
||||
|
||||
class SingleStreamBlock(nn.Module):
|
||||
"""
|
||||
A DiT block with parallel linear layers as described in
|
||||
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qk_scale: float = None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_dim = hidden_size
|
||||
self.num_heads = num_heads
|
||||
head_dim = hidden_size // num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
# qkv and mlp_in
|
||||
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
|
||||
# proj and mlp_out
|
||||
self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
|
||||
|
||||
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
|
||||
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}) -> Tensor:
|
||||
mod = vec
|
||||
x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
|
||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
|
||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
x.addcmul_(mod.gate, output)
|
||||
if x.dtype == torch.float16:
|
||||
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
||||
return x
|
||||
|
||||
|
||||
class LastLayer(nn.Module):
|
||||
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
@@ -11,12 +11,12 @@ import comfy.ldm.common_dit
|
||||
from comfy.ldm.flux.layers import (
|
||||
EmbedND,
|
||||
timestep_embedding,
|
||||
DoubleStreamBlock,
|
||||
SingleStreamBlock,
|
||||
)
|
||||
|
||||
from .layers import (
|
||||
DoubleStreamBlock,
|
||||
LastLayer,
|
||||
SingleStreamBlock,
|
||||
Approximator,
|
||||
ChromaModulationOut,
|
||||
)
|
||||
@@ -90,6 +90,7 @@ class Chroma(nn.Module):
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
modulation=False,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
@@ -98,7 +99,7 @@ class Chroma(nn.Module):
|
||||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=False, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
]
|
||||
)
|
||||
@@ -178,7 +179,10 @@ class Chroma(nn.Module):
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.double_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if i not in self.skip_mmdit:
|
||||
double_mod = (
|
||||
self.get_modulations(mod_vectors, "double_img", idx=i),
|
||||
@@ -221,7 +225,10 @@ class Chroma(nn.Module):
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if i not in self.skip_dit:
|
||||
single_mod = self.get_modulations(mod_vectors, "single", idx=i)
|
||||
if ("single_block", i) in blocks_replace:
|
||||
|
||||
@@ -10,12 +10,10 @@ from torch import Tensor, nn
|
||||
from einops import repeat
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.layers import EmbedND, DoubleStreamBlock, SingleStreamBlock
|
||||
|
||||
from comfy.ldm.chroma.model import Chroma, ChromaParams
|
||||
from comfy.ldm.chroma.layers import (
|
||||
DoubleStreamBlock,
|
||||
SingleStreamBlock,
|
||||
Approximator,
|
||||
)
|
||||
from .layers import (
|
||||
@@ -89,7 +87,6 @@ class ChromaRadiance(Chroma):
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
DoubleStreamBlock(
|
||||
@@ -97,6 +94,7 @@ class ChromaRadiance(Chroma):
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
modulation=False,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
@@ -109,6 +107,7 @@ class ChromaRadiance(Chroma):
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
modulation=False,
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
@@ -189,15 +188,15 @@ class ChromaRadiance(Chroma):
|
||||
nerf_pixels = nn.functional.unfold(img_orig, kernel_size=patch_size, stride=patch_size)
|
||||
nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P]
|
||||
|
||||
# Reshape for per-patch processing
|
||||
nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size)
|
||||
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
|
||||
|
||||
if params.nerf_tile_size > 0 and num_patches > params.nerf_tile_size:
|
||||
# Enable tiling if nerf_tile_size isn't 0 and we actually have more patches than
|
||||
# the tile size.
|
||||
img_dct = self.forward_tiled_nerf(img_out, nerf_pixels, B, C, num_patches, patch_size, params)
|
||||
img_dct = self.forward_tiled_nerf(nerf_hidden, nerf_pixels, B, C, num_patches, patch_size, params)
|
||||
else:
|
||||
# Reshape for per-patch processing
|
||||
nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size)
|
||||
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
|
||||
|
||||
# Get DCT-encoded pixel embeddings [pixel-dct]
|
||||
img_dct = self.nerf_image_embedder(nerf_pixels)
|
||||
|
||||
@@ -240,17 +239,8 @@ class ChromaRadiance(Chroma):
|
||||
end = min(i + tile_size, num_patches)
|
||||
|
||||
# Slice the current tile from the input tensors
|
||||
nerf_hidden_tile = nerf_hidden[:, i:end, :]
|
||||
nerf_pixels_tile = nerf_pixels[:, i:end, :]
|
||||
|
||||
# Get the actual number of patches in this tile (can be smaller for the last tile)
|
||||
num_patches_tile = nerf_hidden_tile.shape[1]
|
||||
|
||||
# Reshape the tile for per-patch processing
|
||||
# [B, NumPatches_tile, D] -> [B * NumPatches_tile, D]
|
||||
nerf_hidden_tile = nerf_hidden_tile.reshape(batch * num_patches_tile, params.hidden_size)
|
||||
# [B, NumPatches_tile, C*P*P] -> [B*NumPatches_tile, C, P*P] -> [B*NumPatches_tile, P*P, C]
|
||||
nerf_pixels_tile = nerf_pixels_tile.reshape(batch * num_patches_tile, channels, patch_size**2).transpose(1, 2)
|
||||
nerf_hidden_tile = nerf_hidden[i * batch:end * batch]
|
||||
nerf_pixels_tile = nerf_pixels[i * batch:end * batch]
|
||||
|
||||
# get DCT-encoded pixel embeddings [pixel-dct]
|
||||
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile)
|
||||
|
||||
@@ -48,11 +48,11 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
|
||||
return embedding
|
||||
|
||||
class MLPEmbedder(nn.Module):
|
||||
def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
|
||||
def __init__(self, in_dim: int, hidden_dim: int, bias=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
|
||||
self.in_layer = operations.Linear(in_dim, hidden_dim, bias=bias, dtype=dtype, device=device)
|
||||
self.silu = nn.SiLU()
|
||||
self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device)
|
||||
self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=bias, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.out_layer(self.silu(self.in_layer(x)))
|
||||
@@ -80,14 +80,14 @@ class QKNorm(torch.nn.Module):
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None):
|
||||
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, proj_bias: bool = True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
|
||||
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
|
||||
self.proj = operations.Linear(dim, dim, bias=proj_bias, dtype=dtype, device=device)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -98,11 +98,11 @@ class ModulationOut:
|
||||
|
||||
|
||||
class Modulation(nn.Module):
|
||||
def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None):
|
||||
def __init__(self, dim: int, double: bool, bias=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.is_double = double
|
||||
self.multiplier = 6 if double else 3
|
||||
self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
|
||||
self.lin = operations.Linear(dim, self.multiplier * dim, bias=bias, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, vec: Tensor) -> tuple:
|
||||
if vec.ndim == 2:
|
||||
@@ -129,77 +129,129 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
|
||||
return tensor
|
||||
|
||||
|
||||
class SiLUActivation(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.gate_fn = nn.SiLU()
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x1, x2 = x.chunk(2, dim=-1)
|
||||
return self.gate_fn(x1) * x2
|
||||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
|
||||
self.modulation = modulation
|
||||
|
||||
if self.modulation:
|
||||
self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
|
||||
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.img_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
|
||||
if mlp_silu_act:
|
||||
self.img_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
|
||||
SiLUActivation(),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
|
||||
)
|
||||
else:
|
||||
self.img_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
if self.modulation:
|
||||
self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
|
||||
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.txt_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
if mlp_silu_act:
|
||||
self.txt_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
|
||||
SiLUActivation(),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
|
||||
)
|
||||
else:
|
||||
self.txt_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
self.flipped_img_txt = flipped_img_txt
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
|
||||
img_mod1, img_mod2 = self.img_mod(vec)
|
||||
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||
if self.modulation:
|
||||
img_mod1, img_mod2 = self.img_mod(vec)
|
||||
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||
else:
|
||||
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = self.img_norm1(img)
|
||||
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
|
||||
img_qkv = self.img_attn.qkv(img_modulated)
|
||||
del img_modulated
|
||||
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
del img_qkv
|
||||
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||
|
||||
# prepare txt for attention
|
||||
txt_modulated = self.txt_norm1(txt)
|
||||
txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
|
||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||
del txt_modulated
|
||||
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
del txt_qkv
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
|
||||
if self.flipped_img_txt:
|
||||
q = torch.cat((img_q, txt_q), dim=2)
|
||||
del img_q, txt_q
|
||||
k = torch.cat((img_k, txt_k), dim=2)
|
||||
del img_k, txt_k
|
||||
v = torch.cat((img_v, txt_v), dim=2)
|
||||
del img_v, txt_v
|
||||
# run actual attention
|
||||
attn = attention(torch.cat((img_q, txt_q), dim=2),
|
||||
torch.cat((img_k, txt_k), dim=2),
|
||||
torch.cat((img_v, txt_v), dim=2),
|
||||
attn = attention(q, k, v,
|
||||
pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
|
||||
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
|
||||
else:
|
||||
q = torch.cat((txt_q, img_q), dim=2)
|
||||
del txt_q, img_q
|
||||
k = torch.cat((txt_k, img_k), dim=2)
|
||||
del txt_k, img_k
|
||||
v = torch.cat((txt_v, img_v), dim=2)
|
||||
del txt_v, img_v
|
||||
# run actual attention
|
||||
attn = attention(torch.cat((txt_q, img_q), dim=2),
|
||||
torch.cat((txt_k, img_k), dim=2),
|
||||
torch.cat((txt_v, img_v), dim=2),
|
||||
attn = attention(q, k, v,
|
||||
pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||
|
||||
# calculate the img bloks
|
||||
img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
|
||||
img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
|
||||
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
|
||||
del img_attn
|
||||
img += apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
|
||||
|
||||
# calculate the txt bloks
|
||||
txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
|
||||
del txt_attn
|
||||
txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)
|
||||
|
||||
if txt.dtype == torch.float16:
|
||||
@@ -220,6 +272,9 @@ class SingleStreamBlock(nn.Module):
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qk_scale: float = None,
|
||||
modulation=True,
|
||||
mlp_silu_act=False,
|
||||
bias=True,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
@@ -231,30 +286,47 @@ class SingleStreamBlock(nn.Module):
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
|
||||
self.mlp_hidden_dim_first = self.mlp_hidden_dim
|
||||
if mlp_silu_act:
|
||||
self.mlp_hidden_dim_first = int(hidden_size * mlp_ratio * 2)
|
||||
self.mlp_act = SiLUActivation()
|
||||
else:
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
|
||||
# qkv and mlp_in
|
||||
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
|
||||
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim_first, bias=bias, dtype=dtype, device=device)
|
||||
# proj and mlp_out
|
||||
self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
|
||||
self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, bias=bias, dtype=dtype, device=device)
|
||||
|
||||
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
|
||||
if modulation:
|
||||
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.modulation = None
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor:
|
||||
mod, _ = self.modulation(vec)
|
||||
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
if self.modulation:
|
||||
mod, _ = self.modulation(vec)
|
||||
else:
|
||||
mod = vec
|
||||
|
||||
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)
|
||||
|
||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
del qkv
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
mlp = self.mlp_act(mlp)
|
||||
output = self.linear2(torch.cat((attn, mlp), 2))
|
||||
x += apply_mod(output, mod.gate, None, modulation_dims)
|
||||
if x.dtype == torch.float16:
|
||||
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
||||
@@ -262,11 +334,11 @@ class SingleStreamBlock(nn.Module):
|
||||
|
||||
|
||||
class LastLayer(nn.Module):
|
||||
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
|
||||
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, bias=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))
|
||||
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=bias, dtype=dtype, device=device)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=bias, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor:
|
||||
if vec.ndim == 2:
|
||||
|
||||
@@ -7,15 +7,8 @@ import comfy.model_management
|
||||
|
||||
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
|
||||
q_shape = q.shape
|
||||
k_shape = k.shape
|
||||
|
||||
if pe is not None:
|
||||
q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
|
||||
k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
|
||||
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
|
||||
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
|
||||
|
||||
q, k = apply_rope(q, k, pe)
|
||||
heads = q.shape[1]
|
||||
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
|
||||
return x
|
||||
|
||||
@@ -15,6 +15,7 @@ from .layers import (
|
||||
MLPEmbedder,
|
||||
SingleStreamBlock,
|
||||
timestep_embedding,
|
||||
Modulation
|
||||
)
|
||||
|
||||
@dataclass
|
||||
@@ -33,6 +34,11 @@ class FluxParams:
|
||||
patch_size: int
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
global_modulation: bool = False
|
||||
mlp_silu_act: bool = False
|
||||
ops_bias: bool = True
|
||||
default_ref_method: str = "offset"
|
||||
ref_index_scale: float = 1.0
|
||||
|
||||
|
||||
class Flux(nn.Module):
|
||||
@@ -58,13 +64,17 @@ class Flux(nn.Module):
|
||||
self.hidden_size = params.hidden_size
|
||||
self.num_heads = params.num_heads
|
||||
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
||||
self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
|
||||
if params.vec_in_dim is not None:
|
||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.vector_in = None
|
||||
|
||||
self.guidance_in = (
|
||||
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
|
||||
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
|
||||
)
|
||||
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
|
||||
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
@@ -73,6 +83,9 @@ class Flux(nn.Module):
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
modulation=params.global_modulation is False,
|
||||
mlp_silu_act=params.mlp_silu_act,
|
||||
proj_bias=params.ops_bias,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
@@ -81,13 +94,30 @@ class Flux(nn.Module):
|
||||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
if final_layer:
|
||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
if params.global_modulation:
|
||||
self.double_stream_modulation_img = Modulation(
|
||||
self.hidden_size,
|
||||
double=True,
|
||||
bias=False,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.double_stream_modulation_txt = Modulation(
|
||||
self.hidden_size,
|
||||
double=True,
|
||||
bias=False,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.single_stream_modulation = Modulation(
|
||||
self.hidden_size, double=False, bias=False, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
@@ -103,9 +133,6 @@ class Flux(nn.Module):
|
||||
attn_mask: Tensor = None,
|
||||
) -> Tensor:
|
||||
|
||||
if y is None:
|
||||
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
|
||||
|
||||
patches = transformer_options.get("patches", {})
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
@@ -118,9 +145,17 @@ class Flux(nn.Module):
|
||||
if guidance is not None:
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
||||
|
||||
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
if self.vector_in is not None:
|
||||
if y is None:
|
||||
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
|
||||
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
vec_orig = vec
|
||||
if self.params.global_modulation:
|
||||
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig))
|
||||
|
||||
if "post_input" in patches:
|
||||
for p in patches["post_input"]:
|
||||
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})
|
||||
@@ -136,7 +171,10 @@ class Flux(nn.Module):
|
||||
pe = None
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.double_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
@@ -177,7 +215,13 @@ class Flux(nn.Module):
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
if self.params.global_modulation:
|
||||
vec, _ = self.single_stream_modulation(vec_orig)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
@@ -207,10 +251,10 @@ class Flux(nn.Module):
|
||||
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
|
||||
def process_img(self, x, index=0, h_offset=0, w_offset=0):
|
||||
def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
|
||||
bs, c, h, w = x.shape
|
||||
patch_size = self.patch_size
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||
@@ -222,10 +266,22 @@ class Flux(nn.Module):
|
||||
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
|
||||
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
|
||||
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
steps_h = h_len
|
||||
steps_w = w_len
|
||||
|
||||
rope_options = transformer_options.get("rope_options", None)
|
||||
if rope_options is not None:
|
||||
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
|
||||
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
|
||||
|
||||
index += rope_options.get("shift_t", 0.0)
|
||||
h_offset += rope_options.get("shift_y", 0.0)
|
||||
w_offset += rope_options.get("shift_x", 0.0)
|
||||
|
||||
img_ids = torch.zeros((steps_h, steps_w, len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
|
||||
img_ids[:, :, 0] = img_ids[:, :, 1] + index
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=torch.float32).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=torch.float32).unsqueeze(0)
|
||||
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
|
||||
@@ -241,16 +297,16 @@ class Flux(nn.Module):
|
||||
|
||||
h_len = ((h_orig + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w_orig + (patch_size // 2)) // patch_size)
|
||||
img, img_ids = self.process_img(x)
|
||||
img, img_ids = self.process_img(x, transformer_options=transformer_options)
|
||||
img_tokens = img.shape[1]
|
||||
if ref_latents is not None:
|
||||
h = 0
|
||||
w = 0
|
||||
index = 0
|
||||
ref_latents_method = kwargs.get("ref_latents_method", "offset")
|
||||
ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method)
|
||||
for ref in ref_latents:
|
||||
if ref_latents_method == "index":
|
||||
index += 1
|
||||
index += self.params.ref_index_scale
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
elif ref_latents_method == "uxo":
|
||||
@@ -274,7 +330,11 @@ class Flux(nn.Module):
|
||||
img = torch.cat([img, kontext], dim=1)
|
||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
|
||||
|
||||
if len(self.params.axes_dim) == 4: # Flux 2
|
||||
txt_ids[:, :, 3] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
|
||||
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||
out = out[:, :img_tokens]
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig]
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig]
|
||||
|
||||
@@ -6,7 +6,6 @@ import comfy.ldm.flux.layers
|
||||
import comfy.ldm.modules.diffusionmodules.mmdit
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
from einops import repeat
|
||||
|
||||
@@ -42,6 +41,8 @@ class HunyuanVideoParams:
|
||||
guidance_embed: bool
|
||||
byt5: bool
|
||||
meanflow: bool
|
||||
use_cond_type_embedding: bool
|
||||
vision_in_dim: int
|
||||
|
||||
|
||||
class SelfAttentionRef(nn.Module):
|
||||
@@ -157,7 +158,10 @@ class TokenRefiner(nn.Module):
|
||||
t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
|
||||
# m = mask.float().unsqueeze(-1)
|
||||
# c = (x.float() * m).sum(dim=1) / m.sum(dim=1) #TODO: the following works when the x.shape is the same length as the tokens but might break otherwise
|
||||
c = x.sum(dim=1) / x.shape[1]
|
||||
if x.dtype == torch.float16:
|
||||
c = x.float().sum(dim=1) / x.shape[1]
|
||||
else:
|
||||
c = x.sum(dim=1) / x.shape[1]
|
||||
|
||||
c = t + self.c_embedder(c.to(x.dtype))
|
||||
x = self.input_embedder(x)
|
||||
@@ -196,11 +200,15 @@ class HunyuanVideo(nn.Module):
|
||||
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||
|
||||
params = HunyuanVideoParams(**kwargs)
|
||||
self.params = params
|
||||
self.patch_size = params.patch_size
|
||||
self.in_channels = params.in_channels
|
||||
self.out_channels = params.out_channels
|
||||
self.use_cond_type_embedding = params.use_cond_type_embedding
|
||||
self.vision_in_dim = params.vision_in_dim
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||
@@ -266,6 +274,18 @@ class HunyuanVideo(nn.Module):
|
||||
if final_layer:
|
||||
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
# HunyuanVideo 1.5 specific modules
|
||||
if self.vision_in_dim is not None:
|
||||
from comfy.ldm.wan.model import MLPProj
|
||||
self.vision_in = MLPProj(in_dim=self.vision_in_dim, out_dim=self.hidden_size, operation_settings=operation_settings)
|
||||
else:
|
||||
self.vision_in = None
|
||||
if self.use_cond_type_embedding:
|
||||
# 0: text_encoder feature 1: byt5 feature 2: vision_encoder feature
|
||||
self.cond_type_embedding = nn.Embedding(3, self.hidden_size)
|
||||
else:
|
||||
self.cond_type_embedding = None
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
img: Tensor,
|
||||
@@ -276,6 +296,7 @@ class HunyuanVideo(nn.Module):
|
||||
timesteps: Tensor,
|
||||
y: Tensor = None,
|
||||
txt_byt5=None,
|
||||
clip_fea=None,
|
||||
guidance: Tensor = None,
|
||||
guiding_frame_index=None,
|
||||
ref_latent=None,
|
||||
@@ -331,12 +352,31 @@ class HunyuanVideo(nn.Module):
|
||||
|
||||
txt = self.txt_in(txt, timesteps, txt_mask, transformer_options=transformer_options)
|
||||
|
||||
if self.cond_type_embedding is not None:
|
||||
self.cond_type_embedding.to(txt.device)
|
||||
cond_emb = self.cond_type_embedding(torch.zeros_like(txt[:, :, 0], device=txt.device, dtype=torch.long))
|
||||
txt = txt + cond_emb.to(txt.dtype)
|
||||
|
||||
if self.byt5_in is not None and txt_byt5 is not None:
|
||||
txt_byt5 = self.byt5_in(txt_byt5)
|
||||
if self.cond_type_embedding is not None:
|
||||
cond_emb = self.cond_type_embedding(torch.ones_like(txt_byt5[:, :, 0], device=txt_byt5.device, dtype=torch.long))
|
||||
txt_byt5 = txt_byt5 + cond_emb.to(txt_byt5.dtype)
|
||||
txt = torch.cat((txt_byt5, txt), dim=1) # byt5 first for HunyuanVideo1.5
|
||||
else:
|
||||
txt = torch.cat((txt, txt_byt5), dim=1)
|
||||
txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
|
||||
txt = torch.cat((txt, txt_byt5), dim=1)
|
||||
txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
|
||||
|
||||
if clip_fea is not None:
|
||||
txt_vision_states = self.vision_in(clip_fea)
|
||||
if self.cond_type_embedding is not None:
|
||||
cond_emb = self.cond_type_embedding(2 * torch.ones_like(txt_vision_states[:, :, 0], dtype=torch.long, device=txt_vision_states.device))
|
||||
txt_vision_states = txt_vision_states + cond_emb
|
||||
txt = torch.cat((txt_vision_states.to(txt.dtype), txt), dim=1)
|
||||
extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
|
||||
txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
|
||||
|
||||
ids = torch.cat((img_ids, txt_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
@@ -349,7 +389,10 @@ class HunyuanVideo(nn.Module):
|
||||
attn_mask = None
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.double_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
@@ -371,7 +414,10 @@ class HunyuanVideo(nn.Module):
|
||||
|
||||
img = torch.cat((img, txt), 1)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
@@ -430,14 +476,14 @@ class HunyuanVideo(nn.Module):
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
return repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
|
||||
def forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)
|
||||
).execute(x, timestep, context, y, txt_byt5, clip_fea, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
|
||||
def _forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
|
||||
bs = x.shape[0]
|
||||
if len(self.patch_size) == 3:
|
||||
img_ids = self.img_ids(x)
|
||||
@@ -445,5 +491,5 @@ class HunyuanVideo(nn.Module):
|
||||
else:
|
||||
img_ids = self.img_ids_2d(x)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, clip_fea, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
|
||||
return out
|
||||
|
||||
120
comfy/ldm/hunyuan_video/upsampler.py
Normal file
120
comfy/ldm/hunyuan_video/upsampler.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm, ResnetBlock, VideoConv3d
|
||||
import model_management, model_patcher
|
||||
|
||||
class SRResidualCausalBlock3D(nn.Module):
|
||||
def __init__(self, channels: int):
|
||||
super().__init__()
|
||||
self.block = nn.Sequential(
|
||||
VideoConv3d(channels, channels, kernel_size=3),
|
||||
nn.SiLU(inplace=True),
|
||||
VideoConv3d(channels, channels, kernel_size=3),
|
||||
nn.SiLU(inplace=True),
|
||||
VideoConv3d(channels, channels, kernel_size=3),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x + self.block(x)
|
||||
|
||||
class SRModel3DV2(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
hidden_channels: int = 64,
|
||||
num_blocks: int = 6,
|
||||
global_residual: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_conv = VideoConv3d(in_channels, hidden_channels, kernel_size=3)
|
||||
self.blocks = nn.ModuleList([SRResidualCausalBlock3D(hidden_channels) for _ in range(num_blocks)])
|
||||
self.out_conv = VideoConv3d(hidden_channels, out_channels, kernel_size=3)
|
||||
self.global_residual = bool(global_residual)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
residual = x
|
||||
y = self.in_conv(x)
|
||||
for blk in self.blocks:
|
||||
y = blk(y)
|
||||
y = self.out_conv(y)
|
||||
if self.global_residual and (y.shape == residual.shape):
|
||||
y = y + residual
|
||||
return y
|
||||
|
||||
|
||||
class Upsampler(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
z_channels: int,
|
||||
out_channels: int,
|
||||
block_out_channels: tuple[int, ...],
|
||||
num_res_blocks: int = 2,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.block_out_channels = block_out_channels
|
||||
self.z_channels = z_channels
|
||||
|
||||
ch = block_out_channels[0]
|
||||
self.conv_in = VideoConv3d(z_channels, ch, kernel_size=3)
|
||||
|
||||
self.up = nn.ModuleList()
|
||||
|
||||
for i, tgt in enumerate(block_out_channels):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
temb_channels=0,
|
||||
conv_shortcut=False,
|
||||
conv_op=VideoConv3d, norm_op=RMS_norm)
|
||||
for j in range(num_res_blocks + 1)])
|
||||
ch = tgt
|
||||
self.up.append(stage)
|
||||
|
||||
self.norm_out = RMS_norm(ch)
|
||||
self.conv_out = VideoConv3d(ch, out_channels, kernel_size=3)
|
||||
|
||||
def forward(self, z):
|
||||
"""
|
||||
Args:
|
||||
z: (B, C, T, H, W)
|
||||
target_shape: (H, W)
|
||||
"""
|
||||
# z to block_in
|
||||
repeats = self.block_out_channels[0] // (self.z_channels)
|
||||
x = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1)
|
||||
|
||||
# upsampling
|
||||
for stage in self.up:
|
||||
for blk in stage.block:
|
||||
x = blk(x)
|
||||
|
||||
out = self.conv_out(F.silu(self.norm_out(x)))
|
||||
return out
|
||||
|
||||
UPSAMPLERS = {
|
||||
"720p": SRModel3DV2,
|
||||
"1080p": Upsampler,
|
||||
}
|
||||
|
||||
class HunyuanVideo15SRModel():
|
||||
def __init__(self, model_type, config):
|
||||
self.load_device = model_management.vae_device()
|
||||
offload_device = model_management.vae_offload_device()
|
||||
self.dtype = model_management.vae_dtype(self.load_device)
|
||||
self.model_class = UPSAMPLERS.get(model_type)
|
||||
self.model = self.model_class(**config).eval()
|
||||
|
||||
self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.model.load_state_dict(sd, strict=True)
|
||||
|
||||
def get_sd(self):
|
||||
return self.model.state_dict()
|
||||
|
||||
def resample_latent(self, latent):
|
||||
model_management.load_model_gpu(self.patcher)
|
||||
return self.model(latent.to(self.load_device))
|
||||
@@ -4,8 +4,40 @@ import torch.nn.functional as F
|
||||
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize
|
||||
import comfy.ops
|
||||
import comfy.ldm.models.autoencoder
|
||||
import comfy.model_management
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
class NoPadConv3d(nn.Module):
|
||||
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs):
|
||||
super().__init__()
|
||||
self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None):
|
||||
|
||||
x = xl[0]
|
||||
xl.clear()
|
||||
|
||||
if conv_carry_out is not None:
|
||||
to_push = x[:, :, -2:, :, :].clone()
|
||||
conv_carry_out.append(to_push)
|
||||
|
||||
if isinstance(op, NoPadConv3d):
|
||||
if conv_carry_in is None:
|
||||
x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate')
|
||||
else:
|
||||
carry_len = conv_carry_in[0].shape[2]
|
||||
x = torch.cat([conv_carry_in.pop(0), x], dim=2)
|
||||
x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate')
|
||||
|
||||
out = op(x)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class RMS_norm(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
@@ -14,7 +46,7 @@ class RMS_norm(nn.Module):
|
||||
self.gamma = nn.Parameter(torch.empty(shape))
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(x, dim=1) * self.scale * self.gamma
|
||||
return F.normalize(x, dim=1) * self.scale * comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device)
|
||||
|
||||
class DnSmpl(nn.Module):
|
||||
def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d):
|
||||
@@ -27,11 +59,12 @@ class DnSmpl(nn.Module):
|
||||
self.tds = tds
|
||||
self.gs = fct * ic // oc
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
|
||||
r1 = 2 if self.tds else 1
|
||||
h = self.conv(x)
|
||||
h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
|
||||
|
||||
if self.tds and self.refiner_vae and conv_carry_in is None:
|
||||
|
||||
if self.tds and self.refiner_vae:
|
||||
hf = h[:, :, :1, :, :]
|
||||
b, c, f, ht, wd = hf.shape
|
||||
hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2)
|
||||
@@ -39,14 +72,7 @@ class DnSmpl(nn.Module):
|
||||
hf = hf.reshape(b, 2 * 2 * c, f, ht // 2, wd // 2)
|
||||
hf = torch.cat([hf, hf], dim=1)
|
||||
|
||||
hn = h[:, :, 1:, :, :]
|
||||
b, c, frms, ht, wd = hn.shape
|
||||
nf = frms // r1
|
||||
hn = hn.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
|
||||
hn = hn.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
hn = hn.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
|
||||
|
||||
h = torch.cat([hf, hn], dim=2)
|
||||
h = h[:, :, 1:, :, :]
|
||||
|
||||
xf = x[:, :, :1, :, :]
|
||||
b, ci, f, ht, wd = xf.shape
|
||||
@@ -54,34 +80,32 @@ class DnSmpl(nn.Module):
|
||||
xf = xf.permute(0, 4, 6, 1, 2, 3, 5)
|
||||
xf = xf.reshape(b, 2 * 2 * ci, f, ht // 2, wd // 2)
|
||||
B, C, T, H, W = xf.shape
|
||||
xf = xf.view(B, h.shape[1], self.gs // 2, T, H, W).mean(dim=2)
|
||||
xf = xf.view(B, hf.shape[1], self.gs // 2, T, H, W).mean(dim=2)
|
||||
|
||||
xn = x[:, :, 1:, :, :]
|
||||
b, ci, frms, ht, wd = xn.shape
|
||||
nf = frms // r1
|
||||
xn = xn.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
|
||||
xn = xn.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
xn = xn.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
|
||||
B, C, T, H, W = xn.shape
|
||||
xn = xn.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
|
||||
sc = torch.cat([xf, xn], dim=2)
|
||||
else:
|
||||
b, c, frms, ht, wd = h.shape
|
||||
x = x[:, :, 1:, :, :]
|
||||
|
||||
nf = frms // r1
|
||||
h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
|
||||
h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
|
||||
if h.shape[2] == 0:
|
||||
return hf + xf
|
||||
|
||||
b, ci, frms, ht, wd = x.shape
|
||||
nf = frms // r1
|
||||
sc = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
|
||||
sc = sc.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
sc = sc.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
|
||||
B, C, T, H, W = sc.shape
|
||||
sc = sc.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
|
||||
b, c, frms, ht, wd = h.shape
|
||||
nf = frms // r1
|
||||
h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
|
||||
h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
|
||||
|
||||
return h + sc
|
||||
b, ci, frms, ht, wd = x.shape
|
||||
nf = frms // r1
|
||||
x = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
|
||||
x = x.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
x = x.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
|
||||
B, C, T, H, W = x.shape
|
||||
x = x.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
|
||||
|
||||
if self.tds and self.refiner_vae and conv_carry_in is None:
|
||||
h = torch.cat([hf, h], dim=2)
|
||||
x = torch.cat([xf, x], dim=2)
|
||||
|
||||
return h + x
|
||||
|
||||
|
||||
class UpSmpl(nn.Module):
|
||||
@@ -94,11 +118,11 @@ class UpSmpl(nn.Module):
|
||||
self.tus = tus
|
||||
self.rp = fct * oc // ic
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
|
||||
r1 = 2 if self.tus else 1
|
||||
h = self.conv(x)
|
||||
h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
|
||||
|
||||
if self.tus and self.refiner_vae:
|
||||
if self.tus and self.refiner_vae and conv_carry_in is None:
|
||||
hf = h[:, :, :1, :, :]
|
||||
b, c, f, ht, wd = hf.shape
|
||||
nc = c // (2 * 2)
|
||||
@@ -107,14 +131,7 @@ class UpSmpl(nn.Module):
|
||||
hf = hf.reshape(b, nc, f, ht * 2, wd * 2)
|
||||
hf = hf[:, : hf.shape[1] // 2]
|
||||
|
||||
hn = h[:, :, 1:, :, :]
|
||||
b, c, frms, ht, wd = hn.shape
|
||||
nc = c // (r1 * 2 * 2)
|
||||
hn = hn.reshape(b, r1, 2, 2, nc, frms, ht, wd)
|
||||
hn = hn.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
hn = hn.reshape(b, nc, frms * r1, ht * 2, wd * 2)
|
||||
|
||||
h = torch.cat([hf, hn], dim=2)
|
||||
h = h[:, :, 1:, :, :]
|
||||
|
||||
xf = x[:, :, :1, :, :]
|
||||
b, ci, f, ht, wd = xf.shape
|
||||
@@ -125,29 +142,43 @@ class UpSmpl(nn.Module):
|
||||
xf = xf.permute(0, 3, 4, 5, 1, 6, 2)
|
||||
xf = xf.reshape(b, nc, f, ht * 2, wd * 2)
|
||||
|
||||
xn = x[:, :, 1:, :, :]
|
||||
xn = xn.repeat_interleave(repeats=self.rp, dim=1)
|
||||
b, c, frms, ht, wd = xn.shape
|
||||
nc = c // (r1 * 2 * 2)
|
||||
xn = xn.reshape(b, r1, 2, 2, nc, frms, ht, wd)
|
||||
xn = xn.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
xn = xn.reshape(b, nc, frms * r1, ht * 2, wd * 2)
|
||||
sc = torch.cat([xf, xn], dim=2)
|
||||
else:
|
||||
b, c, frms, ht, wd = h.shape
|
||||
nc = c // (r1 * 2 * 2)
|
||||
h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
|
||||
h = h.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2)
|
||||
x = x[:, :, 1:, :, :]
|
||||
|
||||
sc = x.repeat_interleave(repeats=self.rp, dim=1)
|
||||
b, c, frms, ht, wd = sc.shape
|
||||
nc = c // (r1 * 2 * 2)
|
||||
sc = sc.reshape(b, r1, 2, 2, nc, frms, ht, wd)
|
||||
sc = sc.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
sc = sc.reshape(b, nc, frms * r1, ht * 2, wd * 2)
|
||||
b, c, frms, ht, wd = h.shape
|
||||
nc = c // (r1 * 2 * 2)
|
||||
h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
|
||||
h = h.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2)
|
||||
|
||||
return h + sc
|
||||
x = x.repeat_interleave(repeats=self.rp, dim=1)
|
||||
b, c, frms, ht, wd = x.shape
|
||||
nc = c // (r1 * 2 * 2)
|
||||
x = x.reshape(b, r1, 2, 2, nc, frms, ht, wd)
|
||||
x = x.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
x = x.reshape(b, nc, frms * r1, ht * 2, wd * 2)
|
||||
|
||||
if self.tus and self.refiner_vae and conv_carry_in is None:
|
||||
h = torch.cat([hf, h], dim=2)
|
||||
x = torch.cat([xf, x], dim=2)
|
||||
|
||||
return h + x
|
||||
|
||||
class HunyuanRefinerResnetBlock(ResnetBlock):
|
||||
def __init__(self, in_channels, out_channels, conv_op=NoPadConv3d, norm_op=RMS_norm):
|
||||
super().__init__(in_channels=in_channels, out_channels=out_channels, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
|
||||
|
||||
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
|
||||
h = x
|
||||
h = [ self.swish(self.norm1(x)) ]
|
||||
h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
|
||||
|
||||
h = [ self.dropout(self.swish(self.norm2(h))) ]
|
||||
h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x+h
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
|
||||
@@ -160,7 +191,7 @@ class Encoder(nn.Module):
|
||||
|
||||
self.refiner_vae = refiner_vae
|
||||
if self.refiner_vae:
|
||||
conv_op = VideoConv3d
|
||||
conv_op = NoPadConv3d
|
||||
norm_op = RMS_norm
|
||||
else:
|
||||
conv_op = ops.Conv3d
|
||||
@@ -175,10 +206,9 @@ class Encoder(nn.Module):
|
||||
|
||||
for i, tgt in enumerate(block_out_channels):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
temb_channels=0,
|
||||
conv_op=conv_op, norm_op=norm_op)
|
||||
stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
conv_op=conv_op, norm_op=norm_op)
|
||||
for j in range(num_res_blocks)])
|
||||
ch = tgt
|
||||
if i < depth:
|
||||
@@ -188,9 +218,9 @@ class Encoder(nn.Module):
|
||||
self.down.append(stage)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
|
||||
self.norm_out = norm_op(ch)
|
||||
self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)
|
||||
@@ -201,31 +231,50 @@ class Encoder(nn.Module):
|
||||
if not self.refiner_vae and x.shape[2] == 1:
|
||||
x = x.expand(-1, -1, self.ffactor_temporal, -1, -1)
|
||||
|
||||
x = self.conv_in(x)
|
||||
if self.refiner_vae:
|
||||
xl = [x[:, :, :1, :, :]]
|
||||
if x.shape[2] > self.ffactor_temporal:
|
||||
xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // self.ffactor_temporal) * self.ffactor_temporal, :, :], self.ffactor_temporal * 2, dim=2)
|
||||
x = xl
|
||||
else:
|
||||
x = [x]
|
||||
out = []
|
||||
|
||||
for stage in self.down:
|
||||
for blk in stage.block:
|
||||
x = blk(x)
|
||||
if hasattr(stage, 'downsample'):
|
||||
x = stage.downsample(x)
|
||||
conv_carry_in = None
|
||||
|
||||
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
|
||||
for i, x1 in enumerate(x):
|
||||
conv_carry_out = []
|
||||
if i == len(x) - 1:
|
||||
conv_carry_out = None
|
||||
x1 = [ x1 ]
|
||||
x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
|
||||
|
||||
for stage in self.down:
|
||||
for blk in stage.block:
|
||||
x1 = blk(x1, conv_carry_in, conv_carry_out)
|
||||
if hasattr(stage, 'downsample'):
|
||||
x1 = stage.downsample(x1, conv_carry_in, conv_carry_out)
|
||||
|
||||
out.append(x1)
|
||||
conv_carry_in = conv_carry_out
|
||||
|
||||
if len(out) > 1:
|
||||
out = torch.cat(out, dim=2)
|
||||
else:
|
||||
out = out[0]
|
||||
|
||||
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out)))
|
||||
del out
|
||||
|
||||
b, c, t, h, w = x.shape
|
||||
grp = c // (self.z_channels << 1)
|
||||
skip = x.view(b, c // grp, grp, t, h, w).mean(2)
|
||||
|
||||
out = self.conv_out(F.silu(self.norm_out(x))) + skip
|
||||
out = conv_carry_causal_3d([F.silu(self.norm_out(x))], self.conv_out) + skip
|
||||
|
||||
if self.refiner_vae:
|
||||
out = self.regul(out)[0]
|
||||
|
||||
out = torch.cat((out[:, :, :1], out), dim=2)
|
||||
out = out.permute(0, 2, 1, 3, 4)
|
||||
b, f_times_2, c, h, w = out.shape
|
||||
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
|
||||
out = out.permute(0, 2, 1, 3, 4).contiguous()
|
||||
|
||||
return out
|
||||
|
||||
class Decoder(nn.Module):
|
||||
@@ -239,7 +288,7 @@ class Decoder(nn.Module):
|
||||
|
||||
self.refiner_vae = refiner_vae
|
||||
if self.refiner_vae:
|
||||
conv_op = VideoConv3d
|
||||
conv_op = NoPadConv3d
|
||||
norm_op = RMS_norm
|
||||
else:
|
||||
conv_op = ops.Conv3d
|
||||
@@ -249,9 +298,9 @@ class Decoder(nn.Module):
|
||||
self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
|
||||
self.up = nn.ModuleList()
|
||||
depth = (ffactor_spatial >> 1).bit_length()
|
||||
@@ -259,10 +308,9 @@ class Decoder(nn.Module):
|
||||
|
||||
for i, tgt in enumerate(block_out_channels):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
temb_channels=0,
|
||||
conv_op=conv_op, norm_op=norm_op)
|
||||
stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
conv_op=conv_op, norm_op=norm_op)
|
||||
for j in range(num_res_blocks + 1)])
|
||||
ch = tgt
|
||||
if i < depth:
|
||||
@@ -275,27 +323,41 @@ class Decoder(nn.Module):
|
||||
self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1)
|
||||
|
||||
def forward(self, z):
|
||||
if self.refiner_vae:
|
||||
z = z.permute(0, 2, 1, 3, 4)
|
||||
b, f, c, h, w = z.shape
|
||||
z = z.reshape(b, f, 2, c // 2, h, w)
|
||||
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
|
||||
z = z.permute(0, 2, 1, 3, 4)
|
||||
z = z[:, :, 1:]
|
||||
|
||||
x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
|
||||
x = conv_carry_causal_3d([z], self.conv_in) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
|
||||
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
|
||||
|
||||
for stage in self.up:
|
||||
for blk in stage.block:
|
||||
x = blk(x)
|
||||
if hasattr(stage, 'upsample'):
|
||||
x = stage.upsample(x)
|
||||
if self.refiner_vae:
|
||||
x = torch.split(x, 2, dim=2)
|
||||
else:
|
||||
x = [ x ]
|
||||
out = []
|
||||
|
||||
out = self.conv_out(F.silu(self.norm_out(x)))
|
||||
conv_carry_in = None
|
||||
|
||||
for i, x1 in enumerate(x):
|
||||
conv_carry_out = []
|
||||
if i == len(x) - 1:
|
||||
conv_carry_out = None
|
||||
for stage in self.up:
|
||||
for blk in stage.block:
|
||||
x1 = blk(x1, conv_carry_in, conv_carry_out)
|
||||
if hasattr(stage, 'upsample'):
|
||||
x1 = stage.upsample(x1, conv_carry_in, conv_carry_out)
|
||||
|
||||
x1 = [ F.silu(self.norm_out(x1)) ]
|
||||
x1 = conv_carry_causal_3d(x1, self.conv_out, conv_carry_in, conv_carry_out)
|
||||
out.append(x1)
|
||||
conv_carry_in = conv_carry_out
|
||||
del x
|
||||
|
||||
if len(out) > 1:
|
||||
out = torch.cat(out, dim=2)
|
||||
else:
|
||||
out = out[0]
|
||||
|
||||
if not self.refiner_vae:
|
||||
if z.shape[-3] == 1:
|
||||
out = out[:, :, -1:]
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@@ -3,12 +3,11 @@ from torch import nn
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.modules.attention
|
||||
import comfy.ldm.common_dit
|
||||
from einops import rearrange
|
||||
import math
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
||||
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
|
||||
def get_timestep_embedding(
|
||||
timesteps: torch.Tensor,
|
||||
@@ -238,20 +237,6 @@ class FeedForward(nn.Module):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and pick the best/fastest one
|
||||
cos_freqs = freqs_cis[0]
|
||||
sin_freqs = freqs_cis[1]
|
||||
|
||||
t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
|
||||
t1, t2 = t_dup.unbind(dim=-1)
|
||||
t_dup = torch.stack((-t2, t1), dim=-1)
|
||||
input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
|
||||
|
||||
out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
@@ -281,8 +266,8 @@ class CrossAttention(nn.Module):
|
||||
k = self.k_norm(k)
|
||||
|
||||
if pe is not None:
|
||||
q = apply_rotary_emb(q, pe)
|
||||
k = apply_rotary_emb(k, pe)
|
||||
q = apply_rope1(q.unsqueeze(1), pe).squeeze(1)
|
||||
k = apply_rope1(k.unsqueeze(1), pe).squeeze(1)
|
||||
|
||||
if mask is None:
|
||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
@@ -306,12 +291,17 @@ class BasicTransformerBlock(nn.Module):
|
||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
|
||||
|
||||
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa
|
||||
attn1_input = comfy.ldm.common_dit.rms_norm(x)
|
||||
attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
|
||||
attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
|
||||
x.addcmul_(attn1_input, gate_msa)
|
||||
del attn1_input
|
||||
|
||||
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
|
||||
|
||||
y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
|
||||
x += self.ff(y) * gate_mlp
|
||||
y = comfy.ldm.common_dit.rms_norm(x)
|
||||
y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
|
||||
x.addcmul_(self.ff(y), gate_mlp)
|
||||
|
||||
return x
|
||||
|
||||
@@ -327,41 +317,35 @@ def get_fractional_positions(indices_grid, max_pos):
|
||||
|
||||
|
||||
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
|
||||
dtype = torch.float32 #self.dtype
|
||||
dtype = torch.float32
|
||||
device = indices_grid.device
|
||||
|
||||
# Get fractional positions and compute frequency indices
|
||||
fractional_positions = get_fractional_positions(indices_grid, max_pos)
|
||||
indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2
|
||||
|
||||
start = 1
|
||||
end = theta
|
||||
device = fractional_positions.device
|
||||
# Compute frequencies and apply cos/sin
|
||||
freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2)
|
||||
cos_vals = freqs.cos().repeat_interleave(2, dim=-1)
|
||||
sin_vals = freqs.sin().repeat_interleave(2, dim=-1)
|
||||
|
||||
indices = theta ** (
|
||||
torch.linspace(
|
||||
math.log(start, theta),
|
||||
math.log(end, theta),
|
||||
dim // 6,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
indices = indices.to(dtype=dtype)
|
||||
|
||||
indices = indices * math.pi / 2
|
||||
|
||||
freqs = (
|
||||
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
|
||||
.transpose(-1, -2)
|
||||
.flatten(2)
|
||||
)
|
||||
|
||||
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
|
||||
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
|
||||
# Pad if dim is not divisible by 6
|
||||
if dim % 6 != 0:
|
||||
cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
|
||||
sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
|
||||
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
|
||||
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
|
||||
return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
|
||||
padding_size = dim % 6
|
||||
cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1)
|
||||
sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1)
|
||||
|
||||
# Reshape and extract one value per pair (since repeat_interleave duplicates each value)
|
||||
cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
|
||||
sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
|
||||
|
||||
# Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension
|
||||
freqs_cis = torch.stack([
|
||||
torch.stack([cos_vals, -sin_vals], dim=-1),
|
||||
torch.stack([sin_vals, cos_vals], dim=-1)
|
||||
], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2]
|
||||
|
||||
return freqs_cis
|
||||
|
||||
|
||||
class LTXVModel(torch.nn.Module):
|
||||
@@ -501,7 +485,7 @@ class LTXVModel(torch.nn.Module):
|
||||
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
|
||||
x = self.norm_out(x)
|
||||
# Modulation
|
||||
x = x * (1 + scale) + shift
|
||||
x = torch.addcmul(x, x, scale).add_(shift)
|
||||
x = self.proj_out(x)
|
||||
|
||||
x = self.patchifier.unpatchify(
|
||||
|
||||
@@ -11,6 +11,7 @@ import comfy.ldm.common_dit
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope
|
||||
import comfy.patcher_extension
|
||||
|
||||
|
||||
@@ -31,6 +32,7 @@ class JointAttention(nn.Module):
|
||||
n_heads: int,
|
||||
n_kv_heads: Optional[int],
|
||||
qk_norm: bool,
|
||||
out_bias: bool = False,
|
||||
operation_settings={},
|
||||
):
|
||||
"""
|
||||
@@ -59,7 +61,7 @@ class JointAttention(nn.Module):
|
||||
self.out = operation_settings.get("operations").Linear(
|
||||
n_heads * self.head_dim,
|
||||
dim,
|
||||
bias=False,
|
||||
bias=out_bias,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
)
|
||||
@@ -70,35 +72,6 @@ class JointAttention(nn.Module):
|
||||
else:
|
||||
self.q_norm = self.k_norm = nn.Identity()
|
||||
|
||||
@staticmethod
|
||||
def apply_rotary_emb(
|
||||
x_in: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Apply rotary embeddings to input tensors using the given frequency
|
||||
tensor.
|
||||
|
||||
This function applies rotary embeddings to the given query 'xq' and
|
||||
key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The
|
||||
input tensors are reshaped as complex numbers, and the frequency tensor
|
||||
is reshaped for broadcasting compatibility. The resulting tensors
|
||||
contain rotary embeddings and are returned as real tensors.
|
||||
|
||||
Args:
|
||||
x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings.
|
||||
freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
|
||||
exponentials.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor
|
||||
and key tensor with rotary embeddings.
|
||||
"""
|
||||
|
||||
t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2)
|
||||
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
|
||||
return t_out.reshape(*x_in.shape)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
@@ -134,8 +107,7 @@ class JointAttention(nn.Module):
|
||||
xq = self.q_norm(xq)
|
||||
xk = self.k_norm(xk)
|
||||
|
||||
xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
|
||||
xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)
|
||||
xq, xk = apply_rope(xq, xk, freqs_cis)
|
||||
|
||||
n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||
if n_rep >= 1:
|
||||
@@ -215,6 +187,8 @@ class JointTransformerBlock(nn.Module):
|
||||
norm_eps: float,
|
||||
qk_norm: bool,
|
||||
modulation=True,
|
||||
z_image_modulation=False,
|
||||
attn_out_bias=False,
|
||||
operation_settings={},
|
||||
) -> None:
|
||||
"""
|
||||
@@ -235,10 +209,10 @@ class JointTransformerBlock(nn.Module):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.head_dim = dim // n_heads
|
||||
self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, operation_settings=operation_settings)
|
||||
self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, out_bias=attn_out_bias, operation_settings=operation_settings)
|
||||
self.feed_forward = FeedForward(
|
||||
dim=dim,
|
||||
hidden_dim=4 * dim,
|
||||
hidden_dim=dim,
|
||||
multiple_of=multiple_of,
|
||||
ffn_dim_multiplier=ffn_dim_multiplier,
|
||||
operation_settings=operation_settings,
|
||||
@@ -252,16 +226,27 @@ class JointTransformerBlock(nn.Module):
|
||||
|
||||
self.modulation = modulation
|
||||
if modulation:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operation_settings.get("operations").Linear(
|
||||
min(dim, 1024),
|
||||
4 * dim,
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
)
|
||||
if z_image_modulation:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
operation_settings.get("operations").Linear(
|
||||
min(dim, 256),
|
||||
4 * dim,
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operation_settings.get("operations").Linear(
|
||||
min(dim, 1024),
|
||||
4 * dim,
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -323,7 +308,7 @@ class FinalLayer(nn.Module):
|
||||
The final layer of NextDiT.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, patch_size, out_channels, operation_settings={}):
|
||||
def __init__(self, hidden_size, patch_size, out_channels, z_image_modulation=False, operation_settings={}):
|
||||
super().__init__()
|
||||
self.norm_final = operation_settings.get("operations").LayerNorm(
|
||||
hidden_size,
|
||||
@@ -340,10 +325,15 @@ class FinalLayer(nn.Module):
|
||||
dtype=operation_settings.get("dtype"),
|
||||
)
|
||||
|
||||
if z_image_modulation:
|
||||
min_mod = 256
|
||||
else:
|
||||
min_mod = 1024
|
||||
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operation_settings.get("operations").Linear(
|
||||
min(hidden_size, 1024),
|
||||
min(hidden_size, min_mod),
|
||||
hidden_size,
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
@@ -373,12 +363,16 @@ class NextDiT(nn.Module):
|
||||
n_heads: int = 32,
|
||||
n_kv_heads: Optional[int] = None,
|
||||
multiple_of: int = 256,
|
||||
ffn_dim_multiplier: Optional[float] = None,
|
||||
ffn_dim_multiplier: float = 4.0,
|
||||
norm_eps: float = 1e-5,
|
||||
qk_norm: bool = False,
|
||||
cap_feat_dim: int = 5120,
|
||||
axes_dims: List[int] = (16, 56, 56),
|
||||
axes_lens: List[int] = (1, 512, 512),
|
||||
rope_theta=10000.0,
|
||||
z_image_modulation=False,
|
||||
time_scale=1.0,
|
||||
pad_tokens_multiple=None,
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
@@ -390,6 +384,8 @@ class NextDiT(nn.Module):
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.patch_size = patch_size
|
||||
self.time_scale = time_scale
|
||||
self.pad_tokens_multiple = pad_tokens_multiple
|
||||
|
||||
self.x_embedder = operation_settings.get("operations").Linear(
|
||||
in_features=patch_size * patch_size * in_channels,
|
||||
@@ -411,6 +407,7 @@ class NextDiT(nn.Module):
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=True,
|
||||
z_image_modulation=z_image_modulation,
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
@@ -434,7 +431,7 @@ class NextDiT(nn.Module):
|
||||
]
|
||||
)
|
||||
|
||||
self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
|
||||
self.t_embedder = TimestepEmbedder(min(dim, 1024), output_size=256 if z_image_modulation else None, **operation_settings)
|
||||
self.cap_embedder = nn.Sequential(
|
||||
operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
|
||||
operation_settings.get("operations").Linear(
|
||||
@@ -457,18 +454,24 @@ class NextDiT(nn.Module):
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
z_image_modulation=z_image_modulation,
|
||||
attn_out_bias=False,
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
for layer_id in range(n_layers)
|
||||
]
|
||||
)
|
||||
self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)
|
||||
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings)
|
||||
|
||||
if self.pad_tokens_multiple is not None:
|
||||
self.x_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
|
||||
self.cap_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
|
||||
|
||||
assert (dim // n_heads) == sum(axes_dims)
|
||||
self.axes_dims = axes_dims
|
||||
self.axes_lens = axes_lens
|
||||
self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims)
|
||||
self.rope_embedder = EmbedND(dim=dim // n_heads, theta=rope_theta, axes_dim=axes_dims)
|
||||
self.dim = dim
|
||||
self.n_heads = n_heads
|
||||
|
||||
@@ -503,96 +506,54 @@ class NextDiT(nn.Module):
|
||||
bsz = len(x)
|
||||
pH = pW = self.patch_size
|
||||
device = x[0].device
|
||||
dtype = x[0].dtype
|
||||
|
||||
if cap_mask is not None:
|
||||
l_effective_cap_len = cap_mask.sum(dim=1).tolist()
|
||||
else:
|
||||
l_effective_cap_len = [num_tokens] * bsz
|
||||
if self.pad_tokens_multiple is not None:
|
||||
pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
|
||||
cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1)
|
||||
|
||||
if cap_mask is not None and not torch.is_floating_point(cap_mask):
|
||||
cap_mask = (cap_mask - 1).to(dtype) * torch.finfo(dtype).max
|
||||
cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device)
|
||||
cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0
|
||||
|
||||
img_sizes = [(img.size(1), img.size(2)) for img in x]
|
||||
l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes]
|
||||
B, C, H, W = x.shape
|
||||
x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
|
||||
|
||||
max_seq_len = max(
|
||||
(cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))
|
||||
)
|
||||
max_cap_len = max(l_effective_cap_len)
|
||||
max_img_len = max(l_effective_img_len)
|
||||
rope_options = transformer_options.get("rope_options", None)
|
||||
h_scale = 1.0
|
||||
w_scale = 1.0
|
||||
h_start = 0
|
||||
w_start = 0
|
||||
if rope_options is not None:
|
||||
h_scale = rope_options.get("scale_y", 1.0)
|
||||
w_scale = rope_options.get("scale_x", 1.0)
|
||||
|
||||
position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)
|
||||
h_start = rope_options.get("shift_y", 0.0)
|
||||
w_start = rope_options.get("shift_x", 0.0)
|
||||
|
||||
for i in range(bsz):
|
||||
cap_len = l_effective_cap_len[i]
|
||||
img_len = l_effective_img_len[i]
|
||||
H, W = img_sizes[i]
|
||||
H_tokens, W_tokens = H // pH, W // pW
|
||||
assert H_tokens * W_tokens == img_len
|
||||
H_tokens, W_tokens = H // pH, W // pW
|
||||
x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device)
|
||||
x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1
|
||||
x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
|
||||
x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
|
||||
|
||||
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
|
||||
position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
|
||||
row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
|
||||
col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
|
||||
position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
|
||||
position_ids[i, cap_len:cap_len+img_len, 2] = col_ids
|
||||
if self.pad_tokens_multiple is not None:
|
||||
pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
|
||||
x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
|
||||
x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra))
|
||||
|
||||
freqs_cis = self.rope_embedder(position_ids).movedim(1, 2).to(dtype)
|
||||
|
||||
# build freqs_cis for cap and image individually
|
||||
cap_freqs_cis_shape = list(freqs_cis.shape)
|
||||
# cap_freqs_cis_shape[1] = max_cap_len
|
||||
cap_freqs_cis_shape[1] = cap_feats.shape[1]
|
||||
cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
|
||||
|
||||
img_freqs_cis_shape = list(freqs_cis.shape)
|
||||
img_freqs_cis_shape[1] = max_img_len
|
||||
img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
|
||||
|
||||
for i in range(bsz):
|
||||
cap_len = l_effective_cap_len[i]
|
||||
img_len = l_effective_img_len[i]
|
||||
cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
|
||||
img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len]
|
||||
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
|
||||
|
||||
# refine context
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options)
|
||||
cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
|
||||
|
||||
# refine image
|
||||
flat_x = []
|
||||
for i in range(bsz):
|
||||
img = x[i]
|
||||
C, H, W = img.size()
|
||||
img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
|
||||
flat_x.append(img)
|
||||
x = flat_x
|
||||
padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype)
|
||||
padded_img_mask = torch.zeros(bsz, max_img_len, dtype=dtype, device=device)
|
||||
for i in range(bsz):
|
||||
padded_img_embed[i, :l_effective_img_len[i]] = x[i]
|
||||
padded_img_mask[i, l_effective_img_len[i]:] = -torch.finfo(dtype).max
|
||||
|
||||
padded_img_embed = self.x_embedder(padded_img_embed)
|
||||
padded_img_mask = padded_img_mask.unsqueeze(1)
|
||||
padded_img_mask = None
|
||||
for layer in self.noise_refiner:
|
||||
padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t, transformer_options=transformer_options)
|
||||
|
||||
if cap_mask is not None:
|
||||
mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
|
||||
mask[:, :max_cap_len] = cap_mask[:, :max_cap_len]
|
||||
else:
|
||||
mask = None
|
||||
|
||||
padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype)
|
||||
for i in range(bsz):
|
||||
cap_len = l_effective_cap_len[i]
|
||||
img_len = l_effective_img_len[i]
|
||||
|
||||
padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len]
|
||||
padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len]
|
||||
x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
|
||||
|
||||
padded_full_embed = torch.cat((cap_feats, x), dim=1)
|
||||
mask = None
|
||||
img_sizes = [(H, W)] * bsz
|
||||
l_effective_cap_len = [cap_feats.shape[1]] * bsz
|
||||
return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
|
||||
|
||||
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
|
||||
@@ -615,7 +576,7 @@ class NextDiT(nn.Module):
|
||||
y: (N,) tensor of text tokens/features
|
||||
"""
|
||||
|
||||
t = self.t_embedder(t, dtype=x.dtype) # (N, D)
|
||||
t = self.t_embedder(t * self.time_scale, dtype=x.dtype) # (N, D)
|
||||
adaln_input = t
|
||||
|
||||
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
|
||||
|
||||
0
comfy/ldm/mmaudio/vae/__init__.py
Normal file
0
comfy/ldm/mmaudio/vae/__init__.py
Normal file
120
comfy/ldm/mmaudio/vae/activations.py
Normal file
120
comfy/ldm/mmaudio/vae/activations.py
Normal file
@@ -0,0 +1,120 @@
|
||||
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import torch
|
||||
from torch import nn, sin, pow
|
||||
from torch.nn import Parameter
|
||||
import comfy.model_management
|
||||
|
||||
class Snake(nn.Module):
|
||||
'''
|
||||
Implementation of a sine-based periodic activation function
|
||||
Shape:
|
||||
- Input: (B, C, T)
|
||||
- Output: (B, C, T), same shape as the input
|
||||
Parameters:
|
||||
- alpha - trainable parameter
|
||||
References:
|
||||
- This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
||||
https://arxiv.org/abs/2006.08195
|
||||
Examples:
|
||||
>>> a1 = snake(256)
|
||||
>>> x = torch.randn(256)
|
||||
>>> x = a1(x)
|
||||
'''
|
||||
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
||||
'''
|
||||
Initialization.
|
||||
INPUT:
|
||||
- in_features: shape of the input
|
||||
- alpha: trainable parameter
|
||||
alpha is initialized to 1 by default, higher values = higher-frequency.
|
||||
alpha will be trained along with the rest of your model.
|
||||
'''
|
||||
super(Snake, self).__init__()
|
||||
self.in_features = in_features
|
||||
|
||||
# initialize alpha
|
||||
self.alpha_logscale = alpha_logscale
|
||||
if self.alpha_logscale:
|
||||
self.alpha = Parameter(torch.empty(in_features))
|
||||
else:
|
||||
self.alpha = Parameter(torch.empty(in_features))
|
||||
|
||||
self.alpha.requires_grad = alpha_trainable
|
||||
|
||||
self.no_div_by_zero = 0.000000001
|
||||
|
||||
def forward(self, x):
|
||||
'''
|
||||
Forward pass of the function.
|
||||
Applies the function to the input elementwise.
|
||||
Snake ∶= x + 1/a * sin^2 (xa)
|
||||
'''
|
||||
alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
||||
if self.alpha_logscale:
|
||||
alpha = torch.exp(alpha)
|
||||
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SnakeBeta(nn.Module):
|
||||
'''
|
||||
A modified Snake function which uses separate parameters for the magnitude of the periodic components
|
||||
Shape:
|
||||
- Input: (B, C, T)
|
||||
- Output: (B, C, T), same shape as the input
|
||||
Parameters:
|
||||
- alpha - trainable parameter that controls frequency
|
||||
- beta - trainable parameter that controls magnitude
|
||||
References:
|
||||
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
||||
https://arxiv.org/abs/2006.08195
|
||||
Examples:
|
||||
>>> a1 = snakebeta(256)
|
||||
>>> x = torch.randn(256)
|
||||
>>> x = a1(x)
|
||||
'''
|
||||
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
||||
'''
|
||||
Initialization.
|
||||
INPUT:
|
||||
- in_features: shape of the input
|
||||
- alpha - trainable parameter that controls frequency
|
||||
- beta - trainable parameter that controls magnitude
|
||||
alpha is initialized to 1 by default, higher values = higher-frequency.
|
||||
beta is initialized to 1 by default, higher values = higher-magnitude.
|
||||
alpha will be trained along with the rest of your model.
|
||||
'''
|
||||
super(SnakeBeta, self).__init__()
|
||||
self.in_features = in_features
|
||||
|
||||
# initialize alpha
|
||||
self.alpha_logscale = alpha_logscale
|
||||
if self.alpha_logscale:
|
||||
self.alpha = Parameter(torch.empty(in_features))
|
||||
self.beta = Parameter(torch.empty(in_features))
|
||||
else:
|
||||
self.alpha = Parameter(torch.empty(in_features))
|
||||
self.beta = Parameter(torch.empty(in_features))
|
||||
|
||||
self.alpha.requires_grad = alpha_trainable
|
||||
self.beta.requires_grad = alpha_trainable
|
||||
|
||||
self.no_div_by_zero = 0.000000001
|
||||
|
||||
def forward(self, x):
|
||||
'''
|
||||
Forward pass of the function.
|
||||
Applies the function to the input elementwise.
|
||||
SnakeBeta ∶= x + 1/b * sin^2 (xa)
|
||||
'''
|
||||
alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
||||
beta = comfy.model_management.cast_to(self.beta, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1)
|
||||
if self.alpha_logscale:
|
||||
alpha = torch.exp(alpha)
|
||||
beta = torch.exp(beta)
|
||||
x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
||||
|
||||
return x
|
||||
157
comfy/ldm/mmaudio/vae/alias_free_torch.py
Normal file
157
comfy/ldm/mmaudio/vae/alias_free_torch.py
Normal file
@@ -0,0 +1,157 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
import comfy.model_management
|
||||
|
||||
if 'sinc' in dir(torch):
|
||||
sinc = torch.sinc
|
||||
else:
|
||||
# This code is adopted from adefossez's julius.core.sinc under the MIT License
|
||||
# https://adefossez.github.io/julius/julius/core.html
|
||||
# LICENSE is in incl_licenses directory.
|
||||
def sinc(x: torch.Tensor):
|
||||
"""
|
||||
Implementation of sinc, i.e. sin(pi * x) / (pi * x)
|
||||
__Warning__: Different to julius.sinc, the input is multiplied by `pi`!
|
||||
"""
|
||||
return torch.where(x == 0,
|
||||
torch.tensor(1., device=x.device, dtype=x.dtype),
|
||||
torch.sin(math.pi * x) / math.pi / x)
|
||||
|
||||
|
||||
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
|
||||
# https://adefossez.github.io/julius/julius/lowpass.html
|
||||
# LICENSE is in incl_licenses directory.
|
||||
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
|
||||
even = (kernel_size % 2 == 0)
|
||||
half_size = kernel_size // 2
|
||||
|
||||
#For kaiser window
|
||||
delta_f = 4 * half_width
|
||||
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
||||
if A > 50.:
|
||||
beta = 0.1102 * (A - 8.7)
|
||||
elif A >= 21.:
|
||||
beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
|
||||
else:
|
||||
beta = 0.
|
||||
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
|
||||
|
||||
# ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
|
||||
if even:
|
||||
time = (torch.arange(-half_size, half_size) + 0.5)
|
||||
else:
|
||||
time = torch.arange(kernel_size) - half_size
|
||||
if cutoff == 0:
|
||||
filter_ = torch.zeros_like(time)
|
||||
else:
|
||||
filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
|
||||
# Normalize filter to have sum = 1, otherwise we will have a small leakage
|
||||
# of the constant component in the input signal.
|
||||
filter_ /= filter_.sum()
|
||||
filter = filter_.view(1, 1, kernel_size)
|
||||
|
||||
return filter
|
||||
|
||||
|
||||
class LowPassFilter1d(nn.Module):
|
||||
def __init__(self,
|
||||
cutoff=0.5,
|
||||
half_width=0.6,
|
||||
stride: int = 1,
|
||||
padding: bool = True,
|
||||
padding_mode: str = 'replicate',
|
||||
kernel_size: int = 12):
|
||||
# kernel_size should be even number for stylegan3 setup,
|
||||
# in this implementation, odd number is also possible.
|
||||
super().__init__()
|
||||
if cutoff < -0.:
|
||||
raise ValueError("Minimum cutoff must be larger than zero.")
|
||||
if cutoff > 0.5:
|
||||
raise ValueError("A cutoff above 0.5 does not make sense.")
|
||||
self.kernel_size = kernel_size
|
||||
self.even = (kernel_size % 2 == 0)
|
||||
self.pad_left = kernel_size // 2 - int(self.even)
|
||||
self.pad_right = kernel_size // 2
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
self.padding_mode = padding_mode
|
||||
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
|
||||
self.register_buffer("filter", filter)
|
||||
|
||||
#input [B, C, T]
|
||||
def forward(self, x):
|
||||
_, C, _ = x.shape
|
||||
|
||||
if self.padding:
|
||||
x = F.pad(x, (self.pad_left, self.pad_right),
|
||||
mode=self.padding_mode)
|
||||
out = F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device),
|
||||
stride=self.stride, groups=C)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class UpSample1d(nn.Module):
|
||||
def __init__(self, ratio=2, kernel_size=None):
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
self.stride = ratio
|
||||
self.pad = self.kernel_size // ratio - 1
|
||||
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
||||
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
||||
filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
|
||||
half_width=0.6 / ratio,
|
||||
kernel_size=self.kernel_size)
|
||||
self.register_buffer("filter", filter)
|
||||
|
||||
# x: [B, C, T]
|
||||
def forward(self, x):
|
||||
_, C, _ = x.shape
|
||||
|
||||
x = F.pad(x, (self.pad, self.pad), mode='replicate')
|
||||
x = self.ratio * F.conv_transpose1d(
|
||||
x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
|
||||
x = x[..., self.pad_left:-self.pad_right]
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class DownSample1d(nn.Module):
|
||||
def __init__(self, ratio=2, kernel_size=None):
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
|
||||
half_width=0.6 / ratio,
|
||||
stride=ratio,
|
||||
kernel_size=self.kernel_size)
|
||||
|
||||
def forward(self, x):
|
||||
xx = self.lowpass(x)
|
||||
|
||||
return xx
|
||||
|
||||
class Activation1d(nn.Module):
|
||||
def __init__(self,
|
||||
activation,
|
||||
up_ratio: int = 2,
|
||||
down_ratio: int = 2,
|
||||
up_kernel_size: int = 12,
|
||||
down_kernel_size: int = 12):
|
||||
super().__init__()
|
||||
self.up_ratio = up_ratio
|
||||
self.down_ratio = down_ratio
|
||||
self.act = activation
|
||||
self.upsample = UpSample1d(up_ratio, up_kernel_size)
|
||||
self.downsample = DownSample1d(down_ratio, down_kernel_size)
|
||||
|
||||
# x: [B,C,T]
|
||||
def forward(self, x):
|
||||
x = self.upsample(x)
|
||||
x = self.act(x)
|
||||
x = self.downsample(x)
|
||||
|
||||
return x
|
||||
156
comfy/ldm/mmaudio/vae/autoencoder.py
Normal file
156
comfy/ldm/mmaudio/vae/autoencoder.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .distributions import DiagonalGaussianDistribution
|
||||
from .vae import VAE_16k
|
||||
from .bigvgan import BigVGANVocoder
|
||||
import logging
|
||||
|
||||
try:
|
||||
import torchaudio
|
||||
except:
|
||||
logging.warning("torchaudio missing, MMAudio VAE model will be broken")
|
||||
|
||||
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5, *, norm_fn):
|
||||
return norm_fn(torch.clamp(x, min=clip_val) * C)
|
||||
|
||||
|
||||
def spectral_normalize_torch(magnitudes, norm_fn):
|
||||
output = dynamic_range_compression_torch(magnitudes, norm_fn=norm_fn)
|
||||
return output
|
||||
|
||||
class MelConverter(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
sampling_rate: float,
|
||||
n_fft: int,
|
||||
num_mels: int,
|
||||
hop_size: int,
|
||||
win_size: int,
|
||||
fmin: float,
|
||||
fmax: float,
|
||||
norm_fn,
|
||||
):
|
||||
super().__init__()
|
||||
self.sampling_rate = sampling_rate
|
||||
self.n_fft = n_fft
|
||||
self.num_mels = num_mels
|
||||
self.hop_size = hop_size
|
||||
self.win_size = win_size
|
||||
self.fmin = fmin
|
||||
self.fmax = fmax
|
||||
self.norm_fn = norm_fn
|
||||
|
||||
# mel = librosa_mel_fn(sr=self.sampling_rate,
|
||||
# n_fft=self.n_fft,
|
||||
# n_mels=self.num_mels,
|
||||
# fmin=self.fmin,
|
||||
# fmax=self.fmax)
|
||||
# mel_basis = torch.from_numpy(mel).float()
|
||||
mel_basis = torch.empty((num_mels, 1 + n_fft // 2))
|
||||
hann_window = torch.hann_window(self.win_size)
|
||||
|
||||
self.register_buffer('mel_basis', mel_basis)
|
||||
self.register_buffer('hann_window', hann_window)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.mel_basis.device
|
||||
|
||||
def forward(self, waveform: torch.Tensor, center: bool = False) -> torch.Tensor:
|
||||
waveform = waveform.clamp(min=-1., max=1.).to(self.device)
|
||||
|
||||
waveform = torch.nn.functional.pad(
|
||||
waveform.unsqueeze(1),
|
||||
[int((self.n_fft - self.hop_size) / 2),
|
||||
int((self.n_fft - self.hop_size) / 2)],
|
||||
mode='reflect')
|
||||
waveform = waveform.squeeze(1)
|
||||
|
||||
spec = torch.stft(waveform,
|
||||
self.n_fft,
|
||||
hop_length=self.hop_size,
|
||||
win_length=self.win_size,
|
||||
window=self.hann_window,
|
||||
center=center,
|
||||
pad_mode='reflect',
|
||||
normalized=False,
|
||||
onesided=True,
|
||||
return_complex=True)
|
||||
|
||||
spec = torch.view_as_real(spec)
|
||||
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
|
||||
spec = torch.matmul(self.mel_basis, spec)
|
||||
spec = spectral_normalize_torch(spec, self.norm_fn)
|
||||
|
||||
return spec
|
||||
|
||||
class AudioAutoencoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
# ckpt_path: str,
|
||||
mode=Literal['16k', '44k'],
|
||||
need_vae_encoder: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert mode == "16k", "Only 16k mode is supported currently."
|
||||
self.mel_converter = MelConverter(sampling_rate=16_000,
|
||||
n_fft=1024,
|
||||
num_mels=80,
|
||||
hop_size=256,
|
||||
win_size=1024,
|
||||
fmin=0,
|
||||
fmax=8_000,
|
||||
norm_fn=torch.log10)
|
||||
|
||||
self.vae = VAE_16k().eval()
|
||||
|
||||
bigvgan_config = {
|
||||
"resblock": "1",
|
||||
"num_mels": 80,
|
||||
"upsample_rates": [4, 4, 2, 2, 2, 2],
|
||||
"upsample_kernel_sizes": [8, 8, 4, 4, 4, 4],
|
||||
"upsample_initial_channel": 1536,
|
||||
"resblock_kernel_sizes": [3, 7, 11],
|
||||
"resblock_dilation_sizes": [
|
||||
[1, 3, 5],
|
||||
[1, 3, 5],
|
||||
[1, 3, 5],
|
||||
],
|
||||
"activation": "snakebeta",
|
||||
"snake_logscale": True,
|
||||
}
|
||||
|
||||
self.vocoder = BigVGANVocoder(
|
||||
bigvgan_config
|
||||
).eval()
|
||||
|
||||
@torch.inference_mode()
|
||||
def encode_audio(self, x) -> DiagonalGaussianDistribution:
|
||||
# x: (B * L)
|
||||
mel = self.mel_converter(x)
|
||||
dist = self.vae.encode(mel)
|
||||
|
||||
return dist
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(self, z):
|
||||
mel_decoded = self.vae.decode(z)
|
||||
audio = self.vocoder(mel_decoded)
|
||||
|
||||
audio = torchaudio.functional.resample(audio, 16000, 44100)
|
||||
return audio
|
||||
|
||||
@torch.no_grad()
|
||||
def encode(self, audio):
|
||||
audio = audio.mean(dim=1)
|
||||
audio = torchaudio.functional.resample(audio, 44100, 16000)
|
||||
dist = self.encode_audio(audio)
|
||||
return dist.mean
|
||||
219
comfy/ldm/mmaudio/vae/bigvgan.py
Normal file
219
comfy/ldm/mmaudio/vae/bigvgan.py
Normal file
@@ -0,0 +1,219 @@
|
||||
# Copyright (c) 2022 NVIDIA CORPORATION.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from types import SimpleNamespace
|
||||
from . import activations
|
||||
from .alias_free_torch import Activation1d
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
def get_padding(kernel_size, dilation=1):
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
class AMPBlock1(torch.nn.Module):
|
||||
|
||||
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
|
||||
super(AMPBlock1, self).__init__()
|
||||
self.h = h
|
||||
|
||||
self.convs1 = nn.ModuleList([
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0])),
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1])),
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[2],
|
||||
padding=get_padding(kernel_size, dilation[2]))
|
||||
])
|
||||
|
||||
self.convs2 = nn.ModuleList([
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1)),
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1)),
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1))
|
||||
])
|
||||
|
||||
self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
|
||||
|
||||
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
|
||||
self.activations = nn.ModuleList([
|
||||
Activation1d(
|
||||
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
])
|
||||
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
|
||||
self.activations = nn.ModuleList([
|
||||
Activation1d(
|
||||
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
])
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"activation incorrectly specified. check the config file and look for 'activation'."
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
acts1, acts2 = self.activations[::2], self.activations[1::2]
|
||||
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
|
||||
xt = a1(x)
|
||||
xt = c1(xt)
|
||||
xt = a2(xt)
|
||||
xt = c2(xt)
|
||||
x = xt + x
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class AMPBlock2(torch.nn.Module):
|
||||
|
||||
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None):
|
||||
super(AMPBlock2, self).__init__()
|
||||
self.h = h
|
||||
|
||||
self.convs = nn.ModuleList([
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0])),
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]))
|
||||
])
|
||||
|
||||
self.num_layers = len(self.convs) # total number of conv layers
|
||||
|
||||
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
|
||||
self.activations = nn.ModuleList([
|
||||
Activation1d(
|
||||
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
])
|
||||
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
|
||||
self.activations = nn.ModuleList([
|
||||
Activation1d(
|
||||
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
])
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"activation incorrectly specified. check the config file and look for 'activation'."
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for c, a in zip(self.convs, self.activations):
|
||||
xt = a(x)
|
||||
xt = c(xt)
|
||||
x = xt + x
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class BigVGANVocoder(torch.nn.Module):
|
||||
# this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
|
||||
def __init__(self, h):
|
||||
super().__init__()
|
||||
if isinstance(h, dict):
|
||||
h = SimpleNamespace(**h)
|
||||
self.h = h
|
||||
|
||||
self.num_kernels = len(h.resblock_kernel_sizes)
|
||||
self.num_upsamples = len(h.upsample_rates)
|
||||
|
||||
# pre conv
|
||||
self.conv_pre = ops.Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
|
||||
|
||||
# define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
|
||||
resblock = AMPBlock1 if h.resblock == '1' else AMPBlock2
|
||||
|
||||
# transposed conv-based upsamplers. does not apply anti-aliasing
|
||||
self.ups = nn.ModuleList()
|
||||
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
||||
self.ups.append(
|
||||
nn.ModuleList([
|
||||
ops.ConvTranspose1d(h.upsample_initial_channel // (2**i),
|
||||
h.upsample_initial_channel // (2**(i + 1)),
|
||||
k,
|
||||
u,
|
||||
padding=(k - u) // 2)
|
||||
]))
|
||||
|
||||
# residual blocks using anti-aliased multi-periodicity composition modules (AMP)
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = h.upsample_initial_channel // (2**(i + 1))
|
||||
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
|
||||
|
||||
# post conv
|
||||
if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing
|
||||
activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale)
|
||||
self.activation_post = Activation1d(activation=activation_post)
|
||||
elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing
|
||||
activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
|
||||
self.activation_post = Activation1d(activation=activation_post)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"activation incorrectly specified. check the config file and look for 'activation'."
|
||||
)
|
||||
|
||||
self.conv_post = ops.Conv1d(ch, 1, 7, 1, padding=3)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
# pre conv
|
||||
x = self.conv_pre(x)
|
||||
|
||||
for i in range(self.num_upsamples):
|
||||
# upsampling
|
||||
for i_up in range(len(self.ups[i])):
|
||||
x = self.ups[i][i_up](x)
|
||||
# AMP blocks
|
||||
xs = None
|
||||
for j in range(self.num_kernels):
|
||||
if xs is None:
|
||||
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||
else:
|
||||
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||
x = xs / self.num_kernels
|
||||
|
||||
# post conv
|
||||
x = self.activation_post(x)
|
||||
x = self.conv_post(x)
|
||||
x = torch.tanh(x)
|
||||
|
||||
return x
|
||||
92
comfy/ldm/mmaudio/vae/distributions.py
Normal file
92
comfy/ldm/mmaudio/vae/distributions.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
class AbstractDistribution:
|
||||
def sample(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def mode(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class DiracDistribution(AbstractDistribution):
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def sample(self):
|
||||
return self.value
|
||||
|
||||
def mode(self):
|
||||
return self.value
|
||||
|
||||
|
||||
class DiagonalGaussianDistribution(object):
|
||||
def __init__(self, parameters, deterministic=False):
|
||||
self.parameters = parameters
|
||||
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
||||
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
||||
self.deterministic = deterministic
|
||||
self.std = torch.exp(0.5 * self.logvar)
|
||||
self.var = torch.exp(self.logvar)
|
||||
if self.deterministic:
|
||||
self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device)
|
||||
|
||||
def sample(self):
|
||||
x = self.mean + self.std * torch.randn(self.mean.shape, device=self.parameters.device)
|
||||
return x
|
||||
|
||||
def kl(self, other=None):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.])
|
||||
else:
|
||||
if other is None:
|
||||
return 0.5 * torch.sum(torch.pow(self.mean, 2)
|
||||
+ self.var - 1.0 - self.logvar,
|
||||
dim=[1, 2, 3])
|
||||
else:
|
||||
return 0.5 * torch.sum(
|
||||
torch.pow(self.mean - other.mean, 2) / other.var
|
||||
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
|
||||
dim=[1, 2, 3])
|
||||
|
||||
def nll(self, sample, dims=[1,2,3]):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.])
|
||||
logtwopi = np.log(2.0 * np.pi)
|
||||
return 0.5 * torch.sum(
|
||||
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
|
||||
dim=dims)
|
||||
|
||||
def mode(self):
|
||||
return self.mean
|
||||
|
||||
|
||||
def normal_kl(mean1, logvar1, mean2, logvar2):
|
||||
"""
|
||||
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
|
||||
Compute the KL divergence between two gaussians.
|
||||
Shapes are automatically broadcasted, so batches can be compared to
|
||||
scalars, among other use cases.
|
||||
"""
|
||||
tensor = None
|
||||
for obj in (mean1, logvar1, mean2, logvar2):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
tensor = obj
|
||||
break
|
||||
assert tensor is not None, "at least one argument must be a Tensor"
|
||||
|
||||
# Force variances to be Tensors. Broadcasting helps convert scalars to
|
||||
# Tensors, but it does not work for torch.exp().
|
||||
logvar1, logvar2 = [
|
||||
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
|
||||
for x in (logvar1, logvar2)
|
||||
]
|
||||
|
||||
return 0.5 * (
|
||||
-1.0
|
||||
+ logvar2
|
||||
- logvar1
|
||||
+ torch.exp(logvar1 - logvar2)
|
||||
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
|
||||
)
|
||||
358
comfy/ldm/mmaudio/vae/vae.py
Normal file
358
comfy/ldm/mmaudio/vae/vae.py
Normal file
@@ -0,0 +1,358 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
|
||||
Upsample1D, nonlinearity)
|
||||
from .distributions import DiagonalGaussianDistribution
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
log = logging.getLogger()
|
||||
|
||||
DATA_MEAN_80D = [
|
||||
-1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439, -1.2922, -1.2927,
|
||||
-1.3170, -1.3543, -1.3401, -1.3836, -1.3907, -1.3912, -1.4313, -1.4152, -1.4527, -1.4728,
|
||||
-1.4568, -1.5101, -1.5051, -1.5172, -1.5623, -1.5373, -1.5746, -1.5687, -1.6032, -1.6131,
|
||||
-1.6081, -1.6331, -1.6489, -1.6489, -1.6700, -1.6738, -1.6953, -1.6969, -1.7048, -1.7280,
|
||||
-1.7361, -1.7495, -1.7658, -1.7814, -1.7889, -1.8064, -1.8221, -1.8377, -1.8417, -1.8643,
|
||||
-1.8857, -1.8929, -1.9173, -1.9379, -1.9531, -1.9673, -1.9824, -2.0042, -2.0215, -2.0436,
|
||||
-2.0766, -2.1064, -2.1418, -2.1855, -2.2319, -2.2767, -2.3161, -2.3572, -2.3954, -2.4282,
|
||||
-2.4659, -2.5072, -2.5552, -2.6074, -2.6584, -2.7107, -2.7634, -2.8266, -2.8981, -2.9673
|
||||
]
|
||||
|
||||
DATA_STD_80D = [
|
||||
1.0291, 1.0411, 1.0043, 0.9820, 0.9677, 0.9543, 0.9450, 0.9392, 0.9343, 0.9297, 0.9276, 0.9263,
|
||||
0.9242, 0.9254, 0.9232, 0.9281, 0.9263, 0.9315, 0.9274, 0.9247, 0.9277, 0.9199, 0.9188, 0.9194,
|
||||
0.9160, 0.9161, 0.9146, 0.9161, 0.9100, 0.9095, 0.9145, 0.9076, 0.9066, 0.9095, 0.9032, 0.9043,
|
||||
0.9038, 0.9011, 0.9019, 0.9010, 0.8984, 0.8983, 0.8986, 0.8961, 0.8962, 0.8978, 0.8962, 0.8973,
|
||||
0.8993, 0.8976, 0.8995, 0.9016, 0.8982, 0.8972, 0.8974, 0.8949, 0.8940, 0.8947, 0.8936, 0.8939,
|
||||
0.8951, 0.8956, 0.9017, 0.9167, 0.9436, 0.9690, 1.0003, 1.0225, 1.0381, 1.0491, 1.0545, 1.0604,
|
||||
1.0761, 1.0929, 1.1089, 1.1196, 1.1176, 1.1156, 1.1117, 1.1070
|
||||
]
|
||||
|
||||
DATA_MEAN_128D = [
|
||||
-3.3462, -2.6723, -2.4893, -2.3143, -2.2664, -2.3317, -2.1802, -2.4006, -2.2357, -2.4597,
|
||||
-2.3717, -2.4690, -2.5142, -2.4919, -2.6610, -2.5047, -2.7483, -2.5926, -2.7462, -2.7033,
|
||||
-2.7386, -2.8112, -2.7502, -2.9594, -2.7473, -3.0035, -2.8891, -2.9922, -2.9856, -3.0157,
|
||||
-3.1191, -2.9893, -3.1718, -3.0745, -3.1879, -3.2310, -3.1424, -3.2296, -3.2791, -3.2782,
|
||||
-3.2756, -3.3134, -3.3509, -3.3750, -3.3951, -3.3698, -3.4505, -3.4509, -3.5089, -3.4647,
|
||||
-3.5536, -3.5788, -3.5867, -3.6036, -3.6400, -3.6747, -3.7072, -3.7279, -3.7283, -3.7795,
|
||||
-3.8259, -3.8447, -3.8663, -3.9182, -3.9605, -3.9861, -4.0105, -4.0373, -4.0762, -4.1121,
|
||||
-4.1488, -4.1874, -4.2461, -4.3170, -4.3639, -4.4452, -4.5282, -4.6297, -4.7019, -4.7960,
|
||||
-4.8700, -4.9507, -5.0303, -5.0866, -5.1634, -5.2342, -5.3242, -5.4053, -5.4927, -5.5712,
|
||||
-5.6464, -5.7052, -5.7619, -5.8410, -5.9188, -6.0103, -6.0955, -6.1673, -6.2362, -6.3120,
|
||||
-6.3926, -6.4797, -6.5565, -6.6511, -6.8130, -6.9961, -7.1275, -7.2457, -7.3576, -7.4663,
|
||||
-7.6136, -7.7469, -7.8815, -8.0132, -8.1515, -8.3071, -8.4722, -8.7418, -9.3975, -9.6628,
|
||||
-9.7671, -9.8863, -9.9992, -10.0860, -10.1709, -10.5418, -11.2795, -11.3861
|
||||
]
|
||||
|
||||
DATA_STD_128D = [
|
||||
2.3804, 2.4368, 2.3772, 2.3145, 2.2803, 2.2510, 2.2316, 2.2083, 2.1996, 2.1835, 2.1769, 2.1659,
|
||||
2.1631, 2.1618, 2.1540, 2.1606, 2.1571, 2.1567, 2.1612, 2.1579, 2.1679, 2.1683, 2.1634, 2.1557,
|
||||
2.1668, 2.1518, 2.1415, 2.1449, 2.1406, 2.1350, 2.1313, 2.1415, 2.1281, 2.1352, 2.1219, 2.1182,
|
||||
2.1327, 2.1195, 2.1137, 2.1080, 2.1179, 2.1036, 2.1087, 2.1036, 2.1015, 2.1068, 2.0975, 2.0991,
|
||||
2.0902, 2.1015, 2.0857, 2.0920, 2.0893, 2.0897, 2.0910, 2.0881, 2.0925, 2.0873, 2.0960, 2.0900,
|
||||
2.0957, 2.0958, 2.0978, 2.0936, 2.0886, 2.0905, 2.0845, 2.0855, 2.0796, 2.0840, 2.0813, 2.0817,
|
||||
2.0838, 2.0840, 2.0917, 2.1061, 2.1431, 2.1976, 2.2482, 2.3055, 2.3700, 2.4088, 2.4372, 2.4609,
|
||||
2.4731, 2.4847, 2.5072, 2.5451, 2.5772, 2.6147, 2.6529, 2.6596, 2.6645, 2.6726, 2.6803, 2.6812,
|
||||
2.6899, 2.6916, 2.6931, 2.6998, 2.7062, 2.7262, 2.7222, 2.7158, 2.7041, 2.7485, 2.7491, 2.7451,
|
||||
2.7485, 2.7233, 2.7297, 2.7233, 2.7145, 2.6958, 2.6788, 2.6439, 2.6007, 2.4786, 2.2469, 2.1877,
|
||||
2.1392, 2.0717, 2.0107, 1.9676, 1.9140, 1.7102, 0.9101, 0.7164
|
||||
]
|
||||
|
||||
|
||||
class VAE(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
data_dim: int,
|
||||
embed_dim: int,
|
||||
hidden_dim: int,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if data_dim == 80:
|
||||
self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_80D, dtype=torch.float32))
|
||||
self.data_std = nn.Buffer(torch.tensor(DATA_STD_80D, dtype=torch.float32))
|
||||
elif data_dim == 128:
|
||||
self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_128D, dtype=torch.float32))
|
||||
self.data_std = nn.Buffer(torch.tensor(DATA_STD_128D, dtype=torch.float32))
|
||||
|
||||
self.data_mean = self.data_mean.view(1, -1, 1)
|
||||
self.data_std = self.data_std.view(1, -1, 1)
|
||||
|
||||
self.encoder = Encoder1D(
|
||||
dim=hidden_dim,
|
||||
ch_mult=(1, 2, 4),
|
||||
num_res_blocks=2,
|
||||
attn_layers=[3],
|
||||
down_layers=[0],
|
||||
in_dim=data_dim,
|
||||
embed_dim=embed_dim,
|
||||
)
|
||||
self.decoder = Decoder1D(
|
||||
dim=hidden_dim,
|
||||
ch_mult=(1, 2, 4),
|
||||
num_res_blocks=2,
|
||||
attn_layers=[3],
|
||||
down_layers=[0],
|
||||
in_dim=data_dim,
|
||||
out_dim=data_dim,
|
||||
embed_dim=embed_dim,
|
||||
)
|
||||
|
||||
self.embed_dim = embed_dim
|
||||
# self.quant_conv = nn.Conv1d(2 * embed_dim, 2 * embed_dim, 1)
|
||||
# self.post_quant_conv = nn.Conv1d(embed_dim, embed_dim, 1)
|
||||
|
||||
self.initialize_weights()
|
||||
|
||||
def initialize_weights(self):
|
||||
pass
|
||||
|
||||
def encode(self, x: torch.Tensor, normalize: bool = True) -> DiagonalGaussianDistribution:
|
||||
if normalize:
|
||||
x = self.normalize(x)
|
||||
moments = self.encoder(x)
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
return posterior
|
||||
|
||||
def decode(self, z: torch.Tensor, unnormalize: bool = True) -> torch.Tensor:
|
||||
dec = self.decoder(z)
|
||||
if unnormalize:
|
||||
dec = self.unnormalize(dec)
|
||||
return dec
|
||||
|
||||
def normalize(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return (x - comfy.model_management.cast_to(self.data_mean, dtype=x.dtype, device=x.device)) / comfy.model_management.cast_to(self.data_std, dtype=x.dtype, device=x.device)
|
||||
|
||||
def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x * comfy.model_management.cast_to(self.data_std, dtype=x.dtype, device=x.device) + comfy.model_management.cast_to(self.data_mean, dtype=x.dtype, device=x.device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
sample_posterior: bool = True,
|
||||
rng: Optional[torch.Generator] = None,
|
||||
normalize: bool = True,
|
||||
unnormalize: bool = True,
|
||||
) -> tuple[torch.Tensor, DiagonalGaussianDistribution]:
|
||||
|
||||
posterior = self.encode(x, normalize=normalize)
|
||||
if sample_posterior:
|
||||
z = posterior.sample(rng)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z, unnormalize=unnormalize)
|
||||
return dec, posterior
|
||||
|
||||
def load_weights(self, src_dict) -> None:
|
||||
self.load_state_dict(src_dict, strict=True)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return next(self.parameters()).device
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.decoder.conv_out.weight
|
||||
|
||||
def remove_weight_norm(self):
|
||||
return self
|
||||
|
||||
|
||||
class Encoder1D(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
dim: int,
|
||||
ch_mult: tuple[int] = (1, 2, 4, 8),
|
||||
num_res_blocks: int,
|
||||
attn_layers: list[int] = [],
|
||||
down_layers: list[int] = [],
|
||||
resamp_with_conv: bool = True,
|
||||
in_dim: int,
|
||||
embed_dim: int,
|
||||
double_z: bool = True,
|
||||
kernel_size: int = 3,
|
||||
clip_act: float = 256.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_layers = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.in_channels = in_dim
|
||||
self.clip_act = clip_act
|
||||
self.down_layers = down_layers
|
||||
self.attn_layers = attn_layers
|
||||
self.conv_in = ops.Conv1d(in_dim, self.dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
|
||||
|
||||
in_ch_mult = (1, ) + tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
# downsampling
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_layers):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = dim * in_ch_mult[i_level]
|
||||
block_out = dim * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(
|
||||
ResnetBlock1D(in_dim=block_in,
|
||||
out_dim=block_out,
|
||||
kernel_size=kernel_size,
|
||||
use_norm=True))
|
||||
block_in = block_out
|
||||
if i_level in attn_layers:
|
||||
attn.append(AttnBlock1D(block_in))
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level in down_layers:
|
||||
down.downsample = Downsample1D(block_in, resamp_with_conv)
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock1D(in_dim=block_in,
|
||||
out_dim=block_in,
|
||||
kernel_size=kernel_size,
|
||||
use_norm=True)
|
||||
self.mid.attn_1 = AttnBlock1D(block_in)
|
||||
self.mid.block_2 = ResnetBlock1D(in_dim=block_in,
|
||||
out_dim=block_in,
|
||||
kernel_size=kernel_size,
|
||||
use_norm=True)
|
||||
|
||||
# end
|
||||
self.conv_out = ops.Conv1d(block_in,
|
||||
2 * embed_dim if double_z else embed_dim,
|
||||
kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
|
||||
|
||||
self.learnable_gain = nn.Parameter(torch.zeros([]))
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
# downsampling
|
||||
h = self.conv_in(x)
|
||||
for i_level in range(self.num_layers):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](h)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
h = h.clamp(-self.clip_act, self.clip_act)
|
||||
if i_level in self.down_layers:
|
||||
h = self.down[i_level].downsample(h)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h)
|
||||
h = h.clamp(-self.clip_act, self.clip_act)
|
||||
|
||||
# end
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h) * (self.learnable_gain + 1)
|
||||
return h
|
||||
|
||||
|
||||
class Decoder1D(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
dim: int,
|
||||
out_dim: int,
|
||||
ch_mult: tuple[int] = (1, 2, 4, 8),
|
||||
num_res_blocks: int,
|
||||
attn_layers: list[int] = [],
|
||||
down_layers: list[int] = [],
|
||||
kernel_size: int = 3,
|
||||
resamp_with_conv: bool = True,
|
||||
in_dim: int,
|
||||
embed_dim: int,
|
||||
clip_act: float = 256.0):
|
||||
super().__init__()
|
||||
self.ch = dim
|
||||
self.num_layers = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.in_channels = in_dim
|
||||
self.clip_act = clip_act
|
||||
self.down_layers = [i + 1 for i in down_layers] # each downlayer add one
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
block_in = dim * ch_mult[self.num_layers - 1]
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = ops.Conv1d(embed_dim, block_in, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
|
||||
self.mid.attn_1 = AttnBlock1D(block_in)
|
||||
self.mid.block_2 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_layers)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = dim * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(ResnetBlock1D(in_dim=block_in, out_dim=block_out, use_norm=True))
|
||||
block_in = block_out
|
||||
if i_level in attn_layers:
|
||||
attn.append(AttnBlock1D(block_in))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level in self.down_layers:
|
||||
up.upsample = Upsample1D(block_in, resamp_with_conv)
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.conv_out = ops.Conv1d(block_in, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
|
||||
self.learnable_gain = nn.Parameter(torch.zeros([]))
|
||||
|
||||
def forward(self, z):
|
||||
# z to block_in
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h)
|
||||
h = h.clamp(-self.clip_act, self.clip_act)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_layers)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](h)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
h = h.clamp(-self.clip_act, self.clip_act)
|
||||
if i_level in self.down_layers:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h) * (self.learnable_gain + 1)
|
||||
return h
|
||||
|
||||
|
||||
def VAE_16k(**kwargs) -> VAE:
|
||||
return VAE(data_dim=80, embed_dim=20, hidden_dim=384, **kwargs)
|
||||
|
||||
|
||||
def VAE_44k(**kwargs) -> VAE:
|
||||
return VAE(data_dim=128, embed_dim=40, hidden_dim=512, **kwargs)
|
||||
|
||||
|
||||
def get_my_vae(name: str, **kwargs) -> VAE:
|
||||
if name == '16k':
|
||||
return VAE_16k(**kwargs)
|
||||
if name == '44k':
|
||||
return VAE_44k(**kwargs)
|
||||
raise ValueError(f'Unknown model: {name}')
|
||||
|
||||
121
comfy/ldm/mmaudio/vae/vae_modules.py
Normal file
121
comfy/ldm/mmaudio/vae/vae_modules.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.diffusionmodules.model import vae_attention
|
||||
import math
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return torch.nn.functional.silu(x) / 0.596
|
||||
|
||||
def mp_sum(a, b, t=0.5):
|
||||
return a.lerp(b, t) / math.sqrt((1 - t)**2 + t**2)
|
||||
|
||||
def normalize(x, dim=None, eps=1e-4):
|
||||
if dim is None:
|
||||
dim = list(range(1, x.ndim))
|
||||
norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
|
||||
norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
|
||||
return x / norm.to(x.dtype)
|
||||
|
||||
class ResnetBlock1D(nn.Module):
|
||||
|
||||
def __init__(self, *, in_dim, out_dim=None, conv_shortcut=False, kernel_size=3, use_norm=True):
|
||||
super().__init__()
|
||||
self.in_dim = in_dim
|
||||
out_dim = in_dim if out_dim is None else out_dim
|
||||
self.out_dim = out_dim
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
self.use_norm = use_norm
|
||||
|
||||
self.conv1 = ops.Conv1d(in_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
|
||||
self.conv2 = ops.Conv1d(out_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
|
||||
if self.in_dim != self.out_dim:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = ops.Conv1d(in_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
|
||||
else:
|
||||
self.nin_shortcut = ops.Conv1d(in_dim, out_dim, kernel_size=1, padding=0, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
# pixel norm
|
||||
if self.use_norm:
|
||||
x = normalize(x, dim=1)
|
||||
|
||||
h = x
|
||||
h = nonlinearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
h = nonlinearity(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_dim != self.out_dim:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return mp_sum(x, h, t=0.3)
|
||||
|
||||
|
||||
class AttnBlock1D(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, num_heads=1):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.qkv = ops.Conv1d(in_channels, in_channels * 3, kernel_size=1, padding=0, bias=False)
|
||||
self.proj_out = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
|
||||
self.optimized_attention = vae_attention()
|
||||
|
||||
def forward(self, x):
|
||||
h = x
|
||||
y = self.qkv(h)
|
||||
y = y.reshape(y.shape[0], -1, 3, y.shape[-1])
|
||||
q, k, v = normalize(y, dim=1).unbind(2)
|
||||
|
||||
h = self.optimized_attention(q, k, v)
|
||||
h = self.proj_out(h)
|
||||
|
||||
return mp_sum(x, h, t=0.3)
|
||||
|
||||
|
||||
class Upsample1D(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
self.conv = ops.Conv1d(in_channels, in_channels, kernel_size=3, padding=1, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.interpolate(x, scale_factor=2.0, mode='nearest-exact') # support 3D tensor(B,C,T)
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Downsample1D(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv1 = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
|
||||
self.conv2 = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
if self.with_conv:
|
||||
x = self.conv1(x)
|
||||
|
||||
x = F.avg_pool1d(x, kernel_size=2, stride=2)
|
||||
|
||||
if self.with_conv:
|
||||
x = self.conv2(x)
|
||||
|
||||
return x
|
||||
@@ -9,6 +9,8 @@ from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistri
|
||||
from comfy.ldm.util import get_obj_from_str, instantiate_from_config
|
||||
from comfy.ldm.modules.ema import LitEma
|
||||
import comfy.ops
|
||||
from einops import rearrange
|
||||
import comfy.model_management
|
||||
|
||||
class DiagonalGaussianRegularizer(torch.nn.Module):
|
||||
def __init__(self, sample: bool = False):
|
||||
@@ -179,6 +181,21 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
|
||||
self.post_quant_conv = conv_op(embed_dim, ddconfig["z_channels"], 1)
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
if ddconfig.get("batch_norm_latent", False):
|
||||
self.bn_eps = 1e-4
|
||||
self.bn_momentum = 0.1
|
||||
self.ps = [2, 2]
|
||||
self.bn = torch.nn.BatchNorm2d(math.prod(self.ps) * ddconfig["z_channels"],
|
||||
eps=self.bn_eps,
|
||||
momentum=self.bn_momentum,
|
||||
affine=False,
|
||||
track_running_stats=True,
|
||||
)
|
||||
self.bn.eval()
|
||||
else:
|
||||
self.bn = None
|
||||
|
||||
|
||||
def get_autoencoder_params(self) -> list:
|
||||
params = super().get_autoencoder_params()
|
||||
return params
|
||||
@@ -201,11 +218,36 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
|
||||
z = torch.cat(z, 0)
|
||||
|
||||
z, reg_log = self.regularization(z)
|
||||
|
||||
if self.bn is not None:
|
||||
z = rearrange(z,
|
||||
"... c (i pi) (j pj) -> ... (c pi pj) i j",
|
||||
pi=self.ps[0],
|
||||
pj=self.ps[1],
|
||||
)
|
||||
|
||||
z = torch.nn.functional.batch_norm(z,
|
||||
comfy.model_management.cast_to(self.bn.running_mean, dtype=z.dtype, device=z.device),
|
||||
comfy.model_management.cast_to(self.bn.running_var, dtype=z.dtype, device=z.device),
|
||||
momentum=self.bn_momentum,
|
||||
eps=self.bn_eps)
|
||||
|
||||
if return_reg_log:
|
||||
return z, reg_log
|
||||
return z
|
||||
|
||||
def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
|
||||
if self.bn is not None:
|
||||
s = torch.sqrt(comfy.model_management.cast_to(self.bn.running_var.view(1, -1, 1, 1), dtype=z.dtype, device=z.device) + self.bn_eps)
|
||||
m = comfy.model_management.cast_to(self.bn.running_mean.view(1, -1, 1, 1), dtype=z.dtype, device=z.device)
|
||||
z = z * s + m
|
||||
z = rearrange(
|
||||
z,
|
||||
"... (c pi pj) i j -> ... c (i pi) (j pj)",
|
||||
pi=self.ps[0],
|
||||
pj=self.ps[1],
|
||||
)
|
||||
|
||||
if self.max_batch_size is None:
|
||||
dec = self.post_quant_conv(z)
|
||||
dec = self.decoder(dec, **decoder_kwargs)
|
||||
|
||||
@@ -211,12 +211,14 @@ class TimestepEmbedder(nn.Module):
|
||||
Embeds scalar timesteps into vector representations.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
if output_size is None:
|
||||
output_size = hidden_size
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ class QwenImageControlNetModel(QwenImageTransformer2DModel):
|
||||
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
||||
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
|
||||
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
|
||||
del ids, txt_ids, img_ids
|
||||
|
||||
hidden_states = self.img_in(hidden_states) + self.controlnet_x_embedder(hint)
|
||||
|
||||
@@ -10,6 +10,7 @@ from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.patcher_extension
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
|
||||
class GELU(nn.Module):
|
||||
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
|
||||
@@ -134,33 +135,34 @@ class Attention(nn.Module):
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
transformer_options={},
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
batch_size = hidden_states.shape[0]
|
||||
seq_img = hidden_states.shape[1]
|
||||
seq_txt = encoder_hidden_states.shape[1]
|
||||
|
||||
img_query = self.to_q(hidden_states).unflatten(-1, (self.heads, -1))
|
||||
img_key = self.to_k(hidden_states).unflatten(-1, (self.heads, -1))
|
||||
img_value = self.to_v(hidden_states).unflatten(-1, (self.heads, -1))
|
||||
# Project and reshape to BHND format (batch, heads, seq, dim)
|
||||
img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
|
||||
img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
|
||||
img_value = self.to_v(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2)
|
||||
|
||||
txt_query = self.add_q_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
|
||||
txt_key = self.add_k_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
|
||||
txt_value = self.add_v_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
|
||||
txt_query = self.add_q_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
|
||||
txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
|
||||
txt_value = self.add_v_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2)
|
||||
|
||||
img_query = self.norm_q(img_query)
|
||||
img_key = self.norm_k(img_key)
|
||||
txt_query = self.norm_added_q(txt_query)
|
||||
txt_key = self.norm_added_k(txt_key)
|
||||
|
||||
joint_query = torch.cat([txt_query, img_query], dim=1)
|
||||
joint_key = torch.cat([txt_key, img_key], dim=1)
|
||||
joint_value = torch.cat([txt_value, img_value], dim=1)
|
||||
joint_query = torch.cat([txt_query, img_query], dim=2)
|
||||
joint_key = torch.cat([txt_key, img_key], dim=2)
|
||||
joint_value = torch.cat([txt_value, img_value], dim=2)
|
||||
|
||||
joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
|
||||
joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
|
||||
joint_query = apply_rope1(joint_query, image_rotary_emb)
|
||||
joint_key = apply_rope1(joint_key, image_rotary_emb)
|
||||
|
||||
joint_query = joint_query.flatten(start_dim=2)
|
||||
joint_key = joint_key.flatten(start_dim=2)
|
||||
joint_value = joint_value.flatten(start_dim=2)
|
||||
|
||||
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options)
|
||||
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
|
||||
attention_mask, transformer_options=transformer_options,
|
||||
skip_reshape=True)
|
||||
|
||||
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
|
||||
img_attn_output = joint_hidden_states[:, seq_txt:, :]
|
||||
@@ -234,10 +236,10 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
|
||||
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
|
||||
|
||||
img_normed = self.img_norm1(hidden_states)
|
||||
img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
|
||||
txt_normed = self.txt_norm1(encoder_hidden_states)
|
||||
txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
|
||||
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
|
||||
del img_mod1
|
||||
txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
|
||||
del txt_mod1
|
||||
|
||||
img_attn_output, txt_attn_output = self.attn(
|
||||
hidden_states=img_modulated,
|
||||
@@ -246,16 +248,20 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
del img_modulated
|
||||
del txt_modulated
|
||||
|
||||
hidden_states = hidden_states + img_gate1 * img_attn_output
|
||||
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
|
||||
del img_attn_output
|
||||
del txt_attn_output
|
||||
del img_gate1
|
||||
del txt_gate1
|
||||
|
||||
img_normed2 = self.img_norm2(hidden_states)
|
||||
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
|
||||
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
|
||||
hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
|
||||
|
||||
txt_normed2 = self.txt_norm2(encoder_hidden_states)
|
||||
txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
|
||||
txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
|
||||
encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
|
||||
|
||||
return encoder_hidden_states, hidden_states
|
||||
@@ -413,7 +419,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
||||
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
|
||||
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
|
||||
del ids, txt_ids, img_ids
|
||||
|
||||
hidden_states = self.img_in(hidden_states)
|
||||
@@ -433,7 +439,10 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
patches = transformer_options.get("patches", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
|
||||
transformer_options["total_blocks"] = len(self.transformer_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
|
||||
@@ -232,6 +232,7 @@ class WanAttentionBlock(nn.Module):
|
||||
# assert e[0].dtype == torch.float32
|
||||
|
||||
# self-attention
|
||||
x = x.contiguous() # otherwise implicit in LayerNorm
|
||||
y = self.self_attn(
|
||||
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
|
||||
freqs, transformer_options=transformer_options)
|
||||
@@ -588,7 +589,7 @@ class WanModel(torch.nn.Module):
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
|
||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None):
|
||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
|
||||
patch_size = self.patch_size
|
||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
|
||||
@@ -601,10 +602,22 @@ class WanModel(torch.nn.Module):
|
||||
if steps_w is None:
|
||||
steps_w = w_len
|
||||
|
||||
h_start = 0
|
||||
w_start = 0
|
||||
rope_options = transformer_options.get("rope_options", None)
|
||||
if rope_options is not None:
|
||||
t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
|
||||
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
|
||||
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
|
||||
|
||||
t_start += rope_options.get("shift_t", 0.0)
|
||||
h_start += rope_options.get("shift_y", 0.0)
|
||||
w_start += rope_options.get("shift_x", 0.0)
|
||||
|
||||
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
|
||||
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
|
||||
|
||||
freqs = self.rope_embedder(img_ids).movedim(1, 2)
|
||||
@@ -630,7 +643,7 @@ class WanModel(torch.nn.Module):
|
||||
if self.ref_conv is not None and "reference_latent" in kwargs:
|
||||
t_len += 1
|
||||
|
||||
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype)
|
||||
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
|
||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
|
||||
|
||||
def unpatchify(self, x, grid_sizes):
|
||||
|
||||
@@ -657,51 +657,51 @@ class WanVAE(nn.Module):
|
||||
)
|
||||
|
||||
def encode(self, x):
|
||||
self.clear_cache()
|
||||
conv_idx = [0]
|
||||
feat_map = [None] * count_conv3d(self.encoder)
|
||||
x = patchify(x, patch_size=2)
|
||||
t = x.shape[2]
|
||||
iter_ = 1 + (t - 1) // 4
|
||||
for i in range(iter_):
|
||||
self._enc_conv_idx = [0]
|
||||
conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.encoder(
|
||||
x[:, :, :1, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx,
|
||||
feat_cache=feat_map,
|
||||
feat_idx=conv_idx,
|
||||
)
|
||||
else:
|
||||
out_ = self.encoder(
|
||||
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx,
|
||||
feat_cache=feat_map,
|
||||
feat_idx=conv_idx,
|
||||
)
|
||||
out = torch.cat([out, out_], 2)
|
||||
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
||||
self.clear_cache()
|
||||
return mu
|
||||
|
||||
def decode(self, z):
|
||||
self.clear_cache()
|
||||
conv_idx = [0]
|
||||
feat_map = [None] * count_conv3d(self.decoder)
|
||||
iter_ = z.shape[2]
|
||||
x = self.conv2(z)
|
||||
for i in range(iter_):
|
||||
self._conv_idx = [0]
|
||||
conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.decoder(
|
||||
x[:, :, i:i + 1, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx,
|
||||
feat_cache=feat_map,
|
||||
feat_idx=conv_idx,
|
||||
first_chunk=True,
|
||||
)
|
||||
else:
|
||||
out_ = self.decoder(
|
||||
x[:, :, i:i + 1, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx,
|
||||
feat_cache=feat_map,
|
||||
feat_idx=conv_idx,
|
||||
)
|
||||
out = torch.cat([out, out_], 2)
|
||||
out = unpatchify(out, patch_size=2)
|
||||
self.clear_cache()
|
||||
return out
|
||||
|
||||
def reparameterize(self, mu, log_var):
|
||||
@@ -715,12 +715,3 @@ class WanVAE(nn.Module):
|
||||
return mu
|
||||
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
|
||||
return mu + std * torch.randn_like(std)
|
||||
|
||||
def clear_cache(self):
|
||||
self._conv_num = count_conv3d(self.decoder)
|
||||
self._conv_idx = [0]
|
||||
self._feat_map = [None] * self._conv_num
|
||||
# cache encode
|
||||
self._enc_conv_num = count_conv3d(self.encoder)
|
||||
self._enc_conv_idx = [0]
|
||||
self._enc_feat_map = [None] * self._enc_conv_num
|
||||
|
||||
@@ -313,6 +313,15 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_map["transformer.{}".format(key_lora)] = k
|
||||
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format
|
||||
|
||||
if isinstance(model, comfy.model_base.Lumina2):
|
||||
diffusers_keys = comfy.utils.z_image_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
|
||||
for k in diffusers_keys:
|
||||
if k.endswith(".weight"):
|
||||
to = diffusers_keys[k]
|
||||
key_lora = k[:-len(".weight")]
|
||||
key_map["diffusion_model.{}".format(key_lora)] = to
|
||||
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
|
||||
|
||||
return key_map
|
||||
|
||||
|
||||
|
||||
@@ -134,10 +134,11 @@ class BaseModel(torch.nn.Module):
|
||||
if not unet_config.get("disable_unet_model_creation", False):
|
||||
if model_config.custom_operations is None:
|
||||
fp8 = model_config.optimizations.get("fp8", False)
|
||||
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
|
||||
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
|
||||
else:
|
||||
operations = model_config.custom_operations
|
||||
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
||||
self.diffusion_model.eval()
|
||||
if comfy.model_management.force_channels_last():
|
||||
self.diffusion_model.to(memory_format=torch.channels_last)
|
||||
logging.debug("using channels last mode for diffusion model")
|
||||
@@ -196,8 +197,14 @@ class BaseModel(torch.nn.Module):
|
||||
extra_conds[o] = extra
|
||||
|
||||
t = self.process_timestep(t, x=x, **extra_conds)
|
||||
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
|
||||
return self.model_sampling.calculate_denoised(sigma, model_output, x)
|
||||
if "latent_shapes" in extra_conds:
|
||||
xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))
|
||||
|
||||
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
|
||||
if len(model_output) > 1 and not torch.is_tensor(model_output):
|
||||
model_output, _ = utils.pack_latents(model_output)
|
||||
|
||||
return self.model_sampling.calculate_denoised(sigma, model_output.float(), x)
|
||||
|
||||
def process_timestep(self, timestep, **kwargs):
|
||||
return timestep
|
||||
@@ -326,6 +333,14 @@ class BaseModel(torch.nn.Module):
|
||||
if self.model_config.scaled_fp8 is not None:
|
||||
unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
|
||||
|
||||
# Save mixed precision metadata
|
||||
if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
|
||||
metadata = {
|
||||
"format_version": "1.0",
|
||||
"layers": self.model_config.layer_quant_config
|
||||
}
|
||||
unet_state_dict["_quantization_metadata"] = metadata
|
||||
|
||||
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
|
||||
|
||||
if self.model_type == ModelType.V_PREDICTION:
|
||||
@@ -669,7 +684,6 @@ class Lotus(BaseModel):
|
||||
class StableCascade_C(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=StageC)
|
||||
self.diffusion_model.eval().requires_grad_(False)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
@@ -698,7 +712,6 @@ class StableCascade_C(BaseModel):
|
||||
class StableCascade_B(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=StageB)
|
||||
self.diffusion_model.eval().requires_grad_(False)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
@@ -885,12 +898,13 @@ class Flux(BaseModel):
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
shape = kwargs["noise"].shape
|
||||
mask_ref_size = kwargs["attention_mask_img_shape"]
|
||||
# the model will pad to the patch size, and then divide
|
||||
# essentially dividing and rounding up
|
||||
(h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
|
||||
attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
mask_ref_size = kwargs.get("attention_mask_img_shape", None)
|
||||
if mask_ref_size is not None:
|
||||
# the model will pad to the patch size, and then divide
|
||||
# essentially dividing and rounding up
|
||||
(h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
|
||||
attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
|
||||
guidance = kwargs.get("guidance", 3.5)
|
||||
if guidance is not None:
|
||||
@@ -912,9 +926,19 @@ class Flux(BaseModel):
|
||||
out = {}
|
||||
ref_latents = kwargs.get("reference_latents", None)
|
||||
if ref_latents is not None:
|
||||
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
|
||||
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
|
||||
return out
|
||||
|
||||
class Flux2(Flux):
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
target_text_len = 512
|
||||
if cross_attn.shape[1] < target_text_len:
|
||||
cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, target_text_len - cross_attn.shape[1], 0))
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
||||
class GenmoMochi(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
@@ -1090,9 +1114,13 @@ class Lumina2(BaseModel):
|
||||
if torch.numel(attention_mask) != attention_mask.sum():
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item()))
|
||||
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
if 'num_tokens' not in out:
|
||||
out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])
|
||||
|
||||
return out
|
||||
|
||||
class WAN21(BaseModel):
|
||||
@@ -1523,3 +1551,94 @@ class HunyuanImage21Refiner(HunyuanImage21):
|
||||
out = super().extra_conds(**kwargs)
|
||||
out['disable_time_r'] = comfy.conds.CONDConstant(True)
|
||||
return out
|
||||
|
||||
class HunyuanVideo15(HunyuanVideo):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
noise = kwargs.get("noise", None)
|
||||
extra_channels = self.diffusion_model.img_in.proj.weight.shape[1] - noise.shape[1] - 1 #noise 32 img cond 32 + mask 1
|
||||
if extra_channels == 0:
|
||||
return None
|
||||
|
||||
image = kwargs.get("concat_latent_image", None)
|
||||
device = kwargs["device"]
|
||||
|
||||
if image is None:
|
||||
shape_image = list(noise.shape)
|
||||
shape_image[1] = extra_channels
|
||||
image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
|
||||
else:
|
||||
latent_dim = self.latent_format.latent_channels
|
||||
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
for i in range(0, image.shape[1], latent_dim):
|
||||
image[:, i: i + latent_dim] = self.process_latent_in(image[:, i: i + latent_dim])
|
||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||
|
||||
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
if mask is None:
|
||||
mask = torch.zeros_like(noise)[:, :1]
|
||||
else:
|
||||
mask = 1.0 - mask
|
||||
mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
if mask.shape[-3] < noise.shape[-3]:
|
||||
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
|
||||
mask = utils.resize_to_batch_size(mask, noise.shape[0])
|
||||
|
||||
return torch.cat((image, mask), dim=1)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
if torch.numel(attention_mask) != attention_mask.sum():
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
conditioning_byt5small = kwargs.get("conditioning_byt5small", None)
|
||||
if conditioning_byt5small is not None:
|
||||
out['txt_byt5'] = comfy.conds.CONDRegular(conditioning_byt5small)
|
||||
|
||||
guidance = kwargs.get("guidance", 6.0)
|
||||
if guidance is not None:
|
||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
|
||||
|
||||
clip_vision_output = kwargs.get("clip_vision_output", None)
|
||||
if clip_vision_output is not None:
|
||||
out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.last_hidden_state)
|
||||
|
||||
return out
|
||||
|
||||
class HunyuanVideo15_SR_Distilled(HunyuanVideo15):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
noise = kwargs.get("noise", None)
|
||||
image = kwargs.get("concat_latent_image", None)
|
||||
noise_augmentation = kwargs.get("noise_augmentation", 0.0)
|
||||
device = kwargs["device"]
|
||||
|
||||
if image is None:
|
||||
image = torch.zeros([noise.shape[0], noise.shape[1] * 2 + 2, noise.shape[-3], noise.shape[-2], noise.shape[-1]], device=comfy.model_management.intermediate_device())
|
||||
else:
|
||||
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
#image = self.process_latent_in(image) # scaling wasn't applied in reference code
|
||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||
lq_image_slice = slice(noise.shape[1] + 1, 2 * noise.shape[1] + 1)
|
||||
if noise_augmentation > 0:
|
||||
generator = torch.Generator(device="cpu")
|
||||
generator.manual_seed(kwargs.get("seed", 0) - 10)
|
||||
noise = torch.randn(image[:, lq_image_slice].shape, generator=generator, dtype=image.dtype, device="cpu").to(image.device)
|
||||
image[:, lq_image_slice] = noise_augmentation * noise + min(1.0 - noise_augmentation, 0.75) * image[:, lq_image_slice]
|
||||
else:
|
||||
image[:, lq_image_slice] = 0.75 * image[:, lq_image_slice]
|
||||
return image
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
out['disable_time_r'] = comfy.conds.CONDConstant(False)
|
||||
return out
|
||||
|
||||
@@ -6,6 +6,20 @@ import math
|
||||
import logging
|
||||
import torch
|
||||
|
||||
|
||||
def detect_layer_quantization(metadata):
|
||||
quant_key = "_quantization_metadata"
|
||||
if metadata is not None and quant_key in metadata:
|
||||
quant_metadata = metadata.pop(quant_key)
|
||||
quant_metadata = json.loads(quant_metadata)
|
||||
if isinstance(quant_metadata, dict) and "layers" in quant_metadata:
|
||||
logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
|
||||
return quant_metadata["layers"]
|
||||
else:
|
||||
raise ValueError("Invalid quantization metadata format")
|
||||
return None
|
||||
|
||||
|
||||
def count_blocks(state_dict_keys, prefix_string):
|
||||
count = 0
|
||||
while True:
|
||||
@@ -172,30 +186,68 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
guidance_keys = list(filter(lambda a: a.startswith("{}guidance_in.".format(key_prefix)), state_dict_keys))
|
||||
dit_config["guidance_embed"] = len(guidance_keys) > 0
|
||||
|
||||
# HunyuanVideo 1.5
|
||||
if '{}cond_type_embedding.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["use_cond_type_embedding"] = True
|
||||
else:
|
||||
dit_config["use_cond_type_embedding"] = False
|
||||
if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0]
|
||||
else:
|
||||
dit_config["vision_in_dim"] = None
|
||||
return dit_config
|
||||
|
||||
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "flux"
|
||||
if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["image_model"] = "flux2"
|
||||
dit_config["axes_dim"] = [32, 32, 32, 32]
|
||||
dit_config["num_heads"] = 48
|
||||
dit_config["mlp_ratio"] = 3.0
|
||||
dit_config["theta"] = 2000
|
||||
dit_config["out_channels"] = 128
|
||||
dit_config["global_modulation"] = True
|
||||
dit_config["vec_in_dim"] = None
|
||||
dit_config["mlp_silu_act"] = True
|
||||
dit_config["qkv_bias"] = False
|
||||
dit_config["ops_bias"] = False
|
||||
dit_config["default_ref_method"] = "index"
|
||||
dit_config["ref_index_scale"] = 10.0
|
||||
patch_size = 1
|
||||
else:
|
||||
dit_config["image_model"] = "flux"
|
||||
dit_config["axes_dim"] = [16, 56, 56]
|
||||
dit_config["num_heads"] = 24
|
||||
dit_config["mlp_ratio"] = 4.0
|
||||
dit_config["theta"] = 10000
|
||||
dit_config["out_channels"] = 16
|
||||
dit_config["qkv_bias"] = True
|
||||
patch_size = 2
|
||||
|
||||
dit_config["in_channels"] = 16
|
||||
patch_size = 2
|
||||
dit_config["hidden_size"] = 3072
|
||||
dit_config["context_in_dim"] = 4096
|
||||
|
||||
dit_config["patch_size"] = patch_size
|
||||
in_key = "{}img_in.weight".format(key_prefix)
|
||||
if in_key in state_dict_keys:
|
||||
dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size)
|
||||
dit_config["out_channels"] = 16
|
||||
w = state_dict[in_key]
|
||||
dit_config["in_channels"] = w.shape[1] // (patch_size * patch_size)
|
||||
dit_config["hidden_size"] = w.shape[0]
|
||||
|
||||
txt_in_key = "{}txt_in.weight".format(key_prefix)
|
||||
if txt_in_key in state_dict_keys:
|
||||
w = state_dict[txt_in_key]
|
||||
dit_config["context_in_dim"] = w.shape[1]
|
||||
dit_config["hidden_size"] = w.shape[0]
|
||||
|
||||
vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
|
||||
if vec_in_key in state_dict_keys:
|
||||
dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
|
||||
dit_config["context_in_dim"] = 4096
|
||||
dit_config["hidden_size"] = 3072
|
||||
dit_config["mlp_ratio"] = 4.0
|
||||
dit_config["num_heads"] = 24
|
||||
|
||||
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["axes_dim"] = [16, 56, 56]
|
||||
dit_config["theta"] = 10000
|
||||
dit_config["qkv_bias"] = True
|
||||
if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
|
||||
dit_config["image_model"] = "chroma"
|
||||
dit_config["in_channels"] = 64
|
||||
@@ -213,7 +265,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["nerf_mlp_ratio"] = 4
|
||||
dit_config["nerf_depth"] = 4
|
||||
dit_config["nerf_max_freqs"] = 8
|
||||
dit_config["nerf_tile_size"] = 32
|
||||
dit_config["nerf_tile_size"] = 512
|
||||
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
|
||||
dit_config["nerf_embedder_dtype"] = torch.float32
|
||||
else:
|
||||
@@ -364,14 +416,31 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["image_model"] = "lumina2"
|
||||
dit_config["patch_size"] = 2
|
||||
dit_config["in_channels"] = 16
|
||||
dit_config["dim"] = 2304
|
||||
dit_config["cap_feat_dim"] = state_dict['{}cap_embedder.1.weight'.format(key_prefix)].shape[1]
|
||||
w = state_dict['{}cap_embedder.1.weight'.format(key_prefix)]
|
||||
dit_config["dim"] = w.shape[0]
|
||||
dit_config["cap_feat_dim"] = w.shape[1]
|
||||
dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
|
||||
dit_config["n_heads"] = 24
|
||||
dit_config["n_kv_heads"] = 8
|
||||
dit_config["qk_norm"] = True
|
||||
dit_config["axes_dims"] = [32, 32, 32]
|
||||
dit_config["axes_lens"] = [300, 512, 512]
|
||||
|
||||
if dit_config["dim"] == 2304: # Original Lumina 2
|
||||
dit_config["n_heads"] = 24
|
||||
dit_config["n_kv_heads"] = 8
|
||||
dit_config["axes_dims"] = [32, 32, 32]
|
||||
dit_config["axes_lens"] = [300, 512, 512]
|
||||
dit_config["rope_theta"] = 10000.0
|
||||
dit_config["ffn_dim_multiplier"] = 4.0
|
||||
elif dit_config["dim"] == 3840: # Z image
|
||||
dit_config["n_heads"] = 30
|
||||
dit_config["n_kv_heads"] = 30
|
||||
dit_config["axes_dims"] = [32, 48, 48]
|
||||
dit_config["axes_lens"] = [1536, 512, 512]
|
||||
dit_config["rope_theta"] = 256.0
|
||||
dit_config["ffn_dim_multiplier"] = (8.0 / 3.0)
|
||||
dit_config["z_image_modulation"] = True
|
||||
dit_config["time_scale"] = 1000.0
|
||||
if '{}cap_pad_token'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["pad_tokens_multiple"] = 32
|
||||
|
||||
return dit_config
|
||||
|
||||
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
|
||||
@@ -701,6 +770,12 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
|
||||
else:
|
||||
model_config.optimizations["fp8"] = True
|
||||
|
||||
# Detect per-layer quantization (mixed precision)
|
||||
layer_quant_config = detect_layer_quantization(metadata)
|
||||
if layer_quant_config:
|
||||
model_config.layer_quant_config = layer_quant_config
|
||||
logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")
|
||||
|
||||
return model_config
|
||||
|
||||
def unet_prefix_from_state_dict(state_dict):
|
||||
|
||||
@@ -89,6 +89,7 @@ if args.deterministic:
|
||||
|
||||
directml_enabled = False
|
||||
if args.directml is not None:
|
||||
logging.warning("WARNING: torch-directml barely works, is very slow, has not been updated in over 1 year and might be removed soon, please don't use it, there are better options.")
|
||||
import torch_directml
|
||||
directml_enabled = True
|
||||
device_index = args.directml
|
||||
@@ -330,13 +331,21 @@ except:
|
||||
|
||||
|
||||
SUPPORT_FP8_OPS = args.supports_fp8_compute
|
||||
|
||||
AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]
|
||||
|
||||
try:
|
||||
if is_amd():
|
||||
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
|
||||
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
|
||||
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
|
||||
logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
|
||||
|
||||
try:
|
||||
rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
|
||||
except:
|
||||
rocm_version = (6, -1)
|
||||
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
|
||||
|
||||
logging.info("AMD arch: {}".format(arch))
|
||||
logging.info("ROCm version: {}".format(rocm_version))
|
||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
@@ -344,11 +353,11 @@ try:
|
||||
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
|
||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
# if torch_version_numeric >= (2, 8):
|
||||
# if any((a in arch) for a in ["gfx1201"]):
|
||||
# ENABLE_PYTORCH_ATTENTION = True
|
||||
if rocm_version >= (7, 0):
|
||||
if any((a in arch) for a in ["gfx1201"]):
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
|
||||
if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx942", "gfx950"]): # TODO: more arches
|
||||
if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx950"]): # TODO: more arches, "gfx942" gives error on pytorch nightly 2.10 1013 rocm7.0
|
||||
SUPPORT_FP8_OPS = True
|
||||
|
||||
except:
|
||||
@@ -370,6 +379,9 @@ try:
|
||||
except:
|
||||
pass
|
||||
|
||||
if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
try:
|
||||
if torch_version_numeric >= (2, 5):
|
||||
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
|
||||
@@ -492,6 +504,7 @@ class LoadedModel:
|
||||
if use_more_vram == 0:
|
||||
use_more_vram = 1e32
|
||||
self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
|
||||
|
||||
real_model = self.model.model
|
||||
|
||||
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None:
|
||||
@@ -676,8 +689,11 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
loaded_memory = loaded_model.model_loaded_memory()
|
||||
current_free_mem = get_free_memory(torch_dev) + loaded_memory
|
||||
|
||||
lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
|
||||
lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)
|
||||
lowvram_model_memory = max(0, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
|
||||
lowvram_model_memory = lowvram_model_memory - loaded_memory
|
||||
|
||||
if lowvram_model_memory == 0:
|
||||
lowvram_model_memory = 0.1
|
||||
|
||||
if vram_set_state == VRAMState.NO_VRAM:
|
||||
lowvram_model_memory = 0.1
|
||||
@@ -925,11 +941,7 @@ def vae_dtype(device=None, allowed_dtypes=[]):
|
||||
if d == torch.float16 and should_use_fp16(device):
|
||||
return d
|
||||
|
||||
# NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
|
||||
# slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3
|
||||
# also a problem on RDNA4 except fp32 is also slow there.
|
||||
# This is due to large bf16 convolutions being extremely slow.
|
||||
if d == torch.bfloat16 and ((not is_amd()) or amd_min_version(device, min_rdna_version=4)) and should_use_bf16(device):
|
||||
if d == torch.bfloat16 and should_use_bf16(device):
|
||||
return d
|
||||
|
||||
return torch.float32
|
||||
@@ -991,12 +1003,6 @@ def device_supports_non_blocking(device):
|
||||
return False
|
||||
return True
|
||||
|
||||
def device_should_use_non_blocking(device):
|
||||
if not device_supports_non_blocking(device):
|
||||
return False
|
||||
return False
|
||||
# return True #TODO: figure out why this causes memory issues on Nvidia and possibly others
|
||||
|
||||
def force_channels_last():
|
||||
if args.force_channels_last:
|
||||
return True
|
||||
@@ -1006,54 +1012,72 @@ def force_channels_last():
|
||||
|
||||
|
||||
STREAMS = {}
|
||||
NUM_STREAMS = 1
|
||||
if args.async_offload:
|
||||
NUM_STREAMS = 2
|
||||
NUM_STREAMS = 0
|
||||
if args.async_offload is not None:
|
||||
NUM_STREAMS = args.async_offload
|
||||
else:
|
||||
# Enable by default on Nvidia
|
||||
if is_nvidia():
|
||||
NUM_STREAMS = 2
|
||||
|
||||
if args.disable_async_offload:
|
||||
NUM_STREAMS = 0
|
||||
|
||||
if NUM_STREAMS > 0:
|
||||
logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
|
||||
|
||||
def current_stream(device):
|
||||
if device is None:
|
||||
return None
|
||||
if is_device_cuda(device):
|
||||
return torch.cuda.current_stream()
|
||||
elif is_device_xpu(device):
|
||||
return torch.xpu.current_stream()
|
||||
else:
|
||||
return None
|
||||
|
||||
stream_counters = {}
|
||||
def get_offload_stream(device):
|
||||
stream_counter = stream_counters.get(device, 0)
|
||||
if NUM_STREAMS <= 1:
|
||||
if NUM_STREAMS == 0:
|
||||
return None
|
||||
|
||||
if torch.compiler.is_compiling():
|
||||
return None
|
||||
|
||||
if device in STREAMS:
|
||||
ss = STREAMS[device]
|
||||
s = ss[stream_counter]
|
||||
#Sync the oldest stream in the queue with the current
|
||||
ss[stream_counter].wait_stream(current_stream(device))
|
||||
stream_counter = (stream_counter + 1) % len(ss)
|
||||
if is_device_cuda(device):
|
||||
ss[stream_counter].wait_stream(torch.cuda.current_stream())
|
||||
elif is_device_xpu(device):
|
||||
ss[stream_counter].wait_stream(torch.xpu.current_stream())
|
||||
stream_counters[device] = stream_counter
|
||||
return s
|
||||
return ss[stream_counter]
|
||||
elif is_device_cuda(device):
|
||||
ss = []
|
||||
for k in range(NUM_STREAMS):
|
||||
ss.append(torch.cuda.Stream(device=device, priority=0))
|
||||
s1 = torch.cuda.Stream(device=device, priority=0)
|
||||
s1.as_context = torch.cuda.stream
|
||||
ss.append(s1)
|
||||
STREAMS[device] = ss
|
||||
s = ss[stream_counter]
|
||||
stream_counter = (stream_counter + 1) % len(ss)
|
||||
stream_counters[device] = stream_counter
|
||||
return s
|
||||
elif is_device_xpu(device):
|
||||
ss = []
|
||||
for k in range(NUM_STREAMS):
|
||||
ss.append(torch.xpu.Stream(device=device, priority=0))
|
||||
s1 = torch.xpu.Stream(device=device, priority=0)
|
||||
s1.as_context = torch.xpu.stream
|
||||
ss.append(s1)
|
||||
STREAMS[device] = ss
|
||||
s = ss[stream_counter]
|
||||
stream_counter = (stream_counter + 1) % len(ss)
|
||||
stream_counters[device] = stream_counter
|
||||
return s
|
||||
return None
|
||||
|
||||
def sync_stream(device, stream):
|
||||
if stream is None:
|
||||
if stream is None or current_stream(device) is None:
|
||||
return
|
||||
if is_device_cuda(device):
|
||||
torch.cuda.current_stream().wait_stream(stream)
|
||||
elif is_device_xpu(device):
|
||||
torch.xpu.current_stream().wait_stream(stream)
|
||||
current_stream(device).wait_stream(stream)
|
||||
|
||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
|
||||
if device is None or weight.device == device:
|
||||
@@ -1061,12 +1085,19 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
|
||||
if dtype is None or weight.dtype == dtype:
|
||||
return weight
|
||||
if stream is not None:
|
||||
with stream:
|
||||
wf_context = stream
|
||||
if hasattr(wf_context, "as_context"):
|
||||
wf_context = wf_context.as_context(stream)
|
||||
with wf_context:
|
||||
return weight.to(dtype=dtype, copy=copy)
|
||||
return weight.to(dtype=dtype, copy=copy)
|
||||
|
||||
|
||||
if stream is not None:
|
||||
with stream:
|
||||
wf_context = stream
|
||||
if hasattr(wf_context, "as_context"):
|
||||
wf_context = wf_context.as_context(stream)
|
||||
with wf_context:
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
r.copy_(weight, non_blocking=non_blocking)
|
||||
else:
|
||||
@@ -1078,6 +1109,83 @@ def cast_to_device(tensor, device, dtype, copy=False):
|
||||
non_blocking = device_supports_non_blocking(device)
|
||||
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
|
||||
|
||||
|
||||
PINNED_MEMORY = {}
|
||||
TOTAL_PINNED_MEMORY = 0
|
||||
MAX_PINNED_MEMORY = -1
|
||||
if not args.disable_pinned_memory:
|
||||
if is_nvidia() or is_amd():
|
||||
if WINDOWS:
|
||||
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50%
|
||||
else:
|
||||
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
|
||||
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
|
||||
|
||||
PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
|
||||
|
||||
def pin_memory(tensor):
|
||||
global TOTAL_PINNED_MEMORY
|
||||
if MAX_PINNED_MEMORY <= 0:
|
||||
return False
|
||||
|
||||
if type(tensor).__name__ not in PINNING_ALLOWED_TYPES:
|
||||
return False
|
||||
|
||||
if not is_device_cpu(tensor.device):
|
||||
return False
|
||||
|
||||
if tensor.is_pinned():
|
||||
#NOTE: Cuda does detect when a tensor is already pinned and would
|
||||
#error below, but there are proven cases where this also queues an error
|
||||
#on the GPU async. So dont trust the CUDA API and guard here
|
||||
return False
|
||||
|
||||
if not tensor.is_contiguous():
|
||||
return False
|
||||
|
||||
size = tensor.numel() * tensor.element_size()
|
||||
if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
|
||||
return False
|
||||
|
||||
ptr = tensor.data_ptr()
|
||||
if ptr == 0:
|
||||
return False
|
||||
|
||||
if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0:
|
||||
PINNED_MEMORY[ptr] = size
|
||||
TOTAL_PINNED_MEMORY += size
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def unpin_memory(tensor):
|
||||
global TOTAL_PINNED_MEMORY
|
||||
if MAX_PINNED_MEMORY <= 0:
|
||||
return False
|
||||
|
||||
if not is_device_cpu(tensor.device):
|
||||
return False
|
||||
|
||||
ptr = tensor.data_ptr()
|
||||
size = tensor.numel() * tensor.element_size()
|
||||
|
||||
size_stored = PINNED_MEMORY.get(ptr, None)
|
||||
if size_stored is None:
|
||||
logging.warning("Tried to unpin tensor not pinned by ComfyUI")
|
||||
return False
|
||||
|
||||
if size != size_stored:
|
||||
logging.warning("Size of pinned tensor changed")
|
||||
return False
|
||||
|
||||
if torch.cuda.cudart().cudaHostUnregister(ptr) == 0:
|
||||
TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr)
|
||||
if len(PINNED_MEMORY) == 0:
|
||||
TOTAL_PINNED_MEMORY = 0
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def sage_attention_enabled():
|
||||
return args.use_sage_attention
|
||||
|
||||
@@ -1330,7 +1438,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
|
||||
if is_amd():
|
||||
arch = torch.cuda.get_device_properties(device).gcnArchName
|
||||
if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16
|
||||
if any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH): # RDNA2 and older don't support bf16
|
||||
if manual_cast:
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -123,16 +123,39 @@ def move_weight_functions(m, device):
|
||||
return memory
|
||||
|
||||
class LowVramPatch:
|
||||
def __init__(self, key, patches):
|
||||
def __init__(self, key, patches, convert_func=None, set_func=None):
|
||||
self.key = key
|
||||
self.patches = patches
|
||||
self.convert_func = convert_func
|
||||
self.set_func = set_func
|
||||
|
||||
def __call__(self, weight):
|
||||
intermediate_dtype = weight.dtype
|
||||
if self.convert_func is not None:
|
||||
weight = self.convert_func(weight, inplace=False)
|
||||
|
||||
if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
|
||||
intermediate_dtype = torch.float32
|
||||
return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key))
|
||||
out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype)
|
||||
if self.set_func is None:
|
||||
return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
|
||||
else:
|
||||
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
|
||||
|
||||
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
|
||||
out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
|
||||
if self.set_func is not None:
|
||||
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype)
|
||||
else:
|
||||
return out
|
||||
|
||||
#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
|
||||
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3
|
||||
|
||||
def low_vram_patch_estimate_vram(model, key):
|
||||
weight, set_func, convert_func = get_key_weight(model, key)
|
||||
if weight is None:
|
||||
return 0
|
||||
return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
|
||||
|
||||
def get_key_weight(model, key):
|
||||
set_func = None
|
||||
@@ -217,13 +240,13 @@ class ModelPatcher:
|
||||
self.object_patches_backup = {}
|
||||
self.weight_wrapper_patches = {}
|
||||
self.model_options = {"transformer_options":{}}
|
||||
self.model_size()
|
||||
self.load_device = load_device
|
||||
self.offload_device = offload_device
|
||||
self.weight_inplace_update = weight_inplace_update
|
||||
self.force_cast_weights = False
|
||||
self.patches_uuid = uuid.uuid4()
|
||||
self.parent = None
|
||||
self.pinned = set()
|
||||
|
||||
self.attachments: dict[str] = {}
|
||||
self.additional_models: dict[str, list[ModelPatcher]] = {}
|
||||
@@ -255,12 +278,18 @@ class ModelPatcher:
|
||||
if not hasattr(self.model, 'current_weight_patches_uuid'):
|
||||
self.model.current_weight_patches_uuid = None
|
||||
|
||||
if not hasattr(self.model, 'model_offload_buffer_memory'):
|
||||
self.model.model_offload_buffer_memory = 0
|
||||
|
||||
def model_size(self):
|
||||
if self.size > 0:
|
||||
return self.size
|
||||
self.size = comfy.model_management.module_size(self.model)
|
||||
return self.size
|
||||
|
||||
def get_ram_usage(self):
|
||||
return self.model_size()
|
||||
|
||||
def loaded_size(self):
|
||||
return self.model.model_loaded_weight_memory
|
||||
|
||||
@@ -268,7 +297,7 @@ class ModelPatcher:
|
||||
return self.model.lowvram_patch_counter
|
||||
|
||||
def clone(self):
|
||||
n = self.__class__(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update)
|
||||
n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
|
||||
n.patches = {}
|
||||
for k in self.patches:
|
||||
n.patches[k] = self.patches[k][:]
|
||||
@@ -280,6 +309,7 @@ class ModelPatcher:
|
||||
n.backup = self.backup
|
||||
n.object_patches_backup = self.object_patches_backup
|
||||
n.parent = self
|
||||
n.pinned = self.pinned
|
||||
|
||||
n.force_cast_weights = self.force_cast_weights
|
||||
|
||||
@@ -436,6 +466,19 @@ class ModelPatcher:
|
||||
def set_model_post_input_patch(self, patch):
|
||||
self.set_model_patch(patch, "post_input")
|
||||
|
||||
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
|
||||
rope_options = self.model_options["transformer_options"].get("rope_options", {})
|
||||
rope_options["scale_x"] = scale_x
|
||||
rope_options["scale_y"] = scale_y
|
||||
rope_options["scale_t"] = scale_t
|
||||
|
||||
rope_options["shift_x"] = shift_x
|
||||
rope_options["shift_y"] = shift_y
|
||||
rope_options["shift_t"] = shift_t
|
||||
|
||||
self.model_options["transformer_options"]["rope_options"] = rope_options
|
||||
|
||||
|
||||
def add_object_patch(self, name, obj):
|
||||
self.object_patches[name] = obj
|
||||
|
||||
@@ -604,6 +647,21 @@ class ModelPatcher:
|
||||
else:
|
||||
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
|
||||
|
||||
def pin_weight_to_device(self, key):
|
||||
weight, set_func, convert_func = get_key_weight(self.model, key)
|
||||
if comfy.model_management.pin_memory(weight):
|
||||
self.pinned.add(key)
|
||||
|
||||
def unpin_weight(self, key):
|
||||
if key in self.pinned:
|
||||
weight, set_func, convert_func = get_key_weight(self.model, key)
|
||||
comfy.model_management.unpin_memory(weight)
|
||||
self.pinned.remove(key)
|
||||
|
||||
def unpin_all_weights(self):
|
||||
for key in list(self.pinned):
|
||||
self.unpin_weight(key)
|
||||
|
||||
def _load_list(self):
|
||||
loading = []
|
||||
for n, m in self.model.named_modules():
|
||||
@@ -616,7 +674,16 @@ class ModelPatcher:
|
||||
skip = True # skip random weights in non leaf modules
|
||||
break
|
||||
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
|
||||
loading.append((comfy.model_management.module_size(m), n, m, params))
|
||||
module_mem = comfy.model_management.module_size(m)
|
||||
module_offload_mem = module_mem
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
weight_key = "{}.weight".format(n)
|
||||
bias_key = "{}.bias".format(n)
|
||||
if weight_key in self.patches:
|
||||
module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key)
|
||||
if bias_key in self.patches:
|
||||
module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key)
|
||||
loading.append((module_offload_mem, module_mem, n, m, params))
|
||||
return loading
|
||||
|
||||
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
|
||||
@@ -625,25 +692,30 @@ class ModelPatcher:
|
||||
mem_counter = 0
|
||||
patch_counter = 0
|
||||
lowvram_counter = 0
|
||||
lowvram_mem_counter = 0
|
||||
loading = self._load_list()
|
||||
|
||||
load_completely = []
|
||||
offloaded = []
|
||||
offload_buffer = 0
|
||||
loading.sort(reverse=True)
|
||||
for x in loading:
|
||||
n = x[1]
|
||||
m = x[2]
|
||||
params = x[3]
|
||||
module_mem = x[0]
|
||||
module_offload_mem, module_mem, n, m, params = x
|
||||
|
||||
lowvram_weight = False
|
||||
|
||||
potential_offload = max(offload_buffer, module_offload_mem * (comfy.model_management.NUM_STREAMS + 1))
|
||||
lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory
|
||||
|
||||
weight_key = "{}.weight".format(n)
|
||||
bias_key = "{}.bias".format(n)
|
||||
|
||||
if not full_load and hasattr(m, "comfy_cast_weights"):
|
||||
if mem_counter + module_mem >= lowvram_model_memory:
|
||||
if not lowvram_fits:
|
||||
offload_buffer = potential_offload
|
||||
lowvram_weight = True
|
||||
lowvram_counter += 1
|
||||
lowvram_mem_counter += module_mem
|
||||
if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
|
||||
continue
|
||||
|
||||
@@ -657,23 +729,28 @@ class ModelPatcher:
|
||||
if force_patch_weights:
|
||||
self.patch_weight_to_device(weight_key)
|
||||
else:
|
||||
m.weight_function = [LowVramPatch(weight_key, self.patches)]
|
||||
_, set_func, convert_func = get_key_weight(self.model, weight_key)
|
||||
m.weight_function = [LowVramPatch(weight_key, self.patches, convert_func, set_func)]
|
||||
patch_counter += 1
|
||||
if bias_key in self.patches:
|
||||
if force_patch_weights:
|
||||
self.patch_weight_to_device(bias_key)
|
||||
else:
|
||||
m.bias_function = [LowVramPatch(bias_key, self.patches)]
|
||||
_, set_func, convert_func = get_key_weight(self.model, bias_key)
|
||||
m.bias_function = [LowVramPatch(bias_key, self.patches, convert_func, set_func)]
|
||||
patch_counter += 1
|
||||
|
||||
cast_weight = True
|
||||
offloaded.append((module_mem, n, m, params))
|
||||
else:
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
wipe_lowvram_weight(m)
|
||||
|
||||
if full_load or mem_counter + module_mem < lowvram_model_memory:
|
||||
if full_load or lowvram_fits:
|
||||
mem_counter += module_mem
|
||||
load_completely.append((module_mem, n, m, params))
|
||||
else:
|
||||
offload_buffer = potential_offload
|
||||
|
||||
if cast_weight and hasattr(m, "comfy_cast_weights"):
|
||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||
@@ -697,7 +774,9 @@ class ModelPatcher:
|
||||
continue
|
||||
|
||||
for param in params:
|
||||
self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to)
|
||||
key = "{}.{}".format(n, param)
|
||||
self.unpin_weight(key)
|
||||
self.patch_weight_to_device(key, device_to=device_to)
|
||||
|
||||
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
||||
m.comfy_patched_weights = True
|
||||
@@ -705,11 +784,17 @@ class ModelPatcher:
|
||||
for x in load_completely:
|
||||
x[2].to(device_to)
|
||||
|
||||
for x in offloaded:
|
||||
n = x[1]
|
||||
params = x[3]
|
||||
for param in params:
|
||||
self.pin_weight_to_device("{}.{}".format(n, param))
|
||||
|
||||
if lowvram_counter > 0:
|
||||
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
|
||||
logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter))
|
||||
self.model.model_lowvram = True
|
||||
else:
|
||||
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
|
||||
logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
|
||||
self.model.model_lowvram = False
|
||||
if full_load:
|
||||
self.model.to(device_to)
|
||||
@@ -718,6 +803,7 @@ class ModelPatcher:
|
||||
self.model.lowvram_patch_counter += patch_counter
|
||||
self.model.device = device_to
|
||||
self.model.model_loaded_weight_memory = mem_counter
|
||||
self.model.model_offload_buffer_memory = offload_buffer
|
||||
self.model.current_weight_patches_uuid = self.patches_uuid
|
||||
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
|
||||
@@ -746,6 +832,7 @@ class ModelPatcher:
|
||||
self.eject_model()
|
||||
if unpatch_weights:
|
||||
self.unpatch_hooks()
|
||||
self.unpin_all_weights()
|
||||
if self.model.model_lowvram:
|
||||
for m in self.model.modules():
|
||||
move_weight_functions(m, device_to)
|
||||
@@ -770,6 +857,7 @@ class ModelPatcher:
|
||||
self.model.to(device_to)
|
||||
self.model.device = device_to
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
self.model.model_offload_buffer_memory = 0
|
||||
|
||||
for m in self.model.modules():
|
||||
if hasattr(m, "comfy_patched_weights"):
|
||||
@@ -781,20 +869,21 @@ class ModelPatcher:
|
||||
|
||||
self.object_patches_backup.clear()
|
||||
|
||||
def partially_unload(self, device_to, memory_to_free=0):
|
||||
def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
|
||||
with self.use_ejected():
|
||||
hooks_unpatched = False
|
||||
memory_freed = 0
|
||||
patch_counter = 0
|
||||
unload_list = self._load_list()
|
||||
unload_list.sort()
|
||||
offload_buffer = self.model.model_offload_buffer_memory
|
||||
|
||||
for unload in unload_list:
|
||||
if memory_to_free < memory_freed:
|
||||
if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed:
|
||||
break
|
||||
module_mem = unload[0]
|
||||
n = unload[1]
|
||||
m = unload[2]
|
||||
params = unload[3]
|
||||
module_offload_mem, module_mem, n, m, params = unload
|
||||
|
||||
potential_offload = (comfy.model_management.NUM_STREAMS + 1) * module_offload_mem
|
||||
|
||||
lowvram_possible = hasattr(m, "comfy_cast_weights")
|
||||
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
|
||||
@@ -825,11 +914,19 @@ class ModelPatcher:
|
||||
module_mem += move_weight_functions(m, device_to)
|
||||
if lowvram_possible:
|
||||
if weight_key in self.patches:
|
||||
m.weight_function.append(LowVramPatch(weight_key, self.patches))
|
||||
patch_counter += 1
|
||||
if force_patch_weights:
|
||||
self.patch_weight_to_device(weight_key)
|
||||
else:
|
||||
_, set_func, convert_func = get_key_weight(self.model, weight_key)
|
||||
m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func))
|
||||
patch_counter += 1
|
||||
if bias_key in self.patches:
|
||||
m.bias_function.append(LowVramPatch(bias_key, self.patches))
|
||||
patch_counter += 1
|
||||
if force_patch_weights:
|
||||
self.patch_weight_to_device(bias_key)
|
||||
else:
|
||||
_, set_func, convert_func = get_key_weight(self.model, bias_key)
|
||||
m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func))
|
||||
patch_counter += 1
|
||||
cast_weight = True
|
||||
|
||||
if cast_weight:
|
||||
@@ -837,11 +934,18 @@ class ModelPatcher:
|
||||
m.comfy_cast_weights = True
|
||||
m.comfy_patched_weights = False
|
||||
memory_freed += module_mem
|
||||
offload_buffer = max(offload_buffer, potential_offload)
|
||||
logging.debug("freed {}".format(n))
|
||||
|
||||
for param in params:
|
||||
self.pin_weight_to_device("{}.{}".format(n, param))
|
||||
|
||||
|
||||
self.model.model_lowvram = True
|
||||
self.model.lowvram_patch_counter += patch_counter
|
||||
self.model.model_loaded_weight_memory -= memory_freed
|
||||
self.model.model_offload_buffer_memory = offload_buffer
|
||||
logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
|
||||
return memory_freed
|
||||
|
||||
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
|
||||
@@ -854,6 +958,9 @@ class ModelPatcher:
|
||||
extra_memory += (used - self.model.model_loaded_weight_memory)
|
||||
|
||||
self.patch_model(load_weights=False)
|
||||
if extra_memory < 0 and not unpatch_weights:
|
||||
self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
|
||||
return 0
|
||||
full_load = False
|
||||
if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
|
||||
self.apply_hooks(self.forced_hooks, force_apply=True)
|
||||
@@ -1241,5 +1348,6 @@ class ModelPatcher:
|
||||
self.clear_cached_hook_weights()
|
||||
|
||||
def __del__(self):
|
||||
self.unpin_all_weights()
|
||||
self.detach(unpatch_all=False)
|
||||
|
||||
|
||||
@@ -21,17 +21,23 @@ def rescale_zero_terminal_snr_sigmas(sigmas):
|
||||
alphas_bar[-1] = 4.8973451890853435e-08
|
||||
return ((1 - alphas_bar) / alphas_bar) ** 0.5
|
||||
|
||||
def reshape_sigma(sigma, noise_dim):
|
||||
if sigma.nelement() == 1:
|
||||
return sigma.view(())
|
||||
else:
|
||||
return sigma.view(sigma.shape[:1] + (1,) * (noise_dim - 1))
|
||||
|
||||
class EPS:
|
||||
def calculate_input(self, sigma, noise):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
|
||||
sigma = reshape_sigma(sigma, noise.ndim)
|
||||
return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||
sigma = reshape_sigma(sigma, model_output.ndim)
|
||||
return model_input - model_output * sigma
|
||||
|
||||
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
|
||||
sigma = reshape_sigma(sigma, noise.ndim)
|
||||
if max_denoise:
|
||||
noise = noise * torch.sqrt(1.0 + sigma ** 2.0)
|
||||
else:
|
||||
@@ -45,12 +51,12 @@ class EPS:
|
||||
|
||||
class V_PREDICTION(EPS):
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||
sigma = reshape_sigma(sigma, model_output.ndim)
|
||||
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
|
||||
class EDM(V_PREDICTION):
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||
sigma = reshape_sigma(sigma, model_output.ndim)
|
||||
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
|
||||
class CONST:
|
||||
@@ -58,15 +64,15 @@ class CONST:
|
||||
return noise
|
||||
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||
sigma = reshape_sigma(sigma, model_output.ndim)
|
||||
return model_input - model_output * sigma
|
||||
|
||||
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
|
||||
sigma = reshape_sigma(sigma, noise.ndim)
|
||||
return sigma * noise + (1.0 - sigma) * latent_image
|
||||
|
||||
def inverse_noise_scaling(self, sigma, latent):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (latent.ndim - 1))
|
||||
sigma = reshape_sigma(sigma, latent.ndim)
|
||||
return latent / (1.0 - sigma)
|
||||
|
||||
class X0(EPS):
|
||||
@@ -80,16 +86,16 @@ class IMG_TO_IMG(X0):
|
||||
class COSMOS_RFLOW:
|
||||
def calculate_input(self, sigma, noise):
|
||||
sigma = (sigma / (sigma + 1))
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
|
||||
sigma = reshape_sigma(sigma, noise.ndim)
|
||||
return noise * (1.0 - sigma)
|
||||
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
sigma = (sigma / (sigma + 1))
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||
sigma = reshape_sigma(sigma, model_output.ndim)
|
||||
return model_input * (1.0 - sigma) - model_output * sigma
|
||||
|
||||
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
|
||||
sigma = reshape_sigma(sigma, noise.ndim)
|
||||
noise = noise * sigma
|
||||
noise += latent_image
|
||||
return noise
|
||||
|
||||
91
comfy/nested_tensor.py
Normal file
91
comfy/nested_tensor.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import torch
|
||||
|
||||
class NestedTensor:
|
||||
def __init__(self, tensors):
|
||||
self.tensors = list(tensors)
|
||||
self.is_nested = True
|
||||
|
||||
def _copy(self):
|
||||
return NestedTensor(self.tensors)
|
||||
|
||||
def apply_operation(self, other, operation):
|
||||
o = self._copy()
|
||||
if isinstance(other, NestedTensor):
|
||||
for i, t in enumerate(o.tensors):
|
||||
o.tensors[i] = operation(t, other.tensors[i])
|
||||
else:
|
||||
for i, t in enumerate(o.tensors):
|
||||
o.tensors[i] = operation(t, other)
|
||||
return o
|
||||
|
||||
def __add__(self, b):
|
||||
return self.apply_operation(b, lambda x, y: x + y)
|
||||
|
||||
def __sub__(self, b):
|
||||
return self.apply_operation(b, lambda x, y: x - y)
|
||||
|
||||
def __mul__(self, b):
|
||||
return self.apply_operation(b, lambda x, y: x * y)
|
||||
|
||||
# def __itruediv__(self, b):
|
||||
# return self.apply_operation(b, lambda x, y: x / y)
|
||||
|
||||
def __truediv__(self, b):
|
||||
return self.apply_operation(b, lambda x, y: x / y)
|
||||
|
||||
def __getitem__(self, *args, **kwargs):
|
||||
return self.apply_operation(None, lambda x, y: x.__getitem__(*args, **kwargs))
|
||||
|
||||
def unbind(self):
|
||||
return self.tensors
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
o = self._copy()
|
||||
for i, t in enumerate(o.tensors):
|
||||
o.tensors[i] = t.to(*args, **kwargs)
|
||||
return o
|
||||
|
||||
def new_ones(self, *args, **kwargs):
|
||||
return self.tensors[0].new_ones(*args, **kwargs)
|
||||
|
||||
def float(self):
|
||||
return self.to(dtype=torch.float)
|
||||
|
||||
def chunk(self, *args, **kwargs):
|
||||
return self.apply_operation(None, lambda x, y: x.chunk(*args, **kwargs))
|
||||
|
||||
def size(self):
|
||||
return self.tensors[0].size()
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self.tensors[0].shape
|
||||
|
||||
@property
|
||||
def ndim(self):
|
||||
dims = 0
|
||||
for t in self.tensors:
|
||||
dims = max(t.ndim, dims)
|
||||
return dims
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.tensors[0].device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.tensors[0].dtype
|
||||
|
||||
@property
|
||||
def layout(self):
|
||||
return self.tensors[0].layout
|
||||
|
||||
|
||||
def cat_nested(tensors, *args, **kwargs):
|
||||
cated_tensors = []
|
||||
for i in range(len(tensors[0].tensors)):
|
||||
tens = []
|
||||
for j in range(len(tensors)):
|
||||
tens.append(tensors[j].tensors[i])
|
||||
cated_tensors.append(torch.cat(tens, *args, **kwargs))
|
||||
return NestedTensor(cated_tensors)
|
||||
364
comfy/ops.py
364
comfy/ops.py
@@ -24,13 +24,18 @@ import comfy.float
|
||||
import comfy.rmsnorm
|
||||
import contextlib
|
||||
|
||||
def run_every_op():
|
||||
if torch.compiler.is_compiling():
|
||||
return
|
||||
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
|
||||
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
|
||||
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
|
||||
|
||||
|
||||
try:
|
||||
if torch.cuda.is_available():
|
||||
if torch.cuda.is_available() and comfy.model_management.WINDOWS:
|
||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||
import inspect
|
||||
if "set_priority" in inspect.signature(sdpa_kernel).parameters:
|
||||
@@ -50,49 +55,94 @@ try:
|
||||
except (ModuleNotFoundError, TypeError):
|
||||
logging.warning("Could not set sdpa backend priority.")
|
||||
|
||||
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
|
||||
NVIDIA_MEMORY_CONV_BUG_WORKAROUND = False
|
||||
try:
|
||||
if comfy.model_management.is_nvidia():
|
||||
cudnn_version = torch.backends.cudnn.version()
|
||||
if (cudnn_version >= 91002 and cudnn_version < 91500) and comfy.model_management.torch_version_numeric >= (2, 9) and comfy.model_management.torch_version_numeric <= (2, 10):
|
||||
#TODO: change upper bound version once it's fixed'
|
||||
NVIDIA_MEMORY_CONV_BUG_WORKAROUND = True
|
||||
logging.info("working around nvidia conv3d memory bug.")
|
||||
except:
|
||||
pass
|
||||
|
||||
if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
|
||||
torch.backends.cudnn.benchmark = True
|
||||
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
|
||||
|
||||
def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
||||
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
|
||||
# NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
|
||||
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
|
||||
# will add async-offload support to your cast and improve performance.
|
||||
if input is not None:
|
||||
if dtype is None:
|
||||
dtype = input.dtype
|
||||
if isinstance(input, QuantizedTensor):
|
||||
dtype = input._layout_params["orig_dtype"]
|
||||
else:
|
||||
dtype = input.dtype
|
||||
if bias_dtype is None:
|
||||
bias_dtype = dtype
|
||||
if device is None:
|
||||
device = input.device
|
||||
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
if offloadable and (device != s.weight.device or
|
||||
(s.bias is not None and device != s.bias.device)):
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
else:
|
||||
offload_stream = None
|
||||
|
||||
if offload_stream is not None:
|
||||
wf_context = offload_stream
|
||||
if hasattr(wf_context, "as_context"):
|
||||
wf_context = wf_context.as_context(offload_stream)
|
||||
else:
|
||||
wf_context = contextlib.nullcontext()
|
||||
|
||||
bias = None
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||
if s.bias is not None:
|
||||
has_function = len(s.bias_function) > 0
|
||||
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
|
||||
|
||||
if has_function:
|
||||
weight_has_function = len(s.weight_function) > 0
|
||||
bias_has_function = len(s.bias_function) > 0
|
||||
|
||||
weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
|
||||
|
||||
bias = None
|
||||
if s.bias is not None:
|
||||
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
|
||||
|
||||
if bias_has_function:
|
||||
with wf_context:
|
||||
for f in s.bias_function:
|
||||
bias = f(bias)
|
||||
|
||||
has_function = len(s.weight_function) > 0
|
||||
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
|
||||
if has_function:
|
||||
if weight_has_function or weight.dtype != dtype:
|
||||
with wf_context:
|
||||
weight = weight.to(dtype=dtype)
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight = weight.dequantize()
|
||||
for f in s.weight_function:
|
||||
weight = f(weight)
|
||||
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
return weight, bias
|
||||
if offloadable:
|
||||
return weight, bias, offload_stream
|
||||
else:
|
||||
#Legacy function signature
|
||||
return weight, bias
|
||||
|
||||
|
||||
def uncast_bias_weight(s, weight, bias, offload_stream):
|
||||
if offload_stream is None:
|
||||
return
|
||||
if weight is not None:
|
||||
device = weight.device
|
||||
else:
|
||||
if bias is None:
|
||||
return
|
||||
device = bias.device
|
||||
offload_stream.wait_stream(comfy.model_management.current_stream(device))
|
||||
|
||||
|
||||
class CastWeightBiasOp:
|
||||
comfy_cast_weights = False
|
||||
@@ -105,10 +155,13 @@ class disable_weight_init:
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = torch.nn.functional.linear(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
@@ -119,10 +172,13 @@ class disable_weight_init:
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return self._conv_forward(input, weight, bias)
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = self._conv_forward(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
@@ -133,10 +189,13 @@ class disable_weight_init:
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return self._conv_forward(input, weight, bias)
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = self._conv_forward(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
@@ -146,11 +205,23 @@ class disable_weight_init:
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def _conv_forward(self, input, weight, bias, *args, **kwargs):
|
||||
if NVIDIA_MEMORY_CONV_BUG_WORKAROUND and weight.dtype in (torch.float16, torch.bfloat16):
|
||||
out = torch.cudnn_convolution(input, weight, self.padding, self.stride, self.dilation, self.groups, benchmark=False, deterministic=False, allow_tf32=True)
|
||||
if bias is not None:
|
||||
out += bias.reshape((1, -1) + (1,) * (out.ndim - 2))
|
||||
return out
|
||||
else:
|
||||
return super()._conv_forward(input, weight, bias, *args, **kwargs)
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return self._conv_forward(input, weight, bias)
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = self._conv_forward(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
@@ -161,10 +232,13 @@ class disable_weight_init:
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
@@ -176,13 +250,17 @@ class disable_weight_init:
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
if self.weight is not None:
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
else:
|
||||
weight = None
|
||||
bias = None
|
||||
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
|
||||
offload_stream = None
|
||||
x = torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
@@ -195,13 +273,18 @@ class disable_weight_init:
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
if self.weight is not None:
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
else:
|
||||
weight = None
|
||||
return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
|
||||
# return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
|
||||
bias = None
|
||||
offload_stream = None
|
||||
x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
|
||||
# x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
@@ -217,12 +300,15 @@ class disable_weight_init:
|
||||
input, output_size, self.stride, self.padding, self.kernel_size,
|
||||
num_spatial_dims, self.dilation)
|
||||
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.conv_transpose2d(
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = torch.nn.functional.conv_transpose2d(
|
||||
input, weight, bias, self.stride, self.padding,
|
||||
output_padding, self.groups, self.dilation)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
@@ -238,12 +324,15 @@ class disable_weight_init:
|
||||
input, output_size, self.stride, self.padding, self.kernel_size,
|
||||
num_spatial_dims, self.dilation)
|
||||
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.conv_transpose1d(
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = torch.nn.functional.conv_transpose1d(
|
||||
input, weight, bias, self.stride, self.padding,
|
||||
output_padding, self.groups, self.dilation)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
@@ -258,10 +347,14 @@ class disable_weight_init:
|
||||
output_dtype = out_dtype
|
||||
if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
|
||||
out_dtype = None
|
||||
weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
|
||||
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
|
||||
weight, bias, offload_stream = cast_bias_weight(self, device=input.device, dtype=out_dtype, offloadable=True)
|
||||
x = torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
@@ -312,20 +405,18 @@ class manual_cast(disable_weight_init):
|
||||
|
||||
|
||||
def fp8_linear(self, input):
|
||||
"""
|
||||
Legacy FP8 linear function for backward compatibility.
|
||||
Uses QuantizedTensor subclass for dispatch.
|
||||
"""
|
||||
dtype = self.weight.dtype
|
||||
if dtype not in [torch.float8_e4m3fn]:
|
||||
return None
|
||||
|
||||
tensor_2d = False
|
||||
if len(input.shape) == 2:
|
||||
tensor_2d = True
|
||||
input = input.unsqueeze(1)
|
||||
|
||||
input_shape = input.shape
|
||||
input_dtype = input.dtype
|
||||
if len(input.shape) == 3:
|
||||
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
|
||||
w = w.t()
|
||||
|
||||
if input.ndim == 3 or input.ndim == 2:
|
||||
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
|
||||
|
||||
scale_weight = self.scale_weight
|
||||
scale_input = self.scale_input
|
||||
@@ -337,23 +428,20 @@ def fp8_linear(self, input):
|
||||
if scale_input is None:
|
||||
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
input = torch.clamp(input, min=-448, max=448, out=input)
|
||||
input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
|
||||
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
|
||||
quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
|
||||
else:
|
||||
scale_input = scale_input.to(input.device)
|
||||
input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()
|
||||
quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype)
|
||||
|
||||
if bias is not None:
|
||||
o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
|
||||
else:
|
||||
o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight)
|
||||
# Wrap weight in QuantizedTensor - this enables unified dispatch
|
||||
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
|
||||
layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
|
||||
quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
|
||||
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
|
||||
|
||||
if isinstance(o, tuple):
|
||||
o = o[0]
|
||||
|
||||
if tensor_2d:
|
||||
return o.reshape(input_shape[0], -1)
|
||||
|
||||
return o.reshape((-1, input_shape[1], self.weight.shape[0]))
|
||||
uncast_bias_weight(self, w, bias, offload_stream)
|
||||
return o
|
||||
|
||||
return None
|
||||
|
||||
@@ -373,8 +461,10 @@ class fp8_ops(manual_cast):
|
||||
except Exception as e:
|
||||
logging.info("Exception during fp8 op: {}".format(e))
|
||||
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = torch.nn.functional.linear(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
|
||||
logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
|
||||
@@ -402,22 +492,26 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
|
||||
if out is not None:
|
||||
return out
|
||||
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
|
||||
if weight.numel() < input.numel(): #TODO: optimize
|
||||
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
|
||||
x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
|
||||
else:
|
||||
return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
|
||||
x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def convert_weight(self, weight, inplace=False, **kwargs):
|
||||
if inplace:
|
||||
weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
|
||||
return weight
|
||||
else:
|
||||
return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
|
||||
return weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32)
|
||||
|
||||
def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
|
||||
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
|
||||
weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
|
||||
if return_weight:
|
||||
return weight
|
||||
if inplace_update:
|
||||
self.weight.data.copy_(weight)
|
||||
else:
|
||||
@@ -444,8 +538,142 @@ if CUBLAS_IS_AVAILABLE:
|
||||
def forward(self, *args, **kwargs):
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
|
||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
|
||||
|
||||
# ==============================================================================
|
||||
# Mixed Precision Operations
|
||||
# ==============================================================================
|
||||
from .quant_ops import QuantizedTensor, QUANT_ALGOS
|
||||
|
||||
|
||||
def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
|
||||
class MixedPrecisionOps(manual_cast):
|
||||
_layer_quant_config = layer_quant_config
|
||||
_compute_dtype = compute_dtype
|
||||
_full_precision_mm = full_precision_mm
|
||||
|
||||
class Linear(torch.nn.Module, CastWeightBiasOp):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
|
||||
# self.factory_kwargs = {"device": device, "dtype": dtype}
|
||||
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
if bias:
|
||||
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
self.tensor_class = None
|
||||
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
|
||||
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
strict, missing_keys, unexpected_keys, error_msgs):
|
||||
|
||||
device = self.factory_kwargs["device"]
|
||||
layer_name = prefix.rstrip('.')
|
||||
weight_key = f"{prefix}weight"
|
||||
weight = state_dict.pop(weight_key, None)
|
||||
if weight is None:
|
||||
raise ValueError(f"Missing weight for layer {layer_name}")
|
||||
|
||||
manually_loaded_keys = [weight_key]
|
||||
|
||||
if layer_name not in MixedPrecisionOps._layer_quant_config:
|
||||
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
|
||||
else:
|
||||
quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
|
||||
if quant_format is None:
|
||||
raise ValueError(f"Unknown quantization format for layer {layer_name}")
|
||||
|
||||
qconfig = QUANT_ALGOS[quant_format]
|
||||
self.layout_type = qconfig["comfy_tensor_layout"]
|
||||
|
||||
weight_scale_key = f"{prefix}weight_scale"
|
||||
layout_params = {
|
||||
'scale': state_dict.pop(weight_scale_key, None),
|
||||
'orig_dtype': MixedPrecisionOps._compute_dtype,
|
||||
'block_size': qconfig.get("group_size", None),
|
||||
}
|
||||
if layout_params['scale'] is not None:
|
||||
manually_loaded_keys.append(weight_scale_key)
|
||||
|
||||
self.weight = torch.nn.Parameter(
|
||||
QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
|
||||
requires_grad=False
|
||||
)
|
||||
|
||||
for param_name in qconfig["parameters"]:
|
||||
param_key = f"{prefix}{param_name}"
|
||||
_v = state_dict.pop(param_key, None)
|
||||
if _v is None:
|
||||
continue
|
||||
setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
|
||||
manually_loaded_keys.append(param_key)
|
||||
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
for key in manually_loaded_keys:
|
||||
if key in missing_keys:
|
||||
missing_keys.remove(key)
|
||||
|
||||
def _forward(self, input, weight, bias):
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = self._forward(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, input, *args, **kwargs):
|
||||
run_every_op()
|
||||
|
||||
if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(input, *args, **kwargs)
|
||||
if (getattr(self, 'layout_type', None) is not None and
|
||||
getattr(self, 'input_scale', None) is not None and
|
||||
not isinstance(input, QuantizedTensor)):
|
||||
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
|
||||
return self._forward(input, self.weight, self.bias)
|
||||
|
||||
def convert_weight(self, weight, inplace=False, **kwargs):
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
return weight.dequantize()
|
||||
else:
|
||||
return weight
|
||||
|
||||
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
|
||||
if getattr(self, 'layout_type', None) is not None:
|
||||
weight = QuantizedTensor.from_float(weight, self.layout_type, scale=None, dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
|
||||
else:
|
||||
weight = weight.to(self.weight.dtype)
|
||||
if return_weight:
|
||||
return weight
|
||||
|
||||
assert inplace_update is False # TODO: eventually remove the inplace_update stuff
|
||||
self.weight = torch.nn.Parameter(weight, requires_grad=False)
|
||||
|
||||
return MixedPrecisionOps
|
||||
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
|
||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
|
||||
|
||||
if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
|
||||
logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
|
||||
return mixed_precision_ops(model_config.layer_quant_config, compute_dtype, full_precision_mm=not fp8_compute)
|
||||
|
||||
if scaled_fp8 is not None:
|
||||
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
|
||||
|
||||
|
||||
@@ -150,7 +150,7 @@ def merge_nested_dicts(dict1: dict, dict2: dict, copy_dict1=True):
|
||||
for key, value in dict2.items():
|
||||
if isinstance(value, dict):
|
||||
curr_value = merged_dict.setdefault(key, {})
|
||||
merged_dict[key] = merge_nested_dicts(value, curr_value)
|
||||
merged_dict[key] = merge_nested_dicts(curr_value, value)
|
||||
elif isinstance(value, list):
|
||||
merged_dict.setdefault(key, []).extend(value)
|
||||
else:
|
||||
|
||||
573
comfy/quant_ops.py
Normal file
573
comfy/quant_ops.py
Normal file
@@ -0,0 +1,573 @@
|
||||
import torch
|
||||
import logging
|
||||
from typing import Tuple, Dict
|
||||
import comfy.float
|
||||
|
||||
_LAYOUT_REGISTRY = {}
|
||||
_GENERIC_UTILS = {}
|
||||
|
||||
|
||||
def register_layout_op(torch_op, layout_type):
|
||||
"""
|
||||
Decorator to register a layout-specific operation handler.
|
||||
Args:
|
||||
torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default)
|
||||
layout_type: Layout class (e.g., TensorCoreFP8Layout)
|
||||
Example:
|
||||
@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
|
||||
def fp8_linear(func, args, kwargs):
|
||||
# FP8-specific linear implementation
|
||||
...
|
||||
"""
|
||||
def decorator(handler_func):
|
||||
if torch_op not in _LAYOUT_REGISTRY:
|
||||
_LAYOUT_REGISTRY[torch_op] = {}
|
||||
_LAYOUT_REGISTRY[torch_op][layout_type] = handler_func
|
||||
return handler_func
|
||||
return decorator
|
||||
|
||||
|
||||
def register_generic_util(torch_op):
|
||||
"""
|
||||
Decorator to register a generic utility that works for all layouts.
|
||||
Args:
|
||||
torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default)
|
||||
|
||||
Example:
|
||||
@register_generic_util(torch.ops.aten.detach.default)
|
||||
def generic_detach(func, args, kwargs):
|
||||
# Works for any layout
|
||||
...
|
||||
"""
|
||||
def decorator(handler_func):
|
||||
_GENERIC_UTILS[torch_op] = handler_func
|
||||
return handler_func
|
||||
return decorator
|
||||
|
||||
|
||||
def _get_layout_from_args(args):
|
||||
for arg in args:
|
||||
if isinstance(arg, QuantizedTensor):
|
||||
return arg._layout_type
|
||||
elif isinstance(arg, (list, tuple)):
|
||||
for item in arg:
|
||||
if isinstance(item, QuantizedTensor):
|
||||
return item._layout_type
|
||||
return None
|
||||
|
||||
|
||||
def _move_layout_params_to_device(params, device):
|
||||
new_params = {}
|
||||
for k, v in params.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
new_params[k] = v.to(device=device)
|
||||
else:
|
||||
new_params[k] = v
|
||||
return new_params
|
||||
|
||||
|
||||
def _copy_layout_params(params):
|
||||
new_params = {}
|
||||
for k, v in params.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
new_params[k] = v.clone()
|
||||
else:
|
||||
new_params[k] = v
|
||||
return new_params
|
||||
|
||||
def _copy_layout_params_inplace(src, dst, non_blocking=False):
|
||||
for k, v in src.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
dst[k].copy_(v, non_blocking=non_blocking)
|
||||
else:
|
||||
dst[k] = v
|
||||
|
||||
class QuantizedLayout:
|
||||
"""
|
||||
Base class for quantization layouts.
|
||||
|
||||
A layout encapsulates the format-specific logic for quantization/dequantization
|
||||
and provides a uniform interface for extracting raw tensors needed for computation.
|
||||
|
||||
New quantization formats should subclass this and implement the required methods.
|
||||
"""
|
||||
@classmethod
|
||||
def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]:
|
||||
raise NotImplementedError(f"{cls.__name__} must implement quantize()")
|
||||
|
||||
@staticmethod
|
||||
def dequantize(qdata, **layout_params) -> torch.Tensor:
|
||||
raise NotImplementedError("TensorLayout must implement dequantize()")
|
||||
|
||||
@classmethod
|
||||
def get_plain_tensors(cls, qtensor) -> torch.Tensor:
|
||||
raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()")
|
||||
|
||||
|
||||
class QuantizedTensor(torch.Tensor):
|
||||
"""
|
||||
Universal quantized tensor that works with any layout.
|
||||
|
||||
This tensor subclass uses a pluggable layout system to support multiple
|
||||
quantization formats (FP8, INT4, INT8, etc.) without code duplication.
|
||||
|
||||
The layout_type determines format-specific behavior, while common operations
|
||||
(detach, clone, to) are handled generically.
|
||||
|
||||
Attributes:
|
||||
_qdata: The quantized tensor data
|
||||
_layout_type: Layout class (e.g., TensorCoreFP8Layout)
|
||||
_layout_params: Dict with layout-specific params (scale, zero_point, etc.)
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def __new__(cls, qdata, layout_type, layout_params):
|
||||
"""
|
||||
Create a quantized tensor.
|
||||
|
||||
Args:
|
||||
qdata: The quantized data tensor
|
||||
layout_type: Layout class (subclass of QuantizedLayout)
|
||||
layout_params: Dict with layout-specific parameters
|
||||
"""
|
||||
return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
|
||||
|
||||
def __init__(self, qdata, layout_type, layout_params):
|
||||
self._qdata = qdata
|
||||
self._layout_type = layout_type
|
||||
self._layout_params = layout_params
|
||||
|
||||
def __repr__(self):
|
||||
layout_name = self._layout_type
|
||||
param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
|
||||
return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"
|
||||
|
||||
@property
|
||||
def layout_type(self):
|
||||
return self._layout_type
|
||||
|
||||
def __tensor_flatten__(self):
|
||||
"""
|
||||
Tensor flattening protocol for proper device movement.
|
||||
"""
|
||||
inner_tensors = ["_qdata"]
|
||||
ctx = {
|
||||
"layout_type": self._layout_type,
|
||||
}
|
||||
|
||||
tensor_params = {}
|
||||
non_tensor_params = {}
|
||||
for k, v in self._layout_params.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
tensor_params[k] = v
|
||||
else:
|
||||
non_tensor_params[k] = v
|
||||
|
||||
ctx["tensor_param_keys"] = list(tensor_params.keys())
|
||||
ctx["non_tensor_params"] = non_tensor_params
|
||||
|
||||
for k, v in tensor_params.items():
|
||||
attr_name = f"_layout_param_{k}"
|
||||
object.__setattr__(self, attr_name, v)
|
||||
inner_tensors.append(attr_name)
|
||||
|
||||
return inner_tensors, ctx
|
||||
|
||||
@staticmethod
|
||||
def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
|
||||
"""
|
||||
Tensor unflattening protocol for proper device movement.
|
||||
Reconstructs the QuantizedTensor after device movement.
|
||||
"""
|
||||
layout_type = ctx["layout_type"]
|
||||
layout_params = dict(ctx["non_tensor_params"])
|
||||
|
||||
for key in ctx["tensor_param_keys"]:
|
||||
attr_name = f"_layout_param_{key}"
|
||||
layout_params[key] = inner_tensors[attr_name]
|
||||
|
||||
return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params)
|
||||
|
||||
@classmethod
|
||||
def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
|
||||
qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs)
|
||||
return cls(qdata, layout_type, layout_params)
|
||||
|
||||
def dequantize(self) -> torch.Tensor:
|
||||
return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
kwargs = kwargs or {}
|
||||
|
||||
# Step 1: Check generic utilities first (detach, clone, to, etc.)
|
||||
if func in _GENERIC_UTILS:
|
||||
return _GENERIC_UTILS[func](func, args, kwargs)
|
||||
|
||||
# Step 2: Check layout-specific handlers (linear, matmul, etc.)
|
||||
layout_type = _get_layout_from_args(args)
|
||||
if layout_type and func in _LAYOUT_REGISTRY:
|
||||
handler = _LAYOUT_REGISTRY[func].get(layout_type)
|
||||
if handler:
|
||||
return handler(func, args, kwargs)
|
||||
|
||||
# Step 3: Fallback to dequantization
|
||||
if isinstance(args[0] if args else None, QuantizedTensor):
|
||||
logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
|
||||
return cls._dequant_and_fallback(func, args, kwargs)
|
||||
|
||||
@classmethod
|
||||
def _dequant_and_fallback(cls, func, args, kwargs):
|
||||
def dequant_arg(arg):
|
||||
if isinstance(arg, QuantizedTensor):
|
||||
return arg.dequantize()
|
||||
elif isinstance(arg, (list, tuple)):
|
||||
return type(arg)(dequant_arg(a) for a in arg)
|
||||
return arg
|
||||
|
||||
new_args = dequant_arg(args)
|
||||
new_kwargs = dequant_arg(kwargs)
|
||||
return func(*new_args, **new_kwargs)
|
||||
|
||||
def data_ptr(self):
|
||||
return self._qdata.data_ptr()
|
||||
|
||||
def is_pinned(self):
|
||||
return self._qdata.is_pinned()
|
||||
|
||||
def is_contiguous(self, *arg, **kwargs):
|
||||
return self._qdata.is_contiguous(*arg, **kwargs)
|
||||
|
||||
# ==============================================================================
|
||||
# Generic Utilities (Layout-Agnostic Operations)
|
||||
# ==============================================================================
|
||||
|
||||
def _create_transformed_qtensor(qt, transform_fn):
|
||||
new_data = transform_fn(qt._qdata)
|
||||
new_params = _copy_layout_params(qt._layout_params)
|
||||
return QuantizedTensor(new_data, qt._layout_type, new_params)
|
||||
|
||||
|
||||
def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
|
||||
if target_dtype is not None and target_dtype != qt.dtype:
|
||||
logging.warning(
|
||||
f"QuantizedTensor: dtype conversion requested to {target_dtype}, "
|
||||
f"but not supported for quantized tensors. Ignoring dtype."
|
||||
)
|
||||
|
||||
if target_layout is not None and target_layout != torch.strided:
|
||||
logging.warning(
|
||||
f"QuantizedTensor: layout change requested to {target_layout}, "
|
||||
f"but not supported. Ignoring layout."
|
||||
)
|
||||
|
||||
# Handle device transfer
|
||||
current_device = qt._qdata.device
|
||||
if target_device is not None:
|
||||
# Normalize device for comparison
|
||||
if isinstance(target_device, str):
|
||||
target_device = torch.device(target_device)
|
||||
if isinstance(current_device, str):
|
||||
current_device = torch.device(current_device)
|
||||
|
||||
if target_device != current_device:
|
||||
logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
|
||||
new_q_data = qt._qdata.to(device=target_device)
|
||||
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
|
||||
new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
|
||||
logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
|
||||
return new_qt
|
||||
|
||||
logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original")
|
||||
return qt
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten.detach.default)
|
||||
def generic_detach(func, args, kwargs):
|
||||
"""Detach operation - creates a detached copy of the quantized tensor."""
|
||||
qt = args[0]
|
||||
if isinstance(qt, QuantizedTensor):
|
||||
return _create_transformed_qtensor(qt, lambda x: x.detach())
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten.clone.default)
|
||||
def generic_clone(func, args, kwargs):
|
||||
"""Clone operation - creates a deep copy of the quantized tensor."""
|
||||
qt = args[0]
|
||||
if isinstance(qt, QuantizedTensor):
|
||||
return _create_transformed_qtensor(qt, lambda x: x.clone())
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten._to_copy.default)
|
||||
def generic_to_copy(func, args, kwargs):
|
||||
"""Device/dtype transfer operation - handles .to(device) calls."""
|
||||
qt = args[0]
|
||||
if isinstance(qt, QuantizedTensor):
|
||||
return _handle_device_transfer(
|
||||
qt,
|
||||
target_device=kwargs.get('device', None),
|
||||
target_dtype=kwargs.get('dtype', None),
|
||||
op_name="_to_copy"
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten.to.dtype_layout)
|
||||
def generic_to_dtype_layout(func, args, kwargs):
|
||||
"""Handle .to(device) calls using the dtype_layout variant."""
|
||||
qt = args[0]
|
||||
if isinstance(qt, QuantizedTensor):
|
||||
return _handle_device_transfer(
|
||||
qt,
|
||||
target_device=kwargs.get('device', None),
|
||||
target_dtype=kwargs.get('dtype', None),
|
||||
target_layout=kwargs.get('layout', None),
|
||||
op_name="to"
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten.copy_.default)
|
||||
def generic_copy_(func, args, kwargs):
|
||||
qt_dest = args[0]
|
||||
src = args[1]
|
||||
non_blocking = args[2] if len(args) > 2 else False
|
||||
if isinstance(qt_dest, QuantizedTensor):
|
||||
if isinstance(src, QuantizedTensor):
|
||||
# Copy from another quantized tensor
|
||||
qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking)
|
||||
qt_dest._layout_type = src._layout_type
|
||||
_copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking)
|
||||
else:
|
||||
# Copy from regular tensor - just copy raw data
|
||||
qt_dest._qdata.copy_(src)
|
||||
return qt_dest
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten.to.dtype)
|
||||
def generic_to_dtype(func, args, kwargs):
|
||||
"""Handle .to(dtype) calls - dtype conversion only."""
|
||||
src = args[0]
|
||||
if isinstance(src, QuantizedTensor):
|
||||
# For dtype-only conversion, just change the orig_dtype, no real cast is needed
|
||||
target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype')
|
||||
src._layout_params["orig_dtype"] = target_dtype
|
||||
return src
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
|
||||
def generic_has_compatible_shallow_copy_type(func, args, kwargs):
|
||||
return True
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten.empty_like.default)
|
||||
def generic_empty_like(func, args, kwargs):
|
||||
"""Empty_like operation - creates an empty tensor with the same quantized structure."""
|
||||
qt = args[0]
|
||||
if isinstance(qt, QuantizedTensor):
|
||||
# Create empty tensor with same shape and dtype as the quantized data
|
||||
hp_dtype = kwargs.pop('dtype', qt._layout_params["orig_dtype"])
|
||||
new_qdata = torch.empty_like(qt._qdata, **kwargs)
|
||||
|
||||
# Handle device transfer for layout params
|
||||
target_device = kwargs.get('device', new_qdata.device)
|
||||
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
|
||||
|
||||
# Update orig_dtype if dtype is specified
|
||||
new_params['orig_dtype'] = hp_dtype
|
||||
|
||||
return QuantizedTensor(new_qdata, qt._layout_type, new_params)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# ==============================================================================
|
||||
# FP8 Layout + Operation Handlers
|
||||
# ==============================================================================
|
||||
class TensorCoreFP8Layout(QuantizedLayout):
|
||||
"""
|
||||
Storage format:
|
||||
- qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
|
||||
- scale: Scalar tensor (float32) for dequantization
|
||||
- orig_dtype: Original dtype before quantization (for casting back)
|
||||
"""
|
||||
@classmethod
|
||||
def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False):
|
||||
orig_dtype = tensor.dtype
|
||||
|
||||
if scale is None:
|
||||
scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max
|
||||
|
||||
if not isinstance(scale, torch.Tensor):
|
||||
scale = torch.tensor(scale)
|
||||
scale = scale.to(device=tensor.device, dtype=torch.float32)
|
||||
|
||||
if inplace_ops:
|
||||
tensor *= (1.0 / scale).to(tensor.dtype)
|
||||
else:
|
||||
tensor = tensor * (1.0 / scale).to(tensor.dtype)
|
||||
|
||||
if stochastic_rounding > 0:
|
||||
tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding)
|
||||
else:
|
||||
lp_amax = torch.finfo(dtype).max
|
||||
torch.clamp(tensor, min=-lp_amax, max=lp_amax, out=tensor)
|
||||
tensor = tensor.to(dtype, memory_format=torch.contiguous_format)
|
||||
|
||||
layout_params = {
|
||||
'scale': scale,
|
||||
'orig_dtype': orig_dtype
|
||||
}
|
||||
return tensor, layout_params
|
||||
|
||||
@staticmethod
|
||||
def dequantize(qdata, scale, orig_dtype, **kwargs):
|
||||
plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
|
||||
plain_tensor.mul_(scale)
|
||||
return plain_tensor
|
||||
|
||||
@classmethod
|
||||
def get_plain_tensors(cls, qtensor):
|
||||
return qtensor._qdata, qtensor._layout_params['scale']
|
||||
|
||||
QUANT_ALGOS = {
|
||||
"float8_e4m3fn": {
|
||||
"storage_t": torch.float8_e4m3fn,
|
||||
"parameters": {"weight_scale", "input_scale"},
|
||||
"comfy_tensor_layout": "TensorCoreFP8Layout",
|
||||
},
|
||||
}
|
||||
|
||||
LAYOUTS = {
|
||||
"TensorCoreFP8Layout": TensorCoreFP8Layout,
|
||||
}
|
||||
|
||||
|
||||
@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout")
|
||||
def fp8_linear(func, args, kwargs):
|
||||
input_tensor = args[0]
|
||||
weight = args[1]
|
||||
bias = args[2] if len(args) > 2 else None
|
||||
|
||||
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
|
||||
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
|
||||
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
|
||||
|
||||
out_dtype = kwargs.get("out_dtype")
|
||||
if out_dtype is None:
|
||||
out_dtype = input_tensor._layout_params['orig_dtype']
|
||||
|
||||
weight_t = plain_weight.t()
|
||||
|
||||
tensor_2d = False
|
||||
if len(plain_input.shape) == 2:
|
||||
tensor_2d = True
|
||||
plain_input = plain_input.unsqueeze(1)
|
||||
|
||||
input_shape = plain_input.shape
|
||||
if len(input_shape) != 3:
|
||||
return None
|
||||
|
||||
try:
|
||||
output = torch._scaled_mm(
|
||||
plain_input.reshape(-1, input_shape[2]).contiguous(),
|
||||
weight_t,
|
||||
bias=bias,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=out_dtype,
|
||||
)
|
||||
|
||||
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
|
||||
output = output[0]
|
||||
|
||||
if not tensor_2d:
|
||||
output = output.reshape((-1, input_shape[1], weight.shape[0]))
|
||||
|
||||
if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
||||
output_scale = scale_a * scale_b
|
||||
output_params = {
|
||||
'scale': output_scale,
|
||||
'orig_dtype': input_tensor._layout_params['orig_dtype']
|
||||
}
|
||||
return QuantizedTensor(output, "TensorCoreFP8Layout", output_params)
|
||||
else:
|
||||
return output
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
|
||||
|
||||
# Case 2: DQ Fallback
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight = weight.dequantize()
|
||||
if isinstance(input_tensor, QuantizedTensor):
|
||||
input_tensor = input_tensor.dequantize()
|
||||
|
||||
return torch.nn.functional.linear(input_tensor, weight, bias)
|
||||
|
||||
def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None):
|
||||
if out_dtype is None:
|
||||
out_dtype = input_tensor._layout_params['orig_dtype']
|
||||
|
||||
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
|
||||
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
|
||||
|
||||
output = torch._scaled_mm(
|
||||
plain_input.contiguous(),
|
||||
plain_weight,
|
||||
bias=bias,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=out_dtype,
|
||||
)
|
||||
|
||||
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
|
||||
output = output[0]
|
||||
return output
|
||||
|
||||
@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout")
|
||||
def fp8_addmm(func, args, kwargs):
|
||||
input_tensor = args[1]
|
||||
weight = args[2]
|
||||
bias = args[0]
|
||||
|
||||
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
|
||||
return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None))
|
||||
|
||||
a = list(args)
|
||||
if isinstance(args[0], QuantizedTensor):
|
||||
a[0] = args[0].dequantize()
|
||||
if isinstance(args[1], QuantizedTensor):
|
||||
a[1] = args[1].dequantize()
|
||||
if isinstance(args[2], QuantizedTensor):
|
||||
a[2] = args[2].dequantize()
|
||||
|
||||
return func(*a, **kwargs)
|
||||
|
||||
@register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout")
|
||||
def fp8_mm(func, args, kwargs):
|
||||
input_tensor = args[0]
|
||||
weight = args[1]
|
||||
|
||||
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
|
||||
return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None))
|
||||
|
||||
a = list(args)
|
||||
if isinstance(args[0], QuantizedTensor):
|
||||
a[0] = args[0].dequantize()
|
||||
if isinstance(args[1], QuantizedTensor):
|
||||
a[1] = args[1].dequantize()
|
||||
return func(*a, **kwargs)
|
||||
|
||||
@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
|
||||
@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
|
||||
def fp8_func(func, args, kwargs):
|
||||
input_tensor = args[0]
|
||||
if isinstance(input_tensor, QuantizedTensor):
|
||||
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
|
||||
ar = list(args)
|
||||
ar[0] = plain_input
|
||||
return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
|
||||
return func(*args, **kwargs)
|
||||
@@ -4,13 +4,9 @@ import comfy.samplers
|
||||
import comfy.utils
|
||||
import numpy as np
|
||||
import logging
|
||||
import comfy.nested_tensor
|
||||
|
||||
def prepare_noise(latent_image, seed, noise_inds=None):
|
||||
"""
|
||||
creates random noise given a latent image and a seed.
|
||||
optional arg skip can be used to skip and discard x number of noise generations for a given seed
|
||||
"""
|
||||
generator = torch.manual_seed(seed)
|
||||
def prepare_noise_inner(latent_image, generator, noise_inds=None):
|
||||
if noise_inds is None:
|
||||
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
||||
|
||||
@@ -21,10 +17,29 @@ def prepare_noise(latent_image, seed, noise_inds=None):
|
||||
if i in unique_inds:
|
||||
noises.append(noise)
|
||||
noises = [noises[i] for i in inverse]
|
||||
noises = torch.cat(noises, axis=0)
|
||||
return torch.cat(noises, axis=0)
|
||||
|
||||
def prepare_noise(latent_image, seed, noise_inds=None):
|
||||
"""
|
||||
creates random noise given a latent image and a seed.
|
||||
optional arg skip can be used to skip and discard x number of noise generations for a given seed
|
||||
"""
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
if latent_image.is_nested:
|
||||
tensors = latent_image.unbind()
|
||||
noises = []
|
||||
for t in tensors:
|
||||
noises.append(prepare_noise_inner(t, generator, noise_inds))
|
||||
noises = comfy.nested_tensor.NestedTensor(noises)
|
||||
else:
|
||||
noises = prepare_noise_inner(latent_image, generator, noise_inds)
|
||||
|
||||
return noises
|
||||
|
||||
def fix_empty_latent_channels(model, latent_image):
|
||||
if latent_image.is_nested:
|
||||
return latent_image
|
||||
latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
|
||||
if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
|
||||
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
|
||||
|
||||
@@ -306,17 +306,10 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
copy_dict1=False)
|
||||
|
||||
if patches is not None:
|
||||
# TODO: replace with merge_nested_dicts function
|
||||
if "patches" in transformer_options:
|
||||
cur_patches = transformer_options["patches"].copy()
|
||||
for p in patches:
|
||||
if p in cur_patches:
|
||||
cur_patches[p] = cur_patches[p] + patches[p]
|
||||
else:
|
||||
cur_patches[p] = patches[p]
|
||||
transformer_options["patches"] = cur_patches
|
||||
else:
|
||||
transformer_options["patches"] = patches
|
||||
transformer_options["patches"] = comfy.patcher_extension.merge_nested_dicts(
|
||||
transformer_options.get("patches", {}),
|
||||
patches
|
||||
)
|
||||
|
||||
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
|
||||
transformer_options["uuids"] = uuids[:]
|
||||
@@ -789,7 +782,7 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
|
||||
return KSAMPLER(sampler_function, extra_options, inpaint_options)
|
||||
|
||||
|
||||
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None):
|
||||
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None, latent_shapes=None):
|
||||
for k in conds:
|
||||
conds[k] = conds[k][:]
|
||||
resolve_areas_and_cond_masks_multidim(conds[k], noise.shape[2:], device)
|
||||
@@ -799,7 +792,7 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
|
||||
|
||||
if hasattr(model, 'extra_conds'):
|
||||
for k in conds:
|
||||
conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
|
||||
conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed, latent_shapes=latent_shapes)
|
||||
|
||||
#make sure each cond area has an opposite one with the same area
|
||||
for k in conds:
|
||||
@@ -969,11 +962,11 @@ class CFGGuider:
|
||||
def predict_noise(self, x, timestep, model_options={}, seed=None):
|
||||
return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
|
||||
|
||||
def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed):
|
||||
def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=None):
|
||||
if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
|
||||
latent_image = self.inner_model.process_latent_in(latent_image)
|
||||
|
||||
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
|
||||
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed, latent_shapes=latent_shapes)
|
||||
|
||||
extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
|
||||
extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas
|
||||
@@ -987,7 +980,7 @@ class CFGGuider:
|
||||
samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
|
||||
return self.inner_model.process_latent_out(samples.to(torch.float32))
|
||||
|
||||
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None, latent_shapes=None):
|
||||
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
|
||||
device = self.model_patcher.load_device
|
||||
|
||||
@@ -1001,7 +994,7 @@ class CFGGuider:
|
||||
|
||||
try:
|
||||
self.model_patcher.pre_run()
|
||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
|
||||
finally:
|
||||
self.model_patcher.cleanup()
|
||||
|
||||
@@ -1014,6 +1007,12 @@ class CFGGuider:
|
||||
if sigmas.shape[-1] == 0:
|
||||
return latent_image
|
||||
|
||||
if latent_image.is_nested:
|
||||
latent_image, latent_shapes = comfy.utils.pack_latents(latent_image.unbind())
|
||||
noise, _ = comfy.utils.pack_latents(noise.unbind())
|
||||
else:
|
||||
latent_shapes = [latent_image.shape]
|
||||
|
||||
self.conds = {}
|
||||
for k in self.original_conds:
|
||||
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
|
||||
@@ -1033,7 +1032,7 @@ class CFGGuider:
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True)
|
||||
)
|
||||
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
|
||||
finally:
|
||||
cast_to_load_options(self.model_options, device=self.model_patcher.offload_device)
|
||||
self.model_options = orig_model_options
|
||||
@@ -1041,6 +1040,9 @@ class CFGGuider:
|
||||
self.model_patcher.restore_hook_patches()
|
||||
|
||||
del self.conds
|
||||
|
||||
if len(latent_shapes) > 1:
|
||||
output = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(output, latent_shapes))
|
||||
return output
|
||||
|
||||
|
||||
|
||||
149
comfy/sd.py
149
comfy/sd.py
@@ -18,6 +18,7 @@ import comfy.ldm.wan.vae2_2
|
||||
import comfy.ldm.hunyuan3d.vae
|
||||
import comfy.ldm.ace.vae.music_dcae_pipeline
|
||||
import comfy.ldm.hunyuan_video.vae
|
||||
import comfy.ldm.mmaudio.vae.autoencoder
|
||||
import comfy.pixel_space_convert
|
||||
import yaml
|
||||
import math
|
||||
@@ -51,6 +52,7 @@ import comfy.text_encoders.ace
|
||||
import comfy.text_encoders.omnigen2
|
||||
import comfy.text_encoders.qwen_image
|
||||
import comfy.text_encoders.hunyuan_image
|
||||
import comfy.text_encoders.z_image
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@@ -58,6 +60,8 @@ import comfy.lora_convert
|
||||
import comfy.hooks
|
||||
import comfy.t2i_adapter.adapter
|
||||
import comfy.taesd.taesd
|
||||
import comfy.taesd.taehv
|
||||
import comfy.latent_formats
|
||||
|
||||
import comfy.ldm.flux.redux
|
||||
|
||||
@@ -142,6 +146,9 @@ class CLIP:
|
||||
n.apply_hooks_to_conds = self.apply_hooks_to_conds
|
||||
return n
|
||||
|
||||
def get_ram_usage(self):
|
||||
return self.patcher.get_ram_usage()
|
||||
|
||||
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
||||
return self.patcher.add_patches(patches, strength_patch, strength_model)
|
||||
|
||||
@@ -275,8 +282,13 @@ class VAE:
|
||||
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
||||
sd = diffusers_convert.convert_vae_state_dict(sd)
|
||||
|
||||
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
|
||||
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
|
||||
if model_management.is_amd():
|
||||
VAE_KL_MEM_RATIO = 2.73
|
||||
else:
|
||||
VAE_KL_MEM_RATIO = 1.0
|
||||
|
||||
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) * VAE_KL_MEM_RATIO #These are for AutoencoderKL and need tweaking (should be lower)
|
||||
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype) * VAE_KL_MEM_RATIO
|
||||
self.downscale_ratio = 8
|
||||
self.upscale_ratio = 8
|
||||
self.latent_channels = 4
|
||||
@@ -287,10 +299,12 @@ class VAE:
|
||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||
self.disable_offload = False
|
||||
self.not_video = False
|
||||
self.size = None
|
||||
|
||||
self.downscale_index_formula = None
|
||||
self.upscale_index_formula = None
|
||||
self.extra_1d_channel = None
|
||||
self.crop_input = True
|
||||
|
||||
if config is None:
|
||||
if "decoder.mid.block_1.mix_factor" in sd:
|
||||
@@ -345,7 +359,7 @@ class VAE:
|
||||
|
||||
self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
|
||||
elif sd['decoder.conv_in.weight'].shape[1] == 32:
|
||||
elif sd['decoder.conv_in.weight'].shape[1] == 32 and sd['decoder.conv_in.weight'].ndim == 5:
|
||||
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False}
|
||||
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
@@ -371,6 +385,17 @@ class VAE:
|
||||
self.upscale_ratio = 4
|
||||
|
||||
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
|
||||
if 'decoder.post_quant_conv.weight' in sd:
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"decoder.post_quant_conv.": "post_quant_conv.", "encoder.quant_conv.": "quant_conv."})
|
||||
|
||||
if 'bn.running_mean' in sd:
|
||||
ddconfig["batch_norm_latent"] = True
|
||||
self.downscale_ratio *= 2
|
||||
self.upscale_ratio *= 2
|
||||
self.latent_channels *= 4
|
||||
old_memory_used_decode = self.memory_used_decode
|
||||
self.memory_used_decode = lambda shape, dtype: old_memory_used_decode(shape, dtype) * 4.0
|
||||
|
||||
if 'post_quant_conv.weight' in sd:
|
||||
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
|
||||
else:
|
||||
@@ -430,20 +455,20 @@ class VAE:
|
||||
elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
|
||||
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
|
||||
ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
|
||||
self.latent_channels = 64
|
||||
self.latent_channels = 32
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
|
||||
self.upscale_index_formula = (4, 16, 16)
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
|
||||
self.downscale_index_formula = (4, 16, 16)
|
||||
self.latent_dim = 3
|
||||
self.not_video = True
|
||||
self.not_video = False
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.EmptyRegularizer"},
|
||||
encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig},
|
||||
decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})
|
||||
|
||||
self.memory_used_encode = lambda shape, dtype: (1400 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (1400 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (2800 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
|
||||
elif "decoder.conv_in.conv.weight" in sd:
|
||||
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||
ddconfig["conv3d"] = True
|
||||
@@ -485,13 +510,14 @@ class VAE:
|
||||
self.memory_used_encode = lambda shape, dtype: 3300 * shape[3] * shape[4] * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: 8000 * shape[3] * shape[4] * (16 * 16) * model_management.dtype_size(dtype)
|
||||
else: # Wan 2.1 VAE
|
||||
dim = sd["decoder.head.0.gamma"].shape[0]
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
|
||||
self.upscale_index_formula = (4, 8, 8)
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
|
||||
self.downscale_index_formula = (4, 8, 8)
|
||||
self.latent_dim = 3
|
||||
self.latent_channels = 16
|
||||
ddconfig = {"dim": 96, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
|
||||
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
|
||||
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
|
||||
@@ -542,6 +568,54 @@ class VAE:
|
||||
self.latent_channels = 3
|
||||
self.latent_dim = 2
|
||||
self.output_channels = 3
|
||||
elif "vocoder.activation_post.downsample.lowpass.filter" in sd: #MMAudio VAE
|
||||
sample_rate = 16000
|
||||
if sample_rate == 16000:
|
||||
mode = '16k'
|
||||
else:
|
||||
mode = '44k'
|
||||
|
||||
self.first_stage_model = comfy.ldm.mmaudio.vae.autoencoder.AudioAutoencoder(mode=mode)
|
||||
self.memory_used_encode = lambda shape, dtype: (30 * shape[2]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (90 * shape[2] * 1411.2) * model_management.dtype_size(dtype)
|
||||
self.latent_channels = 20
|
||||
self.output_channels = 2
|
||||
self.upscale_ratio = 512 * (44100 / sample_rate)
|
||||
self.downscale_ratio = 512 * (44100 / sample_rate)
|
||||
self.latent_dim = 1
|
||||
self.process_output = lambda audio: audio
|
||||
self.process_input = lambda audio: audio
|
||||
self.working_dtypes = [torch.float32]
|
||||
self.crop_input = False
|
||||
elif "decoder.22.bias" in sd: # taehv, taew and lighttae
|
||||
self.latent_channels = sd["decoder.1.weight"].shape[1]
|
||||
self.latent_dim = 3
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
|
||||
self.upscale_index_formula = (4, 16, 16)
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
|
||||
self.downscale_index_formula = (4, 16, 16)
|
||||
if self.latent_channels == 48: # Wan 2.2
|
||||
self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=None) # taehv doesn't need scaling
|
||||
self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
|
||||
self.process_output = lambda image: image
|
||||
self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype))
|
||||
elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15
|
||||
self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=comfy.latent_formats.HunyuanVideo15)
|
||||
self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
|
||||
self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
|
||||
else:
|
||||
if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical
|
||||
latent_format=comfy.latent_formats.HunyuanVideo
|
||||
else:
|
||||
latent_format=None # lighttaew2_1 doesn't need scaling
|
||||
self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=latent_format)
|
||||
self.process_input = self.process_output = lambda image: image
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
|
||||
self.upscale_index_formula = (4, 8, 8)
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
|
||||
self.downscale_index_formula = (4, 8, 8)
|
||||
self.memory_used_encode = lambda shape, dtype: (700 * (max(1, (shape[-3] ** 0.66 * 0.11)) * shape[-2] * shape[-1]) * model_management.dtype_size(dtype))
|
||||
self.memory_used_decode = lambda shape, dtype: (50 * (max(1, (shape[-3] ** 0.65 * 0.26)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
|
||||
else:
|
||||
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
||||
self.first_stage_model = None
|
||||
@@ -569,12 +643,25 @@ class VAE:
|
||||
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
|
||||
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
|
||||
self.model_size()
|
||||
|
||||
def model_size(self):
|
||||
if self.size is not None:
|
||||
return self.size
|
||||
self.size = comfy.model_management.module_size(self.first_stage_model)
|
||||
return self.size
|
||||
|
||||
def get_ram_usage(self):
|
||||
return self.model_size()
|
||||
|
||||
def throw_exception_if_invalid(self):
|
||||
if self.first_stage_model is None:
|
||||
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
|
||||
|
||||
def vae_encode_crop_pixels(self, pixels):
|
||||
if not self.crop_input:
|
||||
return pixels
|
||||
|
||||
downscale_ratio = self.spacial_compression_encode()
|
||||
|
||||
dims = pixels.shape[1:-1]
|
||||
@@ -868,12 +955,18 @@ class CLIPType(Enum):
|
||||
OMNIGEN2 = 17
|
||||
QWEN_IMAGE = 18
|
||||
HUNYUAN_IMAGE = 19
|
||||
HUNYUAN_VIDEO_15 = 20
|
||||
|
||||
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||
clip_data = []
|
||||
for p in ckpt_paths:
|
||||
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
|
||||
sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True)
|
||||
if metadata is not None:
|
||||
quant_metadata = metadata.get("_quantization_metadata", None)
|
||||
if quant_metadata is not None:
|
||||
sd["_quantization_metadata"] = quant_metadata
|
||||
clip_data.append(sd)
|
||||
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
|
||||
|
||||
|
||||
@@ -891,6 +984,10 @@ class TEModel(Enum):
|
||||
QWEN25_7B = 11
|
||||
BYT5_SMALL_GLYPH = 12
|
||||
GEMMA_3_4B = 13
|
||||
MISTRAL3_24B = 14
|
||||
MISTRAL3_24B_PRUNED_FLUX2 = 15
|
||||
QWEN3_4B = 16
|
||||
|
||||
|
||||
def detect_te_model(sd):
|
||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
||||
@@ -923,6 +1020,15 @@ def detect_te_model(sd):
|
||||
if weight.shape[0] == 512:
|
||||
return TEModel.QWEN25_7B
|
||||
if "model.layers.0.post_attention_layernorm.weight" in sd:
|
||||
if 'model.layers.0.self_attn.q_norm.weight' in sd:
|
||||
return TEModel.QWEN3_4B
|
||||
weight = sd['model.layers.0.post_attention_layernorm.weight']
|
||||
if weight.shape[0] == 5120:
|
||||
if "model.layers.39.post_attention_layernorm.weight" in sd:
|
||||
return TEModel.MISTRAL3_24B
|
||||
else:
|
||||
return TEModel.MISTRAL3_24B_PRUNED_FLUX2
|
||||
|
||||
return TEModel.LLAMA3_8
|
||||
return None
|
||||
|
||||
@@ -1037,6 +1143,13 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
else:
|
||||
clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
|
||||
elif te_model == TEModel.MISTRAL3_24B or te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2:
|
||||
clip_target.clip = comfy.text_encoders.flux.flux2_te(**llama_detect(clip_data), pruned=te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2)
|
||||
clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer
|
||||
tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None)
|
||||
elif te_model == TEModel.QWEN3_4B:
|
||||
clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer
|
||||
else:
|
||||
# clip_l
|
||||
if clip_type == CLIPType.SD3:
|
||||
@@ -1083,6 +1196,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
elif clip_type == CLIPType.HUNYUAN_IMAGE:
|
||||
clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
|
||||
elif clip_type == CLIPType.HUNYUAN_VIDEO_15:
|
||||
clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer
|
||||
else:
|
||||
clip_target.clip = sdxl_clip.SDXLClipModel
|
||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||
@@ -1095,6 +1211,8 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
|
||||
parameters = 0
|
||||
for c in clip_data:
|
||||
if "_quantization_metadata" in c:
|
||||
c.pop("_quantization_metadata")
|
||||
parameters += comfy.utils.calculate_parameters(c)
|
||||
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
|
||||
|
||||
@@ -1233,7 +1351,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
return (model_patcher, clip, vae, clipvision)
|
||||
|
||||
|
||||
def load_diffusion_model_state_dict(sd, model_options={}):
|
||||
def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
|
||||
"""
|
||||
Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
|
||||
|
||||
@@ -1267,7 +1385,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
|
||||
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||
|
||||
load_device = model_management.get_torch_device()
|
||||
model_config = model_detection.model_config_from_unet(sd, "")
|
||||
model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
|
||||
|
||||
if model_config is not None:
|
||||
new_sd = sd
|
||||
@@ -1301,7 +1419,10 @@ def load_diffusion_model_state_dict(sd, model_options={}):
|
||||
else:
|
||||
unet_dtype = dtype
|
||||
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
if model_config.layer_quant_config is not None:
|
||||
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
|
||||
else:
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
|
||||
if model_options.get("fp8_optimizations", False):
|
||||
@@ -1317,8 +1438,8 @@ def load_diffusion_model_state_dict(sd, model_options={}):
|
||||
|
||||
|
||||
def load_diffusion_model(unet_path, model_options={}):
|
||||
sd = comfy.utils.load_torch_file(unet_path)
|
||||
model = load_diffusion_model_state_dict(sd, model_options=model_options)
|
||||
sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
|
||||
model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata)
|
||||
if model is None:
|
||||
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
|
||||
|
||||
@@ -90,7 +90,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
|
||||
return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
|
||||
if textmodel_json_config is None:
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
|
||||
@@ -109,13 +108,23 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
|
||||
operations = model_options.get("custom_operations", None)
|
||||
scaled_fp8 = None
|
||||
quantization_metadata = model_options.get("quantization_metadata", None)
|
||||
|
||||
if operations is None:
|
||||
scaled_fp8 = model_options.get("scaled_fp8", None)
|
||||
if scaled_fp8 is not None:
|
||||
operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
|
||||
layer_quant_config = None
|
||||
if quantization_metadata is not None:
|
||||
layer_quant_config = json.loads(quantization_metadata).get("layers", None)
|
||||
|
||||
if layer_quant_config is not None:
|
||||
operations = comfy.ops.mixed_precision_ops(layer_quant_config, dtype, full_precision_mm=True)
|
||||
logging.info(f"Using MixedPrecisionOps for text encoder: {len(layer_quant_config)} quantized layers")
|
||||
else:
|
||||
operations = comfy.ops.manual_cast
|
||||
# Fallback to scaled_fp8_ops for backward compatibility
|
||||
scaled_fp8 = model_options.get("scaled_fp8", None)
|
||||
if scaled_fp8 is not None:
|
||||
operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
|
||||
else:
|
||||
operations = comfy.ops.manual_cast
|
||||
|
||||
self.operations = operations
|
||||
self.transformer = model_class(config, dtype, device, self.operations)
|
||||
@@ -154,7 +163,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
def set_clip_options(self, options):
|
||||
layer_idx = options.get("layer", self.layer_idx)
|
||||
self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
|
||||
if self.layer == "all":
|
||||
if isinstance(self.layer, list) or self.layer == "all":
|
||||
pass
|
||||
elif layer_idx is None or abs(layer_idx) > self.num_layers:
|
||||
self.layer = "last"
|
||||
@@ -256,7 +265,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
if self.enable_attention_masks:
|
||||
attention_mask_model = attention_mask
|
||||
|
||||
if self.layer == "all":
|
||||
if isinstance(self.layer, list):
|
||||
intermediate_output = self.layer
|
||||
elif self.layer == "all":
|
||||
intermediate_output = "all"
|
||||
else:
|
||||
intermediate_output = self.layer_idx
|
||||
@@ -460,7 +471,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
|
||||
return embed_out
|
||||
|
||||
class SDTokenizer:
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data={}, tokenizer_args={}):
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, tokenizer_data={}, tokenizer_args={}):
|
||||
if tokenizer_path is None:
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
||||
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
|
||||
@@ -468,6 +479,7 @@ class SDTokenizer:
|
||||
self.min_length = tokenizer_data.get("{}_min_length".format(embedding_key), min_length)
|
||||
self.end_token = None
|
||||
self.min_padding = min_padding
|
||||
self.pad_left = pad_left
|
||||
|
||||
empty = self.tokenizer('')["input_ids"]
|
||||
self.tokenizer_adds_end_token = has_end_token
|
||||
@@ -522,6 +534,12 @@ class SDTokenizer:
|
||||
return (embed, "{} {}".format(embedding_name[len(stripped):], leftover))
|
||||
return (embed, leftover)
|
||||
|
||||
def pad_tokens(self, tokens, amount):
|
||||
if self.pad_left:
|
||||
for i in range(amount):
|
||||
tokens.insert(0, (self.pad_token, 1.0, 0))
|
||||
else:
|
||||
tokens.extend([(self.pad_token, 1.0, 0)] * amount)
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_options={}, **kwargs):
|
||||
'''
|
||||
@@ -600,7 +618,7 @@ class SDTokenizer:
|
||||
if self.end_token is not None:
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
if self.pad_to_max_length:
|
||||
batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
|
||||
self.pad_tokens(batch, remaining_length)
|
||||
#start new batch
|
||||
batch = []
|
||||
if self.start_token is not None:
|
||||
@@ -614,11 +632,11 @@ class SDTokenizer:
|
||||
if self.end_token is not None:
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
if min_padding is not None:
|
||||
batch.extend([(self.pad_token, 1.0, 0)] * min_padding)
|
||||
self.pad_tokens(batch, min_padding)
|
||||
if self.pad_to_max_length and len(batch) < self.max_length:
|
||||
batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
|
||||
self.pad_tokens(batch, self.max_length - len(batch))
|
||||
if min_length is not None and len(batch) < min_length:
|
||||
batch.extend([(self.pad_token, 1.0, 0)] * (min_length - len(batch)))
|
||||
self.pad_tokens(batch, min_length - len(batch))
|
||||
|
||||
if not return_word_ids:
|
||||
batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
|
||||
|
||||
@@ -21,6 +21,7 @@ import comfy.text_encoders.ace
|
||||
import comfy.text_encoders.omnigen2
|
||||
import comfy.text_encoders.qwen_image
|
||||
import comfy.text_encoders.hunyuan_image
|
||||
import comfy.text_encoders.z_image
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
@@ -741,6 +742,37 @@ class FluxSchnell(Flux):
|
||||
out = model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device)
|
||||
return out
|
||||
|
||||
class Flux2(Flux):
|
||||
unet_config = {
|
||||
"image_model": "flux2",
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 2.02,
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Flux2
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
self.memory_usage_factor = self.memory_usage_factor * (2.0 * 2.0) * 2.36
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Flux2(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return None # TODO
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect))
|
||||
|
||||
class GenmoMochi(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "mochi_preview",
|
||||
@@ -963,7 +995,7 @@ class Lumina2(supported_models_base.BASE):
|
||||
"shift": 6.0,
|
||||
}
|
||||
|
||||
memory_usage_factor = 1.2
|
||||
memory_usage_factor = 1.4
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Flux
|
||||
@@ -982,6 +1014,24 @@ class Lumina2(supported_models_base.BASE):
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}gemma2_2b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.lumina2.LuminaTokenizer, comfy.text_encoders.lumina2.te(**hunyuan_detect))
|
||||
|
||||
class ZImage(Lumina2):
|
||||
unet_config = {
|
||||
"image_model": "lumina2",
|
||||
"dim": 3840,
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"multiplier": 1.0,
|
||||
"shift": 3.0,
|
||||
}
|
||||
|
||||
memory_usage_factor = 1.7
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.z_image.ZImageTokenizer, comfy.text_encoders.z_image.te(**hunyuan_detect))
|
||||
|
||||
class WAN21_T2V(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
@@ -1374,6 +1424,55 @@ class HunyuanImage21Refiner(HunyuanVideo):
|
||||
out = model_base.HunyuanImage21Refiner(self, device=device)
|
||||
return out
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage]
|
||||
class HunyuanVideo15(HunyuanVideo):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan_video",
|
||||
"vision_in_dim": 1152,
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 7.0,
|
||||
}
|
||||
memory_usage_factor = 4.0 #TODO
|
||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
latent_format = latent_formats.HunyuanVideo15
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.HunyuanVideo15(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
|
||||
|
||||
|
||||
class HunyuanVideo15_SR_Distilled(HunyuanVideo):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan_video",
|
||||
"vision_in_dim": 1152,
|
||||
"in_channels": 98,
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 2.0,
|
||||
}
|
||||
memory_usage_factor = 4.0 #TODO
|
||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
latent_format = latent_formats.HunyuanVideo15
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.HunyuanVideo15_SR_Distilled(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2]
|
||||
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
@@ -50,6 +50,7 @@ class BASE:
|
||||
manual_cast_dtype = None
|
||||
custom_operations = None
|
||||
scaled_fp8 = None
|
||||
layer_quant_config = None # Per-layer quantization configuration for mixed precision
|
||||
optimizations = {"fp8": False}
|
||||
|
||||
@classmethod
|
||||
|
||||
171
comfy/taesd/taehv.py
Normal file
171
comfy/taesd/taehv.py
Normal file
@@ -0,0 +1,171 @@
|
||||
# Tiny AutoEncoder for HunyuanVideo and WanVideo https://github.com/madebyollin/taehv
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from tqdm.auto import tqdm
|
||||
from collections import namedtuple, deque
|
||||
|
||||
import comfy.ops
|
||||
operations=comfy.ops.disable_weight_init
|
||||
|
||||
DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
|
||||
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))
|
||||
|
||||
def conv(n_in, n_out, **kwargs):
|
||||
return operations.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
|
||||
|
||||
class Clamp(nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.tanh(x / 3) * 3
|
||||
|
||||
class MemBlock(nn.Module):
|
||||
def __init__(self, n_in, n_out, act_func):
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(conv(n_in * 2, n_out), act_func, conv(n_out, n_out), act_func, conv(n_out, n_out))
|
||||
self.skip = operations.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
|
||||
self.act = act_func
|
||||
def forward(self, x, past):
|
||||
return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))
|
||||
|
||||
class TPool(nn.Module):
|
||||
def __init__(self, n_f, stride):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
self.conv = operations.Conv2d(n_f*stride,n_f, 1, bias=False)
|
||||
def forward(self, x):
|
||||
_NT, C, H, W = x.shape
|
||||
return self.conv(x.reshape(-1, self.stride * C, H, W))
|
||||
|
||||
class TGrow(nn.Module):
|
||||
def __init__(self, n_f, stride):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
self.conv = operations.Conv2d(n_f, n_f*stride, 1, bias=False)
|
||||
def forward(self, x):
|
||||
_NT, C, H, W = x.shape
|
||||
x = self.conv(x)
|
||||
return x.reshape(-1, C, H, W)
|
||||
|
||||
def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
|
||||
|
||||
B, T, C, H, W = x.shape
|
||||
if parallel:
|
||||
x = x.reshape(B*T, C, H, W)
|
||||
# parallel over input timesteps, iterate over blocks
|
||||
for b in tqdm(model, disable=not show_progress_bar):
|
||||
if isinstance(b, MemBlock):
|
||||
BT, C, H, W = x.shape
|
||||
T = BT // B
|
||||
_x = x.reshape(B, T, C, H, W)
|
||||
mem = F.pad(_x, (0,0,0,0,0,0,1,0), value=0)[:,:T].reshape(x.shape)
|
||||
x = b(x, mem)
|
||||
else:
|
||||
x = b(x)
|
||||
BT, C, H, W = x.shape
|
||||
T = BT // B
|
||||
x = x.view(B, T, C, H, W)
|
||||
else:
|
||||
out = []
|
||||
work_queue = deque([TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(B, T * C, H, W).chunk(T, dim=1))])
|
||||
progress_bar = tqdm(range(T), disable=not show_progress_bar)
|
||||
mem = [None] * len(model)
|
||||
while work_queue:
|
||||
xt, i = work_queue.popleft()
|
||||
if i == 0:
|
||||
progress_bar.update(1)
|
||||
if i == len(model):
|
||||
out.append(xt)
|
||||
del xt
|
||||
else:
|
||||
b = model[i]
|
||||
if isinstance(b, MemBlock):
|
||||
if mem[i] is None:
|
||||
xt_new = b(xt, xt * 0)
|
||||
mem[i] = xt.detach().clone()
|
||||
else:
|
||||
xt_new = b(xt, mem[i])
|
||||
mem[i] = xt.detach().clone()
|
||||
del xt
|
||||
work_queue.appendleft(TWorkItem(xt_new, i+1))
|
||||
elif isinstance(b, TPool):
|
||||
if mem[i] is None:
|
||||
mem[i] = []
|
||||
mem[i].append(xt.detach().clone())
|
||||
if len(mem[i]) == b.stride:
|
||||
B, C, H, W = xt.shape
|
||||
xt = b(torch.cat(mem[i], 1).view(B*b.stride, C, H, W))
|
||||
mem[i] = []
|
||||
work_queue.appendleft(TWorkItem(xt, i+1))
|
||||
elif isinstance(b, TGrow):
|
||||
xt = b(xt)
|
||||
NT, C, H, W = xt.shape
|
||||
for xt_next in reversed(xt.view(B, b.stride*C, H, W).chunk(b.stride, 1)):
|
||||
work_queue.appendleft(TWorkItem(xt_next, i+1))
|
||||
del xt
|
||||
else:
|
||||
xt = b(xt)
|
||||
work_queue.appendleft(TWorkItem(xt, i+1))
|
||||
progress_bar.close()
|
||||
x = torch.stack(out, 1)
|
||||
return x
|
||||
|
||||
|
||||
class TAEHV(nn.Module):
|
||||
def __init__(self, latent_channels, parallel=False, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), latent_format=None, show_progress_bar=True):
|
||||
super().__init__()
|
||||
self.image_channels = 3
|
||||
self.patch_size = 1
|
||||
self.latent_channels = latent_channels
|
||||
self.parallel = parallel
|
||||
self.latent_format = latent_format
|
||||
self.show_progress_bar = show_progress_bar
|
||||
self.process_in = latent_format().process_in if latent_format is not None else (lambda x: x)
|
||||
self.process_out = latent_format().process_out if latent_format is not None else (lambda x: x)
|
||||
if self.latent_channels in [48, 32]: # Wan 2.2 and HunyuanVideo1.5
|
||||
self.patch_size = 2
|
||||
if self.latent_channels == 32: # HunyuanVideo1.5
|
||||
act_func = nn.LeakyReLU(0.2, inplace=True)
|
||||
else: # HunyuanVideo, Wan 2.1
|
||||
act_func = nn.ReLU(inplace=True)
|
||||
|
||||
self.encoder = nn.Sequential(
|
||||
conv(self.image_channels*self.patch_size**2, 64), act_func,
|
||||
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
|
||||
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
|
||||
TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
|
||||
conv(64, self.latent_channels),
|
||||
)
|
||||
n_f = [256, 128, 64, 64]
|
||||
self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
|
||||
self.decoder = nn.Sequential(
|
||||
Clamp(), conv(self.latent_channels, n_f[0]), act_func,
|
||||
MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
|
||||
MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False),
|
||||
MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
|
||||
act_func, conv(n_f[3], self.image_channels*self.patch_size**2),
|
||||
)
|
||||
@property
|
||||
def show_progress_bar(self):
|
||||
return self._show_progress_bar
|
||||
|
||||
@show_progress_bar.setter
|
||||
def show_progress_bar(self, value):
|
||||
self._show_progress_bar = value
|
||||
|
||||
def encode(self, x, **kwargs):
|
||||
if self.patch_size > 1: x = F.pixel_unshuffle(x, self.patch_size)
|
||||
x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
||||
if x.shape[1] % 4 != 0:
|
||||
# pad at end to multiple of 4
|
||||
n_pad = 4 - x.shape[1] % 4
|
||||
padding = x[:, -1:].repeat_interleave(n_pad, dim=1)
|
||||
x = torch.cat([x, padding], 1)
|
||||
x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1)
|
||||
return self.process_out(x)
|
||||
|
||||
def decode(self, x, **kwargs):
|
||||
x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
||||
x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar)
|
||||
if self.patch_size > 1: x = F.pixel_shuffle(x, self.patch_size)
|
||||
return x[:, self.frames_to_trim:].movedim(2, 1)
|
||||
@@ -1,10 +1,13 @@
|
||||
from comfy import sd1_clip
|
||||
import comfy.text_encoders.t5
|
||||
import comfy.text_encoders.sd3_clip
|
||||
import comfy.text_encoders.llama
|
||||
import comfy.model_management
|
||||
from transformers import T5TokenizerFast
|
||||
from transformers import T5TokenizerFast, LlamaTokenizerFast
|
||||
import torch
|
||||
import os
|
||||
import json
|
||||
import base64
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
@@ -68,3 +71,106 @@ def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None):
|
||||
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
|
||||
return FluxClipModel_
|
||||
|
||||
def load_mistral_tokenizer(data):
|
||||
if torch.is_tensor(data):
|
||||
data = data.numpy().tobytes()
|
||||
|
||||
try:
|
||||
from transformers.integrations.mistral import MistralConverter
|
||||
except ModuleNotFoundError:
|
||||
from transformers.models.pixtral.convert_pixtral_weights_to_hf import MistralConverter
|
||||
|
||||
mistral_vocab = json.loads(data)
|
||||
|
||||
special_tokens = {}
|
||||
vocab = {}
|
||||
|
||||
max_vocab = mistral_vocab["config"]["default_vocab_size"]
|
||||
max_vocab -= len(mistral_vocab["special_tokens"])
|
||||
|
||||
for w in mistral_vocab["vocab"]:
|
||||
r = w["rank"]
|
||||
if r >= max_vocab:
|
||||
continue
|
||||
|
||||
vocab[base64.b64decode(w["token_bytes"])] = r
|
||||
|
||||
for w in mistral_vocab["special_tokens"]:
|
||||
if "token_bytes" in w:
|
||||
special_tokens[base64.b64decode(w["token_bytes"])] = w["rank"]
|
||||
else:
|
||||
special_tokens[w["token_str"]] = w["rank"]
|
||||
|
||||
all_special = []
|
||||
for v in special_tokens:
|
||||
all_special.append(v)
|
||||
|
||||
special_tokens.update(vocab)
|
||||
vocab = special_tokens
|
||||
return {"tokenizer_object": MistralConverter(vocab=vocab, additional_special_tokens=all_special).converted(), "legacy": False}
|
||||
|
||||
class MistralTokenizerClass:
|
||||
@staticmethod
|
||||
def from_pretrained(path, **kwargs):
|
||||
return LlamaTokenizerFast(**kwargs)
|
||||
|
||||
class Mistral3Tokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
self.tekken_data = tokenizer_data.get("tekken_model", None)
|
||||
super().__init__("", pad_with_end=False, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data)
|
||||
|
||||
def state_dict(self):
|
||||
return {"tekken_model": self.tekken_data}
|
||||
|
||||
class Flux2Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="mistral3_24b", tokenizer=Mistral3Tokenizer)
|
||||
self.llama_template = '[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]{}[/INST]'
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
|
||||
if llama_template is None:
|
||||
llama_text = self.llama_template.format(text)
|
||||
else:
|
||||
llama_text = llama_template.format(text)
|
||||
|
||||
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
|
||||
return tokens
|
||||
|
||||
class Mistral3_24BModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer=[10, 20, 30], layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||
textmodel_json_config = {}
|
||||
num_layers = model_options.get("num_layers", None)
|
||||
if num_layers is not None:
|
||||
textmodel_json_config["num_hidden_layers"] = num_layers
|
||||
if num_layers < 40:
|
||||
textmodel_json_config["final_norm"] = False
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 1, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Mistral3Small24B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
class Flux2TEModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}, name="mistral3_24b", clip_model=Mistral3_24BModel):
|
||||
super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
out, pooled, extra = super().encode_token_weights(token_weight_pairs)
|
||||
|
||||
out = torch.stack((out[:, 0], out[:, 1], out[:, 2]), dim=1)
|
||||
out = out.movedim(1, 2)
|
||||
out = out.reshape(out.shape[0], out.shape[1], -1)
|
||||
return out, pooled, extra
|
||||
|
||||
def flux2_te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None, pruned=False):
|
||||
class Flux2TEModel_(Flux2TEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
if pruned:
|
||||
model_options = model_options.copy()
|
||||
model_options["num_layers"] = 30
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return Flux2TEModel_
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from comfy import sd1_clip
|
||||
import comfy.model_management
|
||||
import comfy.text_encoders.llama
|
||||
from .hunyuan_image import HunyuanImageTokenizer
|
||||
from transformers import LlamaTokenizerFast
|
||||
import torch
|
||||
import os
|
||||
@@ -17,6 +18,9 @@ def llama_detect(state_dict, prefix=""):
|
||||
if scaled_fp8_key in state_dict:
|
||||
out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
|
||||
|
||||
if "_quantization_metadata" in state_dict:
|
||||
out["llama_quantization_metadata"] = state_dict["_quantization_metadata"]
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@@ -73,6 +77,14 @@ class HunyuanVideoTokenizer:
|
||||
return {}
|
||||
|
||||
|
||||
class HunyuanVideo15Tokenizer(HunyuanImageTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.llama_template = "<|im_start|>system\nYou are a helpful assistant. Describe the video by detailing the following aspects:\n1. The main content and theme of the video.\n2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.\n3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.\n4. background environment, light, style and atmosphere.\n5. camera angles, movements, and transitions used in the video.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
return super().tokenize_with_weights(text, return_word_ids, prevent_empty_text=True, **kwargs)
|
||||
|
||||
class HunyuanVideoClipModel(torch.nn.Module):
|
||||
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__()
|
||||
|
||||
@@ -32,6 +32,29 @@ class Llama2Config:
|
||||
q_norm = None
|
||||
k_norm = None
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
|
||||
@dataclass
|
||||
class Mistral3Small24BConfig:
|
||||
vocab_size: int = 131072
|
||||
hidden_size: int = 5120
|
||||
intermediate_size: int = 32768
|
||||
num_hidden_layers: int = 40
|
||||
num_attention_heads: int = 32
|
||||
num_key_value_heads: int = 8
|
||||
max_position_embeddings: int = 8192
|
||||
rms_norm_eps: float = 1e-5
|
||||
rope_theta: float = 1000000000.0
|
||||
transformer_type: str = "llama"
|
||||
head_dim = 128
|
||||
rms_norm_add = False
|
||||
mlp_activation = "silu"
|
||||
qkv_bias = False
|
||||
rope_dims = None
|
||||
q_norm = None
|
||||
k_norm = None
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
|
||||
@dataclass
|
||||
class Qwen25_3BConfig:
|
||||
@@ -53,6 +76,29 @@ class Qwen25_3BConfig:
|
||||
q_norm = None
|
||||
k_norm = None
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
|
||||
@dataclass
|
||||
class Qwen3_4BConfig:
|
||||
vocab_size: int = 151936
|
||||
hidden_size: int = 2560
|
||||
intermediate_size: int = 9728
|
||||
num_hidden_layers: int = 36
|
||||
num_attention_heads: int = 32
|
||||
num_key_value_heads: int = 8
|
||||
max_position_embeddings: int = 40960
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 1000000.0
|
||||
transformer_type: str = "llama"
|
||||
head_dim = 128
|
||||
rms_norm_add = False
|
||||
mlp_activation = "silu"
|
||||
qkv_bias = False
|
||||
rope_dims = None
|
||||
q_norm = "gemma3"
|
||||
k_norm = "gemma3"
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
|
||||
@dataclass
|
||||
class Qwen25_7BVLI_Config:
|
||||
@@ -74,6 +120,7 @@ class Qwen25_7BVLI_Config:
|
||||
q_norm = None
|
||||
k_norm = None
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
|
||||
@dataclass
|
||||
class Gemma2_2B_Config:
|
||||
@@ -96,6 +143,7 @@ class Gemma2_2B_Config:
|
||||
k_norm = None
|
||||
sliding_attention = None
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
|
||||
@dataclass
|
||||
class Gemma3_4B_Config:
|
||||
@@ -118,6 +166,7 @@ class Gemma3_4B_Config:
|
||||
k_norm = "gemma3"
|
||||
sliding_attention = [False, False, False, False, False, 1024]
|
||||
rope_scale = [1.0, 8.0]
|
||||
final_norm: bool = True
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
|
||||
@@ -366,7 +415,12 @@ class Llama2_(nn.Module):
|
||||
transformer(config, index=i, device=device, dtype=dtype, ops=ops)
|
||||
for i in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
|
||||
if config.final_norm:
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
else:
|
||||
self.norm = None
|
||||
|
||||
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]):
|
||||
@@ -402,8 +456,12 @@ class Llama2_(nn.Module):
|
||||
|
||||
intermediate = None
|
||||
all_intermediate = None
|
||||
only_layers = None
|
||||
if intermediate_output is not None:
|
||||
if intermediate_output == "all":
|
||||
if isinstance(intermediate_output, list):
|
||||
all_intermediate = []
|
||||
only_layers = set(intermediate_output)
|
||||
elif intermediate_output == "all":
|
||||
all_intermediate = []
|
||||
intermediate_output = None
|
||||
elif intermediate_output < 0:
|
||||
@@ -411,7 +469,8 @@ class Llama2_(nn.Module):
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
if all_intermediate is not None:
|
||||
all_intermediate.append(x.unsqueeze(1).clone())
|
||||
if only_layers is None or (i in only_layers):
|
||||
all_intermediate.append(x.unsqueeze(1).clone())
|
||||
x = layer(
|
||||
x=x,
|
||||
attention_mask=mask,
|
||||
@@ -421,14 +480,17 @@ class Llama2_(nn.Module):
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
|
||||
x = self.norm(x)
|
||||
if self.norm is not None:
|
||||
x = self.norm(x)
|
||||
|
||||
if all_intermediate is not None:
|
||||
all_intermediate.append(x.unsqueeze(1).clone())
|
||||
if only_layers is None or ((i + 1) in only_layers):
|
||||
all_intermediate.append(x.unsqueeze(1).clone())
|
||||
|
||||
if all_intermediate is not None:
|
||||
intermediate = torch.cat(all_intermediate, dim=1)
|
||||
|
||||
if intermediate is not None and final_layer_norm_intermediate:
|
||||
if intermediate is not None and final_layer_norm_intermediate and self.norm is not None:
|
||||
intermediate = self.norm(intermediate)
|
||||
|
||||
return x, intermediate
|
||||
@@ -453,6 +515,15 @@ class Llama2(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Mistral3Small24B(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Mistral3Small24BConfig(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen25_3B(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
@@ -462,6 +533,15 @@ class Qwen25_3B(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen3_4B(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen3_4BConfig(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
|
||||
@@ -179,36 +179,36 @@
|
||||
"special": false
|
||||
},
|
||||
"151665": {
|
||||
"content": "<|img|>",
|
||||
"content": "<tool_response>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
"special": false
|
||||
},
|
||||
"151666": {
|
||||
"content": "<|endofimg|>",
|
||||
"content": "</tool_response>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
"special": false
|
||||
},
|
||||
"151667": {
|
||||
"content": "<|meta|>",
|
||||
"content": "<think>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
"special": false
|
||||
},
|
||||
"151668": {
|
||||
"content": "<|endofmeta|>",
|
||||
"content": "</think>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
"special": false
|
||||
}
|
||||
},
|
||||
"additional_special_tokens": [
|
||||
|
||||
@@ -17,12 +17,14 @@ class QwenImageTokenizer(sd1_clip.SD1Tokenizer):
|
||||
self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
self.llama_template_images = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], **kwargs):
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, **kwargs):
|
||||
skip_template = False
|
||||
if text.startswith('<|im_start|>'):
|
||||
skip_template = True
|
||||
if text.startswith('<|start_header_id|>'):
|
||||
skip_template = True
|
||||
if prevent_empty_text and text == '':
|
||||
text = ' '
|
||||
|
||||
if skip_template:
|
||||
llama_text = text
|
||||
|
||||
48
comfy/text_encoders/z_image.py
Normal file
48
comfy/text_encoders/z_image.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from transformers import Qwen2Tokenizer
|
||||
import comfy.text_encoders.llama
|
||||
from comfy import sd1_clip
|
||||
import os
|
||||
|
||||
class Qwen3Tokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
|
||||
|
||||
|
||||
class ZImageTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_4b", tokenizer=Qwen3Tokenizer)
|
||||
self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
|
||||
if llama_template is None:
|
||||
llama_text = self.llama_template.format(text)
|
||||
else:
|
||||
llama_text = llama_template.format(text)
|
||||
|
||||
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
|
||||
return tokens
|
||||
|
||||
|
||||
class Qwen3_4BModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
|
||||
class ZImageTEModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, name="qwen3_4b", clip_model=Qwen3_4BModel, model_options=model_options)
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None):
|
||||
class ZImageTEModel_(ZImageTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return ZImageTEModel_
|
||||
@@ -39,7 +39,11 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
|
||||
pass
|
||||
ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"
|
||||
|
||||
from numpy.core.multiarray import scalar
|
||||
def scalar(*args, **kwargs):
|
||||
from numpy.core.multiarray import scalar as sc
|
||||
return sc(*args, **kwargs)
|
||||
scalar.__module__ = "numpy.core.multiarray"
|
||||
|
||||
from numpy import dtype
|
||||
from numpy.dtypes import Float64DType
|
||||
from _codecs import encode
|
||||
@@ -671,6 +675,72 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
|
||||
|
||||
return key_map
|
||||
|
||||
def z_image_to_diffusers(mmdit_config, output_prefix=""):
|
||||
n_layers = mmdit_config.get("n_layers", 0)
|
||||
hidden_size = mmdit_config.get("dim", 0)
|
||||
n_context_refiner = mmdit_config.get("n_refiner_layers", 2)
|
||||
n_noise_refiner = mmdit_config.get("n_refiner_layers", 2)
|
||||
key_map = {}
|
||||
|
||||
def add_block_keys(prefix_from, prefix_to, has_adaln=True):
|
||||
for end in ("weight", "bias"):
|
||||
k = "{}.attention.".format(prefix_from)
|
||||
qkv = "{}.attention.qkv.{}".format(prefix_to, end)
|
||||
key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
|
||||
key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
|
||||
key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
|
||||
|
||||
block_map = {
|
||||
"attention.norm_q.weight": "attention.q_norm.weight",
|
||||
"attention.norm_k.weight": "attention.k_norm.weight",
|
||||
"attention.to_out.0.weight": "attention.out.weight",
|
||||
"attention.to_out.0.bias": "attention.out.bias",
|
||||
"attention_norm1.weight": "attention_norm1.weight",
|
||||
"attention_norm2.weight": "attention_norm2.weight",
|
||||
"feed_forward.w1.weight": "feed_forward.w1.weight",
|
||||
"feed_forward.w2.weight": "feed_forward.w2.weight",
|
||||
"feed_forward.w3.weight": "feed_forward.w3.weight",
|
||||
"ffn_norm1.weight": "ffn_norm1.weight",
|
||||
"ffn_norm2.weight": "ffn_norm2.weight",
|
||||
}
|
||||
if has_adaln:
|
||||
block_map["adaLN_modulation.0.weight"] = "adaLN_modulation.0.weight"
|
||||
block_map["adaLN_modulation.0.bias"] = "adaLN_modulation.0.bias"
|
||||
for k, v in block_map.items():
|
||||
key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, v)
|
||||
|
||||
for i in range(n_layers):
|
||||
add_block_keys("layers.{}".format(i), "{}layers.{}".format(output_prefix, i))
|
||||
|
||||
for i in range(n_context_refiner):
|
||||
add_block_keys("context_refiner.{}".format(i), "{}context_refiner.{}".format(output_prefix, i))
|
||||
|
||||
for i in range(n_noise_refiner):
|
||||
add_block_keys("noise_refiner.{}".format(i), "{}noise_refiner.{}".format(output_prefix, i))
|
||||
|
||||
MAP_BASIC = [
|
||||
("final_layer.linear.weight", "all_final_layer.2-1.linear.weight"),
|
||||
("final_layer.linear.bias", "all_final_layer.2-1.linear.bias"),
|
||||
("final_layer.adaLN_modulation.1.weight", "all_final_layer.2-1.adaLN_modulation.1.weight"),
|
||||
("final_layer.adaLN_modulation.1.bias", "all_final_layer.2-1.adaLN_modulation.1.bias"),
|
||||
("x_embedder.weight", "all_x_embedder.2-1.weight"),
|
||||
("x_embedder.bias", "all_x_embedder.2-1.bias"),
|
||||
("x_pad_token", "x_pad_token"),
|
||||
("cap_embedder.0.weight", "cap_embedder.0.weight"),
|
||||
("cap_embedder.1.weight", "cap_embedder.1.weight"),
|
||||
("cap_embedder.1.bias", "cap_embedder.1.bias"),
|
||||
("cap_pad_token", "cap_pad_token"),
|
||||
("t_embedder.mlp.0.weight", "t_embedder.mlp.0.weight"),
|
||||
("t_embedder.mlp.0.bias", "t_embedder.mlp.0.bias"),
|
||||
("t_embedder.mlp.2.weight", "t_embedder.mlp.2.weight"),
|
||||
("t_embedder.mlp.2.bias", "t_embedder.mlp.2.bias"),
|
||||
]
|
||||
|
||||
for c, diffusers in MAP_BASIC:
|
||||
key_map[diffusers] = "{}{}".format(output_prefix, c)
|
||||
|
||||
return key_map
|
||||
|
||||
def repeat_to_batch_size(tensor, batch_size, dim=0):
|
||||
if tensor.shape[dim] > batch_size:
|
||||
return tensor.narrow(dim, 0, batch_size)
|
||||
@@ -1102,3 +1172,25 @@ def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out):
|
||||
dim=1
|
||||
)
|
||||
return out
|
||||
|
||||
def pack_latents(latents):
|
||||
latent_shapes = []
|
||||
tensors = []
|
||||
for tensor in latents:
|
||||
latent_shapes.append(tensor.shape)
|
||||
tensors.append(tensor.reshape(tensor.shape[0], 1, -1))
|
||||
|
||||
latent = torch.cat(tensors, dim=-1)
|
||||
return latent, latent_shapes
|
||||
|
||||
def unpack_latents(combined_latent, latent_shapes):
|
||||
if len(latent_shapes) > 1:
|
||||
output_tensors = []
|
||||
for shape in latent_shapes:
|
||||
cut = math.prod(shape[1:])
|
||||
tens = combined_latent[:, :, :cut]
|
||||
combined_latent = combined_latent[:, :, cut:]
|
||||
output_tensors.append(tens.reshape([tens.shape[0]] + list(shape)[1:]))
|
||||
else:
|
||||
output_tensors = combined_latent
|
||||
return output_tensors
|
||||
|
||||
@@ -194,6 +194,7 @@ class LoRAAdapter(WeightAdapterBase):
|
||||
lora_diff = torch.mm(
|
||||
mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
|
||||
).reshape(weight.shape)
|
||||
del mat1, mat2
|
||||
if dora_scale is not None:
|
||||
weight = weight_decompose(
|
||||
dora_scale,
|
||||
|
||||
@@ -8,7 +8,7 @@ import os
|
||||
import textwrap
|
||||
import threading
|
||||
from enum import Enum
|
||||
from typing import Optional, Type, get_origin, get_args
|
||||
from typing import Optional, Type, get_origin, get_args, get_type_hints
|
||||
|
||||
|
||||
class TypeTracker:
|
||||
@@ -220,11 +220,18 @@ class AsyncToSyncConverter:
|
||||
self._async_instance = async_class(*args, **kwargs)
|
||||
|
||||
# Handle annotated class attributes (like execution: Execution)
|
||||
# Get all annotations from the class hierarchy
|
||||
all_annotations = {}
|
||||
for base_class in reversed(inspect.getmro(async_class)):
|
||||
if hasattr(base_class, "__annotations__"):
|
||||
all_annotations.update(base_class.__annotations__)
|
||||
# Get all annotations from the class hierarchy and resolve string annotations
|
||||
try:
|
||||
# get_type_hints resolves string annotations to actual type objects
|
||||
# This handles classes using 'from __future__ import annotations'
|
||||
all_annotations = get_type_hints(async_class)
|
||||
except Exception:
|
||||
# Fallback to raw annotations if get_type_hints fails
|
||||
# (e.g., for undefined forward references)
|
||||
all_annotations = {}
|
||||
for base_class in reversed(inspect.getmro(async_class)):
|
||||
if hasattr(base_class, "__annotations__"):
|
||||
all_annotations.update(base_class.__annotations__)
|
||||
|
||||
# For each annotated attribute, check if it needs to be created or wrapped
|
||||
for attr_name, attr_type in all_annotations.items():
|
||||
@@ -625,15 +632,19 @@ class AsyncToSyncConverter:
|
||||
"""Extract class attributes that are classes themselves."""
|
||||
class_attributes = []
|
||||
|
||||
# Get resolved type hints to handle string annotations
|
||||
try:
|
||||
type_hints = get_type_hints(async_class)
|
||||
except Exception:
|
||||
type_hints = {}
|
||||
|
||||
# Look for class attributes that are classes
|
||||
for name, attr in sorted(inspect.getmembers(async_class)):
|
||||
if isinstance(attr, type) and not name.startswith("_"):
|
||||
class_attributes.append((name, attr))
|
||||
elif (
|
||||
hasattr(async_class, "__annotations__")
|
||||
and name in async_class.__annotations__
|
||||
):
|
||||
annotation = async_class.__annotations__[name]
|
||||
elif name in type_hints:
|
||||
# Use resolved type hint instead of raw annotation
|
||||
annotation = type_hints[name]
|
||||
if isinstance(annotation, type):
|
||||
class_attributes.append((name, annotation))
|
||||
|
||||
@@ -908,11 +919,15 @@ class AsyncToSyncConverter:
|
||||
attribute_mappings = {}
|
||||
|
||||
# First check annotations for typed attributes (including from parent classes)
|
||||
# Collect all annotations from the class hierarchy
|
||||
all_annotations = {}
|
||||
for base_class in reversed(inspect.getmro(async_class)):
|
||||
if hasattr(base_class, "__annotations__"):
|
||||
all_annotations.update(base_class.__annotations__)
|
||||
# Resolve string annotations to actual types
|
||||
try:
|
||||
all_annotations = get_type_hints(async_class)
|
||||
except Exception:
|
||||
# Fallback to raw annotations
|
||||
all_annotations = {}
|
||||
for base_class in reversed(inspect.getmro(async_class)):
|
||||
if hasattr(base_class, "__annotations__"):
|
||||
all_annotations.update(base_class.__annotations__)
|
||||
|
||||
for attr_name, attr_type in sorted(all_annotations.items()):
|
||||
for class_name, class_type in class_attributes:
|
||||
|
||||
@@ -7,9 +7,9 @@ from comfy_api.internal.singleton import ProxiedSingleton
|
||||
from comfy_api.internal.async_to_sync import create_sync_class
|
||||
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
|
||||
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
|
||||
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
|
||||
from comfy_api.latest._io import _IO as io #noqa: F401
|
||||
from comfy_api.latest._ui import _UI as ui #noqa: F401
|
||||
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
|
||||
from . import _io as io
|
||||
from . import _ui as ui
|
||||
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
|
||||
from comfy_execution.utils import get_executing_context
|
||||
from comfy_execution.progress import get_progress_state, PreviewImageTuple
|
||||
@@ -104,6 +104,8 @@ class Types:
|
||||
VideoCodec = VideoCodec
|
||||
VideoContainer = VideoContainer
|
||||
VideoComponents = VideoComponents
|
||||
MESH = MESH
|
||||
VOXEL = VOXEL
|
||||
|
||||
ComfyAPI = ComfyAPI_latest
|
||||
|
||||
@@ -114,6 +116,10 @@ if TYPE_CHECKING:
|
||||
ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
|
||||
ComfyAPISync = create_sync_class(ComfyAPI_latest)
|
||||
|
||||
# create new aliases for io and ui
|
||||
IO = io
|
||||
UI = ui
|
||||
|
||||
__all__ = [
|
||||
"ComfyAPI",
|
||||
"ComfyAPISync",
|
||||
@@ -121,4 +127,8 @@ __all__ = [
|
||||
"InputImpl",
|
||||
"Types",
|
||||
"ComfyExtension",
|
||||
"io",
|
||||
"IO",
|
||||
"ui",
|
||||
"UI",
|
||||
]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Union
|
||||
from fractions import Fraction
|
||||
from typing import Optional, Union, IO
|
||||
import io
|
||||
import av
|
||||
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
|
||||
@@ -23,7 +24,7 @@ class VideoInput(ABC):
|
||||
@abstractmethod
|
||||
def save_to(
|
||||
self,
|
||||
path: str,
|
||||
path: Union[str, IO[bytes]],
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
@@ -72,6 +73,33 @@ class VideoInput(ABC):
|
||||
frame_count = components.images.shape[0]
|
||||
return float(frame_count / components.frame_rate)
|
||||
|
||||
def get_frame_count(self) -> int:
|
||||
"""
|
||||
Returns the number of frames in the video.
|
||||
|
||||
Default implementation uses :meth:`get_components`, which may require
|
||||
loading all frames into memory. File-based implementations should
|
||||
override this method and use container/stream metadata instead.
|
||||
|
||||
Returns:
|
||||
Total number of frames as an integer.
|
||||
"""
|
||||
return int(self.get_components().images.shape[0])
|
||||
|
||||
def get_frame_rate(self) -> Fraction:
|
||||
"""
|
||||
Returns the frame rate of the video.
|
||||
|
||||
Default implementation materializes the video into memory via
|
||||
`get_components()`. Subclasses that can inspect the underlying
|
||||
container (e.g. `VideoFromFile`) should override this with a more
|
||||
efficient implementation.
|
||||
|
||||
Returns:
|
||||
Frame rate as a Fraction.
|
||||
"""
|
||||
return self.get_components().frame_rate
|
||||
|
||||
def get_container_format(self) -> str:
|
||||
"""
|
||||
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
|
||||
|
||||
@@ -121,6 +121,71 @@ class VideoFromFile(VideoInput):
|
||||
|
||||
raise ValueError(f"Could not determine duration for file '{self.__file}'")
|
||||
|
||||
def get_frame_count(self) -> int:
|
||||
"""
|
||||
Returns the number of frames in the video without materializing them as
|
||||
torch tensors.
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
video_stream = self._get_first_video_stream(container)
|
||||
# 1. Prefer the frames field if available
|
||||
if video_stream.frames and video_stream.frames > 0:
|
||||
return int(video_stream.frames)
|
||||
|
||||
# 2. Try to estimate from duration and average_rate using only metadata
|
||||
if container.duration is not None and video_stream.average_rate:
|
||||
duration_seconds = float(container.duration / av.time_base)
|
||||
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
|
||||
if estimated_frames > 0:
|
||||
return estimated_frames
|
||||
|
||||
if (
|
||||
getattr(video_stream, "duration", None) is not None
|
||||
and getattr(video_stream, "time_base", None) is not None
|
||||
and video_stream.average_rate
|
||||
):
|
||||
duration_seconds = float(video_stream.duration * video_stream.time_base)
|
||||
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
|
||||
if estimated_frames > 0:
|
||||
return estimated_frames
|
||||
|
||||
# 3. Last resort: decode frames and count them (streaming)
|
||||
frame_count = 0
|
||||
container.seek(0)
|
||||
for packet in container.demux(video_stream):
|
||||
for _ in packet.decode():
|
||||
frame_count += 1
|
||||
|
||||
if frame_count == 0:
|
||||
raise ValueError(f"Could not determine frame count for file '{self.__file}'")
|
||||
return frame_count
|
||||
|
||||
def get_frame_rate(self) -> Fraction:
|
||||
"""
|
||||
Returns the average frame rate of the video using container metadata
|
||||
without decoding all frames.
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
video_stream = self._get_first_video_stream(container)
|
||||
# Preferred: use PyAV's average_rate (usually already a Fraction-like)
|
||||
if video_stream.average_rate:
|
||||
return Fraction(video_stream.average_rate)
|
||||
|
||||
# Fallback: estimate from frames + duration if available
|
||||
if video_stream.frames and container.duration:
|
||||
duration_seconds = float(container.duration / av.time_base)
|
||||
if duration_seconds > 0:
|
||||
return Fraction(video_stream.frames / duration_seconds).limit_denominator()
|
||||
|
||||
# Last resort: match get_components_internal default
|
||||
return Fraction(1)
|
||||
|
||||
def get_container_format(self) -> str:
|
||||
"""
|
||||
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
|
||||
@@ -238,6 +303,13 @@ class VideoFromFile(VideoInput):
|
||||
packet.stream = stream_map[packet.stream]
|
||||
output_container.mux(packet)
|
||||
|
||||
def _get_first_video_stream(self, container: InputContainer):
|
||||
video_stream = next((s for s in container.streams if s.type == "video"), None)
|
||||
if video_stream is None:
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
return video_stream
|
||||
|
||||
|
||||
class VideoFromComponents(VideoInput):
|
||||
"""
|
||||
Class representing video input from tensors.
|
||||
@@ -264,7 +336,10 @@ class VideoFromComponents(VideoInput):
|
||||
raise ValueError("Only MP4 format is supported for now")
|
||||
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
|
||||
raise ValueError("Only H264 codec is supported for now")
|
||||
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output:
|
||||
extra_kwargs = {}
|
||||
if isinstance(format, VideoContainer) and format != VideoContainer.AUTO:
|
||||
extra_kwargs["format"] = format.value
|
||||
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output:
|
||||
# Add metadata before writing any streams
|
||||
if metadata is not None:
|
||||
for key, value in metadata.items():
|
||||
|
||||
@@ -27,6 +27,7 @@ from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classpr
|
||||
prune_dict, shallow_clone_class)
|
||||
from comfy_api.latest._resources import Resources, ResourcesLocal
|
||||
from comfy_execution.graph_utils import ExecutionBlocker
|
||||
from ._util import MESH, VOXEL
|
||||
|
||||
# from comfy_extras.nodes_images import SVG as SVG_ # NOTE: needs to be moved before can be imported due to circular reference
|
||||
|
||||
@@ -336,11 +337,25 @@ class Combo(ComfyTypeIO):
|
||||
class Input(WidgetInput):
|
||||
"""Combo input (dropdown)."""
|
||||
Type = str
|
||||
def __init__(self, id: str, options: list[str]=None, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
default: str=None, control_after_generate: bool=None,
|
||||
upload: UploadType=None, image_folder: FolderType=None,
|
||||
remote: RemoteOptions=None,
|
||||
socketless: bool=None):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
options: list[str] | list[int] | type[Enum] = None,
|
||||
display_name: str=None,
|
||||
optional=False,
|
||||
tooltip: str=None,
|
||||
lazy: bool=None,
|
||||
default: str | int | Enum = None,
|
||||
control_after_generate: bool=None,
|
||||
upload: UploadType=None,
|
||||
image_folder: FolderType=None,
|
||||
remote: RemoteOptions=None,
|
||||
socketless: bool=None,
|
||||
):
|
||||
if isinstance(options, type) and issubclass(options, Enum):
|
||||
options = [v.value for v in options]
|
||||
if isinstance(default, Enum):
|
||||
default = default.value
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless)
|
||||
self.multiselect = False
|
||||
self.options = options
|
||||
@@ -614,6 +629,10 @@ class UpscaleModel(ComfyTypeIO):
|
||||
if TYPE_CHECKING:
|
||||
Type = ImageModelDescriptor
|
||||
|
||||
@comfytype(io_type="LATENT_UPSCALE_MODEL")
|
||||
class LatentUpscaleModel(ComfyTypeIO):
|
||||
Type = Any
|
||||
|
||||
@comfytype(io_type="AUDIO")
|
||||
class Audio(ComfyTypeIO):
|
||||
class AudioDict(TypedDict):
|
||||
@@ -642,11 +661,11 @@ class LossMap(ComfyTypeIO):
|
||||
|
||||
@comfytype(io_type="VOXEL")
|
||||
class Voxel(ComfyTypeIO):
|
||||
Type = Any # TODO: VOXEL class is defined in comfy_extras/nodes_hunyuan3d.py; should be moved to somewhere else before referenced directly in v3
|
||||
Type = VOXEL
|
||||
|
||||
@comfytype(io_type="MESH")
|
||||
class Mesh(ComfyTypeIO):
|
||||
Type = Any # TODO: MESH class is defined in comfy_extras/nodes_hunyuan3d.py; should be moved to somewhere else before referenced directly in v3
|
||||
Type = MESH
|
||||
|
||||
@comfytype(io_type="HOOKS")
|
||||
class Hooks(ComfyTypeIO):
|
||||
@@ -1568,78 +1587,78 @@ class _UIOutput(ABC):
|
||||
...
|
||||
|
||||
|
||||
class _IO:
|
||||
FolderType = FolderType
|
||||
UploadType = UploadType
|
||||
RemoteOptions = RemoteOptions
|
||||
NumberDisplay = NumberDisplay
|
||||
__all__ = [
|
||||
"FolderType",
|
||||
"UploadType",
|
||||
"RemoteOptions",
|
||||
"NumberDisplay",
|
||||
|
||||
comfytype = staticmethod(comfytype)
|
||||
Custom = staticmethod(Custom)
|
||||
Input = Input
|
||||
WidgetInput = WidgetInput
|
||||
Output = Output
|
||||
ComfyTypeI = ComfyTypeI
|
||||
ComfyTypeIO = ComfyTypeIO
|
||||
#---------------------------------
|
||||
"comfytype",
|
||||
"Custom",
|
||||
"Input",
|
||||
"WidgetInput",
|
||||
"Output",
|
||||
"ComfyTypeI",
|
||||
"ComfyTypeIO",
|
||||
# Supported Types
|
||||
Boolean = Boolean
|
||||
Int = Int
|
||||
Float = Float
|
||||
String = String
|
||||
Combo = Combo
|
||||
MultiCombo = MultiCombo
|
||||
Image = Image
|
||||
WanCameraEmbedding = WanCameraEmbedding
|
||||
Webcam = Webcam
|
||||
Mask = Mask
|
||||
Latent = Latent
|
||||
Conditioning = Conditioning
|
||||
Sampler = Sampler
|
||||
Sigmas = Sigmas
|
||||
Noise = Noise
|
||||
Guider = Guider
|
||||
Clip = Clip
|
||||
ControlNet = ControlNet
|
||||
Vae = Vae
|
||||
Model = Model
|
||||
ClipVision = ClipVision
|
||||
ClipVisionOutput = ClipVisionOutput
|
||||
AudioEncoder = AudioEncoder
|
||||
AudioEncoderOutput = AudioEncoderOutput
|
||||
StyleModel = StyleModel
|
||||
Gligen = Gligen
|
||||
UpscaleModel = UpscaleModel
|
||||
Audio = Audio
|
||||
Video = Video
|
||||
SVG = SVG
|
||||
LoraModel = LoraModel
|
||||
LossMap = LossMap
|
||||
Voxel = Voxel
|
||||
Mesh = Mesh
|
||||
Hooks = Hooks
|
||||
HookKeyframes = HookKeyframes
|
||||
TimestepsRange = TimestepsRange
|
||||
LatentOperation = LatentOperation
|
||||
FlowControl = FlowControl
|
||||
Accumulation = Accumulation
|
||||
Load3DCamera = Load3DCamera
|
||||
Load3D = Load3D
|
||||
Load3DAnimation = Load3DAnimation
|
||||
Photomaker = Photomaker
|
||||
Point = Point
|
||||
FaceAnalysis = FaceAnalysis
|
||||
BBOX = BBOX
|
||||
SEGS = SEGS
|
||||
AnyType = AnyType
|
||||
MultiType = MultiType
|
||||
#---------------------------------
|
||||
HiddenHolder = HiddenHolder
|
||||
Hidden = Hidden
|
||||
NodeInfoV1 = NodeInfoV1
|
||||
NodeInfoV3 = NodeInfoV3
|
||||
Schema = Schema
|
||||
ComfyNode = ComfyNode
|
||||
NodeOutput = NodeOutput
|
||||
add_to_dict_v1 = staticmethod(add_to_dict_v1)
|
||||
add_to_dict_v3 = staticmethod(add_to_dict_v3)
|
||||
"Boolean",
|
||||
"Int",
|
||||
"Float",
|
||||
"String",
|
||||
"Combo",
|
||||
"MultiCombo",
|
||||
"Image",
|
||||
"WanCameraEmbedding",
|
||||
"Webcam",
|
||||
"Mask",
|
||||
"Latent",
|
||||
"Conditioning",
|
||||
"Sampler",
|
||||
"Sigmas",
|
||||
"Noise",
|
||||
"Guider",
|
||||
"Clip",
|
||||
"ControlNet",
|
||||
"Vae",
|
||||
"Model",
|
||||
"ClipVision",
|
||||
"ClipVisionOutput",
|
||||
"AudioEncoder",
|
||||
"AudioEncoderOutput",
|
||||
"StyleModel",
|
||||
"Gligen",
|
||||
"UpscaleModel",
|
||||
"Audio",
|
||||
"Video",
|
||||
"SVG",
|
||||
"LoraModel",
|
||||
"LossMap",
|
||||
"Voxel",
|
||||
"Mesh",
|
||||
"Hooks",
|
||||
"HookKeyframes",
|
||||
"TimestepsRange",
|
||||
"LatentOperation",
|
||||
"FlowControl",
|
||||
"Accumulation",
|
||||
"Load3DCamera",
|
||||
"Load3D",
|
||||
"Load3DAnimation",
|
||||
"Photomaker",
|
||||
"Point",
|
||||
"FaceAnalysis",
|
||||
"BBOX",
|
||||
"SEGS",
|
||||
"AnyType",
|
||||
"MultiType",
|
||||
# Other classes
|
||||
"HiddenHolder",
|
||||
"Hidden",
|
||||
"NodeInfoV1",
|
||||
"NodeInfoV3",
|
||||
"Schema",
|
||||
"ComfyNode",
|
||||
"NodeOutput",
|
||||
"add_to_dict_v1",
|
||||
"add_to_dict_v3",
|
||||
]
|
||||
|
||||
@@ -449,15 +449,16 @@ class PreviewText(_UIOutput):
|
||||
return {"text": (self.value,)}
|
||||
|
||||
|
||||
class _UI:
|
||||
SavedResult = SavedResult
|
||||
SavedImages = SavedImages
|
||||
SavedAudios = SavedAudios
|
||||
ImageSaveHelper = ImageSaveHelper
|
||||
AudioSaveHelper = AudioSaveHelper
|
||||
PreviewImage = PreviewImage
|
||||
PreviewMask = PreviewMask
|
||||
PreviewAudio = PreviewAudio
|
||||
PreviewVideo = PreviewVideo
|
||||
PreviewUI3D = PreviewUI3D
|
||||
PreviewText = PreviewText
|
||||
__all__ = [
|
||||
"SavedResult",
|
||||
"SavedImages",
|
||||
"SavedAudios",
|
||||
"ImageSaveHelper",
|
||||
"AudioSaveHelper",
|
||||
"PreviewImage",
|
||||
"PreviewMask",
|
||||
"PreviewAudio",
|
||||
"PreviewVideo",
|
||||
"PreviewUI3D",
|
||||
"PreviewText",
|
||||
]
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
from .video_types import VideoContainer, VideoCodec, VideoComponents
|
||||
from .geometry_types import VOXEL, MESH
|
||||
|
||||
__all__ = [
|
||||
# Utility Types
|
||||
"VideoContainer",
|
||||
"VideoCodec",
|
||||
"VideoComponents",
|
||||
"VOXEL",
|
||||
"MESH",
|
||||
]
|
||||
|
||||
12
comfy_api/latest/_util/geometry_types.py
Normal file
12
comfy_api/latest/_util/geometry_types.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import torch
|
||||
|
||||
|
||||
class VOXEL:
|
||||
def __init__(self, data: torch.Tensor):
|
||||
self.data = data
|
||||
|
||||
|
||||
class MESH:
|
||||
def __init__(self, vertices: torch.Tensor, faces: torch.Tensor):
|
||||
self.vertices = vertices
|
||||
self.faces = faces
|
||||
@@ -1,704 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import aiohttp
|
||||
import io
|
||||
import logging
|
||||
import mimetypes
|
||||
from typing import Optional, Union
|
||||
from comfy.utils import common_upscale
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy_api.util import VideoContainer, VideoCodec
|
||||
from comfy_api.input.video_types import VideoInput
|
||||
from comfy_api.input.basic_types import AudioInput
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiClient,
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
UploadRequest,
|
||||
UploadResponse,
|
||||
)
|
||||
from server import PromptServer
|
||||
from comfy.cli_args import args
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import torch
|
||||
import math
|
||||
import base64
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
import av
|
||||
|
||||
|
||||
async def download_url_to_video_output(
|
||||
video_url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None
|
||||
) -> VideoFromFile:
|
||||
"""Downloads a video from a URL and returns a `VIDEO` output.
|
||||
|
||||
Args:
|
||||
video_url: The URL of the video to download.
|
||||
|
||||
Returns:
|
||||
A Comfy node `VIDEO` output.
|
||||
"""
|
||||
video_io = await download_url_to_bytesio(video_url, timeout, auth_kwargs=auth_kwargs)
|
||||
if video_io is None:
|
||||
error_msg = f"Failed to download video from {video_url}"
|
||||
logging.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
return VideoFromFile(video_io)
|
||||
|
||||
|
||||
def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
|
||||
"""Downscale input image tensor to roughly the specified total pixels."""
|
||||
samples = image.movedim(-1, 1)
|
||||
total = int(total_pixels)
|
||||
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
|
||||
if scale_by >= 1:
|
||||
return image
|
||||
width = round(samples.shape[3] * scale_by)
|
||||
height = round(samples.shape[2] * scale_by)
|
||||
|
||||
s = common_upscale(samples, width, height, "lanczos", "disabled")
|
||||
s = s.movedim(1, -1)
|
||||
return s
|
||||
|
||||
|
||||
async def validate_and_cast_response(
|
||||
response, timeout: int = None, node_id: Union[str, None] = None
|
||||
) -> torch.Tensor:
|
||||
"""Validates and casts a response to a torch.Tensor.
|
||||
|
||||
Args:
|
||||
response: The response to validate and cast.
|
||||
timeout: Request timeout in seconds. Defaults to None (no timeout).
|
||||
|
||||
Returns:
|
||||
A torch.Tensor representing the image (1, H, W, C).
|
||||
|
||||
Raises:
|
||||
ValueError: If the response is not valid.
|
||||
"""
|
||||
# validate raw JSON response
|
||||
data = response.data
|
||||
if not data or len(data) == 0:
|
||||
raise ValueError("No images returned from API endpoint")
|
||||
|
||||
# Initialize list to store image tensors
|
||||
image_tensors: list[torch.Tensor] = []
|
||||
|
||||
# Process each image in the data array
|
||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
|
||||
for img_data in data:
|
||||
img_bytes: bytes
|
||||
if img_data.b64_json:
|
||||
img_bytes = base64.b64decode(img_data.b64_json)
|
||||
elif img_data.url:
|
||||
if node_id:
|
||||
PromptServer.instance.send_progress_text(f"Result URL: {img_data.url}", node_id)
|
||||
async with session.get(img_data.url) as resp:
|
||||
if resp.status != 200:
|
||||
raise ValueError("Failed to download generated image")
|
||||
img_bytes = await resp.read()
|
||||
else:
|
||||
raise ValueError("Invalid image payload – neither URL nor base64 data present.")
|
||||
|
||||
pil_img = Image.open(BytesIO(img_bytes)).convert("RGBA")
|
||||
arr = np.asarray(pil_img).astype(np.float32) / 255.0
|
||||
image_tensors.append(torch.from_numpy(arr))
|
||||
|
||||
return torch.stack(image_tensors, dim=0)
|
||||
|
||||
|
||||
def validate_aspect_ratio(
|
||||
aspect_ratio: str,
|
||||
minimum_ratio: float,
|
||||
maximum_ratio: float,
|
||||
minimum_ratio_str: str,
|
||||
maximum_ratio_str: str,
|
||||
) -> float:
|
||||
"""Validates and casts an aspect ratio string to a float.
|
||||
|
||||
Args:
|
||||
aspect_ratio: The aspect ratio string to validate.
|
||||
minimum_ratio: The minimum aspect ratio.
|
||||
maximum_ratio: The maximum aspect ratio.
|
||||
minimum_ratio_str: The minimum aspect ratio string.
|
||||
maximum_ratio_str: The maximum aspect ratio string.
|
||||
|
||||
Returns:
|
||||
The validated and cast aspect ratio.
|
||||
|
||||
Raises:
|
||||
Exception: If the aspect ratio is not valid.
|
||||
"""
|
||||
# get ratio values
|
||||
numbers = aspect_ratio.split(":")
|
||||
if len(numbers) != 2:
|
||||
raise TypeError(
|
||||
f"Aspect ratio must be in the format X:Y, such as 16:9, but was {aspect_ratio}."
|
||||
)
|
||||
try:
|
||||
numerator = int(numbers[0])
|
||||
denominator = int(numbers[1])
|
||||
except ValueError as exc:
|
||||
raise TypeError(
|
||||
f"Aspect ratio must contain numbers separated by ':', such as 16:9, but was {aspect_ratio}."
|
||||
) from exc
|
||||
calculated_ratio = numerator / denominator
|
||||
# if not close to minimum and maximum, check bounds
|
||||
if not math.isclose(calculated_ratio, minimum_ratio) or not math.isclose(
|
||||
calculated_ratio, maximum_ratio
|
||||
):
|
||||
if calculated_ratio < minimum_ratio:
|
||||
raise TypeError(
|
||||
f"Aspect ratio cannot reduce to any less than {minimum_ratio_str} ({minimum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
|
||||
)
|
||||
if calculated_ratio > maximum_ratio:
|
||||
raise TypeError(
|
||||
f"Aspect ratio cannot reduce to any greater than {maximum_ratio_str} ({maximum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
|
||||
)
|
||||
return aspect_ratio
|
||||
|
||||
|
||||
def mimetype_to_extension(mime_type: str) -> str:
|
||||
"""Converts a MIME type to a file extension."""
|
||||
return mime_type.split("/")[-1].lower()
|
||||
|
||||
|
||||
async def download_url_to_bytesio(
|
||||
url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None
|
||||
) -> BytesIO:
|
||||
"""Downloads content from a URL using requests and returns it as BytesIO.
|
||||
|
||||
Args:
|
||||
url: The URL to download.
|
||||
timeout: Request timeout in seconds. Defaults to None (no timeout).
|
||||
|
||||
Returns:
|
||||
BytesIO object containing the downloaded content.
|
||||
"""
|
||||
headers = {}
|
||||
if url.startswith("/proxy/"):
|
||||
url = str(args.comfy_api_base).rstrip("/") + url
|
||||
auth_token = auth_kwargs.get("auth_token")
|
||||
comfy_api_key = auth_kwargs.get("comfy_api_key")
|
||||
if auth_token:
|
||||
headers["Authorization"] = f"Bearer {auth_token}"
|
||||
elif comfy_api_key:
|
||||
headers["X-API-KEY"] = comfy_api_key
|
||||
timeout_cfg = aiohttp.ClientTimeout(total=timeout) if timeout else None
|
||||
async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
|
||||
async with session.get(url, headers=headers) as resp:
|
||||
resp.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX)
|
||||
return BytesIO(await resp.read())
|
||||
|
||||
|
||||
def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor:
|
||||
"""Converts image data from BytesIO to a torch.Tensor.
|
||||
|
||||
Args:
|
||||
image_bytesio: BytesIO object containing the image data.
|
||||
mode: The PIL mode to convert the image to (e.g., "RGB", "RGBA").
|
||||
|
||||
Returns:
|
||||
A torch.Tensor representing the image (1, H, W, C).
|
||||
|
||||
Raises:
|
||||
PIL.UnidentifiedImageError: If the image data cannot be identified.
|
||||
ValueError: If the specified mode is invalid.
|
||||
"""
|
||||
image = Image.open(image_bytesio)
|
||||
image = image.convert(mode)
|
||||
image_array = np.array(image).astype(np.float32) / 255.0
|
||||
return torch.from_numpy(image_array).unsqueeze(0)
|
||||
|
||||
|
||||
async def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor:
|
||||
"""Downloads an image from a URL and returns a [B, H, W, C] tensor."""
|
||||
image_bytesio = await download_url_to_bytesio(url, timeout)
|
||||
return bytesio_to_image_tensor(image_bytesio)
|
||||
|
||||
|
||||
def process_image_response(response_content: bytes | str) -> torch.Tensor:
|
||||
"""Uses content from a Response object and converts it to a torch.Tensor"""
|
||||
return bytesio_to_image_tensor(BytesIO(response_content))
|
||||
|
||||
|
||||
def _tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image:
|
||||
"""Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling."""
|
||||
if len(image.shape) > 3:
|
||||
image = image[0]
|
||||
# TODO: remove alpha if not allowed and present
|
||||
input_tensor = image.cpu()
|
||||
input_tensor = downscale_image_tensor(
|
||||
input_tensor.unsqueeze(0), total_pixels=total_pixels
|
||||
).squeeze()
|
||||
image_np = (input_tensor.numpy() * 255).astype(np.uint8)
|
||||
img = Image.fromarray(image_np)
|
||||
return img
|
||||
|
||||
|
||||
def _pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
|
||||
"""Converts a PIL Image to a BytesIO object."""
|
||||
if not mime_type:
|
||||
mime_type = "image/png"
|
||||
|
||||
img_byte_arr = io.BytesIO()
|
||||
# Derive PIL format from MIME type (e.g., 'image/png' -> 'PNG')
|
||||
pil_format = mime_type.split("/")[-1].upper()
|
||||
if pil_format == "JPG":
|
||||
pil_format = "JPEG"
|
||||
img.save(img_byte_arr, format=pil_format)
|
||||
img_byte_arr.seek(0)
|
||||
return img_byte_arr
|
||||
|
||||
|
||||
def tensor_to_bytesio(
|
||||
image: torch.Tensor,
|
||||
name: Optional[str] = None,
|
||||
total_pixels: int = 2048 * 2048,
|
||||
mime_type: str = "image/png",
|
||||
) -> BytesIO:
|
||||
"""Converts a torch.Tensor image to a named BytesIO object.
|
||||
|
||||
Args:
|
||||
image: Input torch.Tensor image.
|
||||
name: Optional filename for the BytesIO object.
|
||||
total_pixels: Maximum total pixels for potential downscaling.
|
||||
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
|
||||
|
||||
Returns:
|
||||
Named BytesIO object containing the image data.
|
||||
"""
|
||||
if not mime_type:
|
||||
mime_type = "image/png"
|
||||
|
||||
pil_image = _tensor_to_pil(image, total_pixels=total_pixels)
|
||||
img_binary = _pil_to_bytesio(pil_image, mime_type=mime_type)
|
||||
img_binary.name = (
|
||||
f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}"
|
||||
)
|
||||
return img_binary
|
||||
|
||||
|
||||
def tensor_to_base64_string(
|
||||
image_tensor: torch.Tensor,
|
||||
total_pixels: int = 2048 * 2048,
|
||||
mime_type: str = "image/png",
|
||||
) -> str:
|
||||
"""Convert [B, H, W, C] or [H, W, C] tensor to a base64 string.
|
||||
|
||||
Args:
|
||||
image_tensor: Input torch.Tensor image.
|
||||
total_pixels: Maximum total pixels for potential downscaling.
|
||||
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
|
||||
|
||||
Returns:
|
||||
Base64 encoded string of the image.
|
||||
"""
|
||||
pil_image = _tensor_to_pil(image_tensor, total_pixels=total_pixels)
|
||||
img_byte_arr = _pil_to_bytesio(pil_image, mime_type=mime_type)
|
||||
img_bytes = img_byte_arr.getvalue()
|
||||
# Encode bytes to base64 string
|
||||
base64_encoded_string = base64.b64encode(img_bytes).decode("utf-8")
|
||||
return base64_encoded_string
|
||||
|
||||
|
||||
def tensor_to_data_uri(
|
||||
image_tensor: torch.Tensor,
|
||||
total_pixels: int = 2048 * 2048,
|
||||
mime_type: str = "image/png",
|
||||
) -> str:
|
||||
"""Converts a tensor image to a Data URI string.
|
||||
|
||||
Args:
|
||||
image_tensor: Input torch.Tensor image.
|
||||
total_pixels: Maximum total pixels for potential downscaling.
|
||||
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp').
|
||||
|
||||
Returns:
|
||||
Data URI string (e.g., 'data:image/png;base64,...').
|
||||
"""
|
||||
base64_string = tensor_to_base64_string(image_tensor, total_pixels, mime_type)
|
||||
return f"data:{mime_type};base64,{base64_string}"
|
||||
|
||||
|
||||
def text_filepath_to_base64_string(filepath: str) -> str:
|
||||
"""Converts a text file to a base64 string."""
|
||||
with open(filepath, "rb") as f:
|
||||
file_content = f.read()
|
||||
return base64.b64encode(file_content).decode("utf-8")
|
||||
|
||||
|
||||
def text_filepath_to_data_uri(filepath: str) -> str:
|
||||
"""Converts a text file to a data URI."""
|
||||
base64_string = text_filepath_to_base64_string(filepath)
|
||||
mime_type, _ = mimetypes.guess_type(filepath)
|
||||
if mime_type is None:
|
||||
mime_type = "application/octet-stream"
|
||||
return f"data:{mime_type};base64,{base64_string}"
|
||||
|
||||
|
||||
async def upload_file_to_comfyapi(
|
||||
file_bytes_io: BytesIO,
|
||||
filename: str,
|
||||
upload_mime_type: Optional[str],
|
||||
auth_kwargs: Optional[dict[str, str]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Uploads a single file to ComfyUI API and returns its download URL.
|
||||
|
||||
Args:
|
||||
file_bytes_io: BytesIO object containing the file data.
|
||||
filename: The filename of the file.
|
||||
upload_mime_type: MIME type of the file.
|
||||
auth_kwargs: Optional authentication token(s).
|
||||
|
||||
Returns:
|
||||
The download URL for the uploaded file.
|
||||
"""
|
||||
if upload_mime_type is None:
|
||||
request_object = UploadRequest(file_name=filename)
|
||||
else:
|
||||
request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/customers/storage",
|
||||
method=HttpMethod.POST,
|
||||
request_model=UploadRequest,
|
||||
response_model=UploadResponse,
|
||||
),
|
||||
request=request_object,
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
|
||||
response: UploadResponse = await operation.execute()
|
||||
await ApiClient.upload_file(response.upload_url, file_bytes_io, content_type=upload_mime_type)
|
||||
return response.download_url
|
||||
|
||||
|
||||
def video_to_base64_string(
|
||||
video: VideoInput,
|
||||
container_format: VideoContainer = None,
|
||||
codec: VideoCodec = None
|
||||
) -> str:
|
||||
"""
|
||||
Converts a video input to a base64 string.
|
||||
|
||||
Args:
|
||||
video: The video input to convert
|
||||
container_format: Optional container format to use (defaults to video.container if available)
|
||||
codec: Optional codec to use (defaults to video.codec if available)
|
||||
"""
|
||||
video_bytes_io = io.BytesIO()
|
||||
|
||||
# Use provided format/codec if specified, otherwise use video's own if available
|
||||
format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4)
|
||||
codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264)
|
||||
|
||||
video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use)
|
||||
video_bytes_io.seek(0)
|
||||
return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")
|
||||
|
||||
|
||||
async def upload_video_to_comfyapi(
|
||||
video: VideoInput,
|
||||
auth_kwargs: Optional[dict[str, str]] = None,
|
||||
container: VideoContainer = VideoContainer.MP4,
|
||||
codec: VideoCodec = VideoCodec.H264,
|
||||
max_duration: Optional[int] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Uploads a single video to ComfyUI API and returns its download URL.
|
||||
Uses the specified container and codec for saving the video before upload.
|
||||
|
||||
Args:
|
||||
video: VideoInput object (Comfy VIDEO type).
|
||||
auth_kwargs: Optional authentication token(s).
|
||||
container: The video container format to use (default: MP4).
|
||||
codec: The video codec to use (default: H264).
|
||||
max_duration: Optional maximum duration of the video in seconds. If the video is longer than this, an error will be raised.
|
||||
|
||||
Returns:
|
||||
The download URL for the uploaded video file.
|
||||
"""
|
||||
if max_duration is not None:
|
||||
try:
|
||||
actual_duration = video.duration_seconds
|
||||
if actual_duration is not None and actual_duration > max_duration:
|
||||
raise ValueError(
|
||||
f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)."
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error getting video duration: {e}")
|
||||
raise ValueError(f"Could not verify video duration from source: {e}") from e
|
||||
|
||||
upload_mime_type = f"video/{container.value.lower()}"
|
||||
filename = f"uploaded_video.{container.value.lower()}"
|
||||
|
||||
# Convert VideoInput to BytesIO using specified container/codec
|
||||
video_bytes_io = io.BytesIO()
|
||||
video.save_to(video_bytes_io, format=container, codec=codec)
|
||||
video_bytes_io.seek(0)
|
||||
|
||||
return await upload_file_to_comfyapi(video_bytes_io, filename, upload_mime_type, auth_kwargs)
|
||||
|
||||
|
||||
def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray:
|
||||
"""
|
||||
Prepares audio waveform for av library by converting to a contiguous numpy array.
|
||||
|
||||
Args:
|
||||
waveform: a tensor of shape (1, channels, samples) derived from a Comfy `AUDIO` type.
|
||||
|
||||
Returns:
|
||||
Contiguous numpy array of the audio waveform. If the audio was batched,
|
||||
the first item is taken.
|
||||
"""
|
||||
if waveform.ndim != 3 or waveform.shape[0] != 1:
|
||||
raise ValueError("Expected waveform tensor shape (1, channels, samples)")
|
||||
|
||||
# If batch is > 1, take first item
|
||||
if waveform.shape[0] > 1:
|
||||
waveform = waveform[0]
|
||||
|
||||
# Prepare for av: remove batch dim, move to CPU, make contiguous, convert to numpy array
|
||||
audio_data_np = waveform.squeeze(0).cpu().contiguous().numpy()
|
||||
if audio_data_np.dtype != np.float32:
|
||||
audio_data_np = audio_data_np.astype(np.float32)
|
||||
|
||||
return audio_data_np
|
||||
|
||||
|
||||
def audio_ndarray_to_bytesio(
|
||||
audio_data_np: np.ndarray,
|
||||
sample_rate: int,
|
||||
container_format: str = "mp4",
|
||||
codec_name: str = "aac",
|
||||
) -> BytesIO:
|
||||
"""
|
||||
Encodes a numpy array of audio data into a BytesIO object.
|
||||
"""
|
||||
audio_bytes_io = io.BytesIO()
|
||||
with av.open(audio_bytes_io, mode="w", format=container_format) as output_container:
|
||||
audio_stream = output_container.add_stream(codec_name, rate=sample_rate)
|
||||
frame = av.AudioFrame.from_ndarray(
|
||||
audio_data_np,
|
||||
format="fltp",
|
||||
layout="stereo" if audio_data_np.shape[0] > 1 else "mono",
|
||||
)
|
||||
frame.sample_rate = sample_rate
|
||||
frame.pts = 0
|
||||
|
||||
for packet in audio_stream.encode(frame):
|
||||
output_container.mux(packet)
|
||||
|
||||
# Flush stream
|
||||
for packet in audio_stream.encode(None):
|
||||
output_container.mux(packet)
|
||||
|
||||
audio_bytes_io.seek(0)
|
||||
return audio_bytes_io
|
||||
|
||||
|
||||
async def upload_audio_to_comfyapi(
|
||||
audio: AudioInput,
|
||||
auth_kwargs: Optional[dict[str, str]] = None,
|
||||
container_format: str = "mp4",
|
||||
codec_name: str = "aac",
|
||||
mime_type: str = "audio/mp4",
|
||||
filename: str = "uploaded_audio.mp4",
|
||||
) -> str:
|
||||
"""
|
||||
Uploads a single audio input to ComfyUI API and returns its download URL.
|
||||
Encodes the raw waveform into the specified format before uploading.
|
||||
|
||||
Args:
|
||||
audio: a Comfy `AUDIO` type (contains waveform tensor and sample_rate)
|
||||
auth_kwargs: Optional authentication token(s).
|
||||
|
||||
Returns:
|
||||
The download URL for the uploaded audio file.
|
||||
"""
|
||||
sample_rate: int = audio["sample_rate"]
|
||||
waveform: torch.Tensor = audio["waveform"]
|
||||
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
|
||||
audio_bytes_io = audio_ndarray_to_bytesio(
|
||||
audio_data_np, sample_rate, container_format, codec_name
|
||||
)
|
||||
|
||||
return await upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
|
||||
|
||||
|
||||
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file."""
|
||||
if wav.dtype.is_floating_point:
|
||||
return wav
|
||||
elif wav.dtype == torch.int16:
|
||||
return wav.float() / (2 ** 15)
|
||||
elif wav.dtype == torch.int32:
|
||||
return wav.float() / (2 ** 31)
|
||||
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
|
||||
|
||||
|
||||
def audio_bytes_to_audio_input(audio_bytes: bytes,) -> dict:
|
||||
"""
|
||||
Decode any common audio container from bytes using PyAV and return
|
||||
a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}.
|
||||
"""
|
||||
with av.open(io.BytesIO(audio_bytes)) as af:
|
||||
if not af.streams.audio:
|
||||
raise ValueError("No audio stream found in response.")
|
||||
stream = af.streams.audio[0]
|
||||
|
||||
in_sr = int(stream.codec_context.sample_rate)
|
||||
out_sr = in_sr
|
||||
|
||||
frames: list[torch.Tensor] = []
|
||||
n_channels = stream.channels or 1
|
||||
|
||||
for frame in af.decode(streams=stream.index):
|
||||
arr = frame.to_ndarray() # shape can be [C, T] or [T, C] or [T]
|
||||
buf = torch.from_numpy(arr)
|
||||
if buf.ndim == 1:
|
||||
buf = buf.unsqueeze(0) # [T] -> [1, T]
|
||||
elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels:
|
||||
buf = buf.transpose(0, 1).contiguous() # [T, C] -> [C, T]
|
||||
elif buf.shape[0] != n_channels:
|
||||
buf = buf.reshape(-1, n_channels).t().contiguous() # fallback to [C, T]
|
||||
frames.append(buf)
|
||||
|
||||
if not frames:
|
||||
raise ValueError("Decoded zero audio frames.")
|
||||
|
||||
wav = torch.cat(frames, dim=1) # [C, T]
|
||||
wav = f32_pcm(wav)
|
||||
return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr}
|
||||
|
||||
|
||||
def audio_input_to_mp3(audio: AudioInput) -> io.BytesIO:
|
||||
waveform = audio["waveform"].cpu()
|
||||
|
||||
output_buffer = io.BytesIO()
|
||||
output_container = av.open(output_buffer, mode='w', format="mp3")
|
||||
|
||||
out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"])
|
||||
out_stream.bit_rate = 320000
|
||||
|
||||
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo')
|
||||
frame.sample_rate = audio["sample_rate"]
|
||||
frame.pts = 0
|
||||
output_container.mux(out_stream.encode(frame))
|
||||
output_container.mux(out_stream.encode(None))
|
||||
output_container.close()
|
||||
output_buffer.seek(0)
|
||||
return output_buffer
|
||||
|
||||
|
||||
def audio_to_base64_string(
|
||||
audio: AudioInput, container_format: str = "mp4", codec_name: str = "aac"
|
||||
) -> str:
|
||||
"""Converts an audio input to a base64 string."""
|
||||
sample_rate: int = audio["sample_rate"]
|
||||
waveform: torch.Tensor = audio["waveform"]
|
||||
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
|
||||
audio_bytes_io = audio_ndarray_to_bytesio(
|
||||
audio_data_np, sample_rate, container_format, codec_name
|
||||
)
|
||||
audio_bytes = audio_bytes_io.getvalue()
|
||||
return base64.b64encode(audio_bytes).decode("utf-8")
|
||||
|
||||
|
||||
async def upload_images_to_comfyapi(
|
||||
image: torch.Tensor,
|
||||
max_images=8,
|
||||
auth_kwargs: Optional[dict[str, str]] = None,
|
||||
mime_type: Optional[str] = None,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Uploads images to ComfyUI API and returns download URLs.
|
||||
To upload multiple images, stack them in the batch dimension first.
|
||||
|
||||
Args:
|
||||
image: Input torch.Tensor image.
|
||||
max_images: Maximum number of images to upload.
|
||||
auth_kwargs: Optional authentication token(s).
|
||||
mime_type: Optional MIME type for the image.
|
||||
"""
|
||||
# if batch, try to upload each file if max_images is greater than 0
|
||||
download_urls: list[str] = []
|
||||
is_batch = len(image.shape) > 3
|
||||
batch_len = image.shape[0] if is_batch else 1
|
||||
|
||||
for idx in range(min(batch_len, max_images)):
|
||||
tensor = image[idx] if is_batch else image
|
||||
img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
|
||||
url = await upload_file_to_comfyapi(img_io, img_io.name, mime_type, auth_kwargs)
|
||||
download_urls.append(url)
|
||||
return download_urls
|
||||
|
||||
|
||||
def resize_mask_to_image(
|
||||
mask: torch.Tensor,
|
||||
image: torch.Tensor,
|
||||
upscale_method="nearest-exact",
|
||||
crop="disabled",
|
||||
allow_gradient=True,
|
||||
add_channel_dim=False,
|
||||
):
|
||||
"""
|
||||
Resize mask to be the same dimensions as an image, while maintaining proper format for API calls.
|
||||
"""
|
||||
_, H, W, _ = image.shape
|
||||
mask = mask.unsqueeze(-1)
|
||||
mask = mask.movedim(-1, 1)
|
||||
mask = common_upscale(
|
||||
mask, width=W, height=H, upscale_method=upscale_method, crop=crop
|
||||
)
|
||||
mask = mask.movedim(1, -1)
|
||||
if not add_channel_dim:
|
||||
mask = mask.squeeze(-1)
|
||||
if not allow_gradient:
|
||||
mask = (mask > 0.5).float()
|
||||
return mask
|
||||
|
||||
|
||||
def validate_string(
|
||||
string: str,
|
||||
strip_whitespace=True,
|
||||
field_name="prompt",
|
||||
min_length=None,
|
||||
max_length=None,
|
||||
):
|
||||
if string is None:
|
||||
raise Exception(f"Field '{field_name}' cannot be empty.")
|
||||
if strip_whitespace:
|
||||
string = string.strip()
|
||||
if min_length and len(string) < min_length:
|
||||
raise Exception(
|
||||
f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long."
|
||||
)
|
||||
if max_length and len(string) > max_length:
|
||||
raise Exception(
|
||||
f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long."
|
||||
)
|
||||
|
||||
|
||||
def image_tensor_pair_to_batch(
|
||||
image1: torch.Tensor, image2: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Converts a pair of image tensors to a batch tensor.
|
||||
If the images are not the same size, the smaller image is resized to
|
||||
match the larger image.
|
||||
"""
|
||||
if image1.shape[1:] != image2.shape[1:]:
|
||||
image2 = common_upscale(
|
||||
image2.movedim(-1, 1),
|
||||
image1.shape[2],
|
||||
image1.shape[1],
|
||||
"bilinear",
|
||||
"center",
|
||||
).movedim(1, -1)
|
||||
return torch.cat((image1, image2), dim=0)
|
||||
@@ -1,17 +0,0 @@
|
||||
# generated by datamodel-codegen:
|
||||
# filename: filtered-openapi.yaml
|
||||
# timestamp: 2025-04-29T23:44:54+00:00
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from . import PixverseDto
|
||||
|
||||
|
||||
class ResponseData(BaseModel):
|
||||
ErrCode: Optional[int] = None
|
||||
ErrMsg: Optional[str] = None
|
||||
Resp: Optional[PixverseDto.V2OpenAPII2VResp] = None
|
||||
@@ -1,57 +0,0 @@
|
||||
# generated by datamodel-codegen:
|
||||
# filename: filtered-openapi.yaml
|
||||
# timestamp: 2025-04-29T23:44:54+00:00
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class V2OpenAPII2VResp(BaseModel):
|
||||
video_id: Optional[int] = Field(None, description='Video_id')
|
||||
|
||||
|
||||
class V2OpenAPIT2VReq(BaseModel):
|
||||
aspect_ratio: str = Field(
|
||||
..., description='Aspect ratio (16:9, 4:3, 1:1, 3:4, 9:16)', examples=['16:9']
|
||||
)
|
||||
duration: int = Field(
|
||||
...,
|
||||
description='Video duration (5, 8 seconds, --model=v3.5 only allows 5,8; --quality=1080p does not support 8s)',
|
||||
examples=[5],
|
||||
)
|
||||
model: str = Field(
|
||||
..., description='Model version (only supports v3.5)', examples=['v3.5']
|
||||
)
|
||||
motion_mode: Optional[str] = Field(
|
||||
'normal',
|
||||
description='Motion mode (normal, fast, --fast only available when duration=5; --quality=1080p does not support fast)',
|
||||
examples=['normal'],
|
||||
)
|
||||
negative_prompt: Optional[str] = Field(
|
||||
None, description='Negative prompt\n', max_length=2048
|
||||
)
|
||||
prompt: str = Field(..., description='Prompt', max_length=2048)
|
||||
quality: str = Field(
|
||||
...,
|
||||
description='Video quality ("360p"(Turbo model), "540p", "720p", "1080p")',
|
||||
examples=['540p'],
|
||||
)
|
||||
seed: Optional[int] = Field(None, description='Random seed, range: 0 - 2147483647')
|
||||
style: Optional[str] = Field(
|
||||
None,
|
||||
description='Style (effective when model=v3.5, "anime", "3d_animation", "clay", "comic", "cyberpunk") Do not include style parameter unless needed',
|
||||
examples=['anime'],
|
||||
)
|
||||
template_id: Optional[int] = Field(
|
||||
None,
|
||||
description='Template ID (template_id must be activated before use)',
|
||||
examples=[302325299692608],
|
||||
)
|
||||
water_mark: Optional[bool] = Field(
|
||||
False,
|
||||
description='Watermark (true: add watermark, false: no watermark)',
|
||||
examples=[False],
|
||||
)
|
||||
@@ -50,44 +50,6 @@ class BFLFluxFillImageRequest(BaseModel):
|
||||
mask: str = Field(None, description='A Base64-encoded string representing the mask of the areas you with to modify.')
|
||||
|
||||
|
||||
class BFLFluxCannyImageRequest(BaseModel):
|
||||
prompt: str = Field(..., description='Text prompt for image generation')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||
)
|
||||
canny_low_threshold: Optional[int] = Field(None, description='Low threshold for Canny edge detection')
|
||||
canny_high_threshold: Optional[int] = Field(None, description='High threshold for Canny edge detection')
|
||||
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
|
||||
guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process')
|
||||
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
||||
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||
)
|
||||
output_format: Optional[BFLOutputFormat] = Field(
|
||||
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||
)
|
||||
control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided')
|
||||
preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step')
|
||||
|
||||
|
||||
class BFLFluxDepthImageRequest(BaseModel):
|
||||
prompt: str = Field(..., description='Text prompt for image generation')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||
)
|
||||
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
|
||||
guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process')
|
||||
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
||||
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||
)
|
||||
output_format: Optional[BFLOutputFormat] = Field(
|
||||
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||
)
|
||||
control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided')
|
||||
preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step')
|
||||
|
||||
|
||||
class BFLFluxProGenerateRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The text prompt for image generation.')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
@@ -108,6 +70,29 @@ class BFLFluxProGenerateRequest(BaseModel):
|
||||
# )
|
||||
|
||||
|
||||
class Flux2ProGenerateRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
width: int = Field(1024, description="Must be a multiple of 32.")
|
||||
height: int = Field(768, description="Must be a multiple of 32.")
|
||||
seed: int | None = Field(None)
|
||||
prompt_upsampling: bool | None = Field(None)
|
||||
input_image: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
|
||||
input_image_2: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
|
||||
input_image_3: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
|
||||
input_image_4: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
|
||||
input_image_5: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
|
||||
input_image_6: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
|
||||
input_image_7: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
|
||||
input_image_8: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
|
||||
input_image_9: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
|
||||
safety_tolerance: int | None = Field(
|
||||
5, description="Tolerance level for input and output moderation. Value 0 being most strict.", ge=0, le=5
|
||||
)
|
||||
output_format: str | None = Field(
|
||||
"png", description="Output format for the generated image. Can be 'jpeg' or 'png'."
|
||||
)
|
||||
|
||||
|
||||
class BFLFluxKontextProGenerateRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The text prompt for what you wannt to edit.')
|
||||
input_image: Optional[str] = Field(None, description='Image to edit in base64 format')
|
||||
@@ -147,8 +132,9 @@ class BFLFluxProUltraGenerateRequest(BaseModel):
|
||||
|
||||
|
||||
class BFLFluxProGenerateResponse(BaseModel):
|
||||
id: str = Field(..., description='The unique identifier for the generation task.')
|
||||
polling_url: str = Field(..., description='URL to poll for the generation result.')
|
||||
id: str = Field(..., description="The unique identifier for the generation task.")
|
||||
polling_url: str = Field(..., description="URL to poll for the generation result.")
|
||||
cost: float | None = Field(None, description="Price in cents")
|
||||
|
||||
|
||||
class BFLStatus(str, Enum):
|
||||
@@ -160,15 +146,8 @@ class BFLStatus(str, Enum):
|
||||
error = "Error"
|
||||
|
||||
|
||||
class BFLFluxProStatusResponse(BaseModel):
|
||||
class BFLFluxStatusResponse(BaseModel):
|
||||
id: str = Field(..., description="The unique identifier for the generation task.")
|
||||
status: BFLStatus = Field(..., description="The status of the task.")
|
||||
result: Optional[Dict[str, Any]] = Field(
|
||||
None, description="The result of the task (null if not completed)."
|
||||
)
|
||||
progress: confloat(ge=0.0, le=1.0) = Field(
|
||||
..., description="The progress of the task (0.0 to 1.0)."
|
||||
)
|
||||
details: Optional[Dict[str, Any]] = Field(
|
||||
None, description="Additional details about the task (null if not available)."
|
||||
)
|
||||
result: Optional[Dict[str, Any]] = Field(None, description="The result of the task (null if not completed).")
|
||||
progress: Optional[float] = Field(None, description="The progress of the task (0.0 to 1.0).", ge=0.0, le=1.0)
|
||||
|
||||
@@ -1,963 +0,0 @@
|
||||
"""
|
||||
API Client Framework for api.comfy.org.
|
||||
|
||||
This module provides a flexible framework for making API requests from ComfyUI nodes.
|
||||
It supports both synchronous and asynchronous API operations with proper type validation.
|
||||
|
||||
Key Components:
|
||||
--------------
|
||||
1. ApiClient - Handles HTTP requests with authentication and error handling
|
||||
2. ApiEndpoint - Defines a single HTTP endpoint with its request/response models
|
||||
3. ApiOperation - Executes a single synchronous API operation
|
||||
|
||||
Usage Examples:
|
||||
--------------
|
||||
|
||||
# Example 1: Synchronous API Operation
|
||||
# ------------------------------------
|
||||
# For a simple API call that returns the result immediately:
|
||||
|
||||
# 1. Create the API client
|
||||
api_client = ApiClient(
|
||||
base_url="https://api.example.com",
|
||||
auth_token="your_auth_token_here",
|
||||
comfy_api_key="your_comfy_api_key_here",
|
||||
timeout=30.0,
|
||||
verify_ssl=True
|
||||
)
|
||||
|
||||
# 2. Define the endpoint
|
||||
user_info_endpoint = ApiEndpoint(
|
||||
path="/v1/users/me",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest, # No request body needed
|
||||
response_model=UserProfile, # Pydantic model for the response
|
||||
query_params=None
|
||||
)
|
||||
|
||||
# 3. Create the request object
|
||||
request = EmptyRequest()
|
||||
|
||||
# 4. Create and execute the operation
|
||||
operation = ApiOperation(
|
||||
endpoint=user_info_endpoint,
|
||||
request=request
|
||||
)
|
||||
user_profile = await operation.execute(client=api_client) # Returns immediately with the result
|
||||
|
||||
|
||||
# Example 2: Asynchronous API Operation with Polling
|
||||
# -------------------------------------------------
|
||||
# For an API that starts a task and requires polling for completion:
|
||||
|
||||
# 1. Define the endpoints (initial request and polling)
|
||||
generate_image_endpoint = ApiEndpoint(
|
||||
path="/v1/images/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=ImageGenerationRequest,
|
||||
response_model=TaskCreatedResponse,
|
||||
query_params=None
|
||||
)
|
||||
|
||||
check_task_endpoint = ApiEndpoint(
|
||||
path="/v1/tasks/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=ImageGenerationResult,
|
||||
query_params=None
|
||||
)
|
||||
|
||||
# 2. Create the request object
|
||||
request = ImageGenerationRequest(
|
||||
prompt="a beautiful sunset over mountains",
|
||||
width=1024,
|
||||
height=1024,
|
||||
num_images=1
|
||||
)
|
||||
|
||||
# 3. Create and execute the polling operation
|
||||
operation = PollingOperation(
|
||||
initial_endpoint=generate_image_endpoint,
|
||||
initial_request=request,
|
||||
poll_endpoint=check_task_endpoint,
|
||||
task_id_field="task_id",
|
||||
status_field="status",
|
||||
completed_statuses=["completed"],
|
||||
failed_statuses=["failed", "error"]
|
||||
)
|
||||
|
||||
# This will make the initial request and then poll until completion
|
||||
result = await operation.execute(client=api_client) # Returns the final ImageGenerationResult when done
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import logging
|
||||
import io
|
||||
import os
|
||||
import socket
|
||||
from aiohttp.client_exceptions import ClientError, ClientResponseError
|
||||
from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable, Tuple
|
||||
from enum import Enum
|
||||
import json
|
||||
from urllib.parse import urljoin, urlparse
|
||||
from pydantic import BaseModel, Field
|
||||
import uuid # For generating unique operation IDs
|
||||
|
||||
from server import PromptServer
|
||||
from comfy.cli_args import args
|
||||
from comfy import utils
|
||||
from . import request_logger
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
R = TypeVar("R", bound=BaseModel)
|
||||
P = TypeVar("P", bound=BaseModel) # For poll response
|
||||
|
||||
PROGRESS_BAR_MAX = 100
|
||||
|
||||
|
||||
class NetworkError(Exception):
|
||||
"""Base exception for network-related errors with diagnostic information."""
|
||||
pass
|
||||
|
||||
|
||||
class LocalNetworkError(NetworkError):
|
||||
"""Exception raised when local network connectivity issues are detected."""
|
||||
pass
|
||||
|
||||
|
||||
class ApiServerError(NetworkError):
|
||||
"""Exception raised when the API server is unreachable but internet is working."""
|
||||
pass
|
||||
|
||||
|
||||
class EmptyRequest(BaseModel):
|
||||
"""Base class for empty request bodies.
|
||||
For GET requests, fields will be sent as query parameters."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class UploadRequest(BaseModel):
|
||||
file_name: str = Field(..., description="Filename to upload")
|
||||
content_type: Optional[str] = Field(
|
||||
None,
|
||||
description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
|
||||
)
|
||||
|
||||
|
||||
class UploadResponse(BaseModel):
|
||||
download_url: str = Field(..., description="URL to GET uploaded file")
|
||||
upload_url: str = Field(..., description="URL to PUT file to upload")
|
||||
|
||||
|
||||
class HttpMethod(str, Enum):
|
||||
GET = "GET"
|
||||
POST = "POST"
|
||||
PUT = "PUT"
|
||||
DELETE = "DELETE"
|
||||
PATCH = "PATCH"
|
||||
|
||||
|
||||
class ApiClient:
|
||||
"""
|
||||
Client for making HTTP requests to an API with authentication, error handling, and retry logic.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
auth_token: Optional[str] = None,
|
||||
comfy_api_key: Optional[str] = None,
|
||||
timeout: float = 3600.0,
|
||||
verify_ssl: bool = True,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff_factor: float = 2.0,
|
||||
retry_status_codes: Optional[Tuple[int, ...]] = None,
|
||||
session: Optional[aiohttp.ClientSession] = None,
|
||||
):
|
||||
self.base_url = base_url
|
||||
self.auth_token = auth_token
|
||||
self.comfy_api_key = comfy_api_key
|
||||
self.timeout = timeout
|
||||
self.verify_ssl = verify_ssl
|
||||
self.max_retries = max_retries
|
||||
self.retry_delay = retry_delay
|
||||
self.retry_backoff_factor = retry_backoff_factor
|
||||
# Default retry status codes: 408 (Request Timeout), 429 (Too Many Requests),
|
||||
# 500, 502, 503, 504 (Server Errors)
|
||||
self.retry_status_codes = retry_status_codes or (408, 429, 500, 502, 503, 504)
|
||||
self._session: Optional[aiohttp.ClientSession] = session
|
||||
self._owns_session = session is None # Track if we have to close it
|
||||
|
||||
@staticmethod
|
||||
def _generate_operation_id(path: str) -> str:
|
||||
"""Generates a unique operation ID for logging."""
|
||||
return f"{path.strip('/').replace('/', '_')}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
@staticmethod
|
||||
def _create_json_payload_args(
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"json": data,
|
||||
"headers": headers,
|
||||
}
|
||||
|
||||
def _create_form_data_args(
|
||||
self,
|
||||
data: Dict[str, Any] | None,
|
||||
files: Dict[str, Any] | None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
multipart_parser: Callable | None = None,
|
||||
) -> Dict[str, Any]:
|
||||
if headers and "Content-Type" in headers:
|
||||
del headers["Content-Type"]
|
||||
|
||||
if multipart_parser and data:
|
||||
data = multipart_parser(data)
|
||||
|
||||
if isinstance(data, aiohttp.FormData):
|
||||
form = data # If the parser already returned a FormData, pass it through
|
||||
else:
|
||||
form = aiohttp.FormData(default_to_multipart=True)
|
||||
if data: # regular text fields
|
||||
for k, v in data.items():
|
||||
if v is None:
|
||||
continue # aiohttp fails to serialize "None" values
|
||||
# aiohttp expects strings or bytes; convert enums etc.
|
||||
form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)
|
||||
|
||||
if files:
|
||||
file_iter = files if isinstance(files, list) else files.items()
|
||||
for field_name, file_obj in file_iter:
|
||||
if file_obj is None:
|
||||
continue # aiohttp fails to serialize "None" values
|
||||
# file_obj can be (filename, bytes/io.BytesIO, content_type) tuple
|
||||
if isinstance(file_obj, tuple):
|
||||
filename, file_value, content_type = self._unpack_tuple(file_obj)
|
||||
else:
|
||||
file_value = file_obj
|
||||
filename = getattr(file_obj, "name", field_name)
|
||||
content_type = "application/octet-stream"
|
||||
|
||||
form.add_field(
|
||||
name=field_name,
|
||||
value=file_value,
|
||||
filename=filename,
|
||||
content_type=content_type,
|
||||
)
|
||||
return {"data": form, "headers": headers or {}}
|
||||
|
||||
@staticmethod
|
||||
def _create_urlencoded_form_data_args(
|
||||
data: Dict[str, Any],
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
headers = headers or {}
|
||||
headers["Content-Type"] = "application/x-www-form-urlencoded"
|
||||
return {
|
||||
"data": data,
|
||||
"headers": headers,
|
||||
}
|
||||
|
||||
def get_headers(self) -> Dict[str, str]:
|
||||
"""Get headers for API requests, including authentication if available"""
|
||||
headers = {"Content-Type": "application/json", "Accept": "application/json"}
|
||||
|
||||
if self.auth_token:
|
||||
headers["Authorization"] = f"Bearer {self.auth_token}"
|
||||
elif self.comfy_api_key:
|
||||
headers["X-API-KEY"] = self.comfy_api_key
|
||||
|
||||
return headers
|
||||
|
||||
async def _check_connectivity(self, target_url: str) -> Dict[str, bool]:
|
||||
"""
|
||||
Check connectivity to determine if network issues are local or server-related.
|
||||
|
||||
Args:
|
||||
target_url: URL to check connectivity to
|
||||
|
||||
Returns:
|
||||
Dictionary with connectivity status details
|
||||
"""
|
||||
results = {
|
||||
"internet_accessible": False,
|
||||
"api_accessible": False,
|
||||
"is_local_issue": False,
|
||||
"is_api_issue": False,
|
||||
}
|
||||
timeout = aiohttp.ClientTimeout(total=5.0)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
try:
|
||||
async with session.get("https://www.google.com", ssl=self.verify_ssl) as resp:
|
||||
results["internet_accessible"] = resp.status < 500
|
||||
except (ClientError, asyncio.TimeoutError, socket.gaierror):
|
||||
results["is_local_issue"] = True
|
||||
return results # cannot reach the internet – early exit
|
||||
|
||||
# Now check API health endpoint
|
||||
parsed = urlparse(target_url)
|
||||
health_url = f"{parsed.scheme}://{parsed.netloc}/health"
|
||||
try:
|
||||
async with session.get(health_url, ssl=self.verify_ssl) as resp:
|
||||
results["api_accessible"] = resp.status < 500
|
||||
except ClientError:
|
||||
pass # leave as False
|
||||
|
||||
results["is_api_issue"] = results["internet_accessible"] and not results["api_accessible"]
|
||||
return results
|
||||
|
||||
async def request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
content_type: str = "application/json",
|
||||
multipart_parser: Callable | None = None,
|
||||
retry_count: int = 0, # Used internally for tracking retries
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Make an HTTP request to the API with automatic retries for transient errors.
|
||||
|
||||
Args:
|
||||
method: HTTP method (GET, POST, etc.)
|
||||
path: API endpoint path (will be joined with base_url)
|
||||
params: Query parameters
|
||||
data: body data
|
||||
files: Files to upload
|
||||
headers: Additional headers
|
||||
content_type: Content type of the request. Defaults to application/json.
|
||||
retry_count: Internal parameter for tracking retries, do not set manually
|
||||
|
||||
Returns:
|
||||
Parsed JSON response
|
||||
|
||||
Raises:
|
||||
LocalNetworkError: If local network connectivity issues are detected
|
||||
ApiServerError: If the API server is unreachable but internet is working
|
||||
Exception: For other request failures
|
||||
"""
|
||||
|
||||
# Build full URL and merge headers
|
||||
relative_path = path.lstrip("/")
|
||||
url = urljoin(self.base_url, relative_path)
|
||||
self._check_auth(self.auth_token, self.comfy_api_key)
|
||||
|
||||
request_headers = self.get_headers()
|
||||
if headers:
|
||||
request_headers.update(headers)
|
||||
if files:
|
||||
request_headers.pop("Content-Type", None)
|
||||
if params:
|
||||
params = {k: v for k, v in params.items() if v is not None} # aiohttp fails to serialize None values
|
||||
|
||||
logging.debug(f"[DEBUG] Request Headers: {request_headers}")
|
||||
logging.debug(f"[DEBUG] Files: {files}")
|
||||
logging.debug(f"[DEBUG] Params: {params}")
|
||||
logging.debug(f"[DEBUG] Data: {data}")
|
||||
|
||||
if content_type == "application/x-www-form-urlencoded":
|
||||
payload_args = self._create_urlencoded_form_data_args(data or {}, request_headers)
|
||||
elif content_type == "multipart/form-data":
|
||||
payload_args = self._create_form_data_args(data, files, request_headers, multipart_parser)
|
||||
else:
|
||||
payload_args = self._create_json_payload_args(data, request_headers)
|
||||
|
||||
operation_id = self._generate_operation_id(path)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=request_headers,
|
||||
request_params=params,
|
||||
request_data=data if content_type == "application/json" else "[form-data or other]",
|
||||
)
|
||||
|
||||
session = await self._get_session()
|
||||
try:
|
||||
async with session.request(
|
||||
method,
|
||||
url,
|
||||
params=params,
|
||||
ssl=self.verify_ssl,
|
||||
**payload_args,
|
||||
) as resp:
|
||||
if resp.status >= 400:
|
||||
try:
|
||||
error_data = await resp.json()
|
||||
except (aiohttp.ContentTypeError, json.JSONDecodeError):
|
||||
error_data = await resp.text()
|
||||
|
||||
return await self._handle_http_error(
|
||||
ClientResponseError(resp.request_info, resp.history, status=resp.status, message=error_data),
|
||||
operation_id,
|
||||
method,
|
||||
url,
|
||||
params,
|
||||
data,
|
||||
files,
|
||||
headers,
|
||||
content_type,
|
||||
multipart_parser,
|
||||
retry_count=retry_count,
|
||||
response_content=error_data,
|
||||
)
|
||||
|
||||
# Success – parse JSON (safely) and log
|
||||
try:
|
||||
payload = await resp.json()
|
||||
response_content_to_log = payload
|
||||
except (aiohttp.ContentTypeError, json.JSONDecodeError):
|
||||
payload = {}
|
||||
response_content_to_log = await resp.text()
|
||||
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=response_content_to_log,
|
||||
)
|
||||
return payload
|
||||
|
||||
except (ClientError, asyncio.TimeoutError, socket.gaierror) as e:
|
||||
# Treat as *connection* problem – optionally retry, else escalate
|
||||
if retry_count < self.max_retries:
|
||||
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
|
||||
logging.warning("Connection error. Retrying in %.2fs (%s/%s): %s", delay, retry_count + 1,
|
||||
self.max_retries, str(e))
|
||||
await asyncio.sleep(delay)
|
||||
return await self.request(
|
||||
method,
|
||||
path,
|
||||
params=params,
|
||||
data=data,
|
||||
files=files,
|
||||
headers=headers,
|
||||
content_type=content_type,
|
||||
multipart_parser=multipart_parser,
|
||||
retry_count=retry_count + 1,
|
||||
)
|
||||
# One final connectivity check for diagnostics
|
||||
connectivity = await self._check_connectivity(self.base_url)
|
||||
if connectivity["is_local_issue"]:
|
||||
raise LocalNetworkError(
|
||||
"Unable to connect to the API server due to local network issues. "
|
||||
"Please check your internet connection and try again."
|
||||
) from e
|
||||
raise ApiServerError(
|
||||
f"The API server at {self.base_url} is currently unreachable. "
|
||||
f"The service may be experiencing issues. Please try again later."
|
||||
) from e
|
||||
|
||||
@staticmethod
|
||||
def _check_auth(auth_token, comfy_api_key):
|
||||
"""Verify that an auth token is present or comfy_api_key is present"""
|
||||
if auth_token is None and comfy_api_key is None:
|
||||
raise Exception("Unauthorized: Please login first to use this node.")
|
||||
return auth_token or comfy_api_key
|
||||
|
||||
@staticmethod
|
||||
async def upload_file(
|
||||
upload_url: str,
|
||||
file: io.BytesIO | str,
|
||||
content_type: str | None = None,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff_factor: float = 2.0,
|
||||
) -> aiohttp.ClientResponse:
|
||||
"""Upload a file to the API with retry logic.
|
||||
|
||||
Args:
|
||||
upload_url: The URL to upload to
|
||||
file: Either a file path string, BytesIO object, or tuple of (file_path, filename)
|
||||
content_type: Optional mime type to set for the upload
|
||||
max_retries: Maximum number of retry attempts
|
||||
retry_delay: Initial delay between retries in seconds
|
||||
retry_backoff_factor: Multiplier for the delay after each retry
|
||||
"""
|
||||
headers: Dict[str, str] = {}
|
||||
skip_auto_headers: set[str] = set()
|
||||
if content_type:
|
||||
headers["Content-Type"] = content_type
|
||||
else:
|
||||
# tell aiohttp not to add Content-Type that will break the request signature and result in a 403 status.
|
||||
skip_auto_headers.add("Content-Type")
|
||||
|
||||
# Extract file bytes
|
||||
if isinstance(file, io.BytesIO):
|
||||
file.seek(0)
|
||||
data = file.read()
|
||||
elif isinstance(file, str):
|
||||
with open(file, "rb") as f:
|
||||
data = f.read()
|
||||
else:
|
||||
raise ValueError("File must be BytesIO or str path")
|
||||
|
||||
parsed = urlparse(upload_url)
|
||||
basename = os.path.basename(parsed.path) or parsed.netloc or "upload"
|
||||
operation_id = f"upload_{basename}_{uuid.uuid4().hex[:8]}"
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
request_headers=headers,
|
||||
request_data=f"[File data {len(data)} bytes]",
|
||||
)
|
||||
|
||||
delay = retry_delay
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
timeout = aiohttp.ClientTimeout(total=None) # honour server side timeouts
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.put(
|
||||
upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers,
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content="File uploaded successfully.",
|
||||
)
|
||||
return resp
|
||||
except (ClientError, asyncio.TimeoutError) as e:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
response_status_code=e.status if hasattr(e, "status") else None,
|
||||
response_headers=dict(e.headers) if hasattr(e, "headers") else None,
|
||||
response_content=None,
|
||||
error_message=f"{type(e).__name__}: {str(e)}",
|
||||
)
|
||||
if attempt < max_retries:
|
||||
logging.warning(
|
||||
"Upload failed (%s/%s). Retrying in %.2fs. %s", attempt + 1, max_retries, delay, str(e)
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
delay *= retry_backoff_factor
|
||||
else:
|
||||
raise NetworkError(f"Failed to upload file after {max_retries + 1} attempts: {e}") from e
|
||||
|
||||
async def _handle_http_error(
|
||||
self,
|
||||
exc: ClientResponseError,
|
||||
operation_id: str,
|
||||
*req_meta,
|
||||
retry_count: int,
|
||||
response_content: dict | str = "",
|
||||
) -> Dict[str, Any]:
|
||||
status_code = exc.status
|
||||
if status_code == 401:
|
||||
user_friendly = "Unauthorized: Please login first to use this node."
|
||||
elif status_code == 402:
|
||||
user_friendly = "Payment Required: Please add credits to your account to use this node."
|
||||
elif status_code == 409:
|
||||
user_friendly = "There is a problem with your account. Please contact support@comfy.org."
|
||||
elif status_code == 429:
|
||||
user_friendly = "Rate Limit Exceeded: Please try again later."
|
||||
else:
|
||||
if isinstance(response_content, dict):
|
||||
if "error" in response_content and "message" in response_content["error"]:
|
||||
user_friendly = f"API Error: {response_content['error']['message']}"
|
||||
if "type" in response_content["error"]:
|
||||
user_friendly += f" (Type: {response_content['error']['type']})"
|
||||
else: # Handle cases where error is just a JSON dict with unknown format
|
||||
user_friendly = f"API Error: {json.dumps(response_content)}"
|
||||
else:
|
||||
if len(response_content) < 200: # Arbitrary limit for display
|
||||
user_friendly = f"API Error (raw): {response_content}"
|
||||
else:
|
||||
user_friendly = f"API Error (raw, status {response_content})"
|
||||
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=req_meta[0],
|
||||
request_url=req_meta[1],
|
||||
response_status_code=exc.status,
|
||||
response_headers=dict(req_meta[5]) if req_meta[5] else None,
|
||||
response_content=response_content,
|
||||
error_message=f"HTTP Error {exc.status}",
|
||||
)
|
||||
|
||||
logging.debug(f"[DEBUG] API Error: {user_friendly} (Status: {status_code})")
|
||||
if response_content:
|
||||
logging.debug(f"[DEBUG] Response content: {response_content}")
|
||||
|
||||
# Retry if eligible
|
||||
if status_code in self.retry_status_codes and retry_count < self.max_retries:
|
||||
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
|
||||
logging.warning(
|
||||
"HTTP error %s. Retrying in %.2fs (%s/%s)",
|
||||
status_code,
|
||||
delay,
|
||||
retry_count + 1,
|
||||
self.max_retries,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
return await self.request(
|
||||
req_meta[0], # method
|
||||
req_meta[1].replace(self.base_url, ""), # path
|
||||
params=req_meta[2],
|
||||
data=req_meta[3],
|
||||
files=req_meta[4],
|
||||
headers=req_meta[5],
|
||||
content_type=req_meta[6],
|
||||
multipart_parser=req_meta[7],
|
||||
retry_count=retry_count + 1,
|
||||
)
|
||||
|
||||
raise Exception(user_friendly) from exc
|
||||
|
||||
@staticmethod
|
||||
def _unpack_tuple(t):
|
||||
"""Helper to normalise (filename, file, content_type) tuples."""
|
||||
if len(t) == 3:
|
||||
return t
|
||||
elif len(t) == 2:
|
||||
return t[0], t[1], "application/octet-stream"
|
||||
else:
|
||||
raise ValueError("files tuple must be (filename, file[, content_type])")
|
||||
|
||||
async def _get_session(self) -> aiohttp.ClientSession:
|
||||
if self._session is None or self._session.closed:
|
||||
timeout = aiohttp.ClientTimeout(total=self.timeout)
|
||||
self._session = aiohttp.ClientSession(timeout=timeout)
|
||||
self._owns_session = True
|
||||
return self._session
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._owns_session and self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
|
||||
async def __aenter__(self) -> "ApiClient":
|
||||
"""Allow usage as async‑context‑manager – ensures clean teardown"""
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
await self.close()
|
||||
|
||||
|
||||
class ApiEndpoint(Generic[T, R]):
|
||||
"""Defines an API endpoint with its request and response types"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
method: HttpMethod,
|
||||
request_model: Type[T],
|
||||
response_model: Type[R],
|
||||
query_params: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""Initialize an API endpoint definition.
|
||||
|
||||
Args:
|
||||
path: The URL path for this endpoint, can include placeholders like {id}
|
||||
method: The HTTP method to use (GET, POST, etc.)
|
||||
request_model: Pydantic model class that defines the structure and validation rules for API requests to this endpoint
|
||||
response_model: Pydantic model class that defines the structure and validation rules for API responses from this endpoint
|
||||
query_params: Optional dictionary of query parameters to include in the request
|
||||
"""
|
||||
self.path = path
|
||||
self.method = method
|
||||
self.request_model = request_model
|
||||
self.response_model = response_model
|
||||
self.query_params = query_params or {}
|
||||
|
||||
|
||||
class SynchronousOperation(Generic[T, R]):
|
||||
"""Represents a single synchronous API operation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: ApiEndpoint[T, R],
|
||||
request: T,
|
||||
files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None,
|
||||
api_base: str | None = None,
|
||||
auth_token: Optional[str] = None,
|
||||
comfy_api_key: Optional[str] = None,
|
||||
auth_kwargs: Optional[Dict[str, str]] = None,
|
||||
timeout: float = 7200.0,
|
||||
verify_ssl: bool = True,
|
||||
content_type: str = "application/json",
|
||||
multipart_parser: Callable | None = None,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff_factor: float = 2.0,
|
||||
) -> None:
|
||||
self.endpoint = endpoint
|
||||
self.request = request
|
||||
self.files = files
|
||||
self.api_base: str = api_base or args.comfy_api_base
|
||||
self.auth_token = auth_token
|
||||
self.comfy_api_key = comfy_api_key
|
||||
if auth_kwargs is not None:
|
||||
self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
|
||||
self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
|
||||
self.timeout = timeout
|
||||
self.verify_ssl = verify_ssl
|
||||
self.content_type = content_type
|
||||
self.multipart_parser = multipart_parser
|
||||
self.max_retries = max_retries
|
||||
self.retry_delay = retry_delay
|
||||
self.retry_backoff_factor = retry_backoff_factor
|
||||
|
||||
async def execute(self, client: Optional[ApiClient] = None) -> R:
|
||||
owns_client = client is None
|
||||
if owns_client:
|
||||
client = ApiClient(
|
||||
base_url=self.api_base,
|
||||
auth_token=self.auth_token,
|
||||
comfy_api_key=self.comfy_api_key,
|
||||
timeout=self.timeout,
|
||||
verify_ssl=self.verify_ssl,
|
||||
max_retries=self.max_retries,
|
||||
retry_delay=self.retry_delay,
|
||||
retry_backoff_factor=self.retry_backoff_factor,
|
||||
)
|
||||
|
||||
try:
|
||||
request_dict: Optional[Dict[str, Any]]
|
||||
if isinstance(self.request, EmptyRequest):
|
||||
request_dict = None
|
||||
else:
|
||||
request_dict = self.request.model_dump(exclude_none=True)
|
||||
for k, v in list(request_dict.items()):
|
||||
if isinstance(v, Enum):
|
||||
request_dict[k] = v.value
|
||||
|
||||
logging.debug(
|
||||
f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}"
|
||||
)
|
||||
logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}")
|
||||
logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}")
|
||||
|
||||
response_json = await client.request(
|
||||
self.endpoint.method.value,
|
||||
self.endpoint.path,
|
||||
params=self.endpoint.query_params,
|
||||
data=request_dict,
|
||||
files=self.files,
|
||||
content_type=self.content_type,
|
||||
multipart_parser=self.multipart_parser,
|
||||
)
|
||||
|
||||
logging.debug("=" * 50)
|
||||
logging.debug("[DEBUG] RESPONSE DETAILS:")
|
||||
logging.debug("[DEBUG] Status Code: 200 (Success)")
|
||||
logging.debug(f"[DEBUG] Response Body: {json.dumps(response_json, indent=2)}")
|
||||
logging.debug("=" * 50)
|
||||
|
||||
parsed_response = self.endpoint.response_model.model_validate(response_json)
|
||||
logging.debug(f"[DEBUG] Parsed Response: {parsed_response}")
|
||||
return parsed_response
|
||||
finally:
|
||||
if owns_client:
|
||||
await client.close()
|
||||
|
||||
|
||||
class TaskStatus(str, Enum):
|
||||
"""Enum for task status values"""
|
||||
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
PENDING = "pending"
|
||||
|
||||
|
||||
class PollingOperation(Generic[T, R]):
|
||||
"""Represents an asynchronous API operation that requires polling for completion."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
poll_endpoint: ApiEndpoint[EmptyRequest, R],
|
||||
completed_statuses: list[str],
|
||||
failed_statuses: list[str],
|
||||
status_extractor: Callable[[R], str],
|
||||
progress_extractor: Callable[[R], float] | None = None,
|
||||
result_url_extractor: Callable[[R], str] | None = None,
|
||||
request: Optional[T] = None,
|
||||
api_base: str | None = None,
|
||||
auth_token: Optional[str] = None,
|
||||
comfy_api_key: Optional[str] = None,
|
||||
auth_kwargs: Optional[Dict[str, str]] = None,
|
||||
poll_interval: float = 5.0,
|
||||
max_poll_attempts: int = 120, # Default max polling attempts (10 minutes with 5s interval)
|
||||
max_retries: int = 3, # Max retries per individual API call
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff_factor: float = 2.0,
|
||||
estimated_duration: Optional[float] = None,
|
||||
node_id: Optional[str] = None,
|
||||
) -> None:
|
||||
self.poll_endpoint = poll_endpoint
|
||||
self.request = request
|
||||
self.api_base: str = api_base or args.comfy_api_base
|
||||
self.auth_token = auth_token
|
||||
self.comfy_api_key = comfy_api_key
|
||||
if auth_kwargs is not None:
|
||||
self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
|
||||
self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
|
||||
self.poll_interval = poll_interval
|
||||
self.max_poll_attempts = max_poll_attempts
|
||||
self.max_retries = max_retries
|
||||
self.retry_delay = retry_delay
|
||||
self.retry_backoff_factor = retry_backoff_factor
|
||||
self.estimated_duration = estimated_duration
|
||||
self.status_extractor = status_extractor or (lambda x: getattr(x, "status", None))
|
||||
self.progress_extractor = progress_extractor
|
||||
self.result_url_extractor = result_url_extractor
|
||||
self.node_id = node_id
|
||||
self.completed_statuses = completed_statuses
|
||||
self.failed_statuses = failed_statuses
|
||||
self.final_response: Optional[R] = None
|
||||
|
||||
async def execute(self, client: Optional[ApiClient] = None) -> R:
|
||||
owns_client = client is None
|
||||
if owns_client:
|
||||
client = ApiClient(
|
||||
base_url=self.api_base,
|
||||
auth_token=self.auth_token,
|
||||
comfy_api_key=self.comfy_api_key,
|
||||
max_retries=self.max_retries,
|
||||
retry_delay=self.retry_delay,
|
||||
retry_backoff_factor=self.retry_backoff_factor,
|
||||
)
|
||||
try:
|
||||
return await self._poll_until_complete(client)
|
||||
finally:
|
||||
if owns_client:
|
||||
await client.close()
|
||||
|
||||
def _display_text_on_node(self, text: str):
|
||||
if not self.node_id:
|
||||
return
|
||||
PromptServer.instance.send_progress_text(text, self.node_id)
|
||||
|
||||
def _display_time_progress_on_node(self, time_completed: int | float):
|
||||
if not self.node_id:
|
||||
return
|
||||
if self.estimated_duration is not None:
|
||||
remaining = max(0, int(self.estimated_duration) - time_completed)
|
||||
message = f"Task in progress: {time_completed}s (~{remaining}s remaining)"
|
||||
else:
|
||||
message = f"Task in progress: {time_completed}s"
|
||||
self._display_text_on_node(message)
|
||||
|
||||
def _check_task_status(self, response: R) -> TaskStatus:
|
||||
try:
|
||||
status = self.status_extractor(response)
|
||||
if status in self.completed_statuses:
|
||||
return TaskStatus.COMPLETED
|
||||
if status in self.failed_statuses:
|
||||
return TaskStatus.FAILED
|
||||
return TaskStatus.PENDING
|
||||
except Exception as e:
|
||||
logging.error("Error extracting status: %s", e)
|
||||
return TaskStatus.PENDING
|
||||
|
||||
async def _poll_until_complete(self, client: ApiClient) -> R:
|
||||
"""Poll until the task is complete"""
|
||||
consecutive_errors = 0
|
||||
max_consecutive_errors = min(5, self.max_retries * 2) # Limit consecutive errors
|
||||
|
||||
if self.progress_extractor:
|
||||
progress = utils.ProgressBar(PROGRESS_BAR_MAX)
|
||||
|
||||
status = TaskStatus.PENDING
|
||||
for poll_count in range(1, self.max_poll_attempts + 1):
|
||||
try:
|
||||
logging.debug(f"[DEBUG] Polling attempt #{poll_count}")
|
||||
|
||||
request_dict = (
|
||||
None if self.request is None else self.request.model_dump(exclude_none=True)
|
||||
)
|
||||
|
||||
if poll_count == 1:
|
||||
logging.debug(
|
||||
f"[DEBUG] Poll Request: {self.poll_endpoint.method.value} {self.poll_endpoint.path}"
|
||||
)
|
||||
logging.debug(
|
||||
f"[DEBUG] Poll Request Data: {json.dumps(request_dict, indent=2) if request_dict else 'None'}"
|
||||
)
|
||||
|
||||
# Query task status
|
||||
resp = await client.request(
|
||||
self.poll_endpoint.method.value,
|
||||
self.poll_endpoint.path,
|
||||
params=self.poll_endpoint.query_params,
|
||||
data=request_dict,
|
||||
)
|
||||
consecutive_errors = 0 # reset on success
|
||||
response_obj: R = self.poll_endpoint.response_model.model_validate(resp)
|
||||
|
||||
# Check if task is complete
|
||||
status = self._check_task_status(response_obj)
|
||||
logging.debug(f"[DEBUG] Task Status: {status}")
|
||||
|
||||
# If progress extractor is provided, extract progress
|
||||
if self.progress_extractor:
|
||||
new_progress = self.progress_extractor(response_obj)
|
||||
if new_progress is not None:
|
||||
progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX)
|
||||
|
||||
if status == TaskStatus.COMPLETED:
|
||||
message = "Task completed successfully"
|
||||
if self.result_url_extractor:
|
||||
result_url = self.result_url_extractor(response_obj)
|
||||
if result_url:
|
||||
message = f"Result URL: {result_url}"
|
||||
logging.debug(f"[DEBUG] {message}")
|
||||
self._display_text_on_node(message)
|
||||
self.final_response = response_obj
|
||||
if self.progress_extractor:
|
||||
progress.update(100)
|
||||
return self.final_response
|
||||
if status == TaskStatus.FAILED:
|
||||
message = f"Task failed: {json.dumps(resp)}"
|
||||
logging.error(f"[DEBUG] {message}")
|
||||
raise Exception(message)
|
||||
logging.debug("[DEBUG] Task still pending, continuing to poll...")
|
||||
# Task pending – wait
|
||||
for i in range(int(self.poll_interval)):
|
||||
self._display_time_progress_on_node((poll_count - 1) * self.poll_interval + i)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
except (LocalNetworkError, ApiServerError, NetworkError) as e:
|
||||
consecutive_errors += 1
|
||||
if consecutive_errors >= max_consecutive_errors:
|
||||
raise Exception(
|
||||
f"Polling aborted after {consecutive_errors} network errors: {str(e)}"
|
||||
) from e
|
||||
logging.warning("Network error (%s/%s): %s", consecutive_errors, max_consecutive_errors, str(e))
|
||||
await asyncio.sleep(self.poll_interval)
|
||||
except Exception as e:
|
||||
# For other errors, increment count and potentially abort
|
||||
consecutive_errors += 1
|
||||
if consecutive_errors >= max_consecutive_errors or status == TaskStatus.FAILED:
|
||||
raise Exception(
|
||||
f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}"
|
||||
) from e
|
||||
|
||||
logging.error(f"[DEBUG] Polling error: {str(e)}")
|
||||
logging.warning(
|
||||
f"Error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. "
|
||||
f"Will retry in {self.poll_interval} seconds."
|
||||
)
|
||||
await asyncio.sleep(self.poll_interval)
|
||||
|
||||
# If we've exhausted all polling attempts
|
||||
raise Exception(
|
||||
f"Polling timed out after {self.max_poll_attempts} attempts (" f"{self.max_poll_attempts * self.poll_interval} seconds). "
|
||||
"The operation may still be running on the server but is taking longer than expected."
|
||||
)
|
||||
@@ -1,19 +1,236 @@
|
||||
from __future__ import annotations
|
||||
from datetime import date
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from comfy_api_nodes.apis import GeminiGenerationConfig, GeminiContent, GeminiSafetySetting, GeminiSystemInstructionContent, GeminiTool, GeminiVideoMetadata
|
||||
from pydantic import BaseModel
|
||||
|
||||
class GeminiSafetyCategory(str, Enum):
|
||||
HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
|
||||
HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"
|
||||
HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"
|
||||
HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT"
|
||||
|
||||
|
||||
class GeminiSafetyThreshold(str, Enum):
|
||||
OFF = "OFF"
|
||||
BLOCK_NONE = "BLOCK_NONE"
|
||||
BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE"
|
||||
BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE"
|
||||
BLOCK_ONLY_HIGH = "BLOCK_ONLY_HIGH"
|
||||
|
||||
|
||||
class GeminiSafetySetting(BaseModel):
|
||||
category: GeminiSafetyCategory
|
||||
threshold: GeminiSafetyThreshold
|
||||
|
||||
|
||||
class GeminiRole(str, Enum):
|
||||
user = "user"
|
||||
model = "model"
|
||||
|
||||
|
||||
class GeminiMimeType(str, Enum):
|
||||
application_pdf = "application/pdf"
|
||||
audio_mpeg = "audio/mpeg"
|
||||
audio_mp3 = "audio/mp3"
|
||||
audio_wav = "audio/wav"
|
||||
image_png = "image/png"
|
||||
image_jpeg = "image/jpeg"
|
||||
image_webp = "image/webp"
|
||||
text_plain = "text/plain"
|
||||
video_mov = "video/mov"
|
||||
video_mpeg = "video/mpeg"
|
||||
video_mp4 = "video/mp4"
|
||||
video_mpg = "video/mpg"
|
||||
video_avi = "video/avi"
|
||||
video_wmv = "video/wmv"
|
||||
video_mpegps = "video/mpegps"
|
||||
video_flv = "video/flv"
|
||||
|
||||
|
||||
class GeminiInlineData(BaseModel):
|
||||
data: str | None = Field(
|
||||
None,
|
||||
description="The base64 encoding of the image, PDF, or video to include inline in the prompt. "
|
||||
"When including media inline, you must also specify the media type (mimeType) of the data. Size limit: 20MB",
|
||||
)
|
||||
mimeType: GeminiMimeType | None = Field(None)
|
||||
|
||||
|
||||
class GeminiFileData(BaseModel):
|
||||
fileUri: str | None = Field(None)
|
||||
mimeType: GeminiMimeType | None = Field(None)
|
||||
|
||||
|
||||
class GeminiPart(BaseModel):
|
||||
inlineData: GeminiInlineData | None = Field(None)
|
||||
fileData: GeminiFileData | None = Field(None)
|
||||
text: str | None = Field(None)
|
||||
|
||||
|
||||
class GeminiTextPart(BaseModel):
|
||||
text: str | None = Field(None)
|
||||
|
||||
|
||||
class GeminiContent(BaseModel):
|
||||
parts: list[GeminiPart] = Field([])
|
||||
role: GeminiRole = Field(..., examples=["user"])
|
||||
|
||||
|
||||
class GeminiSystemInstructionContent(BaseModel):
|
||||
parts: list[GeminiTextPart] = Field(
|
||||
...,
|
||||
description="A list of ordered parts that make up a single message. "
|
||||
"Different parts may have different IANA MIME types.",
|
||||
)
|
||||
role: GeminiRole = Field(
|
||||
...,
|
||||
description="The identity of the entity that creates the message. "
|
||||
"The following values are supported: "
|
||||
"user: This indicates that the message is sent by a real person, typically a user-generated message. "
|
||||
"model: This indicates that the message is generated by the model. "
|
||||
"The model value is used to insert messages from model into the conversation during multi-turn conversations. "
|
||||
"For non-multi-turn conversations, this field can be left blank or unset.",
|
||||
)
|
||||
|
||||
|
||||
class GeminiFunctionDeclaration(BaseModel):
|
||||
description: str | None = Field(None)
|
||||
name: str = Field(...)
|
||||
parameters: dict[str, Any] = Field(..., description="JSON schema for the function parameters")
|
||||
|
||||
|
||||
class GeminiTool(BaseModel):
|
||||
functionDeclarations: list[GeminiFunctionDeclaration] | None = Field(None)
|
||||
|
||||
|
||||
class GeminiOffset(BaseModel):
|
||||
nanos: int | None = Field(None, ge=0, le=999999999)
|
||||
seconds: int | None = Field(None, ge=-315576000000, le=315576000000)
|
||||
|
||||
|
||||
class GeminiVideoMetadata(BaseModel):
|
||||
endOffset: GeminiOffset | None = Field(None)
|
||||
startOffset: GeminiOffset | None = Field(None)
|
||||
|
||||
|
||||
class GeminiGenerationConfig(BaseModel):
|
||||
maxOutputTokens: int | None = Field(None, ge=16, le=8192)
|
||||
seed: int | None = Field(None)
|
||||
stopSequences: list[str] | None = Field(None)
|
||||
temperature: float | None = Field(None, ge=0.0, le=2.0)
|
||||
topK: int | None = Field(None, ge=1)
|
||||
topP: float | None = Field(None, ge=0.0, le=1.0)
|
||||
|
||||
|
||||
class GeminiImageConfig(BaseModel):
|
||||
aspectRatio: str | None = Field(None)
|
||||
imageSize: str | None = Field(None)
|
||||
|
||||
|
||||
class GeminiImageGenerationConfig(GeminiGenerationConfig):
|
||||
responseModalities: Optional[List[str]] = None
|
||||
responseModalities: list[str] | None = Field(None)
|
||||
imageConfig: GeminiImageConfig | None = Field(None)
|
||||
|
||||
|
||||
class GeminiImageGenerateContentRequest(BaseModel):
|
||||
contents: List[GeminiContent]
|
||||
generationConfig: Optional[GeminiImageGenerationConfig] = None
|
||||
safetySettings: Optional[List[GeminiSafetySetting]] = None
|
||||
systemInstruction: Optional[GeminiSystemInstructionContent] = None
|
||||
tools: Optional[List[GeminiTool]] = None
|
||||
videoMetadata: Optional[GeminiVideoMetadata] = None
|
||||
contents: list[GeminiContent] = Field(...)
|
||||
generationConfig: GeminiImageGenerationConfig | None = Field(None)
|
||||
safetySettings: list[GeminiSafetySetting] | None = Field(None)
|
||||
systemInstruction: GeminiSystemInstructionContent | None = Field(None)
|
||||
tools: list[GeminiTool] | None = Field(None)
|
||||
videoMetadata: GeminiVideoMetadata | None = Field(None)
|
||||
|
||||
|
||||
class GeminiGenerateContentRequest(BaseModel):
|
||||
contents: list[GeminiContent] = Field(...)
|
||||
generationConfig: GeminiGenerationConfig | None = Field(None)
|
||||
safetySettings: list[GeminiSafetySetting] | None = Field(None)
|
||||
systemInstruction: GeminiSystemInstructionContent | None = Field(None)
|
||||
tools: list[GeminiTool] | None = Field(None)
|
||||
videoMetadata: GeminiVideoMetadata | None = Field(None)
|
||||
|
||||
|
||||
class Modality(str, Enum):
|
||||
MODALITY_UNSPECIFIED = "MODALITY_UNSPECIFIED"
|
||||
TEXT = "TEXT"
|
||||
IMAGE = "IMAGE"
|
||||
VIDEO = "VIDEO"
|
||||
AUDIO = "AUDIO"
|
||||
DOCUMENT = "DOCUMENT"
|
||||
|
||||
|
||||
class ModalityTokenCount(BaseModel):
|
||||
modality: Modality | None = None
|
||||
tokenCount: int | None = Field(None, description="Number of tokens for the given modality.")
|
||||
|
||||
|
||||
class Probability(str, Enum):
|
||||
NEGLIGIBLE = "NEGLIGIBLE"
|
||||
LOW = "LOW"
|
||||
MEDIUM = "MEDIUM"
|
||||
HIGH = "HIGH"
|
||||
UNKNOWN = "UNKNOWN"
|
||||
|
||||
|
||||
class GeminiSafetyRating(BaseModel):
|
||||
category: GeminiSafetyCategory | None = None
|
||||
probability: Probability | None = Field(
|
||||
None,
|
||||
description="The probability that the content violates the specified safety category",
|
||||
)
|
||||
|
||||
|
||||
class GeminiCitation(BaseModel):
|
||||
authors: list[str] | None = None
|
||||
endIndex: int | None = None
|
||||
license: str | None = None
|
||||
publicationDate: date | None = None
|
||||
startIndex: int | None = None
|
||||
title: str | None = None
|
||||
uri: str | None = None
|
||||
|
||||
|
||||
class GeminiCitationMetadata(BaseModel):
|
||||
citations: list[GeminiCitation] | None = None
|
||||
|
||||
|
||||
class GeminiCandidate(BaseModel):
|
||||
citationMetadata: GeminiCitationMetadata | None = None
|
||||
content: GeminiContent | None = None
|
||||
finishReason: str | None = None
|
||||
safetyRatings: list[GeminiSafetyRating] | None = None
|
||||
|
||||
|
||||
class GeminiPromptFeedback(BaseModel):
|
||||
blockReason: str | None = None
|
||||
blockReasonMessage: str | None = None
|
||||
safetyRatings: list[GeminiSafetyRating] | None = None
|
||||
|
||||
|
||||
class GeminiUsageMetadata(BaseModel):
|
||||
cachedContentTokenCount: int | None = Field(
|
||||
None,
|
||||
description="Output only. Number of tokens in the cached part in the input (the cached content).",
|
||||
)
|
||||
candidatesTokenCount: int | None = Field(None, description="Number of tokens in the response(s).")
|
||||
candidatesTokensDetails: list[ModalityTokenCount] | None = Field(
|
||||
None, description="Breakdown of candidate tokens by modality."
|
||||
)
|
||||
promptTokenCount: int | None = Field(
|
||||
None,
|
||||
description="Number of tokens in the request. When cachedContent is set, this is still the total effective prompt size meaning this includes the number of tokens in the cached content.",
|
||||
)
|
||||
promptTokensDetails: list[ModalityTokenCount] | None = Field(
|
||||
None, description="Breakdown of prompt tokens by modality."
|
||||
)
|
||||
thoughtsTokenCount: int | None = Field(None, description="Number of tokens present in thoughts output.")
|
||||
toolUsePromptTokenCount: int | None = Field(None, description="Number of tokens present in tool-use prompt(s).")
|
||||
|
||||
|
||||
class GeminiGenerateContentResponse(BaseModel):
|
||||
candidates: list[GeminiCandidate] | None = Field(None)
|
||||
promptFeedback: GeminiPromptFeedback | None = Field(None)
|
||||
usageMetadata: GeminiUsageMetadata | None = Field(None)
|
||||
modelVersion: str | None = Field(None)
|
||||
|
||||
66
comfy_api_nodes/apis/kling_api.py
Normal file
66
comfy_api_nodes/apis/kling_api.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class OmniProText2VideoRequest(BaseModel):
|
||||
model_name: str = Field(..., description="kling-video-o1")
|
||||
aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
|
||||
duration: str = Field(..., description="'5' or '10'")
|
||||
prompt: str = Field(...)
|
||||
mode: str = Field("pro")
|
||||
|
||||
|
||||
class OmniParamImage(BaseModel):
|
||||
image_url: str = Field(...)
|
||||
type: str | None = Field(None, description="Can be 'first_frame' or 'end_frame'")
|
||||
|
||||
|
||||
class OmniParamVideo(BaseModel):
|
||||
video_url: str = Field(...)
|
||||
refer_type: str | None = Field(..., description="Can be 'base' or 'feature'")
|
||||
keep_original_sound: str = Field(..., description="'yes' or 'no'")
|
||||
|
||||
|
||||
class OmniProFirstLastFrameRequest(BaseModel):
|
||||
model_name: str = Field(..., description="kling-video-o1")
|
||||
image_list: list[OmniParamImage] = Field(..., min_length=1, max_length=7)
|
||||
duration: str = Field(..., description="'5' or '10'")
|
||||
prompt: str = Field(...)
|
||||
mode: str = Field("pro")
|
||||
|
||||
|
||||
class OmniProReferences2VideoRequest(BaseModel):
|
||||
model_name: str = Field(..., description="kling-video-o1")
|
||||
aspect_ratio: str | None = Field(..., description="'16:9', '9:16' or '1:1'")
|
||||
image_list: list[OmniParamImage] | None = Field(
|
||||
None, max_length=7, description="Max length 4 when video is present."
|
||||
)
|
||||
video_list: list[OmniParamVideo] | None = Field(None, max_length=1)
|
||||
duration: str | None = Field(..., description="From 3 to 10.")
|
||||
prompt: str = Field(...)
|
||||
mode: str = Field("pro")
|
||||
|
||||
|
||||
class TaskStatusVideoResult(BaseModel):
|
||||
duration: str | None = Field(None, description="Total video duration")
|
||||
id: str | None = Field(None, description="Generated video ID")
|
||||
url: str | None = Field(None, description="URL for generated video")
|
||||
|
||||
|
||||
class TaskStatusVideoResults(BaseModel):
|
||||
videos: list[TaskStatusVideoResult] | None = Field(None)
|
||||
|
||||
|
||||
class TaskStatusVideoResponseData(BaseModel):
|
||||
created_at: int | None = Field(None, description="Task creation time")
|
||||
updated_at: int | None = Field(None, description="Task update time")
|
||||
task_status: str | None = None
|
||||
task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.")
|
||||
task_id: str | None = Field(None, description="Task ID")
|
||||
task_result: TaskStatusVideoResults | None = Field(None)
|
||||
|
||||
|
||||
class TaskStatusVideoResponse(BaseModel):
|
||||
code: int | None = Field(None, description="Error code")
|
||||
message: str | None = Field(None, description="Error message")
|
||||
request_id: str | None = Field(None, description="Request ID")
|
||||
data: TaskStatusVideoResponseData | None = Field(None)
|
||||
120
comfy_api_nodes/apis/minimax_api.py
Normal file
120
comfy_api_nodes/apis/minimax_api.py
Normal file
@@ -0,0 +1,120 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MinimaxBaseResponse(BaseModel):
|
||||
status_code: int = Field(
|
||||
...,
|
||||
description='Status code. 0 indicates success, other values indicate errors.',
|
||||
)
|
||||
status_msg: str = Field(
|
||||
..., description='Specific error details or success message.'
|
||||
)
|
||||
|
||||
|
||||
class File(BaseModel):
|
||||
bytes: Optional[int] = Field(None, description='File size in bytes')
|
||||
created_at: Optional[int] = Field(
|
||||
None, description='Unix timestamp when the file was created, in seconds'
|
||||
)
|
||||
download_url: Optional[str] = Field(
|
||||
None, description='The URL to download the video'
|
||||
)
|
||||
backup_download_url: Optional[str] = Field(
|
||||
None, description='The backup URL to download the video'
|
||||
)
|
||||
|
||||
file_id: Optional[int] = Field(None, description='Unique identifier for the file')
|
||||
filename: Optional[str] = Field(None, description='The name of the file')
|
||||
purpose: Optional[str] = Field(None, description='The purpose of using the file')
|
||||
|
||||
|
||||
class MinimaxFileRetrieveResponse(BaseModel):
|
||||
base_resp: MinimaxBaseResponse
|
||||
file: File
|
||||
|
||||
|
||||
class MiniMaxModel(str, Enum):
|
||||
T2V_01_Director = 'T2V-01-Director'
|
||||
I2V_01_Director = 'I2V-01-Director'
|
||||
S2V_01 = 'S2V-01'
|
||||
I2V_01 = 'I2V-01'
|
||||
I2V_01_live = 'I2V-01-live'
|
||||
T2V_01 = 'T2V-01'
|
||||
Hailuo_02 = 'MiniMax-Hailuo-02'
|
||||
|
||||
|
||||
class Status6(str, Enum):
|
||||
Queueing = 'Queueing'
|
||||
Preparing = 'Preparing'
|
||||
Processing = 'Processing'
|
||||
Success = 'Success'
|
||||
Fail = 'Fail'
|
||||
|
||||
|
||||
class MinimaxTaskResultResponse(BaseModel):
|
||||
base_resp: MinimaxBaseResponse
|
||||
file_id: Optional[str] = Field(
|
||||
None,
|
||||
description='After the task status changes to Success, this field returns the file ID corresponding to the generated video.',
|
||||
)
|
||||
status: Status6 = Field(
|
||||
...,
|
||||
description="Task status: 'Queueing' (in queue), 'Preparing' (task is preparing), 'Processing' (generating), 'Success' (task completed successfully), or 'Fail' (task failed).",
|
||||
)
|
||||
task_id: str = Field(..., description='The task ID being queried.')
|
||||
|
||||
|
||||
class SubjectReferenceItem(BaseModel):
|
||||
image: Optional[str] = Field(
|
||||
None, description='URL or base64 encoding of the subject reference image.'
|
||||
)
|
||||
mask: Optional[str] = Field(
|
||||
None,
|
||||
description='URL or base64 encoding of the mask for the subject reference image.',
|
||||
)
|
||||
|
||||
|
||||
class MinimaxVideoGenerationRequest(BaseModel):
|
||||
callback_url: Optional[str] = Field(
|
||||
None,
|
||||
description='Optional. URL to receive real-time status updates about the video generation task.',
|
||||
)
|
||||
first_frame_image: Optional[str] = Field(
|
||||
None,
|
||||
description='URL or base64 encoding of the first frame image. Required when model is I2V-01, I2V-01-Director, or I2V-01-live.',
|
||||
)
|
||||
model: MiniMaxModel = Field(
|
||||
...,
|
||||
description='Required. ID of model. Options: T2V-01-Director, I2V-01-Director, S2V-01, I2V-01, I2V-01-live, T2V-01',
|
||||
)
|
||||
prompt: Optional[str] = Field(
|
||||
None,
|
||||
description='Description of the video. Should be less than 2000 characters. Supports camera movement instructions in [brackets].',
|
||||
max_length=2000,
|
||||
)
|
||||
prompt_optimizer: Optional[bool] = Field(
|
||||
True,
|
||||
description='If true (default), the model will automatically optimize the prompt. Set to false for more precise control.',
|
||||
)
|
||||
subject_reference: Optional[list[SubjectReferenceItem]] = Field(
|
||||
None,
|
||||
description='Only available when model is S2V-01. The model will generate a video based on the subject uploaded through this parameter.',
|
||||
)
|
||||
duration: Optional[int] = Field(
|
||||
None,
|
||||
description="The length of the output video in seconds."
|
||||
)
|
||||
resolution: Optional[str] = Field(
|
||||
None,
|
||||
description="The dimensions of the video display. 1080p corresponds to 1920 x 1080 pixels, 768p corresponds to 1366 x 768 pixels."
|
||||
)
|
||||
|
||||
|
||||
class MinimaxVideoGenerationResponse(BaseModel):
|
||||
base_resp: MinimaxBaseResponse
|
||||
task_id: str = Field(
|
||||
..., description='The task ID for the asynchronous video generation task.'
|
||||
)
|
||||
100
comfy_api_nodes/apis/pika_api.py
Normal file
100
comfy_api_nodes/apis/pika_api.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from typing import Optional
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Pikaffect(str, Enum):
|
||||
Cake_ify = "Cake-ify"
|
||||
Crumble = "Crumble"
|
||||
Crush = "Crush"
|
||||
Decapitate = "Decapitate"
|
||||
Deflate = "Deflate"
|
||||
Dissolve = "Dissolve"
|
||||
Explode = "Explode"
|
||||
Eye_pop = "Eye-pop"
|
||||
Inflate = "Inflate"
|
||||
Levitate = "Levitate"
|
||||
Melt = "Melt"
|
||||
Peel = "Peel"
|
||||
Poke = "Poke"
|
||||
Squish = "Squish"
|
||||
Ta_da = "Ta-da"
|
||||
Tear = "Tear"
|
||||
|
||||
|
||||
class PikaBodyGenerate22C2vGenerate22PikascenesPost(BaseModel):
|
||||
aspectRatio: Optional[float] = Field(None, description='Aspect ratio (width / height)')
|
||||
duration: Optional[int] = Field(5)
|
||||
ingredientsMode: str = Field(...)
|
||||
negativePrompt: Optional[str] = Field(None)
|
||||
promptText: Optional[str] = Field(None)
|
||||
resolution: Optional[str] = Field('1080p')
|
||||
seed: Optional[int] = Field(None)
|
||||
|
||||
|
||||
class PikaGenerateResponse(BaseModel):
|
||||
video_id: str = Field(...)
|
||||
|
||||
|
||||
class PikaBodyGenerate22I2vGenerate22I2vPost(BaseModel):
|
||||
duration: Optional[int] = 5
|
||||
negativePrompt: Optional[str] = Field(None)
|
||||
promptText: Optional[str] = Field(None)
|
||||
resolution: Optional[str] = '1080p'
|
||||
seed: Optional[int] = Field(None)
|
||||
|
||||
|
||||
class PikaBodyGenerate22KeyframeGenerate22PikaframesPost(BaseModel):
|
||||
duration: Optional[int] = Field(None, ge=5, le=10)
|
||||
negativePrompt: Optional[str] = Field(None)
|
||||
promptText: str = Field(...)
|
||||
resolution: Optional[str] = '1080p'
|
||||
seed: Optional[int] = Field(None)
|
||||
|
||||
|
||||
class PikaBodyGenerate22T2vGenerate22T2vPost(BaseModel):
|
||||
aspectRatio: Optional[float] = Field(
|
||||
1.7777777777777777,
|
||||
description='Aspect ratio (width / height)',
|
||||
ge=0.4,
|
||||
le=2.5,
|
||||
)
|
||||
duration: Optional[int] = 5
|
||||
negativePrompt: Optional[str] = Field(None)
|
||||
promptText: str = Field(...)
|
||||
resolution: Optional[str] = '1080p'
|
||||
seed: Optional[int] = Field(None)
|
||||
|
||||
|
||||
class PikaBodyGeneratePikadditionsGeneratePikadditionsPost(BaseModel):
|
||||
negativePrompt: Optional[str] = Field(None)
|
||||
promptText: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
|
||||
|
||||
class PikaBodyGeneratePikaffectsGeneratePikaffectsPost(BaseModel):
|
||||
negativePrompt: Optional[str] = Field(None)
|
||||
pikaffect: Optional[str] = None
|
||||
promptText: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
|
||||
|
||||
class PikaBodyGeneratePikaswapsGeneratePikaswapsPost(BaseModel):
|
||||
negativePrompt: Optional[str] = Field(None)
|
||||
promptText: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
modifyRegionRoi: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class PikaStatusEnum(str, Enum):
|
||||
queued = "queued"
|
||||
started = "started"
|
||||
finished = "finished"
|
||||
failed = "failed"
|
||||
|
||||
|
||||
class PikaVideoResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
progress: Optional[int] = Field(None)
|
||||
status: PikaStatusEnum
|
||||
url: Optional[str] = Field(None)
|
||||
133
comfy_api_nodes/apis/topaz_api.py
Normal file
133
comfy_api_nodes/apis/topaz_api.py
Normal file
@@ -0,0 +1,133 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ImageEnhanceRequest(BaseModel):
|
||||
model: str = Field("Reimagine")
|
||||
output_format: str = Field("jpeg")
|
||||
subject_detection: str = Field("All")
|
||||
face_enhancement: bool = Field(True)
|
||||
face_enhancement_creativity: float = Field(0, description="Is ignored if face_enhancement is false")
|
||||
face_enhancement_strength: float = Field(0.8, description="Is ignored if face_enhancement is false")
|
||||
source_url: str = Field(...)
|
||||
output_width: Optional[int] = Field(None)
|
||||
output_height: Optional[int] = Field(None)
|
||||
crop_to_fill: bool = Field(False)
|
||||
prompt: Optional[str] = Field(None, description="Text prompt for creative upscaling guidance")
|
||||
creativity: int = Field(3, description="Creativity settings range from 1 to 9")
|
||||
face_preservation: str = Field("true", description="To preserve the identity of characters")
|
||||
color_preservation: str = Field("true", description="To preserve the original color")
|
||||
|
||||
|
||||
class ImageAsyncTaskResponse(BaseModel):
|
||||
process_id: str = Field(...)
|
||||
|
||||
|
||||
class ImageStatusResponse(BaseModel):
|
||||
process_id: str = Field(...)
|
||||
status: str = Field(...)
|
||||
progress: Optional[int] = Field(None)
|
||||
credits: int = Field(...)
|
||||
|
||||
|
||||
class ImageDownloadResponse(BaseModel):
|
||||
download_url: str = Field(...)
|
||||
expiry: int = Field(...)
|
||||
|
||||
|
||||
class Resolution(BaseModel):
|
||||
width: int = Field(...)
|
||||
height: int = Field(...)
|
||||
|
||||
|
||||
class CreateCreateVideoRequestSource(BaseModel):
|
||||
container: str = Field(...)
|
||||
size: int = Field(..., description="Size of the video file in bytes")
|
||||
duration: int = Field(..., description="Duration of the video file in seconds")
|
||||
frameCount: int = Field(..., description="Total number of frames in the video")
|
||||
frameRate: int = Field(...)
|
||||
resolution: Resolution = Field(...)
|
||||
|
||||
|
||||
class VideoFrameInterpolationFilter(BaseModel):
|
||||
model: str = Field(...)
|
||||
slowmo: Optional[int] = Field(None)
|
||||
fps: int = Field(...)
|
||||
duplicate: bool = Field(...)
|
||||
duplicate_threshold: float = Field(...)
|
||||
|
||||
|
||||
class VideoEnhancementFilter(BaseModel):
|
||||
model: str = Field(...)
|
||||
auto: Optional[str] = Field(None, description="Auto, Manual, Relative")
|
||||
focusFixLevel: Optional[str] = Field(None, description="Downscales video input for correction of blurred subjects")
|
||||
compression: Optional[float] = Field(None, description="Strength of compression recovery")
|
||||
details: Optional[float] = Field(None, description="Amount of detail reconstruction")
|
||||
prenoise: Optional[float] = Field(None, description="Amount of noise to add to input to reduce over-smoothing")
|
||||
noise: Optional[float] = Field(None, description="Amount of noise reduction")
|
||||
halo: Optional[float] = Field(None, description="Amount of halo reduction")
|
||||
preblur: Optional[float] = Field(None, description="Anti-aliasing and deblurring strength")
|
||||
blur: Optional[float] = Field(None, description="Amount of sharpness applied")
|
||||
grain: Optional[float] = Field(None, description="Grain after AI model processing")
|
||||
grainSize: Optional[float] = Field(None, description="Size of generated grain")
|
||||
recoverOriginalDetailValue: Optional[float] = Field(None, description="Source details into the output video")
|
||||
creativity: Optional[str] = Field(None, description="Creativity level(high, low) for slc-1 only")
|
||||
isOptimizedMode: Optional[bool] = Field(None, description="Set to true for Starlight Creative (slc-1) only")
|
||||
|
||||
|
||||
class OutputInformationVideo(BaseModel):
|
||||
resolution: Resolution = Field(...)
|
||||
frameRate: int = Field(...)
|
||||
audioCodec: Optional[str] = Field(..., description="Required if audioTransfer is Copy or Convert")
|
||||
audioTransfer: str = Field(..., description="Copy, Convert, None")
|
||||
dynamicCompressionLevel: str = Field(..., description="Low, Mid, High")
|
||||
|
||||
|
||||
class Overrides(BaseModel):
|
||||
isPaidDiffusion: bool = Field(True)
|
||||
|
||||
|
||||
class CreateVideoRequest(BaseModel):
|
||||
source: CreateCreateVideoRequestSource = Field(...)
|
||||
filters: list[Union[VideoFrameInterpolationFilter, VideoEnhancementFilter]] = Field(...)
|
||||
output: OutputInformationVideo = Field(...)
|
||||
overrides: Overrides = Field(Overrides(isPaidDiffusion=True))
|
||||
|
||||
|
||||
class CreateVideoResponse(BaseModel):
|
||||
requestId: str = Field(...)
|
||||
|
||||
|
||||
class VideoAcceptResponse(BaseModel):
|
||||
uploadId: str = Field(...)
|
||||
urls: list[str] = Field(...)
|
||||
|
||||
|
||||
class VideoCompleteUploadRequestPart(BaseModel):
|
||||
partNum: int = Field(...)
|
||||
eTag: str = Field(...)
|
||||
|
||||
|
||||
class VideoCompleteUploadRequest(BaseModel):
|
||||
uploadResults: list[VideoCompleteUploadRequestPart] = Field(...)
|
||||
|
||||
|
||||
class VideoCompleteUploadResponse(BaseModel):
|
||||
message: str = Field(..., description="Confirmation message")
|
||||
|
||||
|
||||
class VideoStatusResponseEstimates(BaseModel):
|
||||
cost: list[int] = Field(...)
|
||||
|
||||
|
||||
class VideoStatusResponseDownloadUrl(BaseModel):
|
||||
url: str = Field(...)
|
||||
|
||||
|
||||
class VideoStatusResponse(BaseModel):
|
||||
status: str = Field(...)
|
||||
estimates: Optional[VideoStatusResponseEstimates] = Field(None)
|
||||
progress: Optional[float] = Field(None)
|
||||
message: Optional[str] = Field("")
|
||||
download: Optional[VideoStatusResponseDownloadUrl] = Field(None)
|
||||
@@ -1,13 +1,20 @@
|
||||
from __future__ import annotations
|
||||
from comfy_api_nodes.apis import (
|
||||
TripoModelVersion,
|
||||
TripoTextureQuality,
|
||||
)
|
||||
from enum import Enum
|
||||
from typing import Optional, List, Dict, Any, Union
|
||||
|
||||
from pydantic import BaseModel, Field, RootModel
|
||||
|
||||
class TripoModelVersion(str, Enum):
|
||||
v2_5_20250123 = 'v2.5-20250123'
|
||||
v2_0_20240919 = 'v2.0-20240919'
|
||||
v1_4_20240625 = 'v1.4-20240625'
|
||||
|
||||
|
||||
class TripoTextureQuality(str, Enum):
|
||||
standard = 'standard'
|
||||
detailed = 'detailed'
|
||||
|
||||
|
||||
class TripoStyle(str, Enum):
|
||||
PERSON_TO_CARTOON = "person:person2cartoon"
|
||||
ANIMAL_VENOM = "animal:venom"
|
||||
|
||||
99
comfy_api_nodes/apis/veo_api.py
Normal file
99
comfy_api_nodes/apis/veo_api.py
Normal file
@@ -0,0 +1,99 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class VeoRequestInstanceImage(BaseModel):
|
||||
bytesBase64Encoded: str | None = Field(None)
|
||||
gcsUri: str | None = Field(None)
|
||||
mimeType: str | None = Field(None)
|
||||
|
||||
|
||||
class VeoRequestInstance(BaseModel):
|
||||
image: VeoRequestInstanceImage | None = Field(None)
|
||||
lastFrame: VeoRequestInstanceImage | None = Field(None)
|
||||
prompt: str = Field(..., description='Text description of the video')
|
||||
|
||||
|
||||
class VeoRequestParameters(BaseModel):
|
||||
aspectRatio: Optional[str] = Field(None, examples=['16:9'])
|
||||
durationSeconds: Optional[int] = None
|
||||
enhancePrompt: Optional[bool] = None
|
||||
generateAudio: Optional[bool] = Field(
|
||||
None,
|
||||
description='Generate audio for the video. Only supported by veo 3 models.',
|
||||
)
|
||||
negativePrompt: Optional[str] = None
|
||||
personGeneration: str | None = Field(None, description="ALLOW or BLOCK")
|
||||
sampleCount: Optional[int] = None
|
||||
seed: Optional[int] = None
|
||||
storageUri: Optional[str] = Field(
|
||||
None, description='Optional Cloud Storage URI to upload the video'
|
||||
)
|
||||
resolution: str | None = Field(None)
|
||||
|
||||
|
||||
class VeoGenVidRequest(BaseModel):
|
||||
instances: list[VeoRequestInstance] | None = Field(None)
|
||||
parameters: VeoRequestParameters | None = Field(None)
|
||||
|
||||
|
||||
class VeoGenVidResponse(BaseModel):
|
||||
name: str = Field(
|
||||
...,
|
||||
description='Operation resource name',
|
||||
examples=[
|
||||
'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/a1b07c8e-7b5a-4aba-bb34-3e1ccb8afcc8'
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class VeoGenVidPollRequest(BaseModel):
|
||||
operationName: str = Field(
|
||||
...,
|
||||
description='Full operation name (from predict response)',
|
||||
examples=[
|
||||
'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/OPERATION_ID'
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class Video(BaseModel):
|
||||
bytesBase64Encoded: Optional[str] = Field(
|
||||
None, description='Base64-encoded video content'
|
||||
)
|
||||
gcsUri: Optional[str] = Field(None, description='Cloud Storage URI of the video')
|
||||
mimeType: Optional[str] = Field(None, description='Video MIME type')
|
||||
|
||||
|
||||
class Error1(BaseModel):
|
||||
code: Optional[int] = Field(None, description='Error code')
|
||||
message: Optional[str] = Field(None, description='Error message')
|
||||
|
||||
|
||||
class Response1(BaseModel):
|
||||
field_type: Optional[str] = Field(
|
||||
None,
|
||||
alias='@type',
|
||||
examples=[
|
||||
'type.googleapis.com/cloud.ai.large_models.vision.GenerateVideoResponse'
|
||||
],
|
||||
)
|
||||
raiMediaFilteredCount: Optional[int] = Field(
|
||||
None, description='Count of media filtered by responsible AI policies'
|
||||
)
|
||||
raiMediaFilteredReasons: Optional[list[str]] = Field(
|
||||
None, description='Reasons why media was filtered by responsible AI policies'
|
||||
)
|
||||
videos: Optional[list[Video]] = None
|
||||
|
||||
|
||||
class VeoGenVidPollResponse(BaseModel):
|
||||
done: Optional[bool] = None
|
||||
error: Optional[Error1] = Field(
|
||||
None, description='Error details if operation failed'
|
||||
)
|
||||
name: Optional[str] = None
|
||||
response: Optional[Response1] = Field(
|
||||
None, description='The actual prediction response if done is true'
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,6 @@
|
||||
from io import BytesIO
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io as comfy_io
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -11,19 +11,13 @@ from comfy_api_nodes.apis import (
|
||||
IdeogramV3Request,
|
||||
IdeogramV3EditRequest,
|
||||
)
|
||||
|
||||
from comfy_api_nodes.apis.client import (
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
)
|
||||
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
download_url_to_bytesio,
|
||||
bytesio_to_image_tensor,
|
||||
download_url_as_bytesio,
|
||||
resize_mask_to_image,
|
||||
sync_op,
|
||||
)
|
||||
from server import PromptServer
|
||||
|
||||
V1_V1_RES_MAP = {
|
||||
"Auto":"AUTO",
|
||||
@@ -220,7 +214,7 @@ async def download_and_process_images(image_urls):
|
||||
|
||||
for image_url in image_urls:
|
||||
# Using functions from apinode_utils.py to handle downloading and processing
|
||||
image_bytesio = await download_url_to_bytesio(image_url) # Download image content to BytesIO
|
||||
image_bytesio = await download_url_as_bytesio(image_url) # Download image content to BytesIO
|
||||
img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB") # Convert to torch.Tensor with RGB mode
|
||||
image_tensors.append(img_tensor)
|
||||
|
||||
@@ -233,89 +227,76 @@ async def download_and_process_images(image_urls):
|
||||
return stacked_tensors
|
||||
|
||||
|
||||
def display_image_urls_on_node(image_urls, node_id):
|
||||
if node_id and image_urls:
|
||||
if len(image_urls) == 1:
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Generated Image URL:\n{image_urls[0]}", node_id
|
||||
)
|
||||
else:
|
||||
urls_text = "Generated Image URLs:\n" + "\n".join(
|
||||
f"{i+1}. {url}" for i, url in enumerate(image_urls)
|
||||
)
|
||||
PromptServer.instance.send_progress_text(urls_text, node_id)
|
||||
|
||||
|
||||
class IdeogramV1(comfy_io.ComfyNode):
|
||||
class IdeogramV1(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
return IO.Schema(
|
||||
node_id="IdeogramV1",
|
||||
display_name="Ideogram V1",
|
||||
category="api node/image/Ideogram",
|
||||
description="Generates images using the Ideogram V1 model.",
|
||||
is_api_node=True,
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the image generation",
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
IO.Boolean.Input(
|
||||
"turbo",
|
||||
default=False,
|
||||
tooltip="Whether to use turbo mode (faster generation, potentially lower quality)",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=list(V1_V2_RATIO_MAP.keys()),
|
||||
default="1:1",
|
||||
tooltip="The aspect ratio for image generation.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"magic_prompt_option",
|
||||
options=["AUTO", "ON", "OFF"],
|
||||
default="AUTO",
|
||||
tooltip="Determine if MagicPrompt should be used in generation",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
control_after_generate=True,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Description of what to exclude from the image",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"num_images",
|
||||
default=1,
|
||||
min=1,
|
||||
max=8,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Image.Output(),
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
)
|
||||
|
||||
@@ -334,77 +315,63 @@ class IdeogramV1(comfy_io.ComfyNode):
|
||||
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
|
||||
model = "V_1_TURBO" if turbo else "V_1"
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/ideogram/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=IdeogramGenerateRequest,
|
||||
response_model=IdeogramGenerateResponse,
|
||||
),
|
||||
request=IdeogramGenerateRequest(
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
|
||||
response_model=IdeogramGenerateResponse,
|
||||
data=IdeogramGenerateRequest(
|
||||
image_request=ImageRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
num_images=num_images,
|
||||
seed=seed,
|
||||
aspect_ratio=aspect_ratio if aspect_ratio != "ASPECT_1_1" else None,
|
||||
magic_prompt_option=(
|
||||
magic_prompt_option if magic_prompt_option != "AUTO" else None
|
||||
),
|
||||
magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
)
|
||||
),
|
||||
auth_kwargs=auth,
|
||||
max_retries=1,
|
||||
)
|
||||
|
||||
response = await operation.execute()
|
||||
|
||||
if not response.data or len(response.data) == 0:
|
||||
raise Exception("No images were generated in the response")
|
||||
|
||||
image_urls = [image_data.url for image_data in response.data if image_data.url]
|
||||
|
||||
if not image_urls:
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
|
||||
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
|
||||
return comfy_io.NodeOutput(await download_and_process_images(image_urls))
|
||||
return IO.NodeOutput(await download_and_process_images(image_urls))
|
||||
|
||||
|
||||
class IdeogramV2(comfy_io.ComfyNode):
|
||||
class IdeogramV2(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
return IO.Schema(
|
||||
node_id="IdeogramV2",
|
||||
display_name="Ideogram V2",
|
||||
category="api node/image/Ideogram",
|
||||
description="Generates images using the Ideogram V2 model.",
|
||||
is_api_node=True,
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the image generation",
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
IO.Boolean.Input(
|
||||
"turbo",
|
||||
default=False,
|
||||
tooltip="Whether to use turbo mode (faster generation, potentially lower quality)",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=list(V1_V2_RATIO_MAP.keys()),
|
||||
default="1:1",
|
||||
tooltip="The aspect ratio for image generation. Ignored if resolution is not set to AUTO.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=list(V1_V1_RES_MAP.keys()),
|
||||
default="Auto",
|
||||
@@ -412,44 +379,44 @@ class IdeogramV2(comfy_io.ComfyNode):
|
||||
"If not set to AUTO, this overrides the aspect_ratio setting.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"magic_prompt_option",
|
||||
options=["AUTO", "ON", "OFF"],
|
||||
default="AUTO",
|
||||
tooltip="Determine if MagicPrompt should be used in generation",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
control_after_generate=True,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"style_type",
|
||||
options=["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"],
|
||||
default="NONE",
|
||||
tooltip="Style type for generation (V2 only)",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Description of what to exclude from the image",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"num_images",
|
||||
default=1,
|
||||
min=1,
|
||||
max=8,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
#"color_palette": (
|
||||
@@ -462,12 +429,12 @@ class IdeogramV2(comfy_io.ComfyNode):
|
||||
#),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Image.Output(),
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
)
|
||||
|
||||
@@ -500,18 +467,11 @@ class IdeogramV2(comfy_io.ComfyNode):
|
||||
else:
|
||||
final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/ideogram/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=IdeogramGenerateRequest,
|
||||
response_model=IdeogramGenerateResponse,
|
||||
),
|
||||
request=IdeogramGenerateRequest(
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
|
||||
response_model=IdeogramGenerateResponse,
|
||||
data=IdeogramGenerateRequest(
|
||||
image_request=ImageRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
@@ -519,36 +479,28 @@ class IdeogramV2(comfy_io.ComfyNode):
|
||||
seed=seed,
|
||||
aspect_ratio=final_aspect_ratio,
|
||||
resolution=final_resolution,
|
||||
magic_prompt_option=(
|
||||
magic_prompt_option if magic_prompt_option != "AUTO" else None
|
||||
),
|
||||
magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
|
||||
style_type=style_type if style_type != "NONE" else None,
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
color_palette=color_palette if color_palette else None,
|
||||
)
|
||||
),
|
||||
auth_kwargs=auth,
|
||||
max_retries=1,
|
||||
)
|
||||
|
||||
response = await operation.execute()
|
||||
|
||||
if not response.data or len(response.data) == 0:
|
||||
raise Exception("No images were generated in the response")
|
||||
|
||||
image_urls = [image_data.url for image_data in response.data if image_data.url]
|
||||
|
||||
if not image_urls:
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
|
||||
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
|
||||
return comfy_io.NodeOutput(await download_and_process_images(image_urls))
|
||||
return IO.NodeOutput(await download_and_process_images(image_urls))
|
||||
|
||||
|
||||
class IdeogramV3(comfy_io.ComfyNode):
|
||||
class IdeogramV3(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
return IO.Schema(
|
||||
node_id="IdeogramV3",
|
||||
display_name="Ideogram V3",
|
||||
category="api node/image/Ideogram",
|
||||
@@ -556,30 +508,30 @@ class IdeogramV3(comfy_io.ComfyNode):
|
||||
"Supports both regular image generation from text prompts and image editing with mask.",
|
||||
is_api_node=True,
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the image generation or editing",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="Optional reference image for image editing.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Mask.Input(
|
||||
IO.Mask.Input(
|
||||
"mask",
|
||||
tooltip="Optional mask for inpainting (white areas will be replaced)",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=list(V3_RATIO_MAP.keys()),
|
||||
default="1:1",
|
||||
tooltip="The aspect ratio for image generation. Ignored if resolution is not set to Auto.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=V3_RESOLUTIONS,
|
||||
default="Auto",
|
||||
@@ -587,57 +539,57 @@ class IdeogramV3(comfy_io.ComfyNode):
|
||||
"If not set to Auto, this overrides the aspect_ratio setting.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"magic_prompt_option",
|
||||
options=["AUTO", "ON", "OFF"],
|
||||
default="AUTO",
|
||||
tooltip="Determine if MagicPrompt should be used in generation",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
control_after_generate=True,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"num_images",
|
||||
default=1,
|
||||
min=1,
|
||||
max=8,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"rendering_speed",
|
||||
options=["DEFAULT", "TURBO", "QUALITY"],
|
||||
default="DEFAULT",
|
||||
tooltip="Controls the trade-off between generation speed and quality",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"character_image",
|
||||
tooltip="Image to use as character reference.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Mask.Input(
|
||||
IO.Mask.Input(
|
||||
"character_mask",
|
||||
tooltip="Optional mask for character reference image.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Image.Output(),
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
)
|
||||
|
||||
@@ -656,10 +608,6 @@ class IdeogramV3(comfy_io.ComfyNode):
|
||||
character_image=None,
|
||||
character_mask=None,
|
||||
):
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
if rendering_speed == "BALANCED": # for backward compatibility
|
||||
rendering_speed = "DEFAULT"
|
||||
|
||||
@@ -694,9 +642,6 @@ class IdeogramV3(comfy_io.ComfyNode):
|
||||
|
||||
# Check if both image and mask are provided for editing mode
|
||||
if image is not None and mask is not None:
|
||||
# Edit mode
|
||||
path = "/proxy/ideogram/ideogram-v3/edit"
|
||||
|
||||
# Process image and mask
|
||||
input_tensor = image.squeeze().cpu()
|
||||
# Resize mask to match image dimension
|
||||
@@ -749,27 +694,20 @@ class IdeogramV3(comfy_io.ComfyNode):
|
||||
if character_mask_binary:
|
||||
files["character_mask_binary"] = character_mask_binary
|
||||
|
||||
# Execute the operation for edit mode
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=path,
|
||||
method=HttpMethod.POST,
|
||||
request_model=IdeogramV3EditRequest,
|
||||
response_model=IdeogramGenerateResponse,
|
||||
),
|
||||
request=edit_request,
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/ideogram/ideogram-v3/edit", method="POST"),
|
||||
response_model=IdeogramGenerateResponse,
|
||||
data=edit_request,
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
max_retries=1,
|
||||
)
|
||||
|
||||
elif image is not None or mask is not None:
|
||||
# If only one of image or mask is provided, raise an error
|
||||
raise Exception("Ideogram V3 image editing requires both an image AND a mask")
|
||||
else:
|
||||
# Generation mode
|
||||
path = "/proxy/ideogram/ideogram-v3/generate"
|
||||
|
||||
# Create generation request
|
||||
gen_request = IdeogramV3Request(
|
||||
prompt=prompt,
|
||||
@@ -800,43 +738,34 @@ class IdeogramV3(comfy_io.ComfyNode):
|
||||
if files:
|
||||
gen_request.style_type = "AUTO"
|
||||
|
||||
# Execute the operation for generation mode
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=path,
|
||||
method=HttpMethod.POST,
|
||||
request_model=IdeogramV3Request,
|
||||
response_model=IdeogramGenerateResponse,
|
||||
),
|
||||
request=gen_request,
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path="/proxy/ideogram/ideogram-v3/generate", method="POST"),
|
||||
response_model=IdeogramGenerateResponse,
|
||||
data=gen_request,
|
||||
files=files if files else None,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
max_retries=1,
|
||||
)
|
||||
|
||||
# Execute the operation and process response
|
||||
response = await operation.execute()
|
||||
|
||||
if not response.data or len(response.data) == 0:
|
||||
raise Exception("No images were generated in the response")
|
||||
|
||||
image_urls = [image_data.url for image_data in response.data if image_data.url]
|
||||
|
||||
if not image_urls:
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
|
||||
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
|
||||
return comfy_io.NodeOutput(await download_and_process_images(image_urls))
|
||||
return IO.NodeOutput(await download_and_process_images(image_urls))
|
||||
|
||||
|
||||
class IdeogramExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
IdeogramV1,
|
||||
IdeogramV2,
|
||||
IdeogramV3,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> IdeogramExtension:
|
||||
return IdeogramExtension()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
199
comfy_api_nodes/nodes_ltxv.py
Normal file
199
comfy_api_nodes/nodes_ltxv.py
Normal file
@@ -0,0 +1,199 @@
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
get_number_of_images,
|
||||
sync_op_raw,
|
||||
upload_images_to_comfyapi,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
MODELS_MAP = {
|
||||
"LTX-2 (Pro)": "ltx-2-pro",
|
||||
"LTX-2 (Fast)": "ltx-2-fast",
|
||||
}
|
||||
|
||||
|
||||
class ExecuteTaskRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
model: str = Field(...)
|
||||
duration: int = Field(...)
|
||||
resolution: str = Field(...)
|
||||
fps: Optional[int] = Field(25)
|
||||
generate_audio: Optional[bool] = Field(True)
|
||||
image_uri: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class TextToVideoNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="LtxvApiTextToVideo",
|
||||
display_name="LTXV Text To Video",
|
||||
category="api node/video/LTXV",
|
||||
description="Professional-quality videos with customizable duration and resolution.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=list(MODELS_MAP.keys())),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
),
|
||||
IO.Combo.Input("duration", options=[6, 8, 10, 12, 14, 16, 18, 20], default=8),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=[
|
||||
"1920x1080",
|
||||
"2560x1440",
|
||||
"3840x2160",
|
||||
],
|
||||
),
|
||||
IO.Combo.Input("fps", options=[25, 50], default=25),
|
||||
IO.Boolean.Input(
|
||||
"generate_audio",
|
||||
default=False,
|
||||
optional=True,
|
||||
tooltip="When true, the generated video will include AI-generated audio matching the scene.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
duration: int,
|
||||
resolution: str,
|
||||
fps: int = 25,
|
||||
generate_audio: bool = False,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=10000)
|
||||
if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25):
|
||||
raise ValueError(
|
||||
"Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS."
|
||||
)
|
||||
response = await sync_op_raw(
|
||||
cls,
|
||||
ApiEndpoint("/proxy/ltx/v1/text-to-video", "POST"),
|
||||
data=ExecuteTaskRequest(
|
||||
prompt=prompt,
|
||||
model=MODELS_MAP[model],
|
||||
duration=duration,
|
||||
resolution=resolution,
|
||||
fps=fps,
|
||||
generate_audio=generate_audio,
|
||||
),
|
||||
as_binary=True,
|
||||
max_retries=1,
|
||||
)
|
||||
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
|
||||
|
||||
|
||||
class ImageToVideoNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="LtxvApiImageToVideo",
|
||||
display_name="LTXV Image To Video",
|
||||
category="api node/video/LTXV",
|
||||
description="Professional-quality videos with customizable duration and resolution based on start image.",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="First frame to be used for the video."),
|
||||
IO.Combo.Input("model", options=list(MODELS_MAP.keys())),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
),
|
||||
IO.Combo.Input("duration", options=[6, 8, 10, 12, 14, 16, 18, 20], default=8),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=[
|
||||
"1920x1080",
|
||||
"2560x1440",
|
||||
"3840x2160",
|
||||
],
|
||||
),
|
||||
IO.Combo.Input("fps", options=[25, 50], default=25),
|
||||
IO.Boolean.Input(
|
||||
"generate_audio",
|
||||
default=False,
|
||||
optional=True,
|
||||
tooltip="When true, the generated video will include AI-generated audio matching the scene.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: torch.Tensor,
|
||||
model: str,
|
||||
prompt: str,
|
||||
duration: int,
|
||||
resolution: str,
|
||||
fps: int = 25,
|
||||
generate_audio: bool = False,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=10000)
|
||||
if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25):
|
||||
raise ValueError(
|
||||
"Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS."
|
||||
)
|
||||
if get_number_of_images(image) != 1:
|
||||
raise ValueError("Currently only one input image is supported.")
|
||||
response = await sync_op_raw(
|
||||
cls,
|
||||
ApiEndpoint("/proxy/ltx/v1/image-to-video", "POST"),
|
||||
data=ExecuteTaskRequest(
|
||||
image_uri=(await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0],
|
||||
prompt=prompt,
|
||||
model=MODELS_MAP[model],
|
||||
duration=duration,
|
||||
resolution=resolution,
|
||||
fps=fps,
|
||||
generate_audio=generate_audio,
|
||||
),
|
||||
as_binary=True,
|
||||
max_retries=1,
|
||||
)
|
||||
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
|
||||
|
||||
|
||||
class LtxvApiExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
TextToVideoNode,
|
||||
ImageToVideoNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> LtxvApiExtension:
|
||||
return LtxvApiExtension()
|
||||
@@ -1,75 +1,57 @@
|
||||
from __future__ import annotations
|
||||
from inspect import cleandoc
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io as comfy_io
|
||||
from comfy_api.input_impl.video_types import VideoFromFile
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.apis.luma_api import (
|
||||
LumaImageModel,
|
||||
LumaVideoModel,
|
||||
LumaVideoOutputResolution,
|
||||
LumaVideoModelOutputDuration,
|
||||
LumaAspectRatio,
|
||||
LumaState,
|
||||
LumaImageGenerationRequest,
|
||||
LumaGenerationRequest,
|
||||
LumaGeneration,
|
||||
LumaCharacterRef,
|
||||
LumaModifyImageRef,
|
||||
LumaConceptChain,
|
||||
LumaGeneration,
|
||||
LumaGenerationRequest,
|
||||
LumaImageGenerationRequest,
|
||||
LumaImageIdentity,
|
||||
LumaImageModel,
|
||||
LumaImageReference,
|
||||
LumaIO,
|
||||
LumaKeyframes,
|
||||
LumaModifyImageRef,
|
||||
LumaReference,
|
||||
LumaReferenceChain,
|
||||
LumaImageReference,
|
||||
LumaKeyframes,
|
||||
LumaConceptChain,
|
||||
LumaIO,
|
||||
LumaVideoModel,
|
||||
LumaVideoModelOutputDuration,
|
||||
LumaVideoOutputResolution,
|
||||
get_luma_concepts,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
download_url_to_image_tensor,
|
||||
download_url_to_video_output,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_images_to_comfyapi,
|
||||
process_image_response,
|
||||
validate_string,
|
||||
)
|
||||
from server import PromptServer
|
||||
|
||||
import aiohttp
|
||||
import torch
|
||||
from io import BytesIO
|
||||
|
||||
LUMA_T2V_AVERAGE_DURATION = 105
|
||||
LUMA_I2V_AVERAGE_DURATION = 100
|
||||
|
||||
def image_result_url_extractor(response: LumaGeneration):
|
||||
return response.assets.image if hasattr(response, "assets") and hasattr(response.assets, "image") else None
|
||||
|
||||
def video_result_url_extractor(response: LumaGeneration):
|
||||
return response.assets.video if hasattr(response, "assets") and hasattr(response.assets, "video") else None
|
||||
|
||||
class LumaReferenceNode(comfy_io.ComfyNode):
|
||||
"""
|
||||
Holds an image and weight for use with Luma Generate Image node.
|
||||
"""
|
||||
|
||||
class LumaReferenceNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="LumaReferenceNode",
|
||||
display_name="Luma Reference",
|
||||
category="api node/image/Luma",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Holds an image and weight for use with Luma Generate Image node.",
|
||||
inputs=[
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="Image to use as reference.",
|
||||
),
|
||||
comfy_io.Float.Input(
|
||||
IO.Float.Input(
|
||||
"weight",
|
||||
default=1.0,
|
||||
min=0.0,
|
||||
@@ -77,72 +59,56 @@ class LumaReferenceNode(comfy_io.ComfyNode):
|
||||
step=0.01,
|
||||
tooltip="Weight of image reference.",
|
||||
),
|
||||
comfy_io.Custom(LumaIO.LUMA_REF).Input(
|
||||
IO.Custom(LumaIO.LUMA_REF).Input(
|
||||
"luma_ref",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Custom(LumaIO.LUMA_REF).Output(display_name="luma_ref")],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
outputs=[IO.Custom(LumaIO.LUMA_REF).Output(display_name="luma_ref")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(
|
||||
cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None
|
||||
) -> comfy_io.NodeOutput:
|
||||
def execute(cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None) -> IO.NodeOutput:
|
||||
if luma_ref is not None:
|
||||
luma_ref = luma_ref.clone()
|
||||
else:
|
||||
luma_ref = LumaReferenceChain()
|
||||
luma_ref.add(LumaReference(image=image, weight=round(weight, 2)))
|
||||
return comfy_io.NodeOutput(luma_ref)
|
||||
return IO.NodeOutput(luma_ref)
|
||||
|
||||
|
||||
class LumaConceptsNode(comfy_io.ComfyNode):
|
||||
"""
|
||||
Holds one or more Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.
|
||||
"""
|
||||
|
||||
class LumaConceptsNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="LumaConceptsNode",
|
||||
display_name="Luma Concepts",
|
||||
category="api node/video/Luma",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.",
|
||||
inputs=[
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"concept1",
|
||||
options=get_luma_concepts(include_none=True),
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"concept2",
|
||||
options=get_luma_concepts(include_none=True),
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"concept3",
|
||||
options=get_luma_concepts(include_none=True),
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"concept4",
|
||||
options=get_luma_concepts(include_none=True),
|
||||
),
|
||||
comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Input(
|
||||
IO.Custom(LumaIO.LUMA_CONCEPTS).Input(
|
||||
"luma_concepts",
|
||||
tooltip="Optional Camera Concepts to add to the ones chosen here.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Output(display_name="luma_concepts")],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
outputs=[IO.Custom(LumaIO.LUMA_CONCEPTS).Output(display_name="luma_concepts")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -153,42 +119,38 @@ class LumaConceptsNode(comfy_io.ComfyNode):
|
||||
concept3: str,
|
||||
concept4: str,
|
||||
luma_concepts: LumaConceptChain = None,
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
chain = LumaConceptChain(str_list=[concept1, concept2, concept3, concept4])
|
||||
if luma_concepts is not None:
|
||||
chain = luma_concepts.clone_and_merge(chain)
|
||||
return comfy_io.NodeOutput(chain)
|
||||
return IO.NodeOutput(chain)
|
||||
|
||||
|
||||
class LumaImageGenerationNode(comfy_io.ComfyNode):
|
||||
"""
|
||||
Generates images synchronously based on prompt and aspect ratio.
|
||||
"""
|
||||
|
||||
class LumaImageGenerationNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="LumaImageNode",
|
||||
display_name="Luma Text to Image",
|
||||
category="api node/image/Luma",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Generates images synchronously based on prompt and aspect ratio.",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the image generation",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=[model.value for model in LumaImageModel],
|
||||
options=LumaImageModel,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=[ratio.value for ratio in LumaAspectRatio],
|
||||
options=LumaAspectRatio,
|
||||
default=LumaAspectRatio.ratio_16_9,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
@@ -196,7 +158,7 @@ class LumaImageGenerationNode(comfy_io.ComfyNode):
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
comfy_io.Float.Input(
|
||||
IO.Float.Input(
|
||||
"style_image_weight",
|
||||
default=1.0,
|
||||
min=0.0,
|
||||
@@ -204,27 +166,27 @@ class LumaImageGenerationNode(comfy_io.ComfyNode):
|
||||
step=0.01,
|
||||
tooltip="Weight of style image. Ignored if no style_image provided.",
|
||||
),
|
||||
comfy_io.Custom(LumaIO.LUMA_REF).Input(
|
||||
IO.Custom(LumaIO.LUMA_REF).Input(
|
||||
"image_luma_ref",
|
||||
tooltip="Luma Reference node connection to influence generation with input images; up to 4 images can be considered.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"style_image",
|
||||
tooltip="Style reference image; only 1 image will be used.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"character_image",
|
||||
tooltip="Character reference images; can be a batch of multiple, up to 4 images can be considered.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Image.Output()],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -237,45 +199,30 @@ class LumaImageGenerationNode(comfy_io.ComfyNode):
|
||||
aspect_ratio: str,
|
||||
seed,
|
||||
style_image_weight: float,
|
||||
image_luma_ref: LumaReferenceChain = None,
|
||||
style_image: torch.Tensor = None,
|
||||
character_image: torch.Tensor = None,
|
||||
) -> comfy_io.NodeOutput:
|
||||
image_luma_ref: Optional[LumaReferenceChain] = None,
|
||||
style_image: Optional[torch.Tensor] = None,
|
||||
character_image: Optional[torch.Tensor] = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=3)
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
# handle image_luma_ref
|
||||
api_image_ref = None
|
||||
if image_luma_ref is not None:
|
||||
api_image_ref = await cls._convert_luma_refs(
|
||||
image_luma_ref, max_refs=4, auth_kwargs=auth_kwargs,
|
||||
)
|
||||
api_image_ref = await cls._convert_luma_refs(image_luma_ref, max_refs=4)
|
||||
# handle style_luma_ref
|
||||
api_style_ref = None
|
||||
if style_image is not None:
|
||||
api_style_ref = await cls._convert_style_image(
|
||||
style_image, weight=style_image_weight, auth_kwargs=auth_kwargs,
|
||||
)
|
||||
api_style_ref = await cls._convert_style_image(style_image, weight=style_image_weight)
|
||||
# handle character_ref images
|
||||
character_ref = None
|
||||
if character_image is not None:
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
character_image, max_images=4, auth_kwargs=auth_kwargs,
|
||||
)
|
||||
character_ref = LumaCharacterRef(
|
||||
identity0=LumaImageIdentity(images=download_urls)
|
||||
)
|
||||
download_urls = await upload_images_to_comfyapi(cls, character_image, max_images=4)
|
||||
character_ref = LumaCharacterRef(identity0=LumaImageIdentity(images=download_urls))
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/luma/generations/image",
|
||||
method=HttpMethod.POST,
|
||||
request_model=LumaImageGenerationRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
request=LumaImageGenerationRequest(
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/luma/generations/image", method="POST"),
|
||||
response_model=LumaGeneration,
|
||||
data=LumaImageGenerationRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
aspect_ratio=aspect_ratio,
|
||||
@@ -283,41 +230,21 @@ class LumaImageGenerationNode(comfy_io.ComfyNode):
|
||||
style_ref=api_style_ref,
|
||||
character_ref=character_ref,
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
response_api: LumaGeneration = await operation.execute()
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/luma/generations/{response_api.id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
completed_statuses=[LumaState.completed],
|
||||
failed_statuses=[LumaState.failed],
|
||||
response_poll = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
|
||||
response_model=LumaGeneration,
|
||||
status_extractor=lambda x: x.state,
|
||||
result_url_extractor=image_result_url_extractor,
|
||||
node_id=cls.hidden.unique_id,
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
response_poll = await operation.execute()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.assets.image) as img_response:
|
||||
img = process_image_response(await img_response.content.read())
|
||||
return comfy_io.NodeOutput(img)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(response_poll.assets.image))
|
||||
|
||||
@classmethod
|
||||
async def _convert_luma_refs(
|
||||
cls, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None
|
||||
):
|
||||
async def _convert_luma_refs(cls, luma_ref: LumaReferenceChain, max_refs: int):
|
||||
luma_urls = []
|
||||
ref_count = 0
|
||||
for ref in luma_ref.refs:
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
ref.image, max_images=1, auth_kwargs=auth_kwargs
|
||||
)
|
||||
download_urls = await upload_images_to_comfyapi(cls, ref.image, max_images=1)
|
||||
luma_urls.append(download_urls[0])
|
||||
ref_count += 1
|
||||
if ref_count >= max_refs:
|
||||
@@ -325,38 +252,30 @@ class LumaImageGenerationNode(comfy_io.ComfyNode):
|
||||
return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs)
|
||||
|
||||
@classmethod
|
||||
async def _convert_style_image(
|
||||
cls, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None
|
||||
):
|
||||
chain = LumaReferenceChain(
|
||||
first_ref=LumaReference(image=style_image, weight=weight)
|
||||
)
|
||||
return await cls._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)
|
||||
async def _convert_style_image(cls, style_image: torch.Tensor, weight: float):
|
||||
chain = LumaReferenceChain(first_ref=LumaReference(image=style_image, weight=weight))
|
||||
return await cls._convert_luma_refs(chain, max_refs=1)
|
||||
|
||||
|
||||
class LumaImageModifyNode(comfy_io.ComfyNode):
|
||||
"""
|
||||
Modifies images synchronously based on prompt and aspect ratio.
|
||||
"""
|
||||
|
||||
class LumaImageModifyNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="LumaImageModifyNode",
|
||||
display_name="Luma Image to Image",
|
||||
category="api node/image/Luma",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Modifies images synchronously based on prompt and aspect ratio.",
|
||||
inputs=[
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the image generation",
|
||||
),
|
||||
comfy_io.Float.Input(
|
||||
IO.Float.Input(
|
||||
"image_weight",
|
||||
default=0.1,
|
||||
min=0.0,
|
||||
@@ -364,11 +283,11 @@ class LumaImageModifyNode(comfy_io.ComfyNode):
|
||||
step=0.01,
|
||||
tooltip="Weight of the image; the closer to 1.0, the less the image will be modified.",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=[model.value for model in LumaImageModel],
|
||||
options=LumaImageModel,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
@@ -377,11 +296,11 @@ class LumaImageModifyNode(comfy_io.ComfyNode):
|
||||
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Image.Output()],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -394,99 +313,68 @@ class LumaImageModifyNode(comfy_io.ComfyNode):
|
||||
image: torch.Tensor,
|
||||
image_weight: float,
|
||||
seed,
|
||||
) -> comfy_io.NodeOutput:
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
# first, upload image
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
image, max_images=1, auth_kwargs=auth_kwargs,
|
||||
)
|
||||
) -> IO.NodeOutput:
|
||||
download_urls = await upload_images_to_comfyapi(cls, image, max_images=1)
|
||||
image_url = download_urls[0]
|
||||
# next, make Luma call with download url provided
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/luma/generations/image",
|
||||
method=HttpMethod.POST,
|
||||
request_model=LumaImageGenerationRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
request=LumaImageGenerationRequest(
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/luma/generations/image", method="POST"),
|
||||
response_model=LumaGeneration,
|
||||
data=LumaImageGenerationRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
modify_image_ref=LumaModifyImageRef(
|
||||
url=image_url, weight=round(max(min(1.0-image_weight, 0.98), 0.0), 2)
|
||||
url=image_url, weight=round(max(min(1.0 - image_weight, 0.98), 0.0), 2)
|
||||
),
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
response_api: LumaGeneration = await operation.execute()
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/luma/generations/{response_api.id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
completed_statuses=[LumaState.completed],
|
||||
failed_statuses=[LumaState.failed],
|
||||
response_poll = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
|
||||
response_model=LumaGeneration,
|
||||
status_extractor=lambda x: x.state,
|
||||
result_url_extractor=image_result_url_extractor,
|
||||
node_id=cls.hidden.unique_id,
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
response_poll = await operation.execute()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.assets.image) as img_response:
|
||||
img = process_image_response(await img_response.content.read())
|
||||
return comfy_io.NodeOutput(img)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(response_poll.assets.image))
|
||||
|
||||
|
||||
class LumaTextToVideoGenerationNode(comfy_io.ComfyNode):
|
||||
"""
|
||||
Generates videos synchronously based on prompt and output_size.
|
||||
"""
|
||||
|
||||
class LumaTextToVideoGenerationNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="LumaVideoNode",
|
||||
display_name="Luma Text to Video",
|
||||
category="api node/video/Luma",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Generates videos synchronously based on prompt and output_size.",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the video generation",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=[model.value for model in LumaVideoModel],
|
||||
options=LumaVideoModel,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=[ratio.value for ratio in LumaAspectRatio],
|
||||
options=LumaAspectRatio,
|
||||
default=LumaAspectRatio.ratio_16_9,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=[resolution.value for resolution in LumaVideoOutputResolution],
|
||||
options=LumaVideoOutputResolution,
|
||||
default=LumaVideoOutputResolution.res_540p,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"duration",
|
||||
options=[dur.value for dur in LumaVideoModelOutputDuration],
|
||||
options=LumaVideoModelOutputDuration,
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
IO.Boolean.Input(
|
||||
"loop",
|
||||
default=False,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
@@ -494,17 +382,17 @@ class LumaTextToVideoGenerationNode(comfy_io.ComfyNode):
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Input(
|
||||
IO.Custom(LumaIO.LUMA_CONCEPTS).Input(
|
||||
"luma_concepts",
|
||||
tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.",
|
||||
optional=True,
|
||||
)
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -519,24 +407,17 @@ class LumaTextToVideoGenerationNode(comfy_io.ComfyNode):
|
||||
duration: str,
|
||||
loop: bool,
|
||||
seed,
|
||||
luma_concepts: LumaConceptChain = None,
|
||||
) -> comfy_io.NodeOutput:
|
||||
luma_concepts: Optional[LumaConceptChain] = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False, min_length=3)
|
||||
duration = duration if model != LumaVideoModel.ray_1_6 else None
|
||||
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
|
||||
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/luma/generations",
|
||||
method=HttpMethod.POST,
|
||||
request_model=LumaGenerationRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
request=LumaGenerationRequest(
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/luma/generations", method="POST"),
|
||||
response_model=LumaGeneration,
|
||||
data=LumaGenerationRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
resolution=resolution,
|
||||
@@ -545,77 +426,55 @@ class LumaTextToVideoGenerationNode(comfy_io.ComfyNode):
|
||||
loop=loop,
|
||||
concepts=luma_concepts.create_api_model() if luma_concepts else None,
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
response_api: LumaGeneration = await operation.execute()
|
||||
|
||||
if cls.hidden.unique_id:
|
||||
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id)
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/luma/generations/{response_api.id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
completed_statuses=[LumaState.completed],
|
||||
failed_statuses=[LumaState.failed],
|
||||
response_poll = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
|
||||
response_model=LumaGeneration,
|
||||
status_extractor=lambda x: x.state,
|
||||
result_url_extractor=video_result_url_extractor,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=LUMA_T2V_AVERAGE_DURATION,
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
response_poll = await operation.execute()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.assets.video) as vid_response:
|
||||
return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
|
||||
return IO.NodeOutput(await download_url_to_video_output(response_poll.assets.video))
|
||||
|
||||
|
||||
class LumaImageToVideoGenerationNode(comfy_io.ComfyNode):
|
||||
"""
|
||||
Generates videos synchronously based on prompt, input images, and output_size.
|
||||
"""
|
||||
|
||||
class LumaImageToVideoGenerationNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="LumaImageToVideoNode",
|
||||
display_name="Luma Image to Video",
|
||||
category="api node/video/Luma",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Generates videos synchronously based on prompt, input images, and output_size.",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the video generation",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=[model.value for model in LumaVideoModel],
|
||||
options=LumaVideoModel,
|
||||
),
|
||||
# comfy_io.Combo.Input(
|
||||
# IO.Combo.Input(
|
||||
# "aspect_ratio",
|
||||
# options=[ratio.value for ratio in LumaAspectRatio],
|
||||
# default=LumaAspectRatio.ratio_16_9,
|
||||
# ),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=[resolution.value for resolution in LumaVideoOutputResolution],
|
||||
options=LumaVideoOutputResolution,
|
||||
default=LumaVideoOutputResolution.res_540p,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"duration",
|
||||
options=[dur.value for dur in LumaVideoModelOutputDuration],
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
IO.Boolean.Input(
|
||||
"loop",
|
||||
default=False,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
@@ -623,27 +482,27 @@ class LumaImageToVideoGenerationNode(comfy_io.ComfyNode):
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"first_image",
|
||||
tooltip="First frame of generated video.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"last_image",
|
||||
tooltip="Last frame of generated video.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Input(
|
||||
IO.Custom(LumaIO.LUMA_CONCEPTS).Input(
|
||||
"luma_concepts",
|
||||
tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.",
|
||||
optional=True,
|
||||
)
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -660,27 +519,17 @@ class LumaImageToVideoGenerationNode(comfy_io.ComfyNode):
|
||||
first_image: torch.Tensor = None,
|
||||
last_image: torch.Tensor = None,
|
||||
luma_concepts: LumaConceptChain = None,
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
if first_image is None and last_image is None:
|
||||
raise Exception(
|
||||
"At least one of first_image and last_image requires an input."
|
||||
)
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
keyframes = await cls._convert_to_keyframes(first_image, last_image, auth_kwargs=auth_kwargs)
|
||||
raise Exception("At least one of first_image and last_image requires an input.")
|
||||
keyframes = await cls._convert_to_keyframes(first_image, last_image)
|
||||
duration = duration if model != LumaVideoModel.ray_1_6 else None
|
||||
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/luma/generations",
|
||||
method=HttpMethod.POST,
|
||||
request_model=LumaGenerationRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
request=LumaGenerationRequest(
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/luma/generations", method="POST"),
|
||||
response_model=LumaGeneration,
|
||||
data=LumaGenerationRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
aspect_ratio=LumaAspectRatio.ratio_16_9, # ignored, but still needed by the API for some reason
|
||||
@@ -690,61 +539,38 @@ class LumaImageToVideoGenerationNode(comfy_io.ComfyNode):
|
||||
keyframes=keyframes,
|
||||
concepts=luma_concepts.create_api_model() if luma_concepts else None,
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
response_api: LumaGeneration = await operation.execute()
|
||||
|
||||
if cls.hidden.unique_id:
|
||||
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id)
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/luma/generations/{response_api.id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
completed_statuses=[LumaState.completed],
|
||||
failed_statuses=[LumaState.failed],
|
||||
response_poll = await poll_op(
|
||||
cls,
|
||||
poll_endpoint=ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
|
||||
response_model=LumaGeneration,
|
||||
status_extractor=lambda x: x.state,
|
||||
result_url_extractor=video_result_url_extractor,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=LUMA_I2V_AVERAGE_DURATION,
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
response_poll = await operation.execute()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.assets.video) as vid_response:
|
||||
return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
|
||||
return IO.NodeOutput(await download_url_to_video_output(response_poll.assets.video))
|
||||
|
||||
@classmethod
|
||||
async def _convert_to_keyframes(
|
||||
cls,
|
||||
first_image: torch.Tensor = None,
|
||||
last_image: torch.Tensor = None,
|
||||
auth_kwargs: Optional[dict[str,str]] = None,
|
||||
):
|
||||
if first_image is None and last_image is None:
|
||||
return None
|
||||
frame0 = None
|
||||
frame1 = None
|
||||
if first_image is not None:
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
first_image, max_images=1, auth_kwargs=auth_kwargs,
|
||||
)
|
||||
download_urls = await upload_images_to_comfyapi(cls, first_image, max_images=1)
|
||||
frame0 = LumaImageReference(type="image", url=download_urls[0])
|
||||
if last_image is not None:
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
last_image, max_images=1, auth_kwargs=auth_kwargs,
|
||||
)
|
||||
download_urls = await upload_images_to_comfyapi(cls, last_image, max_images=1)
|
||||
frame1 = LumaImageReference(type="image", url=download_urls[0])
|
||||
return LumaKeyframes(frame0=frame0, frame1=frame1)
|
||||
|
||||
|
||||
class LumaExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
LumaImageGenerationNode,
|
||||
LumaImageModifyNode,
|
||||
|
||||
@@ -1,71 +1,57 @@
|
||||
from inspect import cleandoc
|
||||
from typing import Optional
|
||||
import logging
|
||||
import torch
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io as comfy_io
|
||||
from comfy_api.input_impl.video_types import VideoFromFile
|
||||
from comfy_api_nodes.apis import (
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.apis.minimax_api import (
|
||||
MinimaxFileRetrieveResponse,
|
||||
MiniMaxModel,
|
||||
MinimaxTaskResultResponse,
|
||||
MinimaxVideoGenerationRequest,
|
||||
MinimaxVideoGenerationResponse,
|
||||
MinimaxFileRetrieveResponse,
|
||||
MinimaxTaskResultResponse,
|
||||
SubjectReferenceItem,
|
||||
MiniMaxModel,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
download_url_to_bytesio,
|
||||
download_url_to_video_output,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_images_to_comfyapi,
|
||||
validate_string,
|
||||
)
|
||||
from server import PromptServer
|
||||
|
||||
|
||||
I2V_AVERAGE_DURATION = 114
|
||||
T2V_AVERAGE_DURATION = 234
|
||||
|
||||
|
||||
async def _generate_mm_video(
|
||||
cls: type[IO.ComfyNode],
|
||||
*,
|
||||
auth: dict[str, str],
|
||||
node_id: str,
|
||||
prompt_text: str,
|
||||
seed: int,
|
||||
model: str,
|
||||
image: Optional[torch.Tensor] = None, # used for ImageToVideo
|
||||
subject: Optional[torch.Tensor] = None, # used for SubjectToVideo
|
||||
image: Optional[torch.Tensor] = None, # used for ImageToVideo
|
||||
subject: Optional[torch.Tensor] = None, # used for SubjectToVideo
|
||||
average_duration: Optional[int] = None,
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
if image is None:
|
||||
validate_string(prompt_text, field_name="prompt_text")
|
||||
# upload image, if passed in
|
||||
image_url = None
|
||||
if image is not None:
|
||||
image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=auth))[0]
|
||||
image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0]
|
||||
|
||||
# TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model
|
||||
subject_reference = None
|
||||
if subject is not None:
|
||||
subject_url = (await upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=auth))[0]
|
||||
subject_url = (await upload_images_to_comfyapi(cls, subject, max_images=1))[0]
|
||||
subject_reference = [SubjectReferenceItem(image=subject_url)]
|
||||
|
||||
|
||||
video_generate_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/minimax/video_generation",
|
||||
method=HttpMethod.POST,
|
||||
request_model=MinimaxVideoGenerationRequest,
|
||||
response_model=MinimaxVideoGenerationResponse,
|
||||
),
|
||||
request=MinimaxVideoGenerationRequest(
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/minimax/video_generation", method="POST"),
|
||||
response_model=MinimaxVideoGenerationResponse,
|
||||
data=MinimaxVideoGenerationRequest(
|
||||
model=MiniMaxModel(model),
|
||||
prompt=prompt_text,
|
||||
callback_url=None,
|
||||
@@ -73,95 +59,64 @@ async def _generate_mm_video(
|
||||
subject_reference=subject_reference,
|
||||
prompt_optimizer=None,
|
||||
),
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
response = await video_generate_operation.execute()
|
||||
|
||||
task_id = response.task_id
|
||||
if not task_id:
|
||||
raise Exception(f"MiniMax generation failed: {response.base_resp}")
|
||||
|
||||
video_generate_operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path="/proxy/minimax/query/video_generation",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=MinimaxTaskResultResponse,
|
||||
query_params={"task_id": task_id},
|
||||
),
|
||||
completed_statuses=["Success"],
|
||||
failed_statuses=["Fail"],
|
||||
task_result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/minimax/query/video_generation", query_params={"task_id": task_id}),
|
||||
response_model=MinimaxTaskResultResponse,
|
||||
status_extractor=lambda x: x.status.value,
|
||||
estimated_duration=average_duration,
|
||||
node_id=node_id,
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
task_result = await video_generate_operation.execute()
|
||||
|
||||
file_id = task_result.file_id
|
||||
if file_id is None:
|
||||
raise Exception("Request was not successful. Missing file ID.")
|
||||
file_retrieve_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/minimax/files/retrieve",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=MinimaxFileRetrieveResponse,
|
||||
query_params={"file_id": int(file_id)},
|
||||
),
|
||||
request=EmptyRequest(),
|
||||
auth_kwargs=auth,
|
||||
file_result = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/minimax/files/retrieve", query_params={"file_id": int(file_id)}),
|
||||
response_model=MinimaxFileRetrieveResponse,
|
||||
)
|
||||
file_result = await file_retrieve_operation.execute()
|
||||
|
||||
file_url = file_result.file.download_url
|
||||
if file_url is None:
|
||||
raise Exception(
|
||||
f"No video was found in the response. Full response: {file_result.model_dump()}"
|
||||
)
|
||||
logging.info("Generated video URL: %s", file_url)
|
||||
if node_id:
|
||||
if hasattr(file_result.file, "backup_download_url"):
|
||||
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
|
||||
else:
|
||||
message = f"Result URL: {file_url}"
|
||||
PromptServer.instance.send_progress_text(message, node_id)
|
||||
|
||||
# Download and return as VideoFromFile
|
||||
video_io = await download_url_to_bytesio(file_url)
|
||||
if video_io is None:
|
||||
error_msg = f"Failed to download video from {file_url}"
|
||||
logging.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
return comfy_io.NodeOutput(VideoFromFile(video_io))
|
||||
raise Exception(f"No video was found in the response. Full response: {file_result.model_dump()}")
|
||||
if file_result.file.backup_download_url:
|
||||
try:
|
||||
return IO.NodeOutput(await download_url_to_video_output(file_url, timeout=10, max_retries=2))
|
||||
except Exception: # if we have a second URL to retrieve the result, try again using that one
|
||||
return IO.NodeOutput(
|
||||
await download_url_to_video_output(file_result.file.backup_download_url, max_retries=3)
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(file_url))
|
||||
|
||||
|
||||
class MinimaxTextToVideoNode(comfy_io.ComfyNode):
|
||||
"""
|
||||
Generates videos synchronously based on a prompt, and optional parameters using MiniMax's API.
|
||||
"""
|
||||
|
||||
class MinimaxTextToVideoNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="MinimaxTextToVideoNode",
|
||||
display_name="MiniMax Text to Video",
|
||||
category="api node/video/MiniMax",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Generates videos synchronously based on a prompt, and optional parameters.",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt_text",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text prompt to guide the video generation",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["T2V-01", "T2V-01-Director"],
|
||||
default="T2V-01",
|
||||
tooltip="Model to use for video generation",
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
@@ -172,11 +127,11 @@ class MinimaxTextToVideoNode(comfy_io.ComfyNode):
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -187,13 +142,9 @@ class MinimaxTextToVideoNode(comfy_io.ComfyNode):
|
||||
prompt_text: str,
|
||||
model: str = "T2V-01",
|
||||
seed: int = 0,
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
return await _generate_mm_video(
|
||||
auth={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
node_id=cls.hidden.unique_id,
|
||||
cls,
|
||||
prompt_text=prompt_text,
|
||||
seed=seed,
|
||||
model=model,
|
||||
@@ -203,36 +154,32 @@ class MinimaxTextToVideoNode(comfy_io.ComfyNode):
|
||||
)
|
||||
|
||||
|
||||
class MinimaxImageToVideoNode(comfy_io.ComfyNode):
|
||||
"""
|
||||
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
|
||||
"""
|
||||
|
||||
class MinimaxImageToVideoNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="MinimaxImageToVideoNode",
|
||||
display_name="MiniMax Image to Video",
|
||||
category="api node/video/MiniMax",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Generates videos synchronously based on an image and prompt, and optional parameters.",
|
||||
inputs=[
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="Image to use as first frame of video generation",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt_text",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text prompt to guide the video generation",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["I2V-01-Director", "I2V-01", "I2V-01-live"],
|
||||
default="I2V-01",
|
||||
tooltip="Model to use for video generation",
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
@@ -243,11 +190,11 @@ class MinimaxImageToVideoNode(comfy_io.ComfyNode):
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -259,13 +206,9 @@ class MinimaxImageToVideoNode(comfy_io.ComfyNode):
|
||||
prompt_text: str,
|
||||
model: str = "I2V-01",
|
||||
seed: int = 0,
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
return await _generate_mm_video(
|
||||
auth={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
node_id=cls.hidden.unique_id,
|
||||
cls,
|
||||
prompt_text=prompt_text,
|
||||
seed=seed,
|
||||
model=model,
|
||||
@@ -275,36 +218,32 @@ class MinimaxImageToVideoNode(comfy_io.ComfyNode):
|
||||
)
|
||||
|
||||
|
||||
class MinimaxSubjectToVideoNode(comfy_io.ComfyNode):
|
||||
"""
|
||||
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
|
||||
"""
|
||||
|
||||
class MinimaxSubjectToVideoNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="MinimaxSubjectToVideoNode",
|
||||
display_name="MiniMax Subject to Video",
|
||||
category="api node/video/MiniMax",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Generates videos synchronously based on an image and prompt, and optional parameters.",
|
||||
inputs=[
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"subject",
|
||||
tooltip="Image of subject to reference for video generation",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt_text",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text prompt to guide the video generation",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["S2V-01"],
|
||||
default="S2V-01",
|
||||
tooltip="Model to use for video generation",
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
@@ -315,11 +254,11 @@ class MinimaxSubjectToVideoNode(comfy_io.ComfyNode):
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -331,13 +270,9 @@ class MinimaxSubjectToVideoNode(comfy_io.ComfyNode):
|
||||
prompt_text: str,
|
||||
model: str = "S2V-01",
|
||||
seed: int = 0,
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
return await _generate_mm_video(
|
||||
auth={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
node_id=cls.hidden.unique_id,
|
||||
cls,
|
||||
prompt_text=prompt_text,
|
||||
seed=seed,
|
||||
model=model,
|
||||
@@ -347,24 +282,22 @@ class MinimaxSubjectToVideoNode(comfy_io.ComfyNode):
|
||||
)
|
||||
|
||||
|
||||
class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
|
||||
"""Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model."""
|
||||
|
||||
class MinimaxHailuoVideoNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="MinimaxHailuoVideoNode",
|
||||
display_name="MiniMax Hailuo Video",
|
||||
category="api node/video/MiniMax",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model.",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt_text",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text prompt to guide the video generation.",
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
@@ -374,25 +307,25 @@ class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"first_frame_image",
|
||||
tooltip="Optional image to use as the first frame to generate a video.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
IO.Boolean.Input(
|
||||
"prompt_optimizer",
|
||||
default=True,
|
||||
tooltip="Optimize prompt to improve generation quality when needed.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"duration",
|
||||
options=[6, 10],
|
||||
default=6,
|
||||
tooltip="The length of the output video in seconds.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["768P", "1080P"],
|
||||
default="768P",
|
||||
@@ -400,11 +333,11 @@ class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -419,11 +352,7 @@ class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
|
||||
duration: int = 6,
|
||||
resolution: str = "768P",
|
||||
model: str = "MiniMax-Hailuo-02",
|
||||
) -> comfy_io.NodeOutput:
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
) -> IO.NodeOutput:
|
||||
if first_frame_image is None:
|
||||
validate_string(prompt_text, field_name="prompt_text")
|
||||
|
||||
@@ -435,16 +364,13 @@ class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
|
||||
# upload image, if passed in
|
||||
image_url = None
|
||||
if first_frame_image is not None:
|
||||
image_url = (await upload_images_to_comfyapi(first_frame_image, max_images=1, auth_kwargs=auth))[0]
|
||||
image_url = (await upload_images_to_comfyapi(cls, first_frame_image, max_images=1))[0]
|
||||
|
||||
video_generate_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/minimax/video_generation",
|
||||
method=HttpMethod.POST,
|
||||
request_model=MinimaxVideoGenerationRequest,
|
||||
response_model=MinimaxVideoGenerationResponse,
|
||||
),
|
||||
request=MinimaxVideoGenerationRequest(
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/minimax/video_generation", method="POST"),
|
||||
response_model=MinimaxVideoGenerationResponse,
|
||||
data=MinimaxVideoGenerationRequest(
|
||||
model=MiniMaxModel(model),
|
||||
prompt=prompt_text,
|
||||
callback_url=None,
|
||||
@@ -453,72 +379,47 @@ class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
|
||||
duration=duration,
|
||||
resolution=resolution,
|
||||
),
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
response = await video_generate_operation.execute()
|
||||
|
||||
task_id = response.task_id
|
||||
if not task_id:
|
||||
raise Exception(f"MiniMax generation failed: {response.base_resp}")
|
||||
|
||||
average_duration = 120 if resolution == "768P" else 240
|
||||
video_generate_operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path="/proxy/minimax/query/video_generation",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=MinimaxTaskResultResponse,
|
||||
query_params={"task_id": task_id},
|
||||
),
|
||||
completed_statuses=["Success"],
|
||||
failed_statuses=["Fail"],
|
||||
task_result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/minimax/query/video_generation", query_params={"task_id": task_id}),
|
||||
response_model=MinimaxTaskResultResponse,
|
||||
status_extractor=lambda x: x.status.value,
|
||||
estimated_duration=average_duration,
|
||||
node_id=cls.hidden.unique_id,
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
task_result = await video_generate_operation.execute()
|
||||
|
||||
file_id = task_result.file_id
|
||||
if file_id is None:
|
||||
raise Exception("Request was not successful. Missing file ID.")
|
||||
file_retrieve_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/minimax/files/retrieve",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=MinimaxFileRetrieveResponse,
|
||||
query_params={"file_id": int(file_id)},
|
||||
),
|
||||
request=EmptyRequest(),
|
||||
auth_kwargs=auth,
|
||||
file_result = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/minimax/files/retrieve", query_params={"file_id": int(file_id)}),
|
||||
response_model=MinimaxFileRetrieveResponse,
|
||||
)
|
||||
file_result = await file_retrieve_operation.execute()
|
||||
|
||||
file_url = file_result.file.download_url
|
||||
if file_url is None:
|
||||
raise Exception(
|
||||
f"No video was found in the response. Full response: {file_result.model_dump()}"
|
||||
)
|
||||
logging.info(f"Generated video URL: {file_url}")
|
||||
if cls.hidden.unique_id:
|
||||
if hasattr(file_result.file, "backup_download_url"):
|
||||
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
|
||||
else:
|
||||
message = f"Result URL: {file_url}"
|
||||
PromptServer.instance.send_progress_text(message, cls.hidden.unique_id)
|
||||
raise Exception(f"No video was found in the response. Full response: {file_result.model_dump()}")
|
||||
|
||||
video_io = await download_url_to_bytesio(file_url)
|
||||
if video_io is None:
|
||||
error_msg = f"Failed to download video from {file_url}"
|
||||
logging.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
return comfy_io.NodeOutput(VideoFromFile(video_io))
|
||||
if file_result.file.backup_download_url:
|
||||
try:
|
||||
return IO.NodeOutput(await download_url_to_video_output(file_url, timeout=10, max_retries=2))
|
||||
except Exception: # if we have a second URL to retrieve the result, try again using that one
|
||||
return IO.NodeOutput(
|
||||
await download_url_to_video_output(file_result.file.backup_download_url, max_retries=3)
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(file_url))
|
||||
|
||||
|
||||
class MinimaxExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
MinimaxTextToVideoNode,
|
||||
MinimaxImageToVideoNode,
|
||||
|
||||
@@ -1,33 +1,30 @@
|
||||
import logging
|
||||
from typing import Any, Callable, Optional, TypeVar
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
from comfy_api_nodes.util.validation_utils import validate_image_dimensions
|
||||
|
||||
from comfy_api_nodes.apis import (
|
||||
MoonvalleyTextToVideoRequest,
|
||||
MoonvalleyTextToVideoInferenceParams,
|
||||
MoonvalleyVideoToVideoInferenceParams,
|
||||
MoonvalleyVideoToVideoRequest,
|
||||
MoonvalleyPromptResponse,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
download_url_to_video_output,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
)
|
||||
|
||||
from comfy_api.input import VideoInput
|
||||
from comfy_api.latest import ComfyExtension, InputImpl, io as comfy_io
|
||||
import av
|
||||
import io
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.apis import (
|
||||
MoonvalleyPromptResponse,
|
||||
MoonvalleyTextToVideoInferenceParams,
|
||||
MoonvalleyTextToVideoRequest,
|
||||
MoonvalleyVideoToVideoInferenceParams,
|
||||
MoonvalleyVideoToVideoRequest,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_video_output,
|
||||
poll_op,
|
||||
sync_op,
|
||||
trim_video,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
validate_container_format_is_mp4,
|
||||
validate_image_dimensions,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
API_UPLOADS_ENDPOINT = "/proxy/moonvalley/uploads"
|
||||
API_PROMPTS_ENDPOINT = "/proxy/moonvalley/prompts"
|
||||
@@ -50,13 +47,6 @@ MAX_VID_HEIGHT = 10000
|
||||
MAX_VIDEO_SIZE = 1024 * 1024 * 1024 # 1 GB max for in-memory video processing
|
||||
|
||||
MOONVALLEY_MAREY_MAX_PROMPT_LENGTH = 5000
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class MoonvalleyApiError(Exception):
|
||||
"""Base exception for Moonvalley API errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def is_valid_task_creation_response(response: MoonvalleyPromptResponse) -> bool:
|
||||
@@ -68,64 +58,7 @@ def validate_task_creation_response(response) -> None:
|
||||
if not is_valid_task_creation_response(response):
|
||||
error_msg = f"Moonvalley Marey API: Initial request failed. Code: {response.code}, Message: {response.message}, Data: {response}"
|
||||
logging.error(error_msg)
|
||||
raise MoonvalleyApiError(error_msg)
|
||||
|
||||
|
||||
def get_video_from_response(response):
|
||||
video = response.output_url
|
||||
logging.info(
|
||||
"Moonvalley Marey API: Task %s succeeded. Video URL: %s", response.id, video
|
||||
)
|
||||
return video
|
||||
|
||||
|
||||
def get_video_url_from_response(response) -> Optional[str]:
|
||||
"""Returns the first video url from the Moonvalley video generation task result.
|
||||
Will not raise an error if the response is not valid.
|
||||
"""
|
||||
if response:
|
||||
return str(get_video_from_response(response))
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
async def poll_until_finished(
|
||||
auth_kwargs: dict[str, str],
|
||||
api_endpoint: ApiEndpoint[Any, R],
|
||||
result_url_extractor: Optional[Callable[[R], str]] = None,
|
||||
node_id: Optional[str] = None,
|
||||
) -> R:
|
||||
"""Polls the Moonvalley API endpoint until the task reaches a terminal state, then returns the response."""
|
||||
return await PollingOperation(
|
||||
poll_endpoint=api_endpoint,
|
||||
completed_statuses=[
|
||||
"completed",
|
||||
],
|
||||
max_poll_attempts=240, # 64 minutes with 16s interval
|
||||
poll_interval=16.0,
|
||||
failed_statuses=["error"],
|
||||
status_extractor=lambda response: (
|
||||
response.status if response and response.status else None
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
result_url_extractor=result_url_extractor,
|
||||
node_id=node_id,
|
||||
).execute()
|
||||
|
||||
|
||||
def validate_prompts(
|
||||
prompt: str, negative_prompt: str, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH
|
||||
):
|
||||
"""Verifies that the prompt isn't empty and that neither prompt is too long."""
|
||||
if not prompt:
|
||||
raise ValueError("Positive prompt is empty")
|
||||
if len(prompt) > max_length:
|
||||
raise ValueError(f"Positive prompt is too long: {len(prompt)} characters")
|
||||
if negative_prompt and len(negative_prompt) > max_length:
|
||||
raise ValueError(
|
||||
f"Negative prompt is too long: {len(negative_prompt)} characters"
|
||||
)
|
||||
return True
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
|
||||
def validate_video_to_video_input(video: VideoInput) -> VideoInput:
|
||||
@@ -144,7 +77,7 @@ def validate_video_to_video_input(video: VideoInput) -> VideoInput:
|
||||
"""
|
||||
width, height = _get_video_dimensions(video)
|
||||
_validate_video_dimensions(width, height)
|
||||
_validate_container_format(video)
|
||||
validate_container_format_is_mp4(video)
|
||||
|
||||
return _validate_and_trim_duration(video)
|
||||
|
||||
@@ -169,21 +102,8 @@ def _validate_video_dimensions(width: int, height: int) -> None:
|
||||
}
|
||||
|
||||
if (width, height) not in supported_resolutions:
|
||||
supported_list = ", ".join(
|
||||
[f"{w}x{h}" for w, h in sorted(supported_resolutions)]
|
||||
)
|
||||
raise ValueError(
|
||||
f"Resolution {width}x{height} not supported. Supported: {supported_list}"
|
||||
)
|
||||
|
||||
|
||||
def _validate_container_format(video: VideoInput) -> None:
|
||||
"""Validates video container format is MP4."""
|
||||
container_format = video.get_container_format()
|
||||
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
|
||||
raise ValueError(
|
||||
f"Only MP4 container format supported. Got: {container_format}"
|
||||
)
|
||||
supported_list = ", ".join([f"{w}x{h}" for w, h in sorted(supported_resolutions)])
|
||||
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
|
||||
|
||||
|
||||
def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
|
||||
@@ -196,7 +116,7 @@ def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
|
||||
def _validate_minimum_duration(duration: float) -> None:
|
||||
"""Ensures video is at least 5 seconds long."""
|
||||
if duration < 5:
|
||||
raise MoonvalleyApiError("Input video must be at least 5 seconds long.")
|
||||
raise ValueError("Input video must be at least 5 seconds long.")
|
||||
|
||||
|
||||
def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
|
||||
@@ -206,127 +126,6 @@ def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
|
||||
return video
|
||||
|
||||
|
||||
def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
|
||||
"""
|
||||
Returns a new VideoInput object trimmed from the beginning to the specified duration,
|
||||
using av to avoid loading entire video into memory.
|
||||
|
||||
Args:
|
||||
video: Input video to trim
|
||||
duration_sec: Duration in seconds to keep from the beginning
|
||||
|
||||
Returns:
|
||||
VideoFromFile object that owns the output buffer
|
||||
"""
|
||||
output_buffer = io.BytesIO()
|
||||
|
||||
input_container = None
|
||||
output_container = None
|
||||
|
||||
try:
|
||||
# Get the stream source - this avoids loading entire video into memory
|
||||
# when the source is already a file path
|
||||
input_source = video.get_stream_source()
|
||||
|
||||
# Open containers
|
||||
input_container = av.open(input_source, mode="r")
|
||||
output_container = av.open(output_buffer, mode="w", format="mp4")
|
||||
|
||||
# Set up output streams for re-encoding
|
||||
video_stream = None
|
||||
audio_stream = None
|
||||
|
||||
for stream in input_container.streams:
|
||||
logging.info(f"Found stream: type={stream.type}, class={type(stream)}")
|
||||
if isinstance(stream, av.VideoStream):
|
||||
# Create output video stream with same parameters
|
||||
video_stream = output_container.add_stream(
|
||||
"h264", rate=stream.average_rate
|
||||
)
|
||||
video_stream.width = stream.width
|
||||
video_stream.height = stream.height
|
||||
video_stream.pix_fmt = "yuv420p"
|
||||
logging.info(
|
||||
f"Added video stream: {stream.width}x{stream.height} @ {stream.average_rate}fps"
|
||||
)
|
||||
elif isinstance(stream, av.AudioStream):
|
||||
# Create output audio stream with same parameters
|
||||
audio_stream = output_container.add_stream(
|
||||
"aac", rate=stream.sample_rate
|
||||
)
|
||||
audio_stream.sample_rate = stream.sample_rate
|
||||
audio_stream.layout = stream.layout
|
||||
logging.info(
|
||||
f"Added audio stream: {stream.sample_rate}Hz, {stream.channels} channels"
|
||||
)
|
||||
|
||||
# Calculate target frame count that's divisible by 16
|
||||
fps = input_container.streams.video[0].average_rate
|
||||
estimated_frames = int(duration_sec * fps)
|
||||
target_frames = (
|
||||
estimated_frames // 16
|
||||
) * 16 # Round down to nearest multiple of 16
|
||||
|
||||
if target_frames == 0:
|
||||
raise ValueError("Video too short: need at least 16 frames for Moonvalley")
|
||||
|
||||
frame_count = 0
|
||||
audio_frame_count = 0
|
||||
|
||||
# Decode and re-encode video frames
|
||||
if video_stream:
|
||||
for frame in input_container.decode(video=0):
|
||||
if frame_count >= target_frames:
|
||||
break
|
||||
|
||||
# Re-encode frame
|
||||
for packet in video_stream.encode(frame):
|
||||
output_container.mux(packet)
|
||||
frame_count += 1
|
||||
|
||||
# Flush encoder
|
||||
for packet in video_stream.encode():
|
||||
output_container.mux(packet)
|
||||
|
||||
logging.info(
|
||||
f"Encoded {frame_count} video frames (target: {target_frames})"
|
||||
)
|
||||
|
||||
# Decode and re-encode audio frames
|
||||
if audio_stream:
|
||||
input_container.seek(0) # Reset to beginning for audio
|
||||
for frame in input_container.decode(audio=0):
|
||||
if frame.time >= duration_sec:
|
||||
break
|
||||
|
||||
# Re-encode frame
|
||||
for packet in audio_stream.encode(frame):
|
||||
output_container.mux(packet)
|
||||
audio_frame_count += 1
|
||||
|
||||
# Flush encoder
|
||||
for packet in audio_stream.encode():
|
||||
output_container.mux(packet)
|
||||
|
||||
logging.info(f"Encoded {audio_frame_count} audio frames")
|
||||
|
||||
# Close containers
|
||||
output_container.close()
|
||||
input_container.close()
|
||||
|
||||
# Return as VideoFromFile using the buffer
|
||||
output_buffer.seek(0)
|
||||
return InputImpl.VideoFromFile(output_buffer)
|
||||
|
||||
except Exception as e:
|
||||
# Clean up on error
|
||||
if input_container is not None:
|
||||
input_container.close()
|
||||
if output_container is not None:
|
||||
output_container.close()
|
||||
raise RuntimeError(f"Failed to trim video: {str(e)}") from e
|
||||
|
||||
|
||||
def parse_width_height_from_res(resolution: str):
|
||||
# Accepts a string like "16:9 (1920 x 1080)" and returns width, height as a dict
|
||||
res_map = {
|
||||
@@ -335,7 +134,7 @@ def parse_width_height_from_res(resolution: str):
|
||||
"1:1 (1152 x 1152)": {"width": 1152, "height": 1152},
|
||||
"4:3 (1536 x 1152)": {"width": 1536, "height": 1152},
|
||||
"3:4 (1152 x 1536)": {"width": 1152, "height": 1536},
|
||||
"21:9 (2560 x 1080)": {"width": 2560, "height": 1080},
|
||||
# "21:9 (2560 x 1080)": {"width": 2560, "height": 1080},
|
||||
}
|
||||
return res_map.get(resolution, {"width": 1920, "height": 1080})
|
||||
|
||||
@@ -350,52 +149,47 @@ def parse_control_parameter(value):
|
||||
return control_map.get(value, control_map["Motion Transfer"])
|
||||
|
||||
|
||||
async def get_response(
|
||||
task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||
) -> MoonvalleyPromptResponse:
|
||||
return await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{API_PROMPTS_ENDPOINT}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
),
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
node_id=node_id,
|
||||
async def get_response(cls: type[IO.ComfyNode], task_id: str) -> MoonvalleyPromptResponse:
|
||||
return await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{API_PROMPTS_ENDPOINT}/{task_id}"),
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
status_extractor=lambda r: (r.status if r and r.status else None),
|
||||
poll_interval=16.0,
|
||||
max_poll_attempts=240,
|
||||
)
|
||||
|
||||
|
||||
class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
|
||||
class MoonvalleyImg2VideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="MoonvalleyImg2VideoNode",
|
||||
display_name="Moonvalley Marey Image to Video",
|
||||
category="api node/video/Moonvalley Marey",
|
||||
description="Moonvalley Marey Image to Video Node",
|
||||
inputs=[
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="The reference image used to generate the video",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
|
||||
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
|
||||
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
|
||||
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
|
||||
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
|
||||
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
|
||||
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
|
||||
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
|
||||
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
|
||||
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
|
||||
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
|
||||
tooltip="Negative prompt text",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=[
|
||||
"16:9 (1920 x 1080)",
|
||||
@@ -403,42 +197,43 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
|
||||
"1:1 (1152 x 1152)",
|
||||
"4:3 (1536 x 1152)",
|
||||
"3:4 (1152 x 1536)",
|
||||
"21:9 (2560 x 1080)",
|
||||
# "21:9 (2560 x 1080)",
|
||||
],
|
||||
default="16:9 (1920 x 1080)",
|
||||
tooltip="Resolution of the output video",
|
||||
),
|
||||
comfy_io.Float.Input(
|
||||
IO.Float.Input(
|
||||
"prompt_adherence",
|
||||
default=10.0,
|
||||
default=4.5,
|
||||
min=1.0,
|
||||
max=20.0,
|
||||
step=1.0,
|
||||
tooltip="Guidance scale for generation control",
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=9,
|
||||
min=0,
|
||||
max=4294967295,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Random seed value",
|
||||
control_after_generate=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"steps",
|
||||
default=100,
|
||||
default=33,
|
||||
min=1,
|
||||
max=100,
|
||||
step=1,
|
||||
tooltip="Number of denoising steps",
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -453,22 +248,17 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
|
||||
prompt_adherence: float,
|
||||
seed: int,
|
||||
steps: int,
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
validate_image_dimensions(image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH)
|
||||
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
width_height = parse_width_height_from_res(resolution)
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
inference_params = MoonvalleyTextToVideoInferenceParams(
|
||||
negative_prompt=negative_prompt,
|
||||
steps=steps,
|
||||
seed=seed,
|
||||
guidance_scale=prompt_adherence,
|
||||
num_frames=128,
|
||||
width=width_height["width"],
|
||||
height=width_height["height"],
|
||||
use_negative_prompts=True,
|
||||
@@ -476,85 +266,69 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
|
||||
|
||||
# Get MIME type from tensor - assuming PNG format for image tensors
|
||||
mime_type = "image/png"
|
||||
|
||||
image_url = (
|
||||
await upload_images_to_comfyapi(
|
||||
image, max_images=1, auth_kwargs=auth, mime_type=mime_type
|
||||
)
|
||||
)[0]
|
||||
|
||||
request = MoonvalleyTextToVideoRequest(
|
||||
image_url=image_url, prompt_text=prompt, inference_params=inference_params
|
||||
)
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=API_IMG2VIDEO_ENDPOINT,
|
||||
method=HttpMethod.POST,
|
||||
request_model=MoonvalleyTextToVideoRequest,
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
image_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type=mime_type))[0]
|
||||
task_creation_response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=API_IMG2VIDEO_ENDPOINT, method="POST"),
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
data=MoonvalleyTextToVideoRequest(
|
||||
image_url=image_url, prompt_text=prompt, inference_params=inference_params
|
||||
),
|
||||
request=request,
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.id
|
||||
|
||||
final_response = await get_response(
|
||||
task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id
|
||||
)
|
||||
final_response = await get_response(cls, task_creation_response.id)
|
||||
video = await download_url_to_video_output(final_response.output_url)
|
||||
return comfy_io.NodeOutput(video)
|
||||
return IO.NodeOutput(video)
|
||||
|
||||
|
||||
class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
|
||||
class MoonvalleyVideo2VideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="MoonvalleyVideo2VideoNode",
|
||||
display_name="Moonvalley Marey Video to Video",
|
||||
category="api node/video/Moonvalley Marey",
|
||||
description="",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="Describes the video to generate",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
|
||||
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
|
||||
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
|
||||
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
|
||||
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
|
||||
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
|
||||
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
|
||||
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
|
||||
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
|
||||
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
|
||||
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
|
||||
tooltip="Negative prompt text",
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=9,
|
||||
min=0,
|
||||
max=4294967295,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Random seed value",
|
||||
control_after_generate=False,
|
||||
),
|
||||
comfy_io.Video.Input(
|
||||
IO.Video.Input(
|
||||
"video",
|
||||
tooltip="The reference video used to generate the output video. Must be at least 5 seconds long. "
|
||||
"Videos longer than 5s will be automatically trimmed. Only MP4 format supported.",
|
||||
"Videos longer than 5s will be automatically trimmed. Only MP4 format supported.",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"control_type",
|
||||
options=["Motion Transfer", "Pose Transfer"],
|
||||
default="Motion Transfer",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"motion_intensity",
|
||||
default=100,
|
||||
min=0,
|
||||
@@ -563,12 +337,21 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
|
||||
tooltip="Only used if control_type is 'Motion Transfer'",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"steps",
|
||||
default=33,
|
||||
min=1,
|
||||
max=100,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Number of inference steps",
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -582,16 +365,13 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
|
||||
video: Optional[VideoInput] = None,
|
||||
control_type: str = "Motion Transfer",
|
||||
motion_intensity: Optional[int] = 100,
|
||||
) -> comfy_io.NodeOutput:
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
steps=33,
|
||||
prompt_adherence=4.5,
|
||||
) -> IO.NodeOutput:
|
||||
validated_video = validate_video_to_video_input(video)
|
||||
video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=auth)
|
||||
|
||||
validate_prompts(prompt, negative_prompt)
|
||||
video_url = await upload_video_to_comfyapi(cls, validated_video)
|
||||
validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
|
||||
# Only include motion_intensity for Motion Transfer
|
||||
control_params = {}
|
||||
@@ -602,65 +382,52 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
|
||||
negative_prompt=negative_prompt,
|
||||
seed=seed,
|
||||
control_params=control_params,
|
||||
steps=steps,
|
||||
guidance_scale=prompt_adherence,
|
||||
)
|
||||
|
||||
control = parse_control_parameter(control_type)
|
||||
|
||||
request = MoonvalleyVideoToVideoRequest(
|
||||
control_type=control,
|
||||
video_url=video_url,
|
||||
prompt_text=prompt,
|
||||
inference_params=inference_params,
|
||||
)
|
||||
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=API_VIDEO2VIDEO_ENDPOINT,
|
||||
method=HttpMethod.POST,
|
||||
request_model=MoonvalleyVideoToVideoRequest,
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
task_creation_response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=API_VIDEO2VIDEO_ENDPOINT, method="POST"),
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
data=MoonvalleyVideoToVideoRequest(
|
||||
control_type=parse_control_parameter(control_type),
|
||||
video_url=video_url,
|
||||
prompt_text=prompt,
|
||||
inference_params=inference_params,
|
||||
),
|
||||
request=request,
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.id
|
||||
|
||||
final_response = await get_response(
|
||||
task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id
|
||||
)
|
||||
|
||||
video = await download_url_to_video_output(final_response.output_url)
|
||||
return comfy_io.NodeOutput(video)
|
||||
final_response = await get_response(cls, task_creation_response.id)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.output_url))
|
||||
|
||||
|
||||
class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
|
||||
class MoonvalleyTxt2VideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="MoonvalleyTxt2VideoNode",
|
||||
display_name="Moonvalley Marey Text to Video",
|
||||
category="api node/video/Moonvalley Marey",
|
||||
description="",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
|
||||
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
|
||||
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
|
||||
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
|
||||
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
|
||||
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
|
||||
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
|
||||
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
|
||||
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
|
||||
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
|
||||
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
|
||||
tooltip="Negative prompt text",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=[
|
||||
"16:9 (1920 x 1080)",
|
||||
@@ -673,37 +440,38 @@ class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
|
||||
default="16:9 (1920 x 1080)",
|
||||
tooltip="Resolution of the output video",
|
||||
),
|
||||
comfy_io.Float.Input(
|
||||
IO.Float.Input(
|
||||
"prompt_adherence",
|
||||
default=10.0,
|
||||
default=4.0,
|
||||
min=1.0,
|
||||
max=20.0,
|
||||
step=1.0,
|
||||
tooltip="Guidance scale for generation control",
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=9,
|
||||
min=0,
|
||||
max=4294967295,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Random seed value",
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"steps",
|
||||
default=100,
|
||||
default=33,
|
||||
min=1,
|
||||
max=100,
|
||||
step=1,
|
||||
tooltip="Inference steps",
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -717,15 +485,11 @@ class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
|
||||
prompt_adherence: float,
|
||||
seed: int,
|
||||
steps: int,
|
||||
) -> comfy_io.NodeOutput:
|
||||
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
width_height = parse_width_height_from_res(resolution)
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
inference_params = MoonvalleyTextToVideoInferenceParams(
|
||||
negative_prompt=negative_prompt,
|
||||
steps=steps,
|
||||
@@ -735,35 +499,21 @@ class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
|
||||
width=width_height["width"],
|
||||
height=width_height["height"],
|
||||
)
|
||||
request = MoonvalleyTextToVideoRequest(
|
||||
prompt_text=prompt, inference_params=inference_params
|
||||
)
|
||||
|
||||
init_op = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=API_TXT2VIDEO_ENDPOINT,
|
||||
method=HttpMethod.POST,
|
||||
request_model=MoonvalleyTextToVideoRequest,
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
),
|
||||
request=request,
|
||||
auth_kwargs=auth,
|
||||
task_creation_response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=API_TXT2VIDEO_ENDPOINT, method="POST"),
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
data=MoonvalleyTextToVideoRequest(prompt_text=prompt, inference_params=inference_params),
|
||||
)
|
||||
task_creation_response = await init_op.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.id
|
||||
|
||||
final_response = await get_response(
|
||||
task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id
|
||||
)
|
||||
|
||||
video = await download_url_to_video_output(final_response.output_url)
|
||||
return comfy_io.NodeOutput(video)
|
||||
final_response = await get_response(cls, task_creation_response.id)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.output_url))
|
||||
|
||||
|
||||
class MoonvalleyExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
MoonvalleyImg2VideoNode,
|
||||
MoonvalleyTxt2VideoNode,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -7,40 +7,23 @@ from __future__ import annotations
|
||||
|
||||
from io import BytesIO
|
||||
import logging
|
||||
from typing import Optional, TypeVar
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io as comfy_io
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy_api.latest import ComfyExtension, IO
|
||||
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
from comfy_api_nodes.apis import pika_api as pika_defs
|
||||
from comfy_api_nodes.util import (
|
||||
validate_string,
|
||||
download_url_to_video_output,
|
||||
tensor_to_bytesio,
|
||||
)
|
||||
from comfy_api_nodes.apis import (
|
||||
PikaBodyGenerate22C2vGenerate22PikascenesPost,
|
||||
PikaBodyGenerate22I2vGenerate22I2vPost,
|
||||
PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
|
||||
PikaBodyGenerate22T2vGenerate22T2vPost,
|
||||
PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
||||
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
|
||||
PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
|
||||
PikaGenerateResponse,
|
||||
PikaVideoResponse,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
EmptyRequest,
|
||||
HttpMethod,
|
||||
PollingOperation,
|
||||
SynchronousOperation,
|
||||
sync_op,
|
||||
poll_op,
|
||||
)
|
||||
|
||||
R = TypeVar("R")
|
||||
|
||||
PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions"
|
||||
PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps"
|
||||
@@ -55,152 +38,58 @@ PATH_PIKASCENES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikascenes"
|
||||
PATH_VIDEO_GET = "/proxy/pika/videos"
|
||||
|
||||
|
||||
class PikaDurationEnum(int, Enum):
|
||||
integer_5 = 5
|
||||
integer_10 = 10
|
||||
|
||||
|
||||
class PikaResolutionEnum(str, Enum):
|
||||
field_1080p = "1080p"
|
||||
field_720p = "720p"
|
||||
|
||||
|
||||
class Pikaffect(str, Enum):
|
||||
Cake_ify = "Cake-ify"
|
||||
Crumble = "Crumble"
|
||||
Crush = "Crush"
|
||||
Decapitate = "Decapitate"
|
||||
Deflate = "Deflate"
|
||||
Dissolve = "Dissolve"
|
||||
Explode = "Explode"
|
||||
Eye_pop = "Eye-pop"
|
||||
Inflate = "Inflate"
|
||||
Levitate = "Levitate"
|
||||
Melt = "Melt"
|
||||
Peel = "Peel"
|
||||
Poke = "Poke"
|
||||
Squish = "Squish"
|
||||
Ta_da = "Ta-da"
|
||||
Tear = "Tear"
|
||||
|
||||
|
||||
class PikaApiError(Exception):
|
||||
"""Exception for Pika API errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def is_valid_video_response(response: PikaVideoResponse) -> bool:
|
||||
"""Check if the video response is valid."""
|
||||
return hasattr(response, "url") and response.url is not None
|
||||
|
||||
|
||||
def is_valid_initial_response(response: PikaGenerateResponse) -> bool:
|
||||
"""Check if the initial response is valid."""
|
||||
return hasattr(response, "video_id") and response.video_id is not None
|
||||
|
||||
|
||||
async def poll_for_task_status(
|
||||
task_id: str,
|
||||
auth_kwargs: Optional[dict[str, str]] = None,
|
||||
node_id: Optional[str] = None,
|
||||
) -> PikaGenerateResponse:
|
||||
polling_operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"{PATH_VIDEO_GET}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=PikaVideoResponse,
|
||||
),
|
||||
completed_statuses=[
|
||||
"finished",
|
||||
],
|
||||
failed_statuses=["failed", "cancelled"],
|
||||
status_extractor=lambda response: (
|
||||
response.status.value if response.status else None
|
||||
),
|
||||
progress_extractor=lambda response: (
|
||||
response.progress if hasattr(response, "progress") else None
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
result_url_extractor=lambda response: (
|
||||
response.url if hasattr(response, "url") else None
|
||||
),
|
||||
node_id=node_id,
|
||||
estimated_duration=60
|
||||
)
|
||||
return await polling_operation.execute()
|
||||
|
||||
|
||||
async def execute_task(
|
||||
initial_operation: SynchronousOperation[R, PikaGenerateResponse],
|
||||
auth_kwargs: Optional[dict[str, str]] = None,
|
||||
node_id: Optional[str] = None,
|
||||
) -> tuple[VideoFromFile]:
|
||||
"""Executes the initial operation then polls for the task status until it is completed.
|
||||
|
||||
Args:
|
||||
initial_operation: The initial operation to execute.
|
||||
auth_kwargs: The authentication token(s) to use for the API call.
|
||||
|
||||
Returns:
|
||||
A tuple containing the video file as a VIDEO output.
|
||||
"""
|
||||
initial_response = await initial_operation.execute()
|
||||
if not is_valid_initial_response(initial_response):
|
||||
error_msg = f"Pika initial request failed. Code: {initial_response.code}, Message: {initial_response.message}, Data: {initial_response.data}"
|
||||
task_id: str,
|
||||
cls: type[IO.ComfyNode],
|
||||
) -> IO.NodeOutput:
|
||||
final_response: pika_defs.PikaVideoResponse = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{PATH_VIDEO_GET}/{task_id}"),
|
||||
response_model=pika_defs.PikaVideoResponse,
|
||||
status_extractor=lambda response: (response.status.value if response.status else None),
|
||||
progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None),
|
||||
estimated_duration=60,
|
||||
max_poll_attempts=240,
|
||||
)
|
||||
if not final_response.url:
|
||||
error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}"
|
||||
logging.error(error_msg)
|
||||
raise PikaApiError(error_msg)
|
||||
|
||||
task_id = initial_response.video_id
|
||||
final_response = await poll_for_task_status(task_id, auth_kwargs, node_id=node_id)
|
||||
if not is_valid_video_response(final_response):
|
||||
error_msg = (
|
||||
f"Pika task {task_id} succeeded but no video data found in response."
|
||||
)
|
||||
logging.error(error_msg)
|
||||
raise PikaApiError(error_msg)
|
||||
|
||||
video_url = str(final_response.url)
|
||||
raise Exception(error_msg)
|
||||
video_url = final_response.url
|
||||
logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)
|
||||
|
||||
return (await download_url_to_video_output(video_url),)
|
||||
return IO.NodeOutput(await download_url_to_video_output(video_url))
|
||||
|
||||
|
||||
def get_base_inputs_types() -> list[comfy_io.Input]:
|
||||
def get_base_inputs_types() -> list[IO.Input]:
|
||||
"""Get the base required inputs types common to all Pika nodes."""
|
||||
return [
|
||||
comfy_io.String.Input("prompt_text", multiline=True),
|
||||
comfy_io.String.Input("negative_prompt", multiline=True),
|
||||
comfy_io.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
|
||||
comfy_io.Combo.Input(
|
||||
"resolution", options=[resolution.value for resolution in PikaResolutionEnum], default="1080p"
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"duration", options=[duration.value for duration in PikaDurationEnum], default=5
|
||||
),
|
||||
IO.String.Input("prompt_text", multiline=True),
|
||||
IO.String.Input("negative_prompt", multiline=True),
|
||||
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
|
||||
IO.Combo.Input("resolution", options=["1080p", "720p"], default="1080p"),
|
||||
IO.Combo.Input("duration", options=[5, 10], default=5),
|
||||
]
|
||||
|
||||
|
||||
class PikaImageToVideoV2_2(comfy_io.ComfyNode):
|
||||
class PikaImageToVideo(IO.ComfyNode):
|
||||
"""Pika 2.2 Image to Video Node."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="PikaImageToVideoNode2_2",
|
||||
display_name="Pika Image to Video",
|
||||
description="Sends an image and prompt to the Pika API v2.2 to generate a video.",
|
||||
category="api node/video/Pika",
|
||||
inputs=[
|
||||
comfy_io.Image.Input("image", tooltip="The image to convert to video"),
|
||||
IO.Image.Input("image", tooltip="The image to convert to video"),
|
||||
*get_base_inputs_types(),
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -214,53 +103,40 @@ class PikaImageToVideoV2_2(comfy_io.ComfyNode):
|
||||
seed: int,
|
||||
resolution: str,
|
||||
duration: int,
|
||||
) -> comfy_io.NodeOutput:
|
||||
# Convert image to BytesIO
|
||||
) -> IO.NodeOutput:
|
||||
image_bytes_io = tensor_to_bytesio(image)
|
||||
image_bytes_io.seek(0)
|
||||
|
||||
pika_files = {"image": ("image.png", image_bytes_io, "image/png")}
|
||||
|
||||
# Prepare non-file data
|
||||
pika_request_data = PikaBodyGenerate22I2vGenerate22I2vPost(
|
||||
pika_request_data = pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost(
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_IMAGE_TO_VIDEO,
|
||||
method=HttpMethod.POST,
|
||||
request_model=PikaBodyGenerate22I2vGenerate22I2vPost,
|
||||
response_model=PikaGenerateResponse,
|
||||
),
|
||||
request=pika_request_data,
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaTextToVideoNodeV2_2(comfy_io.ComfyNode):
|
||||
class PikaTextToVideoNode(IO.ComfyNode):
|
||||
"""Pika Text2Video v2.2 Node."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="PikaTextToVideoNode2_2",
|
||||
display_name="Pika Text to Video",
|
||||
description="Sends a text prompt to the Pika API v2.2 to generate a video.",
|
||||
category="api node/video/Pika",
|
||||
inputs=[
|
||||
*get_base_inputs_types(),
|
||||
comfy_io.Float.Input(
|
||||
IO.Float.Input(
|
||||
"aspect_ratio",
|
||||
step=0.001,
|
||||
min=0.4,
|
||||
@@ -269,11 +145,11 @@ class PikaTextToVideoNodeV2_2(comfy_io.ComfyNode):
|
||||
tooltip="Aspect ratio (width / height)",
|
||||
)
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -287,19 +163,12 @@ class PikaTextToVideoNodeV2_2(comfy_io.ComfyNode):
|
||||
resolution: str,
|
||||
duration: int,
|
||||
aspect_ratio: float,
|
||||
) -> comfy_io.NodeOutput:
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_TEXT_TO_VIDEO,
|
||||
method=HttpMethod.POST,
|
||||
request_model=PikaBodyGenerate22T2vGenerate22T2vPost,
|
||||
response_model=PikaGenerateResponse,
|
||||
),
|
||||
request=PikaBodyGenerate22T2vGenerate22T2vPost(
|
||||
) -> IO.NodeOutput:
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
@@ -307,30 +176,29 @@ class PikaTextToVideoNodeV2_2(comfy_io.ComfyNode):
|
||||
duration=duration,
|
||||
aspectRatio=aspect_ratio,
|
||||
),
|
||||
auth_kwargs=auth,
|
||||
content_type="application/x-www-form-urlencoded",
|
||||
)
|
||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaScenesV2_2(comfy_io.ComfyNode):
|
||||
class PikaScenes(IO.ComfyNode):
|
||||
"""PikaScenes v2.2 Node."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="PikaScenesV2_2",
|
||||
display_name="Pika Scenes (Video Image Composition)",
|
||||
description="Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them.",
|
||||
category="api node/video/Pika",
|
||||
inputs=[
|
||||
*get_base_inputs_types(),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"ingredients_mode",
|
||||
options=["creative", "precise"],
|
||||
default="creative",
|
||||
),
|
||||
comfy_io.Float.Input(
|
||||
IO.Float.Input(
|
||||
"aspect_ratio",
|
||||
step=0.001,
|
||||
min=0.4,
|
||||
@@ -338,37 +206,37 @@ class PikaScenesV2_2(comfy_io.ComfyNode):
|
||||
default=1.7777777777777777,
|
||||
tooltip="Aspect ratio (width / height)",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"image_ingredient_1",
|
||||
optional=True,
|
||||
tooltip="Image that will be used as ingredient to create a video.",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"image_ingredient_2",
|
||||
optional=True,
|
||||
tooltip="Image that will be used as ingredient to create a video.",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"image_ingredient_3",
|
||||
optional=True,
|
||||
tooltip="Image that will be used as ingredient to create a video.",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"image_ingredient_4",
|
||||
optional=True,
|
||||
tooltip="Image that will be used as ingredient to create a video.",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"image_ingredient_5",
|
||||
optional=True,
|
||||
tooltip="Image that will be used as ingredient to create a video.",
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -388,8 +256,7 @@ class PikaScenesV2_2(comfy_io.ComfyNode):
|
||||
image_ingredient_3: Optional[torch.Tensor] = None,
|
||||
image_ingredient_4: Optional[torch.Tensor] = None,
|
||||
image_ingredient_5: Optional[torch.Tensor] = None,
|
||||
) -> comfy_io.NodeOutput:
|
||||
# Convert all passed images to BytesIO
|
||||
) -> IO.NodeOutput:
|
||||
all_image_bytes_io = []
|
||||
for image in [
|
||||
image_ingredient_1,
|
||||
@@ -399,16 +266,14 @@ class PikaScenesV2_2(comfy_io.ComfyNode):
|
||||
image_ingredient_5,
|
||||
]:
|
||||
if image is not None:
|
||||
image_bytes_io = tensor_to_bytesio(image)
|
||||
image_bytes_io.seek(0)
|
||||
all_image_bytes_io.append(image_bytes_io)
|
||||
all_image_bytes_io.append(tensor_to_bytesio(image))
|
||||
|
||||
pika_files = [
|
||||
("images", (f"image_{i}.png", image_bytes_io, "image/png"))
|
||||
for i, image_bytes_io in enumerate(all_image_bytes_io)
|
||||
]
|
||||
|
||||
pika_request_data = PikaBodyGenerate22C2vGenerate22PikascenesPost(
|
||||
pika_request_data = pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost(
|
||||
ingredientsMode=ingredients_mode,
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
@@ -417,53 +282,45 @@ class PikaScenesV2_2(comfy_io.ComfyNode):
|
||||
duration=duration,
|
||||
aspectRatio=aspect_ratio,
|
||||
)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_PIKASCENES,
|
||||
method=HttpMethod.POST,
|
||||
request_model=PikaBodyGenerate22C2vGenerate22PikascenesPost,
|
||||
response_model=PikaGenerateResponse,
|
||||
),
|
||||
request=pika_request_data,
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_PIKASCENES, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
|
||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikAdditionsNode(comfy_io.ComfyNode):
|
||||
class PikAdditionsNode(IO.ComfyNode):
|
||||
"""Pika Pikadditions Node. Add an image into a video."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="Pikadditions",
|
||||
display_name="Pikadditions (Video Object Insertion)",
|
||||
description="Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result.",
|
||||
category="api node/video/Pika",
|
||||
inputs=[
|
||||
comfy_io.Video.Input("video", tooltip="The video to add an image to."),
|
||||
comfy_io.Image.Input("image", tooltip="The image to add to the video."),
|
||||
comfy_io.String.Input("prompt_text", multiline=True),
|
||||
comfy_io.String.Input("negative_prompt", multiline=True),
|
||||
comfy_io.Int.Input(
|
||||
IO.Video.Input("video", tooltip="The video to add an image to."),
|
||||
IO.Image.Input("image", tooltip="The image to add to the video."),
|
||||
IO.String.Input("prompt_text", multiline=True),
|
||||
IO.String.Input("negative_prompt", multiline=True),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
min=0,
|
||||
max=0xFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -476,70 +333,70 @@ class PikAdditionsNode(comfy_io.ComfyNode):
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
) -> comfy_io.NodeOutput:
|
||||
# Convert video to BytesIO
|
||||
) -> IO.NodeOutput:
|
||||
video_bytes_io = BytesIO()
|
||||
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
|
||||
video_bytes_io.seek(0)
|
||||
|
||||
# Convert image to BytesIO
|
||||
image_bytes_io = tensor_to_bytesio(image)
|
||||
image_bytes_io.seek(0)
|
||||
|
||||
pika_files = {
|
||||
"video": ("video.mp4", video_bytes_io, "video/mp4"),
|
||||
"image": ("image.png", image_bytes_io, "image/png"),
|
||||
}
|
||||
|
||||
# Prepare non-file data
|
||||
pika_request_data = PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
|
||||
pika_request_data = pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_PIKADDITIONS,
|
||||
method=HttpMethod.POST,
|
||||
request_model=PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
||||
response_model=PikaGenerateResponse,
|
||||
),
|
||||
request=pika_request_data,
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_PIKADDITIONS, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
|
||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaSwapsNode(comfy_io.ComfyNode):
|
||||
class PikaSwapsNode(IO.ComfyNode):
|
||||
"""Pika Pikaswaps Node."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="Pikaswaps",
|
||||
display_name="Pika Swaps (Video Object Replacement)",
|
||||
description="Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates.",
|
||||
category="api node/video/Pika",
|
||||
inputs=[
|
||||
comfy_io.Video.Input("video", tooltip="The video to swap an object in."),
|
||||
comfy_io.Image.Input("image", tooltip="The image used to replace the masked object in the video."),
|
||||
comfy_io.Mask.Input("mask", tooltip="Use the mask to define areas in the video to replace"),
|
||||
comfy_io.String.Input("prompt_text", multiline=True),
|
||||
comfy_io.String.Input("negative_prompt", multiline=True),
|
||||
comfy_io.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
|
||||
IO.Video.Input("video", tooltip="The video to swap an object in."),
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="The image used to replace the masked object in the video.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Mask.Input(
|
||||
"mask",
|
||||
tooltip="Use the mask to define areas in the video to replace.",
|
||||
optional=True,
|
||||
),
|
||||
IO.String.Input("prompt_text", multiline=True, optional=True),
|
||||
IO.String.Input("negative_prompt", multiline=True, optional=True),
|
||||
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True, optional=True),
|
||||
IO.String.Input(
|
||||
"region_to_modify",
|
||||
multiline=True,
|
||||
optional=True,
|
||||
tooltip="Plaintext description of the object / region to modify.",
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -548,85 +405,65 @@ class PikaSwapsNode(comfy_io.ComfyNode):
|
||||
async def execute(
|
||||
cls,
|
||||
video: VideoInput,
|
||||
image: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
) -> comfy_io.NodeOutput:
|
||||
# Convert video to BytesIO
|
||||
image: Optional[torch.Tensor] = None,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
prompt_text: str = "",
|
||||
negative_prompt: str = "",
|
||||
seed: int = 0,
|
||||
region_to_modify: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
video_bytes_io = BytesIO()
|
||||
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
|
||||
video_bytes_io.seek(0)
|
||||
|
||||
# Convert mask to binary mask with three channels
|
||||
mask = torch.round(mask)
|
||||
mask = mask.repeat(1, 3, 1, 1)
|
||||
|
||||
# Convert 3-channel binary mask to BytesIO
|
||||
mask_bytes_io = BytesIO()
|
||||
mask_bytes_io.write(mask.numpy().astype(np.uint8))
|
||||
mask_bytes_io.seek(0)
|
||||
|
||||
# Convert image to BytesIO
|
||||
image_bytes_io = tensor_to_bytesio(image)
|
||||
image_bytes_io.seek(0)
|
||||
|
||||
pika_files = {
|
||||
"video": ("video.mp4", video_bytes_io, "video/mp4"),
|
||||
"image": ("image.png", image_bytes_io, "image/png"),
|
||||
"modifyRegionMask": ("mask.png", mask_bytes_io, "image/png"),
|
||||
}
|
||||
if mask is not None:
|
||||
pika_files["modifyRegionMask"] = ("mask.png", tensor_to_bytesio(mask), "image/png")
|
||||
if image is not None:
|
||||
pika_files["image"] = ("image.png", tensor_to_bytesio(image), "image/png")
|
||||
|
||||
# Prepare non-file data
|
||||
pika_request_data = PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
|
||||
pika_request_data = pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
modifyRegionRoi=region_to_modify if region_to_modify else None,
|
||||
)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_PIKADDITIONS,
|
||||
method=HttpMethod.POST,
|
||||
request_model=PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
||||
response_model=PikaGenerateResponse,
|
||||
),
|
||||
request=pika_request_data,
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_PIKASWAPS, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaffectsNode(comfy_io.ComfyNode):
|
||||
class PikaffectsNode(IO.ComfyNode):
|
||||
"""Pika Pikaffects Node."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="Pikaffects",
|
||||
display_name="Pikaffects (Video Effects)",
|
||||
description="Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear",
|
||||
category="api node/video/Pika",
|
||||
inputs=[
|
||||
comfy_io.Image.Input("image", tooltip="The reference image to apply the Pikaffect to."),
|
||||
comfy_io.Combo.Input(
|
||||
"pikaffect", options=[pikaffect.value for pikaffect in Pikaffect], default="Cake-ify"
|
||||
IO.Image.Input("image", tooltip="The reference image to apply the Pikaffect to."),
|
||||
IO.Combo.Input(
|
||||
"pikaffect", options=pika_defs.Pikaffect, default="Cake-ify"
|
||||
),
|
||||
comfy_io.String.Input("prompt_text", multiline=True),
|
||||
comfy_io.String.Input("negative_prompt", multiline=True),
|
||||
comfy_io.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
|
||||
IO.String.Input("prompt_text", multiline=True),
|
||||
IO.String.Input("negative_prompt", multiline=True),
|
||||
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -639,19 +476,12 @@ class PikaffectsNode(comfy_io.ComfyNode):
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
) -> comfy_io.NodeOutput:
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_PIKAFFECTS,
|
||||
method=HttpMethod.POST,
|
||||
request_model=PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
|
||||
response_model=PikaGenerateResponse,
|
||||
),
|
||||
request=PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
|
||||
) -> IO.NodeOutput:
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_PIKAFFECTS, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
|
||||
pikaffect=pikaffect,
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
@@ -659,31 +489,30 @@ class PikaffectsNode(comfy_io.ComfyNode):
|
||||
),
|
||||
files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaStartEndFrameNode2_2(comfy_io.ComfyNode):
|
||||
class PikaStartEndFrameNode(IO.ComfyNode):
|
||||
"""PikaFrames v2.2 Node."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="PikaStartEndFrameNode2_2",
|
||||
display_name="Pika Start and End Frame to Video",
|
||||
description="Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them.",
|
||||
category="api node/video/Pika",
|
||||
inputs=[
|
||||
comfy_io.Image.Input("image_start", tooltip="The first image to combine."),
|
||||
comfy_io.Image.Input("image_end", tooltip="The last image to combine."),
|
||||
IO.Image.Input("image_start", tooltip="The first image to combine."),
|
||||
IO.Image.Input("image_end", tooltip="The last image to combine."),
|
||||
*get_base_inputs_types(),
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -698,23 +527,17 @@ class PikaStartEndFrameNode2_2(comfy_io.ComfyNode):
|
||||
seed: int,
|
||||
resolution: str,
|
||||
duration: int,
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt_text, field_name="prompt_text", min_length=1)
|
||||
pika_files = [
|
||||
("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
|
||||
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
|
||||
]
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_PIKAFRAMES,
|
||||
method=HttpMethod.POST,
|
||||
request_model=PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
|
||||
response_model=PikaGenerateResponse,
|
||||
),
|
||||
request=PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_PIKAFRAMES, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
@@ -723,22 +546,21 @@ class PikaStartEndFrameNode2_2(comfy_io.ComfyNode):
|
||||
),
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaApiNodesExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
PikaImageToVideoV2_2,
|
||||
PikaTextToVideoNodeV2_2,
|
||||
PikaScenesV2_2,
|
||||
PikaImageToVideo,
|
||||
PikaTextToVideoNode,
|
||||
PikaScenes,
|
||||
PikAdditionsNode,
|
||||
PikaSwapsNode,
|
||||
PikaffectsNode,
|
||||
PikaStartEndFrameNode2_2,
|
||||
PikaStartEndFrameNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user