Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-12 19:20:03 +00:00)

Compare commits: 164 commits, v3-process ... update-tem
| Author | SHA1 | Date |
|---|---|---|
| | eeb6be4c2a | |
| | 336b2a7086 | |
| | b75d349f25 | |
| | 7b8389578e | |
| | 9e00ce5b76 | |
| | f5e66d5e47 | |
| | 87b0359392 | |
| | cb96d4d18c | |
| | 394348f5ca | |
| | 7601e89255 | |
| | 6a1d3a1ae1 | |
| | 65ee24c978 | |
| | 17027f2a6a | |
| | b5c8be8b1d | |
| | 24fdb92edf | |
| | d526974576 | |
| | e1ab6bb394 | |
| | 048f49adbd | |
| | 47bfd5a33f | |
| | fdf49a2861 | |
| | f41e5f398d | |
| | 27cbac865e | |
| | 3d0003c24c | |
| | 7d6103325e | |
| | 2d4a08b717 | |
| | 9a02382568 | |
| | bd01d9f7fd | |
| | 443056c401 | |
| | f60923590c | |
| | 1ef328c007 | |
| | 94c298f962 | |
| | 2fde9597f4 | |
| | f91078b1ff | |
| | 3b3ef9a77a | |
| | 8b0b93df51 | |
| | 1c7eaeca10 | |
| | 18e7d6dba5 | |
| | e1d85e7577 | |
| | 1199411747 | |
| | 5ebcab3c7d | |
| | c350009236 | |
| | dea899f221 | |
| | e632e5de28 | |
| | 2abd2b5c20 | |
| | a1a70362ca | |
| | cf97b033ee | |
| | eb1c42f649 | |
| | e05c907126 | |
| | 09dc24c8a9 | |
| | 1d69245981 | |
| | 97f198e421 | |
| | bda0eb2448 | |
| | c4a6b389de | |
| | 4cd881866b | |
| | 265adad858 | |
| | 7f3e4d486c | |
| | a389ee01bb | |
| | 9c71a66790 | |
| | af4b7b5edb | |
| | 0f4ef3afa0 | |
| | 6b88478f9f | |
| | e199c8cc67 | |
| | 0652cb8e2d | |
| | 958a17199a | |
| | e974e554ca | |
| | 4e2110c794 | |
| | e617cddf24 | |
| | 1f3f7a2823 | |
| | 88df172790 | |
| | 6d6a18b0b7 | |
| | 97ff9fae7e | |
| | 135fa49ec2 | |
| | 44869ff786 | |
| | 20182a393f | |
| | 5f109fe6a0 | |
| | c58c13b2ba | |
| | 7f374e42c8 | |
| | 27d1bd8829 | |
| | 614cf9805e | |
| | 513b0c46fb | |
| | dfac94695b | |
| | 163b629c70 | |
| | 998bf60beb | |
| | 906c089957 | |
| | 25de7b1bfa | |
| | ab7ab5be23 | |
| | ec4fc2a09a | |
| | 1a58087ac2 | |
| | 6c14f3afac | |
| | e525673f72 | |
| | 3fa7a5c04a | |
| | 210f7a1ba5 | |
| | d202c2ba74 | |
| | 8817f8fc14 | |
| | 22e40d2ace | |
| | 3bea4efc6b | |
| | 8cf2ba4ba6 | |
| | b61a40cbc9 | |
| | f2bb3230b7 | |
| | 614b8d3345 | |
| | 6abc30aae9 | |
| | 55bad30375 | |
| | c305deed56 | |
| | 601ee1775a | |
| | c170fd2db5 | |
| | 9d529e5308 | |
| | f6bbc1ac84 | |
| | 098a352f13 | |
| | e86b79ab9e | |
| | 426cde37f1 | |
| | dd5af0c587 | |
| | 388b306a2b | |
| | 24188b3141 | |
| | 1bcda6df98 | |
| | a1864c01f2 | |
| | 4739d7717f | |
| | f13cff0be6 | |
| | 9cdc64998f | |
| | 560b1bdfca | |
| | b7992f871a | |
| | 2c2aa409b0 | |
| | a4787ac83b | |
| | b5c59b763c | |
| | b4f30bd408 | |
| | dad076aee6 | |
| | 0cf33953a7 | |
| | 5b80addafd | |
| | 9da397ea2f | |
| | 92d97380bd | |
| | 99ce2a1f66 | |
| | b1467da480 | |
| | d8d60b5609 | |
| | b1293d50ef | |
| | 19b466160c | |
| | bc0ad9bb49 | |
| | 4054b4bf38 | |
| | 55ac7d333c | |
| | afa8a24fe1 | |
| | 493b81e48f | |
| | 6b035bfce2 | |
| | 74b7f0b04b | |
| | f72c6616b2 | |
| | 1c10b33f9b | |
| | ddfce1af4f | |
| | 7a883849ea | |
| | 84867067ea | |
| | 3374e900d0 | |
| | 51696e3fdc | |
| | dfff7e5332 | |
| | e4ea393666 | |
| | c8674bc6e9 | |
| | 3dfdcf66b6 | |
| | 95ca2e56c8 | |
| | 27ffd12c45 | |
| | e693e4db6a | |
| | d68ece7301 | |
| | 894837de9a | |
| | fdc92863b6 | |
| | a125cd84b0 | |
| | 84e9ce32c6 | |
| | f43b8ab2a2 | |
| | 14d642acd6 | |
| | aa895db7e8 | |
| | cdfc25a160 | |
@@ -0,0 +1,3 @@
..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
pause
@@ -1,2 +1,3 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
pause
@@ -1,2 +1,3 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
pause
8  .github/ISSUE_TEMPLATE/bug-report.yml  vendored

@@ -8,13 +8,15 @@ body:
        Before submitting a **Bug Report**, please ensure the following:

        - **1:** You are running the latest version of ComfyUI.
        - **2:** You have looked at the existing bug reports and made sure this isn't already reported.
        - **2:** You have your ComfyUI logs and relevant workflow on hand and will post them in this bug report.
        - **3:** You confirmed that the bug is not caused by a custom node. You can disable all custom nodes by passing
          `--disable-all-custom-nodes` command line argument.
          `--disable-all-custom-nodes` command line argument. If you have custom node try updating them to the latest version.
        - **4:** This is an actual bug in ComfyUI, not just a support question. A bug is when you can specify exact
          steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen.

        If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
        ## Very Important

        Please make sure that you post ALL your ComfyUI logs in the bug report. A bug report without logs will likely be ignored.
  - type: checkboxes
    id: custom-nodes-test
    attributes:
21  .github/PULL_REQUEST_TEMPLATE/api-node.md  vendored  Normal file

@@ -0,0 +1,21 @@
<!-- API_NODE_PR_CHECKLIST: do not remove -->

## API Node PR Checklist

### Scope
- [ ] **Is API Node Change**

### Pricing & Billing
- [ ] **Need pricing update**
- [ ] **No pricing update**

If **Need pricing update**:
- [ ] Metronome rate cards updated
- [ ] Auto‑billing tests updated and passing

### QA
- [ ] **QA done**
- [ ] **QA not required**

### Comms
- [ ] Informed **Kosinkadink**
58  .github/workflows/api-node-template.yml  vendored  Normal file

@@ -0,0 +1,58 @@
name: Append API Node PR template

on:
  pull_request_target:
    types: [opened, reopened, synchronize, ready_for_review]
    paths:
      - 'comfy_api_nodes/**'  # only run if these files changed

permissions:
  contents: read
  pull-requests: write

jobs:
  inject:
    runs-on: ubuntu-latest
    steps:
      - name: Ensure template exists and append to PR body
        uses: actions/github-script@v7
        with:
          script: |
            const { owner, repo } = context.repo;
            const number = context.payload.pull_request.number;
            const templatePath = '.github/PULL_REQUEST_TEMPLATE/api-node.md';
            const marker = '<!-- API_NODE_PR_CHECKLIST: do not remove -->';

            const { data: pr } = await github.rest.pulls.get({ owner, repo, pull_number: number });

            let templateText;
            try {
              const res = await github.rest.repos.getContent({
                owner,
                repo,
                path: templatePath,
                ref: pr.base.ref
              });
              const buf = Buffer.from(res.data.content, res.data.encoding || 'base64');
              templateText = buf.toString('utf8');
            } catch (e) {
              core.setFailed(`Required PR template not found at "${templatePath}" on ${pr.base.ref}. Please add it to the repo.`);
              return;
            }

            // Enforce the presence of the marker inside the template (for idempotence)
            if (!templateText.includes(marker)) {
              core.setFailed(`Template at "${templatePath}" does not contain the required marker:\n${marker}\nAdd it so we can detect duplicates safely.`);
              return;
            }

            // If the PR already contains the marker, do not append again.
            const body = pr.body || '';
            if (body.includes(marker)) {
              core.info('Template already present in PR body; nothing to inject.');
              return;
            }

            const newBody = (body ? body + '\n\n' : '') + templateText + '\n';
            await github.rest.pulls.update({ owner, repo, pull_number: number, body: newBody });
            core.notice('API Node template appended to PR description.');
23  .github/workflows/release-stable-all.yml  vendored

@@ -14,13 +14,13 @@ jobs:
      contents: "write"
      packages: "write"
      pull-requests: "read"
    name: "Release NVIDIA Default (cu129)"
    name: "Release NVIDIA Default (cu130)"
    uses: ./.github/workflows/stable-release.yml
    with:
      git_tag: ${{ inputs.git_tag }}
      cache_tag: "cu129"
      cache_tag: "cu130"
      python_minor: "13"
      python_patch: "6"
      python_patch: "9"
      rel_name: "nvidia"
      rel_extra_name: ""
      test_release: true

@@ -43,6 +43,23 @@ jobs:
      test_release: true
    secrets: inherit

  release_nvidia_cu126:
    permissions:
      contents: "write"
      packages: "write"
      pull-requests: "read"
    name: "Release NVIDIA cu126"
    uses: ./.github/workflows/stable-release.yml
    with:
      git_tag: ${{ inputs.git_tag }}
      cache_tag: "cu126"
      python_minor: "12"
      python_patch: "10"
      rel_name: "nvidia"
      rel_extra_name: "_cu126"
      test_release: true
    secrets: inherit

  release_amd_rocm:
    permissions:
      contents: "write"
20  .github/workflows/test-ci.yml  vendored

@@ -21,14 +21,15 @@ jobs:
      fail-fast: false
      matrix:
        # os: [macos, linux, windows]
        os: [macos, linux]
        python_version: ["3.9", "3.10", "3.11", "3.12"]
        # os: [macos, linux]
        os: [linux]
        python_version: ["3.10", "3.11", "3.12"]
        cuda_version: ["12.1"]
        torch_version: ["stable"]
        include:
          - os: macos
            runner_label: [self-hosted, macOS]
            flags: "--use-pytorch-cross-attention"
          # - os: macos
          #   runner_label: [self-hosted, macOS]
          #   flags: "--use-pytorch-cross-attention"
          - os: linux
            runner_label: [self-hosted, Linux]
            flags: ""

@@ -73,14 +74,15 @@ jobs:
    strategy:
      fail-fast: false
      matrix:
        os: [macos, linux]
        # os: [macos, linux]
        os: [linux]
        python_version: ["3.11"]
        cuda_version: ["12.1"]
        torch_version: ["nightly"]
        include:
          - os: macos
            runner_label: [self-hosted, macOS]
            flags: "--use-pytorch-cross-attention"
          # - os: macos
          #   runner_label: [self-hosted, macOS]
          #   flags: "--use-pytorch-cross-attention"
          - os: linux
            runner_label: [self-hosted, Linux]
            flags: ""
@@ -17,7 +17,7 @@ on:
      description: 'cuda version'
      required: true
      type: string
      default: "129"
      default: "130"

    python_minor:
      description: 'python minor version'

@@ -29,7 +29,7 @@ on:
      description: 'python patch version'
      required: true
      type: string
      default: "6"
      default: "9"
# push:
#     branches:
#       - master
168  QUANTIZATION.md  Normal file

@@ -0,0 +1,168 @@
# The Comfy guide to Quantization

## How does quantization work?

Quantization aims to map a high-precision value x_f to a lower-precision format with minimal loss in accuracy. These smaller formats then serve to reduce the model's memory footprint and increase throughput by using specialized hardware.

When simply converting a value from FP16 to FP8 using the round-nearest method we might hit two issues:
- The dynamic range of FP16 (-65,504, 65,504) far exceeds FP8 formats like E4M3 (-448, 448) or E5M2 (-57,344, 57,344), potentially resulting in clipped values
- The original values are concentrated in a small range (e.g. -1, 1), leaving many FP8 bits "unused"

By using a scaling factor, we aim to map these values into the quantized-dtype range, making use of the full spectrum. One of the easiest and most common approaches is per-tensor absolute-maximum scaling.

```
absmax = max(abs(tensor))
scale = absmax / max_dynamic_range_low_precision

# Quantization
tensor_q = (tensor / scale).to(low_precision_dtype)

# De-Quantization
tensor_dq = tensor_q.to(fp16) * scale

tensor_dq ~ tensor
```

Given that additional information (the scaling factor) is needed to "interpret" the quantized values, we describe these as derived datatypes.

## Quantization in Comfy

```
QuantizedTensor (torch.Tensor subclass)
        ↓ __torch_dispatch__
Two-Level Registry (generic + layout handlers)
        ↓
MixedPrecisionOps + Metadata Detection
```

### Representation

To represent these derived datatypes, ComfyUI implements them as a subclass of torch.Tensor: the `QuantizedTensor` class found in `comfy/quant_ops.py`.

A `Layout` class defines how a specific quantization format behaves:
- Required parameters
- Quantize method
- De-Quantize method

```python
from comfy.quant_ops import QuantizedLayout

class MyLayout(QuantizedLayout):
    @classmethod
    def quantize(cls, tensor, **kwargs):
        # Convert to quantized format
        qdata = ...
        params = {'scale': ..., 'orig_dtype': tensor.dtype}
        return qdata, params

    @staticmethod
    def dequantize(qdata, scale, orig_dtype, **kwargs):
        return qdata.to(orig_dtype) * scale
```
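
To make the contract concrete, here is a self-contained toy layout in the same shape as `MyLayout`, using the per-tensor absolute-maximum scaling from the first section. `AbsMaxFP8Layout` and the 448 range (the E4M3 maximum mentioned above) are illustrative only; the real layouts live in `comfy/quant_ops.py`, and float8 dtypes require a recent PyTorch build.

```python
import torch

class AbsMaxFP8Layout:
    """Toy layout following the QuantizedLayout shape above; illustration only."""

    @classmethod
    def quantize(cls, tensor, max_range=448.0):
        # Per-tensor absmax scaling, then cast to FP8 E4M3.
        scale = tensor.abs().max() / max_range
        qdata = (tensor / scale).to(torch.float8_e4m3fn)
        return qdata, {'scale': scale, 'orig_dtype': tensor.dtype}

    @staticmethod
    def dequantize(qdata, scale, orig_dtype, **kwargs):
        return qdata.to(orig_dtype) * scale

w = torch.randn(256, 256, dtype=torch.float16)
qdata, params = AbsMaxFP8Layout.quantize(w)
w_dq = AbsMaxFP8Layout.dequantize(qdata, **params)
print((w - w_dq).abs().max())  # small round-trip error
```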

To run operations on these QuantizedTensors we use two registry systems to define the supported operations.
The first is a **generic registry** that handles operations common to all quantized formats (e.g., `.to()`, `.clone()`, `.reshape()`).

The second registry is layout-specific and allows implementing fast paths such as nn.Linear.
```python
from comfy.quant_ops import register_layout_op

@register_layout_op(torch.ops.aten.linear.default, MyLayout)
def my_linear(func, args, kwargs):
    # Extract tensors, call optimized kernel
    ...
```
When `torch.nn.functional.linear()` is called with QuantizedTensor arguments, `__torch_dispatch__` automatically routes to the registered implementation.
For any unsupported operation, QuantizedTensor falls back to calling `dequantize` and dispatching the high-precision implementation.
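
The fallback can be pictured with a minimal wrapper-tensor sketch. This is the generic `__torch_dispatch__` pattern, not ComfyUI's actual `QuantizedTensor`: a real implementation would first consult the registries, while this toy version always dequantizes and re-dispatches in high precision.

```python
import torch

class ToyQuantTensor(torch.Tensor):
    """Minimal wrapper-subclass sketch of the dequantize-and-fallback behaviour (illustration only)."""

    @staticmethod
    def __new__(cls, qdata, scale, orig_dtype):
        # The wrapper advertises the original shape/dtype; the payload lives in attributes.
        return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, dtype=orig_dtype, device=qdata.device)

    def __init__(self, qdata, scale, orig_dtype):
        self.qdata, self.scale, self.orig_dtype = qdata, scale, orig_dtype

    def dequantize(self):
        return self.qdata.to(self.orig_dtype) * self.scale

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # No fast path registered here: dequantize every quantized argument
        # and run the op on plain high-precision tensors.
        kwargs = kwargs or {}
        args = [a.dequantize() if isinstance(a, cls) else a for a in args]
        return func(*args, **kwargs)

qdata = (torch.randn(256, 256, dtype=torch.float16) * 0.01).to(torch.float8_e4m3fn)
w = ToyQuantTensor(qdata, scale=0.01, orig_dtype=torch.float16)
x = torch.randn(8, 256, dtype=torch.float16)
y = torch.nn.functional.linear(x, w)  # no registered kernel -> dequantize + fp16 linear
```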

### Mixed Precision

The `MixedPrecisionOps` class (lines 542-648 in `comfy/ops.py`) enables per-layer quantization decisions, allowing different layers in a model to use different precisions. This is activated when a model config contains a `layer_quant_config` dictionary that specifies which layers should be quantized and how.

**Architecture:**

```python
class MixedPrecisionOps(disable_weight_init):
    _layer_quant_config = {}            # Maps layer names to quantization configs
    _compute_dtype = torch.bfloat16     # Default compute / dequantize precision
```

**Key mechanism:**

The custom `Linear._load_from_state_dict()` method inspects each layer during model loading:
- If the layer name is **not** in `_layer_quant_config`: load the weight as a regular tensor in `_compute_dtype`
- If the layer name **is** in `_layer_quant_config`:
  - Load the weight as a `QuantizedTensor` with the specified layout (e.g., `TensorCoreFP8Layout`)
  - Load the associated quantization parameters (scales, block_size, etc.)

**Why it's needed:**

Not all layers tolerate quantization equally. Sensitive operations like final projections can be kept in higher precision, while compute-heavy matmuls are quantized. This provides most of the performance benefits while maintaining quality.

The system is selected in `pick_operations()` when `model_config.layer_quant_config` is present, making it the highest-priority operation mode.

## Checkpoint Format

Quantized checkpoints are stored as standard safetensors files with quantized weight tensors and associated scaling parameters, plus a `_quantization_metadata` JSON entry describing the quantization scheme.

The quantized checkpoint will contain the same layers as the original checkpoint, but:
- The weights are stored as quantized values, sometimes using a different storage datatype, e.g. a uint8 container for FP8.
- For each quantized weight, a number of additional scaling parameters are stored alongside it, depending on the recipe.
- We store a metadata.json in the metadata of the final safetensors file containing the `_quantization_metadata` describing which layers are quantized and which layout has been used.

### Scaling Parameter Details
We define 4 possible scaling parameters that should cover most recipes in the near future:
- **weight_scale**: quantization scalers for the weights
- **weight_scale_2**: global scalers in the context of double scaling
- **pre_quant_scale**: scalers used for smoothing salient weights
- **input_scale**: quantization scalers for the activations

| Format | Storage dtype | weight_scale | weight_scale_2 | pre_quant_scale | input_scale |
|--------|---------------|--------------|----------------|-----------------|-------------|
| float8_e4m3fn | float32 | float32 (scalar) | - | - | float32 (scalar) |

You can find the defined formats in `comfy/quant_ops.py` (`QUANT_ALGOS`).

### Quantization Metadata

The metadata stored alongside the checkpoint contains:
- **format_version**: String to define a version of the standard
- **layers**: A dictionary mapping layer names to their quantization format. The format string maps to the definitions found in `QUANT_ALGOS`.

Example:
```json
{
    "_quantization_metadata": {
        "format_version": "1.0",
        "layers": {
            "model.layers.0.mlp.up_proj": "float8_e4m3fn",
            "model.layers.0.mlp.down_proj": "float8_e4m3fn",
            "model.layers.1.mlp.up_proj": "float8_e4m3fn"
        }
    }
}
```

## Creating Quantized Checkpoints

To create compatible checkpoints, use any quantization tool, provided the output follows the checkpoint format described above and uses a layout defined in `QUANT_ALGOS`.

### Weight Quantization

Weight quantization is straightforward: compute the scaling factor directly from the weight tensor using the absolute-maximum method described earlier. Each layer's weights are quantized independently and stored with their corresponding `weight_scale` parameter.
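
A minimal writer sketch, tying this together with the checkpoint format above. The `<layer>.weight_scale` key naming and the layer-selection logic are assumptions for illustration; only the `_quantization_metadata` entry and the `float8_e4m3fn` format name come from the sections above, and saving FP8 tensors requires a recent safetensors version.

```python
import json
import torch
from safetensors.torch import save_file

def quantize_weight(w, max_range=448.0):
    # Per-tensor absmax scaling as described earlier.
    scale = w.abs().max().float() / max_range
    return (w / scale).to(torch.float8_e4m3fn), scale

def write_quantized_checkpoint(state_dict, layer_names, out_path):
    tensors, layers_meta = {}, {}
    for name, tensor in state_dict.items():
        layer = name.rsplit(".weight", 1)[0]
        if name.endswith(".weight") and layer in layer_names:
            qweight, scale = quantize_weight(tensor)
            tensors[name] = qweight
            tensors[layer + ".weight_scale"] = scale  # assumed key naming, illustration only
            layers_meta[layer] = "float8_e4m3fn"
        else:
            tensors[name] = tensor
    metadata = {"_quantization_metadata": json.dumps(
        {"format_version": "1.0", "layers": layers_meta})}
    save_file(tensors, out_path, metadata=metadata)
```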

### Calibration (for Activation Quantization)

Activation quantization (e.g., for FP8 Tensor Core operations) requires `input_scale` parameters that cannot be determined from static weights alone. Since activation values depend on actual inputs, we use **post-training calibration (PTQ)**:

1. **Collect statistics**: Run inference on N representative samples
2. **Track activations**: Record the absolute maximum (`amax`) of inputs to each quantized layer
3. **Compute scales**: Derive `input_scale` from the collected statistics
4. **Store in checkpoint**: Save `input_scale` parameters alongside the weights

The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters, as in the sketch below.
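
A sketch of the calibration loop using standard PyTorch forward pre-hooks. The layer selection, the calibration batches and the `input_scale = amax / 448` rule follow the absmax recipe above and are illustrative rather than ComfyUI's exact tooling.

```python
import torch

def calibrate_input_scales(model, calibration_batches, layer_names, max_range=448.0):
    amax = {name: 0.0 for name in layer_names}
    hooks = []

    def make_hook(name):
        def hook(module, inputs):
            # Track the absolute maximum of the layer input over all calibration samples.
            amax[name] = max(amax[name], inputs[0].detach().abs().max().item())
        return hook

    for name, module in model.named_modules():
        if name in layer_names:
            hooks.append(module.register_forward_pre_hook(make_hook(name)))

    with torch.no_grad():
        for batch in calibration_batches:  # representative prompts / latents
            model(batch)

    for h in hooks:
        h.remove()

    # Derive per-layer input_scale from the collected statistics (step 3 above).
    return {name: amax[name] / max_range for name in amax}
```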

29  README.md

@@ -112,10 +112,11 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git

## Release Process

ComfyUI follows a weekly release cycle targeting Friday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
ComfyUI follows a weekly release cycle targeting Monday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:

1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
   - Releases a new stable version (e.g., v0.7.0)
   - Releases a new stable version (e.g., v0.7.0) roughly every week.
   - Commits outside of the stable release tags may be very unstable and break many custom nodes.
   - Serves as the foundation for the desktop release

2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)**

@@ -172,15 +173,19 @@ There is a portable standalone build for Windows that should work for running on

### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia.7z)

Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints
Simply download, extract with [7-Zip](https://7-zip.org) or with the windows explorer on recent windows versions and run. For smaller models you normally only need to put the checkpoints (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints but many of the larger models have multiple files. Make sure to follow the instructions to know which subfolder to put them in ComfyUI\models\

If you have trouble extracting it, right click the file -> properties -> unblock

Update your Nvidia drivers if it doesn't start.

#### Alternative Downloads:

[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)

[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z) (Supports Nvidia 10 series and older GPUs).
[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z).

[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).

#### How do I share models between another UI and ComfyUI?

@@ -197,7 +202,11 @@ comfy install

## Manual Install (Windows, Linux)

Python 3.13 is very well supported. If you have trouble with some custom node dependencies you can try 3.12
Python 3.14 works but you may encounter issues with the torch compile node. The free threaded variant is still missing some dependencies.

Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12

### Instructions:

Git clone this repo.

@@ -214,7 +223,7 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins

This is the command to install the nightly with ROCm 7.0 which might have some performance improvements:

```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.0```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1```


### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only.

@@ -235,7 +244,7 @@ RDNA 4 (RX 9000 series):

### Intel GPUs (Windows and Linux)

(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)

1. To install PyTorch xpu, use the following command:

@@ -245,15 +254,11 @@ This is the command to install the Pytorch xpu nightly which might have some per

```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu```

(Option 2) Alternatively, Intel GPUs supported by Intel Extension for PyTorch (IPEX) can leverage IPEX for improved performance.

1. visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information.

### NVIDIA

Nvidia users should install stable pytorch using this command:

```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu129```
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu130```

This is the command to install pytorch nightly instead which might have performance improvements.

@@ -10,7 +10,8 @@ import importlib
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import TypedDict, Optional
from typing import Dict, TypedDict, Optional
from aiohttp import web
from importlib.metadata import version

import requests

@@ -257,7 +258,54 @@ comfyui-frontend-package is not installed.
            sys.exit(-1)

    @classmethod
    def templates_path(cls) -> str:
    def template_asset_map(cls) -> Optional[Dict[str, str]]:
        """Return a mapping of template asset names to their absolute paths."""
        try:
            from comfyui_workflow_templates import (
                get_asset_path,
                iter_templates,
            )
        except ImportError:
            logging.error(
                f"""
********** ERROR ***********

comfyui-workflow-templates is not installed.

{frontend_install_warning_message()}

********** ERROR ***********
""".strip()
            )
            return None

        try:
            template_entries = list(iter_templates())
        except Exception as exc:
            logging.error(f"Failed to enumerate workflow templates: {exc}")
            return None

        asset_map: Dict[str, str] = {}
        try:
            for entry in template_entries:
                for asset in entry.assets:
                    asset_map[asset.filename] = get_asset_path(
                        entry.template_id, asset.filename
                    )
        except Exception as exc:
            logging.error(f"Failed to resolve template asset paths: {exc}")
            return None

        if not asset_map:
            logging.error("No workflow template assets found. Did the packages install correctly?")
            return None

        return asset_map


    @classmethod
    def legacy_templates_path(cls) -> Optional[str]:
        """Return the legacy templates directory shipped inside the meta package."""
        try:
            import comfyui_workflow_templates

@@ -276,6 +324,7 @@ comfyui-workflow-templates is not installed.
********** ERROR ***********
""".strip()
            )
            return None

    @classmethod
    def embedded_docs_path(cls) -> str:

@@ -392,3 +441,17 @@ comfyui-workflow-templates is not installed.
        logging.info("Falling back to the default frontend.")
        check_frontend_version()
        return cls.default_frontend_path()

    @classmethod
    def template_asset_handler(cls):
        assets = cls.template_asset_map()
        if not assets:
            return None

        async def serve_template(request: web.Request) -> web.StreamResponse:
            rel_path = request.match_info.get("path", "")
            target = assets.get(rel_path)
            if target is None:
                raise web.HTTPNotFound()
            return web.FileResponse(target)

        return serve_template
112  app/subgraph_manager.py  Normal file

@@ -0,0 +1,112 @@
from __future__ import annotations

from typing import TypedDict
import os
import folder_paths
import glob
from aiohttp import web
import hashlib


class Source:
    custom_node = "custom_node"

class SubgraphEntry(TypedDict):
    source: str
    """
    Source of subgraph - custom_nodes vs templates.
    """
    path: str
    """
    Relative path of the subgraph file.
    For custom nodes, will be the relative directory like <custom_node_dir>/subgraphs/<name>.json
    """
    name: str
    """
    Name of subgraph file.
    """
    info: CustomNodeSubgraphEntryInfo
    """
    Additional info about subgraph; in the case of custom_nodes, will contain nodepack name
    """
    data: str

class CustomNodeSubgraphEntryInfo(TypedDict):
    node_pack: str
    """Node pack name."""

class SubgraphManager:
    def __init__(self):
        self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None

    async def load_entry_data(self, entry: SubgraphEntry):
        with open(entry['path'], 'r') as f:
            entry['data'] = f.read()
        return entry

    async def sanitize_entry(self, entry: SubgraphEntry | None, remove_data=False) -> SubgraphEntry | None:
        if entry is None:
            return None
        entry = entry.copy()
        entry.pop('path', None)
        if remove_data:
            entry.pop('data', None)
        return entry

    async def sanitize_entries(self, entries: dict[str, SubgraphEntry], remove_data=False) -> dict[str, SubgraphEntry]:
        entries = entries.copy()
        for key in list(entries.keys()):
            entries[key] = await self.sanitize_entry(entries[key], remove_data)
        return entries

    async def get_custom_node_subgraphs(self, loadedModules, force_reload=False):
        # if not forced to reload and cached, return cache
        if not force_reload and self.cached_custom_node_subgraphs is not None:
            return self.cached_custom_node_subgraphs
        # Load subgraphs from custom nodes
        subfolder = "subgraphs"
        subgraphs_dict: dict[SubgraphEntry] = {}

        for folder in folder_paths.get_folder_paths("custom_nodes"):
            pattern = os.path.join(folder, f"*/{subfolder}/*.json")
            matched_files = glob.glob(pattern)
            for file in matched_files:
                # replace backslashes with forward slashes
                file = file.replace('\\', '/')
                info: CustomNodeSubgraphEntryInfo = {
                    "node_pack": "custom_nodes." + file.split('/')[-3]
                }
                source = Source.custom_node
                # hash source + path to make sure id will be as unique as possible, but
                # reproducible across backend reloads
                id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
                entry: SubgraphEntry = {
                    "source": Source.custom_node,
                    "name": os.path.splitext(os.path.basename(file))[0],
                    "path": file,
                    "info": info,
                }
                subgraphs_dict[id] = entry
        self.cached_custom_node_subgraphs = subgraphs_dict
        return subgraphs_dict

    async def get_custom_node_subgraph(self, id: str, loadedModules):
        subgraphs = await self.get_custom_node_subgraphs(loadedModules)
        entry: SubgraphEntry = subgraphs.get(id, None)
        if entry is not None and entry.get('data', None) is None:
            await self.load_entry_data(entry)
        return entry

    def add_routes(self, routes, loadedModules):
        @routes.get("/global_subgraphs")
        async def get_global_subgraphs(request):
            subgraphs_dict = await self.get_custom_node_subgraphs(loadedModules)
            # NOTE: we may want to include other sources of global subgraphs such as templates in the future;
            # that's the reasoning for the current implementation
            return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True))

        @routes.get("/global_subgraphs/{id}")
        async def get_global_subgraph(request):
            id = request.match_info.get("id", None)
            subgraph = await self.get_custom_node_subgraph(id, loadedModules)
            return web.json_response(await self.sanitize_entry(subgraph))

@@ -105,6 +105,7 @@ cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")

attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")

@@ -145,7 +146,9 @@ class PerformanceFeature(enum.Enum):
    CublasOps = "cublas_ops"
    AutoTune = "autotune"

parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))

parser.add_argument("--disable-pinned-memory", action="store_true", help="Disable pinned memory use.")

parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")

@@ -310,11 +310,13 @@ class ControlLoraOps:
                self.bias = None

            def forward(self, input):
                weight, bias = comfy.ops.cast_bias_weight(self, input)
                weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
                if self.up is not None:
                    return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
                    x = torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
                else:
                    return torch.nn.functional.linear(input, weight, bias)
                    x = torch.nn.functional.linear(input, weight, bias)
                comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
                return x

        class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
            def __init__(

@@ -350,12 +352,13 @@ class ControlLoraOps:

            def forward(self, input):
                weight, bias = comfy.ops.cast_bias_weight(self, input)
                weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
                if self.up is not None:
                    return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
                    x = torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
                else:
                    return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
                    x = torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
                comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
                return x

class ControlLora(ControlNet):
    def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options

@@ -1,15 +1,15 @@
import torch
from torch import Tensor, nn

from comfy.ldm.flux.math import attention
from comfy.ldm.flux.layers import (
    MLPEmbedder,
    RMSNorm,
    QKNorm,
    SelfAttention,
    ModulationOut,
)

# TODO: remove this in a few months
SingleStreamBlock = None
DoubleStreamBlock = None


class ChromaModulationOut(ModulationOut):

@@ -48,124 +48,6 @@ class Approximator(nn.Module):
        return x


class DoubleStreamBlock(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)

        self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.img_mlp = nn.Sequential(
            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
            nn.GELU(approximate="tanh"),
            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
        )

        self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)

        self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.txt_mlp = nn.Sequential(
            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
            nn.GELU(approximate="tanh"),
            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
        )
        self.flipped_img_txt = flipped_img_txt

    def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}):
        (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec

        # prepare image for attention
        img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img))
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt))
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        attn = attention(torch.cat((txt_q, img_q), dim=2),
                         torch.cat((txt_k, img_k), dim=2),
                         torch.cat((txt_v, img_v), dim=2),
                         pe=pe, mask=attn_mask, transformer_options=transformer_options)

        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

        # calculate the img bloks
        img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn))
        img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img))))

        # calculate the txt bloks
        txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
        txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt))))

        if txt.dtype == torch.float16:
            txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)

        return img, txt


class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float = None,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
        # proj and mlp_out
        self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)

        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)

        self.hidden_size = hidden_size
        self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

        self.mlp_act = nn.GELU(approximate="tanh")

    def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}) -> Tensor:
        mod = vec
        x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        x.addcmul_(mod.gate, output)
        if x.dtype == torch.float16:
            x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
        return x


class LastLayer(nn.Module):
    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
        super().__init__()

@@ -11,12 +11,12 @@ import comfy.ldm.common_dit
from comfy.ldm.flux.layers import (
    EmbedND,
    timestep_embedding,
    DoubleStreamBlock,
    SingleStreamBlock,
)

from .layers import (
    DoubleStreamBlock,
    LastLayer,
    SingleStreamBlock,
    Approximator,
    ChromaModulationOut,
)

@@ -90,6 +90,7 @@ class Chroma(nn.Module):
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                    modulation=False,
                    dtype=dtype, device=device, operations=operations
                )
                for _ in range(params.depth)

@@ -98,7 +99,7 @@ class Chroma(nn.Module):

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=False, dtype=dtype, device=device, operations=operations)
                for _ in range(params.depth_single_blocks)
            ]
        )

@@ -10,12 +10,10 @@ from torch import Tensor, nn
from einops import repeat
import comfy.ldm.common_dit

from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.layers import EmbedND, DoubleStreamBlock, SingleStreamBlock

from comfy.ldm.chroma.model import Chroma, ChromaParams
from comfy.ldm.chroma.layers import (
    DoubleStreamBlock,
    SingleStreamBlock,
    Approximator,
)
from .layers import (

@@ -89,7 +87,6 @@ class ChromaRadiance(Chroma):
            dtype=dtype, device=device, operations=operations
        )


        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(

@@ -97,6 +94,7 @@ class ChromaRadiance(Chroma):
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                    modulation=False,
                    dtype=dtype, device=device, operations=operations
                )
                for _ in range(params.depth)

@@ -109,6 +107,7 @@ class ChromaRadiance(Chroma):
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    modulation=False,
                    dtype=dtype, device=device, operations=operations,
                )
                for _ in range(params.depth_single_blocks)

@@ -189,15 +188,15 @@ class ChromaRadiance(Chroma):
        nerf_pixels = nn.functional.unfold(img_orig, kernel_size=patch_size, stride=patch_size)
        nerf_pixels = nerf_pixels.transpose(1, 2)  # -> [B, NumPatches, C * P * P]

        # Reshape for per-patch processing
        nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size)
        nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)

        if params.nerf_tile_size > 0 and num_patches > params.nerf_tile_size:
            # Enable tiling if nerf_tile_size isn't 0 and we actually have more patches than
            # the tile size.
            img_dct = self.forward_tiled_nerf(img_out, nerf_pixels, B, C, num_patches, patch_size, params)
            img_dct = self.forward_tiled_nerf(nerf_hidden, nerf_pixels, B, C, num_patches, patch_size, params)
        else:
            # Reshape for per-patch processing
            nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size)
            nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)

            # Get DCT-encoded pixel embeddings [pixel-dct]
            img_dct = self.nerf_image_embedder(nerf_pixels)

@@ -240,17 +239,8 @@ class ChromaRadiance(Chroma):
            end = min(i + tile_size, num_patches)

            # Slice the current tile from the input tensors
            nerf_hidden_tile = nerf_hidden[:, i:end, :]
            nerf_pixels_tile = nerf_pixels[:, i:end, :]

            # Get the actual number of patches in this tile (can be smaller for the last tile)
            num_patches_tile = nerf_hidden_tile.shape[1]

            # Reshape the tile for per-patch processing
            # [B, NumPatches_tile, D] -> [B * NumPatches_tile, D]
            nerf_hidden_tile = nerf_hidden_tile.reshape(batch * num_patches_tile, params.hidden_size)
            # [B, NumPatches_tile, C*P*P] -> [B*NumPatches_tile, C, P*P] -> [B*NumPatches_tile, P*P, C]
            nerf_pixels_tile = nerf_pixels_tile.reshape(batch * num_patches_tile, channels, patch_size**2).transpose(1, 2)
            nerf_hidden_tile = nerf_hidden[i * batch:end * batch]
            nerf_pixels_tile = nerf_pixels[i * batch:end * batch]

            # get DCT-encoded pixel embeddings [pixel-dct]
            img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile)

@@ -130,13 +130,17 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):


class DoubleStreamBlock(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, dtype=None, device=None, operations=None):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
        self.modulation = modulation

        if self.modulation:
            self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)

        self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)

@@ -147,7 +151,9 @@ class DoubleStreamBlock(nn.Module):
            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
        )

        self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
        if self.modulation:
            self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)

        self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)

@@ -160,46 +166,65 @@ class DoubleStreamBlock(nn.Module):
        self.flipped_img_txt = flipped_img_txt

    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)
        if self.modulation:
            img_mod1, img_mod2 = self.img_mod(vec)
            txt_mod1, txt_mod2 = self.txt_mod(vec)
        else:
            (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
        img_qkv = self.img_attn.qkv(img_modulated)
        del img_modulated
        img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        del img_qkv
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        del txt_modulated
        txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        del txt_qkv
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        if self.flipped_img_txt:
            q = torch.cat((img_q, txt_q), dim=2)
            del img_q, txt_q
            k = torch.cat((img_k, txt_k), dim=2)
            del img_k, txt_k
            v = torch.cat((img_v, txt_v), dim=2)
            del img_v, txt_v
            # run actual attention
            attn = attention(torch.cat((img_q, txt_q), dim=2),
                             torch.cat((img_k, txt_k), dim=2),
                             torch.cat((img_v, txt_v), dim=2),
            attn = attention(q, k, v,
                             pe=pe, mask=attn_mask, transformer_options=transformer_options)
            del q, k, v

            img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
        else:
            q = torch.cat((txt_q, img_q), dim=2)
            del txt_q, img_q
            k = torch.cat((txt_k, img_k), dim=2)
            del txt_k, img_k
            v = torch.cat((txt_v, img_v), dim=2)
            del txt_v, img_v
            # run actual attention
            attn = attention(torch.cat((txt_q, img_q), dim=2),
                             torch.cat((txt_k, img_k), dim=2),
                             torch.cat((txt_v, img_v), dim=2),
            attn = attention(q, k, v,
                             pe=pe, mask=attn_mask, transformer_options=transformer_options)
            del q, k, v

            txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]

        # calculate the img bloks
        img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
        img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
        img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
        del img_attn
        img += apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)

        # calculate the txt bloks
        txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
        del txt_attn
        txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)

        if txt.dtype == torch.float16:

@@ -220,6 +245,7 @@ class SingleStreamBlock(nn.Module):
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float = None,
        modulation=True,
        dtype=None,
        device=None,
        operations=None

@@ -242,19 +268,29 @@ class SingleStreamBlock(nn.Module):
        self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
        if modulation:
            self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
        else:
            self.modulation = None

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor:
        mod, _ = self.modulation(vec)
        if self.modulation:
            mod, _ = self.modulation(vec)
        else:
            mod = vec

        qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        del qkv
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
        del q, k, v
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        mlp = self.mlp_act(mlp)
        output = self.linear2(torch.cat((attn, mlp), 2))
        x += apply_mod(output, mod.gate, None, modulation_dims)
        if x.dtype == torch.float16:
            x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)

@@ -7,15 +7,8 @@ import comfy.model_management


def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
    q_shape = q.shape
    k_shape = k.shape

    if pe is not None:
        q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
        k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
        q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
        k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)

    q, k = apply_rope(q, k, pe)
    heads = q.shape[1]
    x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
    return x
@@ -210,7 +210,7 @@ class Flux(nn.Module):
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
|
||||
def process_img(self, x, index=0, h_offset=0, w_offset=0):
|
||||
def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
|
||||
bs, c, h, w = x.shape
|
||||
patch_size = self.patch_size
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||
@@ -222,10 +222,22 @@ class Flux(nn.Module):
|
||||
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
|
||||
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
|
||||
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
steps_h = h_len
|
||||
steps_w = w_len
|
||||
|
||||
rope_options = transformer_options.get("rope_options", None)
|
||||
if rope_options is not None:
|
||||
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
|
||||
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
|
||||
|
||||
index += rope_options.get("shift_t", 0.0)
|
||||
h_offset += rope_options.get("shift_y", 0.0)
|
||||
w_offset += rope_options.get("shift_x", 0.0)
|
||||
|
||||
img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, 0] = img_ids[:, :, 1] + index
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
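For reference, a minimal standalone sketch (not part of the ComfyUI API) of what the rope_options handling in process_img above does: the scale and shift keys stretch and translate the RoPE coordinate grid while the number of tokens (steps_h, steps_w) stays fixed.

# Illustrative only: same linspace construction as above, with hypothetical rope_options values.
import torch

def image_coords(h_len, w_len, rope_options=None):
    steps_h, steps_w = h_len, w_len
    h_offset = w_offset = 0.0
    if rope_options is not None:
        h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
        w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
        h_offset += rope_options.get("shift_y", 0.0)
        w_offset += rope_options.get("shift_x", 0.0)
    rows = torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h)
    cols = torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w)
    return rows, cols

rows, _ = image_coords(4, 4)                                     # tensor([0., 1., 2., 3.])
rows2, _ = image_coords(4, 4, {"scale_y": 2.0, "shift_y": 1.0})  # tensor([1., 3., 5., 7.])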
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
|
||||
@@ -241,7 +253,7 @@ class Flux(nn.Module):
|
||||
|
||||
h_len = ((h_orig + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w_orig + (patch_size // 2)) // patch_size)
|
||||
img, img_ids = self.process_img(x)
|
||||
img, img_ids = self.process_img(x, transformer_options=transformer_options)
|
||||
img_tokens = img.shape[1]
|
||||
if ref_latents is not None:
|
||||
h = 0
|
||||
|
||||
@@ -3,12 +3,11 @@ from torch import nn
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.modules.attention
|
||||
import comfy.ldm.common_dit
|
||||
from einops import rearrange
|
||||
import math
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
||||
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
|
||||
def get_timestep_embedding(
|
||||
timesteps: torch.Tensor,
|
||||
@@ -238,20 +237,6 @@ class FeedForward(nn.Module):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and pick the best/fastest one
|
||||
cos_freqs = freqs_cis[0]
|
||||
sin_freqs = freqs_cis[1]
|
||||
|
||||
t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
|
||||
t1, t2 = t_dup.unbind(dim=-1)
|
||||
t_dup = torch.stack((-t2, t1), dim=-1)
|
||||
input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
|
||||
|
||||
out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
@@ -281,8 +266,8 @@ class CrossAttention(nn.Module):
|
||||
k = self.k_norm(k)
|
||||
|
||||
if pe is not None:
|
||||
q = apply_rotary_emb(q, pe)
|
||||
k = apply_rotary_emb(k, pe)
|
||||
q = apply_rope1(q.unsqueeze(1), pe).squeeze(1)
|
||||
k = apply_rope1(k.unsqueeze(1), pe).squeeze(1)
|
||||
|
||||
if mask is None:
|
||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
@@ -306,12 +291,17 @@ class BasicTransformerBlock(nn.Module):
|
||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
|
||||
|
||||
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa
|
||||
attn1_input = comfy.ldm.common_dit.rms_norm(x)
|
||||
attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
|
||||
attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
|
||||
x.addcmul_(attn1_input, gate_msa)
|
||||
del attn1_input
|
||||
|
||||
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
|
||||
|
||||
y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
|
||||
x += self.ff(y) * gate_mlp
|
||||
y = comfy.ldm.common_dit.rms_norm(x)
|
||||
y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
|
||||
x.addcmul_(self.ff(y), gate_mlp)
|
||||
|
||||
return x
|
||||
|
||||
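The addcmul/add_ rewrite above is an in-place fusion of the earlier rms_norm(x) * (1 + scale) + shift expression; a quick standalone check of that identity:

# Sanity check (not part of the model): torch.addcmul(y, y, s).add_(b) computes
# y + y*s + b, which equals y*(1 + s) + b, i.e. the same modulation as before.
import torch

y = torch.randn(2, 5)
s = torch.randn(2, 5)
b = torch.randn(2, 5)
fused = torch.addcmul(y, y, s).add_(b)
reference = y * (1 + s) + b
assert torch.allclose(fused, reference)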
@@ -327,41 +317,35 @@ def get_fractional_positions(indices_grid, max_pos):
|
||||
|
||||
|
||||
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
|
||||
dtype = torch.float32 #self.dtype
|
||||
dtype = torch.float32
|
||||
device = indices_grid.device
|
||||
|
||||
# Get fractional positions and compute frequency indices
|
||||
fractional_positions = get_fractional_positions(indices_grid, max_pos)
|
||||
indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2
|
||||
|
||||
start = 1
|
||||
end = theta
|
||||
device = fractional_positions.device
|
||||
# Compute frequencies and apply cos/sin
|
||||
freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2)
|
||||
cos_vals = freqs.cos().repeat_interleave(2, dim=-1)
|
||||
sin_vals = freqs.sin().repeat_interleave(2, dim=-1)
|
||||
|
||||
indices = theta ** (
|
||||
torch.linspace(
|
||||
math.log(start, theta),
|
||||
math.log(end, theta),
|
||||
dim // 6,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
indices = indices.to(dtype=dtype)
|
||||
|
||||
indices = indices * math.pi / 2
|
||||
|
||||
freqs = (
|
||||
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
|
||||
.transpose(-1, -2)
|
||||
.flatten(2)
|
||||
)
|
||||
|
||||
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
|
||||
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
|
||||
# Pad if dim is not divisible by 6
|
||||
if dim % 6 != 0:
|
||||
cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
|
||||
sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
|
||||
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
|
||||
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
|
||||
return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
|
||||
padding_size = dim % 6
|
||||
cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1)
|
||||
sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1)
|
||||
|
||||
# Reshape and extract one value per pair (since repeat_interleave duplicates each value)
|
||||
cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
|
||||
sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
|
||||
|
||||
# Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension
|
||||
freqs_cis = torch.stack([
|
||||
torch.stack([cos_vals, -sin_vals], dim=-1),
|
||||
torch.stack([sin_vals, cos_vals], dim=-1)
|
||||
], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2]
|
||||
|
||||
return freqs_cis
|
||||
|
||||
|
||||
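precompute_freqs_cis now returns explicit 2x2 rotation blocks instead of separate cos/sin tensors. A small illustrative sketch of how one such [[cos, -sin], [sin, cos]] block rotates a channel pair; this mirrors the layout comment above, not the exact apply_rope1 call signature.

# Illustrative only: one entry of the trailing [..., 2, 2] dims of freqs_cis applied to a pair.
import torch

theta = torch.tensor(0.3)
rot = torch.stack([
    torch.stack([torch.cos(theta), -torch.sin(theta)]),
    torch.stack([torch.sin(theta),  torch.cos(theta)]),
])                                  # 2x2 rotation matrix
pair = torch.tensor([1.0, 0.0])     # one (even, odd) channel pair of q or k
rotated = rot @ pair                # [cos(0.3), sin(0.3)]
assert torch.allclose(rotated.norm(), pair.norm())  # rotation preserves the norm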
class LTXVModel(torch.nn.Module):
|
||||
@@ -501,7 +485,7 @@ class LTXVModel(torch.nn.Module):
|
||||
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
|
||||
x = self.norm_out(x)
|
||||
# Modulation
|
||||
x = x * (1 + scale) + shift
|
||||
x = torch.addcmul(x, x, scale).add_(shift)
|
||||
x = self.proj_out(x)
|
||||
|
||||
x = self.patchifier.unpatchify(
|
||||
|
||||
@@ -522,7 +522,7 @@ class NextDiT(nn.Module):
|
||||
max_cap_len = max(l_effective_cap_len)
|
||||
max_img_len = max(l_effective_img_len)
|
||||
|
||||
position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)
|
||||
position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.float32, device=device)
|
||||
|
||||
for i in range(bsz):
|
||||
cap_len = l_effective_cap_len[i]
|
||||
@@ -531,10 +531,22 @@ class NextDiT(nn.Module):
|
||||
H_tokens, W_tokens = H // pH, W // pW
|
||||
assert H_tokens * W_tokens == img_len
|
||||
|
||||
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
|
||||
rope_options = transformer_options.get("rope_options", None)
|
||||
h_scale = 1.0
|
||||
w_scale = 1.0
|
||||
h_start = 0
|
||||
w_start = 0
|
||||
if rope_options is not None:
|
||||
h_scale = rope_options.get("scale_y", 1.0)
|
||||
w_scale = rope_options.get("scale_x", 1.0)
|
||||
|
||||
h_start = rope_options.get("shift_y", 0.0)
|
||||
w_start = rope_options.get("shift_x", 0.0)
|
||||
|
||||
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.float32, device=device)
|
||||
position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
|
||||
row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
|
||||
col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
|
||||
row_ids = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
|
||||
col_ids = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
|
||||
position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
|
||||
position_ids[i, cap_len:cap_len+img_len, 2] = col_ids
|
||||
|
||||
|
||||
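Same idea as the Flux change earlier: rope_options applies a per-axis scale and start offset to the token grid. A tiny standalone example of the row coordinates:

# With H_tokens=4, h_scale=0.5, h_start=2.0 the rows become 2.0, 2.5, 3.0, 3.5 instead of 0..3.
import torch
H_tokens, h_scale, h_start = 4, 0.5, 2.0
row_ids = torch.arange(H_tokens, dtype=torch.float32) * h_scale + h_start
print(row_ids)  # tensor([2.0000, 2.5000, 3.0000, 3.5000])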
comfy/ldm/mmaudio/vae/__init__.py (new file, 0 lines)
comfy/ldm/mmaudio/vae/activations.py (new file, 120 lines)
@@ -0,0 +1,120 @@
|
||||
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import torch
|
||||
from torch import nn, sin, pow
|
||||
from torch.nn import Parameter
|
||||
import comfy.model_management
|
||||
|
||||
class Snake(nn.Module):
|
||||
'''
|
||||
Implementation of a sine-based periodic activation function
|
||||
Shape:
|
||||
- Input: (B, C, T)
|
||||
- Output: (B, C, T), same shape as the input
|
||||
Parameters:
|
||||
- alpha - trainable parameter
|
||||
References:
|
||||
- This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
||||
https://arxiv.org/abs/2006.08195
|
||||
Examples:
|
||||
>>> a1 = snake(256)
|
||||
>>> x = torch.randn(256)
|
||||
>>> x = a1(x)
|
||||
'''
|
||||
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
||||
'''
|
||||
Initialization.
|
||||
INPUT:
|
||||
- in_features: shape of the input
|
||||
- alpha: trainable parameter
|
||||
alpha is initialized to 1 by default, higher values = higher-frequency.
|
||||
alpha will be trained along with the rest of your model.
|
||||
'''
|
||||
super(Snake, self).__init__()
|
||||
self.in_features = in_features
|
||||
|
||||
# initialize alpha
|
||||
self.alpha_logscale = alpha_logscale
|
||||
if self.alpha_logscale:
|
||||
self.alpha = Parameter(torch.empty(in_features))
|
||||
else:
|
||||
self.alpha = Parameter(torch.empty(in_features))
|
||||
|
||||
self.alpha.requires_grad = alpha_trainable
|
||||
|
||||
self.no_div_by_zero = 0.000000001
|
||||
|
||||
def forward(self, x):
|
||||
'''
|
||||
Forward pass of the function.
|
||||
Applies the function to the input elementwise.
|
||||
Snake ∶= x + 1/a * sin^2 (xa)
|
||||
'''
|
||||
alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
||||
if self.alpha_logscale:
|
||||
alpha = torch.exp(alpha)
|
||||
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SnakeBeta(nn.Module):
|
||||
'''
|
||||
A modified Snake function which uses separate parameters for the magnitude of the periodic components
|
||||
Shape:
|
||||
- Input: (B, C, T)
|
||||
- Output: (B, C, T), same shape as the input
|
||||
Parameters:
|
||||
- alpha - trainable parameter that controls frequency
|
||||
- beta - trainable parameter that controls magnitude
|
||||
References:
|
||||
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
||||
https://arxiv.org/abs/2006.08195
|
||||
Examples:
|
||||
>>> a1 = snakebeta(256)
|
||||
>>> x = torch.randn(256)
|
||||
>>> x = a1(x)
|
||||
'''
|
||||
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
||||
'''
|
||||
Initialization.
|
||||
INPUT:
|
||||
- in_features: shape of the input
|
||||
- alpha - trainable parameter that controls frequency
|
||||
- beta - trainable parameter that controls magnitude
|
||||
alpha is initialized to 1 by default, higher values = higher-frequency.
|
||||
beta is initialized to 1 by default, higher values = higher-magnitude.
|
||||
alpha will be trained along with the rest of your model.
|
||||
'''
|
||||
super(SnakeBeta, self).__init__()
|
||||
self.in_features = in_features
|
||||
|
||||
# initialize alpha
|
||||
self.alpha_logscale = alpha_logscale
|
||||
if self.alpha_logscale:
|
||||
self.alpha = Parameter(torch.empty(in_features))
|
||||
self.beta = Parameter(torch.empty(in_features))
|
||||
else:
|
||||
self.alpha = Parameter(torch.empty(in_features))
|
||||
self.beta = Parameter(torch.empty(in_features))
|
||||
|
||||
self.alpha.requires_grad = alpha_trainable
|
||||
self.beta.requires_grad = alpha_trainable
|
||||
|
||||
self.no_div_by_zero = 0.000000001
|
||||
|
||||
def forward(self, x):
|
||||
'''
|
||||
Forward pass of the function.
|
||||
Applies the function to the input elementwise.
|
||||
SnakeBeta ∶= x + 1/b * sin^2 (xa)
|
||||
'''
|
||||
alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
||||
beta = comfy.model_management.cast_to(self.beta, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1)
|
||||
if self.alpha_logscale:
|
||||
alpha = torch.exp(alpha)
|
||||
beta = torch.exp(beta)
|
||||
x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
||||
|
||||
return x
|
||||
comfy/ldm/mmaudio/vae/alias_free_torch.py (new file, 157 lines)
@@ -0,0 +1,157 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
import comfy.model_management
|
||||
|
||||
if 'sinc' in dir(torch):
|
||||
sinc = torch.sinc
|
||||
else:
|
||||
# This code is adopted from adefossez's julius.core.sinc under the MIT License
|
||||
# https://adefossez.github.io/julius/julius/core.html
|
||||
# LICENSE is in incl_licenses directory.
|
||||
def sinc(x: torch.Tensor):
|
||||
"""
|
||||
Implementation of sinc, i.e. sin(pi * x) / (pi * x)
|
||||
__Warning__: Different to julius.sinc, the input is multiplied by `pi`!
|
||||
"""
|
||||
return torch.where(x == 0,
|
||||
torch.tensor(1., device=x.device, dtype=x.dtype),
|
||||
torch.sin(math.pi * x) / math.pi / x)
|
||||
|
||||
|
||||
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
|
||||
# https://adefossez.github.io/julius/julius/lowpass.html
|
||||
# LICENSE is in incl_licenses directory.
|
||||
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
|
||||
even = (kernel_size % 2 == 0)
|
||||
half_size = kernel_size // 2
|
||||
|
||||
#For kaiser window
|
||||
delta_f = 4 * half_width
|
||||
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
||||
if A > 50.:
|
||||
beta = 0.1102 * (A - 8.7)
|
||||
elif A >= 21.:
|
||||
beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
|
||||
else:
|
||||
beta = 0.
|
||||
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
|
||||
|
||||
# ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
|
||||
if even:
|
||||
time = (torch.arange(-half_size, half_size) + 0.5)
|
||||
else:
|
||||
time = torch.arange(kernel_size) - half_size
|
||||
if cutoff == 0:
|
||||
filter_ = torch.zeros_like(time)
|
||||
else:
|
||||
filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
|
||||
# Normalize filter to have sum = 1, otherwise we will have a small leakage
|
||||
# of the constant component in the input signal.
|
||||
filter_ /= filter_.sum()
|
||||
filter = filter_.view(1, 1, kernel_size)
|
||||
|
||||
return filter
|
||||
|
||||
|
||||
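The beta selection above is the standard Kaiser window rule: estimate the required attenuation A from the transition band, then pick beta piecewise. A standalone numeric check with kernel_size=12 and half_width=0.3 (the values UpSample1d passes for ratio=2):

# Standalone arithmetic only; reproduces the A/beta computation from kaiser_sinc_filter1d.
import math
half_width, kernel_size = 0.3, 12
half_size = kernel_size // 2
delta_f = 4 * half_width
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
if A > 50.:
    beta = 0.1102 * (A - 8.7)
elif A >= 21.:
    beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.)
else:
    beta = 0.
print(round(A, 2), round(beta, 3))  # 51.02 4.664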
class LowPassFilter1d(nn.Module):
|
||||
def __init__(self,
|
||||
cutoff=0.5,
|
||||
half_width=0.6,
|
||||
stride: int = 1,
|
||||
padding: bool = True,
|
||||
padding_mode: str = 'replicate',
|
||||
kernel_size: int = 12):
|
||||
# kernel_size should be an even number for the stylegan3 setup,
|
||||
# in this implementation, an odd number is also possible.
|
||||
super().__init__()
|
||||
if cutoff < -0.:
|
||||
raise ValueError("Minimum cutoff must be larger than zero.")
|
||||
if cutoff > 0.5:
|
||||
raise ValueError("A cutoff above 0.5 does not make sense.")
|
||||
self.kernel_size = kernel_size
|
||||
self.even = (kernel_size % 2 == 0)
|
||||
self.pad_left = kernel_size // 2 - int(self.even)
|
||||
self.pad_right = kernel_size // 2
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
self.padding_mode = padding_mode
|
||||
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
|
||||
self.register_buffer("filter", filter)
|
||||
|
||||
#input [B, C, T]
|
||||
def forward(self, x):
|
||||
_, C, _ = x.shape
|
||||
|
||||
if self.padding:
|
||||
x = F.pad(x, (self.pad_left, self.pad_right),
|
||||
mode=self.padding_mode)
|
||||
out = F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device),
|
||||
stride=self.stride, groups=C)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class UpSample1d(nn.Module):
|
||||
def __init__(self, ratio=2, kernel_size=None):
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
self.stride = ratio
|
||||
self.pad = self.kernel_size // ratio - 1
|
||||
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
||||
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
||||
filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
|
||||
half_width=0.6 / ratio,
|
||||
kernel_size=self.kernel_size)
|
||||
self.register_buffer("filter", filter)
|
||||
|
||||
# x: [B, C, T]
|
||||
def forward(self, x):
|
||||
_, C, _ = x.shape
|
||||
|
||||
x = F.pad(x, (self.pad, self.pad), mode='replicate')
|
||||
x = self.ratio * F.conv_transpose1d(
|
||||
x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
|
||||
x = x[..., self.pad_left:-self.pad_right]
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class DownSample1d(nn.Module):
|
||||
def __init__(self, ratio=2, kernel_size=None):
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
|
||||
half_width=0.6 / ratio,
|
||||
stride=ratio,
|
||||
kernel_size=self.kernel_size)
|
||||
|
||||
def forward(self, x):
|
||||
xx = self.lowpass(x)
|
||||
|
||||
return xx
|
||||
|
||||
class Activation1d(nn.Module):
|
||||
def __init__(self,
|
||||
activation,
|
||||
up_ratio: int = 2,
|
||||
down_ratio: int = 2,
|
||||
up_kernel_size: int = 12,
|
||||
down_kernel_size: int = 12):
|
||||
super().__init__()
|
||||
self.up_ratio = up_ratio
|
||||
self.down_ratio = down_ratio
|
||||
self.act = activation
|
||||
self.upsample = UpSample1d(up_ratio, up_kernel_size)
|
||||
self.downsample = DownSample1d(down_ratio, down_kernel_size)
|
||||
|
||||
# x: [B,C,T]
|
||||
def forward(self, x):
|
||||
x = self.upsample(x)
|
||||
x = self.act(x)
|
||||
x = self.downsample(x)
|
||||
|
||||
return x
|
||||
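Illustrative usage of the anti-aliased activation wrapper: the nonlinearity runs at 2x the sample rate between the upsample and the low-pass downsample, so the [B, C, T] shape is preserved. This assumes the new modules are importable at the paths shown in this diff; parameters are uninitialized until a checkpoint is loaded, so only the shapes are meaningful.

# Sketch only: wrap a SnakeBeta activation with anti-aliased up/down sampling.
import torch
from comfy.ldm.mmaudio.vae.activations import SnakeBeta
from comfy.ldm.mmaudio.vae.alias_free_torch import Activation1d

act = Activation1d(activation=SnakeBeta(64, alpha_logscale=True))
x = torch.randn(1, 64, 256)
y = act(x)
print(y.shape)  # torch.Size([1, 64, 256])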
comfy/ldm/mmaudio/vae/autoencoder.py (new file, 156 lines)
@@ -0,0 +1,156 @@
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .distributions import DiagonalGaussianDistribution
|
||||
from .vae import VAE_16k
|
||||
from .bigvgan import BigVGANVocoder
|
||||
import logging
|
||||
|
||||
try:
|
||||
import torchaudio
|
||||
except Exception:
|
||||
logging.warning("torchaudio missing, MMAudio VAE model will be broken")
|
||||
|
||||
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5, *, norm_fn):
|
||||
return norm_fn(torch.clamp(x, min=clip_val) * C)
|
||||
|
||||
|
||||
def spectral_normalize_torch(magnitudes, norm_fn):
|
||||
output = dynamic_range_compression_torch(magnitudes, norm_fn=norm_fn)
|
||||
return output
|
||||
|
||||
class MelConverter(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
sampling_rate: float,
|
||||
n_fft: int,
|
||||
num_mels: int,
|
||||
hop_size: int,
|
||||
win_size: int,
|
||||
fmin: float,
|
||||
fmax: float,
|
||||
norm_fn,
|
||||
):
|
||||
super().__init__()
|
||||
self.sampling_rate = sampling_rate
|
||||
self.n_fft = n_fft
|
||||
self.num_mels = num_mels
|
||||
self.hop_size = hop_size
|
||||
self.win_size = win_size
|
||||
self.fmin = fmin
|
||||
self.fmax = fmax
|
||||
self.norm_fn = norm_fn
|
||||
|
||||
# mel = librosa_mel_fn(sr=self.sampling_rate,
|
||||
# n_fft=self.n_fft,
|
||||
# n_mels=self.num_mels,
|
||||
# fmin=self.fmin,
|
||||
# fmax=self.fmax)
|
||||
# mel_basis = torch.from_numpy(mel).float()
|
||||
mel_basis = torch.empty((num_mels, 1 + n_fft // 2))
|
||||
hann_window = torch.hann_window(self.win_size)
|
||||
|
||||
self.register_buffer('mel_basis', mel_basis)
|
||||
self.register_buffer('hann_window', hann_window)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.mel_basis.device
|
||||
|
||||
def forward(self, waveform: torch.Tensor, center: bool = False) -> torch.Tensor:
|
||||
waveform = waveform.clamp(min=-1., max=1.).to(self.device)
|
||||
|
||||
waveform = torch.nn.functional.pad(
|
||||
waveform.unsqueeze(1),
|
||||
[int((self.n_fft - self.hop_size) / 2),
|
||||
int((self.n_fft - self.hop_size) / 2)],
|
||||
mode='reflect')
|
||||
waveform = waveform.squeeze(1)
|
||||
|
||||
spec = torch.stft(waveform,
|
||||
self.n_fft,
|
||||
hop_length=self.hop_size,
|
||||
win_length=self.win_size,
|
||||
window=self.hann_window,
|
||||
center=center,
|
||||
pad_mode='reflect',
|
||||
normalized=False,
|
||||
onesided=True,
|
||||
return_complex=True)
|
||||
|
||||
spec = torch.view_as_real(spec)
|
||||
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
|
||||
spec = torch.matmul(self.mel_basis, spec)
|
||||
spec = spectral_normalize_torch(spec, self.norm_fn)
|
||||
|
||||
return spec
|
||||
|
||||
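Frame-count arithmetic for the mel front end, using the 16 kHz settings the autoencoder below passes in (n_fft=1024, hop_size=256, center=False with reflect padding of (n_fft - hop_size) // 2 on each side):

# Standalone arithmetic only: one second of 16 kHz audio -> 62 mel frames.
n_fft, hop_size, sr = 1024, 256, 16000
samples = sr * 1
padded = samples + 2 * ((n_fft - hop_size) // 2)
frames = (padded - n_fft) // hop_size + 1
print(frames)  # 62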
class AudioAutoencoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
# ckpt_path: str,
|
||||
mode=Literal['16k', '44k'],
|
||||
need_vae_encoder: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert mode == "16k", "Only 16k mode is supported currently."
|
||||
self.mel_converter = MelConverter(sampling_rate=16_000,
|
||||
n_fft=1024,
|
||||
num_mels=80,
|
||||
hop_size=256,
|
||||
win_size=1024,
|
||||
fmin=0,
|
||||
fmax=8_000,
|
||||
norm_fn=torch.log10)
|
||||
|
||||
self.vae = VAE_16k().eval()
|
||||
|
||||
bigvgan_config = {
|
||||
"resblock": "1",
|
||||
"num_mels": 80,
|
||||
"upsample_rates": [4, 4, 2, 2, 2, 2],
|
||||
"upsample_kernel_sizes": [8, 8, 4, 4, 4, 4],
|
||||
"upsample_initial_channel": 1536,
|
||||
"resblock_kernel_sizes": [3, 7, 11],
|
||||
"resblock_dilation_sizes": [
|
||||
[1, 3, 5],
|
||||
[1, 3, 5],
|
||||
[1, 3, 5],
|
||||
],
|
||||
"activation": "snakebeta",
|
||||
"snake_logscale": True,
|
||||
}
|
||||
|
||||
self.vocoder = BigVGANVocoder(
|
||||
bigvgan_config
|
||||
).eval()
|
||||
|
||||
@torch.inference_mode()
|
||||
def encode_audio(self, x) -> DiagonalGaussianDistribution:
|
||||
# x: (B * L)
|
||||
mel = self.mel_converter(x)
|
||||
dist = self.vae.encode(mel)
|
||||
|
||||
return dist
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(self, z):
|
||||
mel_decoded = self.vae.decode(z)
|
||||
audio = self.vocoder(mel_decoded)
|
||||
|
||||
audio = torchaudio.functional.resample(audio, 16000, 44100)
|
||||
return audio
|
||||
|
||||
@torch.no_grad()
|
||||
def encode(self, audio):
|
||||
audio = audio.mean(dim=1)
|
||||
audio = torchaudio.functional.resample(audio, 44100, 16000)
|
||||
dist = self.encode_audio(audio)
|
||||
return dist.mean
|
||||
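A hedged sketch of the intended call flow for the new audio autoencoder: stereo 44.1 kHz audio is downmixed and resampled to 16 kHz, encoded through the mel + VAE path, and decoded back through BigVGAN and a resample to 44.1 kHz. It requires torchaudio and real checkpoint weights (the modules are built with empty parameters until comfy loads them), so treat it as illustrative of the shapes and call order only.

# Illustrative call flow only, not a working inference example.
import torch
from comfy.ldm.mmaudio.vae.autoencoder import AudioAutoencoder

ae = AudioAutoencoder(mode='16k')
audio_44k = torch.randn(1, 2, 44100)     # [B, channels, samples] at 44.1 kHz
latent = ae.encode(audio_44k)            # mean of the diagonal Gaussian posterior
decoded = ae.decode(latent)              # waveform resampled back to 44.1 kHz
print(latent.shape, decoded.shape)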
comfy/ldm/mmaudio/vae/bigvgan.py (new file, 219 lines)
@@ -0,0 +1,219 @@
|
||||
# Copyright (c) 2022 NVIDIA CORPORATION.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from types import SimpleNamespace
|
||||
from . import activations
|
||||
from .alias_free_torch import Activation1d
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
def get_padding(kernel_size, dilation=1):
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
class AMPBlock1(torch.nn.Module):
|
||||
|
||||
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
|
||||
super(AMPBlock1, self).__init__()
|
||||
self.h = h
|
||||
|
||||
self.convs1 = nn.ModuleList([
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0])),
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1])),
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[2],
|
||||
padding=get_padding(kernel_size, dilation[2]))
|
||||
])
|
||||
|
||||
self.convs2 = nn.ModuleList([
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1)),
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1)),
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1))
|
||||
])
|
||||
|
||||
self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
|
||||
|
||||
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
|
||||
self.activations = nn.ModuleList([
|
||||
Activation1d(
|
||||
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
])
|
||||
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
|
||||
self.activations = nn.ModuleList([
|
||||
Activation1d(
|
||||
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
])
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"activation incorrectly specified. check the config file and look for 'activation'."
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
acts1, acts2 = self.activations[::2], self.activations[1::2]
|
||||
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
|
||||
xt = a1(x)
|
||||
xt = c1(xt)
|
||||
xt = a2(xt)
|
||||
xt = c2(xt)
|
||||
x = xt + x
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class AMPBlock2(torch.nn.Module):
|
||||
|
||||
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None):
|
||||
super(AMPBlock2, self).__init__()
|
||||
self.h = h
|
||||
|
||||
self.convs = nn.ModuleList([
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0])),
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]))
|
||||
])
|
||||
|
||||
self.num_layers = len(self.convs) # total number of conv layers
|
||||
|
||||
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
|
||||
self.activations = nn.ModuleList([
|
||||
Activation1d(
|
||||
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
])
|
||||
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
|
||||
self.activations = nn.ModuleList([
|
||||
Activation1d(
|
||||
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
])
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"activation incorrectly specified. check the config file and look for 'activation'."
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for c, a in zip(self.convs, self.activations):
|
||||
xt = a(x)
|
||||
xt = c(xt)
|
||||
x = xt + x
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class BigVGANVocoder(torch.nn.Module):
|
||||
# this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
|
||||
def __init__(self, h):
|
||||
super().__init__()
|
||||
if isinstance(h, dict):
|
||||
h = SimpleNamespace(**h)
|
||||
self.h = h
|
||||
|
||||
self.num_kernels = len(h.resblock_kernel_sizes)
|
||||
self.num_upsamples = len(h.upsample_rates)
|
||||
|
||||
# pre conv
|
||||
self.conv_pre = ops.Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
|
||||
|
||||
# define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
|
||||
resblock = AMPBlock1 if h.resblock == '1' else AMPBlock2
|
||||
|
||||
# transposed conv-based upsamplers. does not apply anti-aliasing
|
||||
self.ups = nn.ModuleList()
|
||||
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
||||
self.ups.append(
|
||||
nn.ModuleList([
|
||||
ops.ConvTranspose1d(h.upsample_initial_channel // (2**i),
|
||||
h.upsample_initial_channel // (2**(i + 1)),
|
||||
k,
|
||||
u,
|
||||
padding=(k - u) // 2)
|
||||
]))
|
||||
|
||||
# residual blocks using anti-aliased multi-periodicity composition modules (AMP)
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = h.upsample_initial_channel // (2**(i + 1))
|
||||
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
|
||||
|
||||
# post conv
|
||||
if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing
|
||||
activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale)
|
||||
self.activation_post = Activation1d(activation=activation_post)
|
||||
elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing
|
||||
activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
|
||||
self.activation_post = Activation1d(activation=activation_post)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"activation incorrectly specified. check the config file and look for 'activation'."
|
||||
)
|
||||
|
||||
self.conv_post = ops.Conv1d(ch, 1, 7, 1, padding=3)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
# pre conv
|
||||
x = self.conv_pre(x)
|
||||
|
||||
for i in range(self.num_upsamples):
|
||||
# upsampling
|
||||
for i_up in range(len(self.ups[i])):
|
||||
x = self.ups[i][i_up](x)
|
||||
# AMP blocks
|
||||
xs = None
|
||||
for j in range(self.num_kernels):
|
||||
if xs is None:
|
||||
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||
else:
|
||||
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||
x = xs / self.num_kernels
|
||||
|
||||
# post conv
|
||||
x = self.activation_post(x)
|
||||
x = self.conv_post(x)
|
||||
x = torch.tanh(x)
|
||||
|
||||
return x
|
||||
comfy/ldm/mmaudio/vae/distributions.py (new file, 92 lines)
@@ -0,0 +1,92 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
class AbstractDistribution:
|
||||
def sample(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def mode(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class DiracDistribution(AbstractDistribution):
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def sample(self):
|
||||
return self.value
|
||||
|
||||
def mode(self):
|
||||
return self.value
|
||||
|
||||
|
||||
class DiagonalGaussianDistribution(object):
|
||||
def __init__(self, parameters, deterministic=False):
|
||||
self.parameters = parameters
|
||||
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
||||
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
||||
self.deterministic = deterministic
|
||||
self.std = torch.exp(0.5 * self.logvar)
|
||||
self.var = torch.exp(self.logvar)
|
||||
if self.deterministic:
|
||||
self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device)
|
||||
|
||||
def sample(self):
|
||||
x = self.mean + self.std * torch.randn(self.mean.shape, device=self.parameters.device)
|
||||
return x
|
||||
|
||||
def kl(self, other=None):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.])
|
||||
else:
|
||||
if other is None:
|
||||
return 0.5 * torch.sum(torch.pow(self.mean, 2)
|
||||
+ self.var - 1.0 - self.logvar,
|
||||
dim=[1, 2, 3])
|
||||
else:
|
||||
return 0.5 * torch.sum(
|
||||
torch.pow(self.mean - other.mean, 2) / other.var
|
||||
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
|
||||
dim=[1, 2, 3])
|
||||
|
||||
def nll(self, sample, dims=[1,2,3]):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.])
|
||||
logtwopi = np.log(2.0 * np.pi)
|
||||
return 0.5 * torch.sum(
|
||||
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
|
||||
dim=dims)
|
||||
|
||||
def mode(self):
|
||||
return self.mean
|
||||
|
||||
|
||||
def normal_kl(mean1, logvar1, mean2, logvar2):
|
||||
"""
|
||||
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
|
||||
Compute the KL divergence between two gaussians.
|
||||
Shapes are automatically broadcasted, so batches can be compared to
|
||||
scalars, among other use cases.
|
||||
"""
|
||||
tensor = None
|
||||
for obj in (mean1, logvar1, mean2, logvar2):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
tensor = obj
|
||||
break
|
||||
assert tensor is not None, "at least one argument must be a Tensor"
|
||||
|
||||
# Force variances to be Tensors. Broadcasting helps convert scalars to
|
||||
# Tensors, but it does not work for torch.exp().
|
||||
logvar1, logvar2 = [
|
||||
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
|
||||
for x in (logvar1, logvar2)
|
||||
]
|
||||
|
||||
return 0.5 * (
|
||||
-1.0
|
||||
+ logvar2
|
||||
- logvar1
|
||||
+ torch.exp(logvar1 - logvar2)
|
||||
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
|
||||
)
|
||||
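A quick standalone check of normal_kl against the closed form: identical Gaussians give zero, and with unit variances it reduces to half the squared mean difference.

# Sanity check only, assuming the new module path from this diff is importable.
import torch
from comfy.ldm.mmaudio.vae.distributions import normal_kl

m = torch.tensor([0.3, -1.2])
v = torch.tensor([0.1, 0.7])
assert torch.allclose(normal_kl(m, v, m, v), torch.zeros(2))
assert torch.allclose(normal_kl(torch.tensor(0.), 0.0, torch.tensor(1.), 0.0),
                      torch.tensor(0.5))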
comfy/ldm/mmaudio/vae/vae.py (new file, 358 lines)
@@ -0,0 +1,358 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
|
||||
Upsample1D, nonlinearity)
|
||||
from .distributions import DiagonalGaussianDistribution
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
log = logging.getLogger()
|
||||
|
||||
DATA_MEAN_80D = [
|
||||
-1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439, -1.2922, -1.2927,
|
||||
-1.3170, -1.3543, -1.3401, -1.3836, -1.3907, -1.3912, -1.4313, -1.4152, -1.4527, -1.4728,
|
||||
-1.4568, -1.5101, -1.5051, -1.5172, -1.5623, -1.5373, -1.5746, -1.5687, -1.6032, -1.6131,
|
||||
-1.6081, -1.6331, -1.6489, -1.6489, -1.6700, -1.6738, -1.6953, -1.6969, -1.7048, -1.7280,
|
||||
-1.7361, -1.7495, -1.7658, -1.7814, -1.7889, -1.8064, -1.8221, -1.8377, -1.8417, -1.8643,
|
||||
-1.8857, -1.8929, -1.9173, -1.9379, -1.9531, -1.9673, -1.9824, -2.0042, -2.0215, -2.0436,
|
||||
-2.0766, -2.1064, -2.1418, -2.1855, -2.2319, -2.2767, -2.3161, -2.3572, -2.3954, -2.4282,
|
||||
-2.4659, -2.5072, -2.5552, -2.6074, -2.6584, -2.7107, -2.7634, -2.8266, -2.8981, -2.9673
|
||||
]
|
||||
|
||||
DATA_STD_80D = [
|
||||
1.0291, 1.0411, 1.0043, 0.9820, 0.9677, 0.9543, 0.9450, 0.9392, 0.9343, 0.9297, 0.9276, 0.9263,
|
||||
0.9242, 0.9254, 0.9232, 0.9281, 0.9263, 0.9315, 0.9274, 0.9247, 0.9277, 0.9199, 0.9188, 0.9194,
|
||||
0.9160, 0.9161, 0.9146, 0.9161, 0.9100, 0.9095, 0.9145, 0.9076, 0.9066, 0.9095, 0.9032, 0.9043,
|
||||
0.9038, 0.9011, 0.9019, 0.9010, 0.8984, 0.8983, 0.8986, 0.8961, 0.8962, 0.8978, 0.8962, 0.8973,
|
||||
0.8993, 0.8976, 0.8995, 0.9016, 0.8982, 0.8972, 0.8974, 0.8949, 0.8940, 0.8947, 0.8936, 0.8939,
|
||||
0.8951, 0.8956, 0.9017, 0.9167, 0.9436, 0.9690, 1.0003, 1.0225, 1.0381, 1.0491, 1.0545, 1.0604,
|
||||
1.0761, 1.0929, 1.1089, 1.1196, 1.1176, 1.1156, 1.1117, 1.1070
|
||||
]
|
||||
|
||||
DATA_MEAN_128D = [
|
||||
-3.3462, -2.6723, -2.4893, -2.3143, -2.2664, -2.3317, -2.1802, -2.4006, -2.2357, -2.4597,
|
||||
-2.3717, -2.4690, -2.5142, -2.4919, -2.6610, -2.5047, -2.7483, -2.5926, -2.7462, -2.7033,
|
||||
-2.7386, -2.8112, -2.7502, -2.9594, -2.7473, -3.0035, -2.8891, -2.9922, -2.9856, -3.0157,
|
||||
-3.1191, -2.9893, -3.1718, -3.0745, -3.1879, -3.2310, -3.1424, -3.2296, -3.2791, -3.2782,
|
||||
-3.2756, -3.3134, -3.3509, -3.3750, -3.3951, -3.3698, -3.4505, -3.4509, -3.5089, -3.4647,
|
||||
-3.5536, -3.5788, -3.5867, -3.6036, -3.6400, -3.6747, -3.7072, -3.7279, -3.7283, -3.7795,
|
||||
-3.8259, -3.8447, -3.8663, -3.9182, -3.9605, -3.9861, -4.0105, -4.0373, -4.0762, -4.1121,
|
||||
-4.1488, -4.1874, -4.2461, -4.3170, -4.3639, -4.4452, -4.5282, -4.6297, -4.7019, -4.7960,
|
||||
-4.8700, -4.9507, -5.0303, -5.0866, -5.1634, -5.2342, -5.3242, -5.4053, -5.4927, -5.5712,
|
||||
-5.6464, -5.7052, -5.7619, -5.8410, -5.9188, -6.0103, -6.0955, -6.1673, -6.2362, -6.3120,
|
||||
-6.3926, -6.4797, -6.5565, -6.6511, -6.8130, -6.9961, -7.1275, -7.2457, -7.3576, -7.4663,
|
||||
-7.6136, -7.7469, -7.8815, -8.0132, -8.1515, -8.3071, -8.4722, -8.7418, -9.3975, -9.6628,
|
||||
-9.7671, -9.8863, -9.9992, -10.0860, -10.1709, -10.5418, -11.2795, -11.3861
|
||||
]
|
||||
|
||||
DATA_STD_128D = [
|
||||
2.3804, 2.4368, 2.3772, 2.3145, 2.2803, 2.2510, 2.2316, 2.2083, 2.1996, 2.1835, 2.1769, 2.1659,
|
||||
2.1631, 2.1618, 2.1540, 2.1606, 2.1571, 2.1567, 2.1612, 2.1579, 2.1679, 2.1683, 2.1634, 2.1557,
|
||||
2.1668, 2.1518, 2.1415, 2.1449, 2.1406, 2.1350, 2.1313, 2.1415, 2.1281, 2.1352, 2.1219, 2.1182,
|
||||
2.1327, 2.1195, 2.1137, 2.1080, 2.1179, 2.1036, 2.1087, 2.1036, 2.1015, 2.1068, 2.0975, 2.0991,
|
||||
2.0902, 2.1015, 2.0857, 2.0920, 2.0893, 2.0897, 2.0910, 2.0881, 2.0925, 2.0873, 2.0960, 2.0900,
|
||||
2.0957, 2.0958, 2.0978, 2.0936, 2.0886, 2.0905, 2.0845, 2.0855, 2.0796, 2.0840, 2.0813, 2.0817,
|
||||
2.0838, 2.0840, 2.0917, 2.1061, 2.1431, 2.1976, 2.2482, 2.3055, 2.3700, 2.4088, 2.4372, 2.4609,
|
||||
2.4731, 2.4847, 2.5072, 2.5451, 2.5772, 2.6147, 2.6529, 2.6596, 2.6645, 2.6726, 2.6803, 2.6812,
|
||||
2.6899, 2.6916, 2.6931, 2.6998, 2.7062, 2.7262, 2.7222, 2.7158, 2.7041, 2.7485, 2.7491, 2.7451,
|
||||
2.7485, 2.7233, 2.7297, 2.7233, 2.7145, 2.6958, 2.6788, 2.6439, 2.6007, 2.4786, 2.2469, 2.1877,
|
||||
2.1392, 2.0717, 2.0107, 1.9676, 1.9140, 1.7102, 0.9101, 0.7164
|
||||
]
|
||||
|
||||
|
||||
class VAE(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
data_dim: int,
|
||||
embed_dim: int,
|
||||
hidden_dim: int,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if data_dim == 80:
|
||||
self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_80D, dtype=torch.float32))
|
||||
self.data_std = nn.Buffer(torch.tensor(DATA_STD_80D, dtype=torch.float32))
|
||||
elif data_dim == 128:
|
||||
self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_128D, dtype=torch.float32))
|
||||
self.data_std = nn.Buffer(torch.tensor(DATA_STD_128D, dtype=torch.float32))
|
||||
|
||||
self.data_mean = self.data_mean.view(1, -1, 1)
|
||||
self.data_std = self.data_std.view(1, -1, 1)
|
||||
|
||||
self.encoder = Encoder1D(
|
||||
dim=hidden_dim,
|
||||
ch_mult=(1, 2, 4),
|
||||
num_res_blocks=2,
|
||||
attn_layers=[3],
|
||||
down_layers=[0],
|
||||
in_dim=data_dim,
|
||||
embed_dim=embed_dim,
|
||||
)
|
||||
self.decoder = Decoder1D(
|
||||
dim=hidden_dim,
|
||||
ch_mult=(1, 2, 4),
|
||||
num_res_blocks=2,
|
||||
attn_layers=[3],
|
||||
down_layers=[0],
|
||||
in_dim=data_dim,
|
||||
out_dim=data_dim,
|
||||
embed_dim=embed_dim,
|
||||
)
|
||||
|
||||
self.embed_dim = embed_dim
|
||||
# self.quant_conv = nn.Conv1d(2 * embed_dim, 2 * embed_dim, 1)
|
||||
# self.post_quant_conv = nn.Conv1d(embed_dim, embed_dim, 1)
|
||||
|
||||
self.initialize_weights()
|
||||
|
||||
def initialize_weights(self):
|
||||
pass
|
||||
|
||||
def encode(self, x: torch.Tensor, normalize: bool = True) -> DiagonalGaussianDistribution:
|
||||
if normalize:
|
||||
x = self.normalize(x)
|
||||
moments = self.encoder(x)
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
return posterior
|
||||
|
||||
def decode(self, z: torch.Tensor, unnormalize: bool = True) -> torch.Tensor:
|
||||
dec = self.decoder(z)
|
||||
if unnormalize:
|
||||
dec = self.unnormalize(dec)
|
||||
return dec
|
||||
|
||||
def normalize(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return (x - comfy.model_management.cast_to(self.data_mean, dtype=x.dtype, device=x.device)) / comfy.model_management.cast_to(self.data_std, dtype=x.dtype, device=x.device)
|
||||
|
||||
def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x * comfy.model_management.cast_to(self.data_std, dtype=x.dtype, device=x.device) + comfy.model_management.cast_to(self.data_mean, dtype=x.dtype, device=x.device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
sample_posterior: bool = True,
|
||||
rng: Optional[torch.Generator] = None,
|
||||
normalize: bool = True,
|
||||
unnormalize: bool = True,
|
||||
) -> tuple[torch.Tensor, DiagonalGaussianDistribution]:
|
||||
|
||||
posterior = self.encode(x, normalize=normalize)
|
||||
if sample_posterior:
|
||||
z = posterior.sample(rng)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z, unnormalize=unnormalize)
|
||||
return dec, posterior
|
||||
|
||||
def load_weights(self, src_dict) -> None:
|
||||
self.load_state_dict(src_dict, strict=True)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return next(self.parameters()).device
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.decoder.conv_out.weight
|
||||
|
||||
def remove_weight_norm(self):
|
||||
return self
|
||||
|
||||
|
||||
class Encoder1D(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
dim: int,
|
||||
ch_mult: tuple[int] = (1, 2, 4, 8),
|
||||
num_res_blocks: int,
|
||||
attn_layers: list[int] = [],
|
||||
down_layers: list[int] = [],
|
||||
resamp_with_conv: bool = True,
|
||||
in_dim: int,
|
||||
embed_dim: int,
|
||||
double_z: bool = True,
|
||||
kernel_size: int = 3,
|
||||
clip_act: float = 256.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_layers = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.in_channels = in_dim
|
||||
self.clip_act = clip_act
|
||||
self.down_layers = down_layers
|
||||
self.attn_layers = attn_layers
|
||||
self.conv_in = ops.Conv1d(in_dim, self.dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
|
||||
|
||||
in_ch_mult = (1, ) + tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
# downsampling
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_layers):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = dim * in_ch_mult[i_level]
|
||||
block_out = dim * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(
|
||||
ResnetBlock1D(in_dim=block_in,
|
||||
out_dim=block_out,
|
||||
kernel_size=kernel_size,
|
||||
use_norm=True))
|
||||
block_in = block_out
|
||||
if i_level in attn_layers:
|
||||
attn.append(AttnBlock1D(block_in))
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level in down_layers:
|
||||
down.downsample = Downsample1D(block_in, resamp_with_conv)
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock1D(in_dim=block_in,
|
||||
out_dim=block_in,
|
||||
kernel_size=kernel_size,
|
||||
use_norm=True)
|
||||
self.mid.attn_1 = AttnBlock1D(block_in)
|
||||
self.mid.block_2 = ResnetBlock1D(in_dim=block_in,
|
||||
out_dim=block_in,
|
||||
kernel_size=kernel_size,
|
||||
use_norm=True)
|
||||
|
||||
# end
|
||||
self.conv_out = ops.Conv1d(block_in,
|
||||
2 * embed_dim if double_z else embed_dim,
|
||||
kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
|
||||
|
||||
self.learnable_gain = nn.Parameter(torch.zeros([]))
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
# downsampling
|
||||
h = self.conv_in(x)
|
||||
for i_level in range(self.num_layers):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](h)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
h = h.clamp(-self.clip_act, self.clip_act)
|
||||
if i_level in self.down_layers:
|
||||
h = self.down[i_level].downsample(h)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h)
|
||||
h = h.clamp(-self.clip_act, self.clip_act)
|
||||
|
||||
# end
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h) * (self.learnable_gain + 1)
|
||||
return h
|
||||
|
||||
|
||||
class Decoder1D(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
dim: int,
|
||||
out_dim: int,
|
||||
ch_mult: tuple[int] = (1, 2, 4, 8),
|
||||
num_res_blocks: int,
|
||||
attn_layers: list[int] = [],
|
||||
down_layers: list[int] = [],
|
||||
kernel_size: int = 3,
|
||||
resamp_with_conv: bool = True,
|
||||
in_dim: int,
|
||||
embed_dim: int,
|
||||
clip_act: float = 256.0):
|
||||
super().__init__()
|
||||
self.ch = dim
|
||||
self.num_layers = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.in_channels = in_dim
|
||||
self.clip_act = clip_act
|
||||
self.down_layers = [i + 1 for i in down_layers] # each down layer shifts the index by one
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
block_in = dim * ch_mult[self.num_layers - 1]
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = ops.Conv1d(embed_dim, block_in, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
|
||||
self.mid.attn_1 = AttnBlock1D(block_in)
|
||||
self.mid.block_2 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_layers)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = dim * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(ResnetBlock1D(in_dim=block_in, out_dim=block_out, use_norm=True))
|
||||
block_in = block_out
|
||||
if i_level in attn_layers:
|
||||
attn.append(AttnBlock1D(block_in))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level in self.down_layers:
|
||||
up.upsample = Upsample1D(block_in, resamp_with_conv)
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.conv_out = ops.Conv1d(block_in, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
|
||||
self.learnable_gain = nn.Parameter(torch.zeros([]))
|
||||
|
||||
def forward(self, z):
|
||||
# z to block_in
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h)
|
||||
h = h.clamp(-self.clip_act, self.clip_act)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_layers)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](h)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
h = h.clamp(-self.clip_act, self.clip_act)
|
||||
if i_level in self.down_layers:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h) * (self.learnable_gain + 1)
|
||||
return h
|
||||
|
||||
|
||||
def VAE_16k(**kwargs) -> VAE:
|
||||
return VAE(data_dim=80, embed_dim=20, hidden_dim=384, **kwargs)
|
||||
|
||||
|
||||
def VAE_44k(**kwargs) -> VAE:
|
||||
return VAE(data_dim=128, embed_dim=40, hidden_dim=512, **kwargs)
|
||||
|
||||
|
||||
def get_my_vae(name: str, **kwargs) -> VAE:
|
||||
if name == '16k':
|
||||
return VAE_16k(**kwargs)
|
||||
if name == '44k':
|
||||
return VAE_44k(**kwargs)
|
||||
raise ValueError(f'Unknown model: {name}')
|
||||
|
||||
comfy/ldm/mmaudio/vae/vae_modules.py (new file, 121 lines)
@@ -0,0 +1,121 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.diffusionmodules.model import vae_attention
|
||||
import math
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return torch.nn.functional.silu(x) / 0.596
|
||||
|
||||
def mp_sum(a, b, t=0.5):
|
||||
return a.lerp(b, t) / math.sqrt((1 - t)**2 + t**2)
|
||||
|
||||
def normalize(x, dim=None, eps=1e-4):
|
||||
if dim is None:
|
||||
dim = list(range(1, x.ndim))
|
||||
norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
|
||||
norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
|
||||
return x / norm.to(x.dtype)
|
||||
|
||||
class ResnetBlock1D(nn.Module):
|
||||
|
||||
def __init__(self, *, in_dim, out_dim=None, conv_shortcut=False, kernel_size=3, use_norm=True):
|
||||
super().__init__()
|
||||
self.in_dim = in_dim
|
||||
out_dim = in_dim if out_dim is None else out_dim
|
||||
self.out_dim = out_dim
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
self.use_norm = use_norm
|
||||
|
||||
self.conv1 = ops.Conv1d(in_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
|
||||
self.conv2 = ops.Conv1d(out_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
|
||||
if self.in_dim != self.out_dim:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = ops.Conv1d(in_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
|
||||
else:
|
||||
self.nin_shortcut = ops.Conv1d(in_dim, out_dim, kernel_size=1, padding=0, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
# pixel norm
|
||||
if self.use_norm:
|
||||
x = normalize(x, dim=1)
|
||||
|
||||
h = x
|
||||
h = nonlinearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
h = nonlinearity(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_dim != self.out_dim:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return mp_sum(x, h, t=0.3)
|
||||
|
||||
|
||||
class AttnBlock1D(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, num_heads=1):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.qkv = ops.Conv1d(in_channels, in_channels * 3, kernel_size=1, padding=0, bias=False)
|
||||
self.proj_out = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
|
||||
self.optimized_attention = vae_attention()
|
||||
|
||||
def forward(self, x):
|
||||
h = x
|
||||
y = self.qkv(h)
|
||||
y = y.reshape(y.shape[0], -1, 3, y.shape[-1])
|
||||
q, k, v = normalize(y, dim=1).unbind(2)
|
||||
|
||||
h = self.optimized_attention(q, k, v)
|
||||
h = self.proj_out(h)
|
||||
|
||||
return mp_sum(x, h, t=0.3)
|
||||
|
||||
|
||||
class Upsample1D(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
self.conv = ops.Conv1d(in_channels, in_channels, kernel_size=3, padding=1, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.interpolate(x, scale_factor=2.0, mode='nearest-exact') # support 3D tensor(B,C,T)
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Downsample1D(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv1 = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
|
||||
self.conv2 = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
if self.with_conv:
|
||||
x = self.conv1(x)
|
||||
|
||||
x = F.avg_pool1d(x, kernel_size=2, stride=2)
|
||||
|
||||
if self.with_conv:
|
||||
x = self.conv2(x)
|
||||
|
||||
return x
|
||||
@@ -44,7 +44,7 @@ class QwenImageControlNetModel(QwenImageTransformer2DModel):
|
||||
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
||||
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
|
||||
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
|
||||
del ids, txt_ids, img_ids
|
||||
|
||||
hidden_states = self.img_in(hidden_states) + self.controlnet_x_embedder(hint)
|
||||
|
||||
@@ -10,6 +10,7 @@ from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.patcher_extension
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
|
||||
class GELU(nn.Module):
|
||||
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
|
||||
@@ -134,33 +135,34 @@ class Attention(nn.Module):
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
transformer_options={},
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
batch_size = hidden_states.shape[0]
|
||||
seq_img = hidden_states.shape[1]
|
||||
seq_txt = encoder_hidden_states.shape[1]
|
||||
|
||||
img_query = self.to_q(hidden_states).unflatten(-1, (self.heads, -1))
|
||||
img_key = self.to_k(hidden_states).unflatten(-1, (self.heads, -1))
|
||||
img_value = self.to_v(hidden_states).unflatten(-1, (self.heads, -1))
|
||||
# Project and reshape to BHND format (batch, heads, seq, dim)
|
||||
img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
|
||||
img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
|
||||
img_value = self.to_v(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2)
|
||||
|
||||
txt_query = self.add_q_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
|
||||
txt_key = self.add_k_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
|
||||
txt_value = self.add_v_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
|
||||
txt_query = self.add_q_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
|
||||
txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
|
||||
txt_value = self.add_v_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2)
|
||||
|
||||
img_query = self.norm_q(img_query)
|
||||
img_key = self.norm_k(img_key)
|
||||
txt_query = self.norm_added_q(txt_query)
|
||||
txt_key = self.norm_added_k(txt_key)
|
||||
|
||||
joint_query = torch.cat([txt_query, img_query], dim=1)
|
||||
joint_key = torch.cat([txt_key, img_key], dim=1)
|
||||
joint_value = torch.cat([txt_value, img_value], dim=1)
|
||||
joint_query = torch.cat([txt_query, img_query], dim=2)
|
||||
joint_key = torch.cat([txt_key, img_key], dim=2)
|
||||
joint_value = torch.cat([txt_value, img_value], dim=2)
|
||||
|
||||
joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
|
||||
joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
|
||||
joint_query = apply_rope1(joint_query, image_rotary_emb)
|
||||
joint_key = apply_rope1(joint_key, image_rotary_emb)
|
||||
|
||||
joint_query = joint_query.flatten(start_dim=2)
|
||||
joint_key = joint_key.flatten(start_dim=2)
|
||||
joint_value = joint_value.flatten(start_dim=2)
|
||||
|
||||
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options)
|
||||
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
|
||||
attention_mask, transformer_options=transformer_options,
|
||||
skip_reshape=True)
|
||||
|
||||
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
|
||||
img_attn_output = joint_hidden_states[:, seq_txt:, :]
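# Minimal standalone sketch of the BHND layout introduced above (assumptions: plain
# scaled_dot_product_attention stands in for optimized_attention_masked with
# skip_reshape=True, and RoPE is omitted): project, reshape to (batch, heads, seq,
# head_dim), concatenate text and image tokens along the sequence axis (dim=2), attend
# jointly, then split the joint sequence back into its text and image parts.
import torch
import torch.nn.functional as F

B, heads, head_dim = 2, 4, 64
seq_txt, seq_img = 7, 9
txt = torch.randn(B, seq_txt, heads * head_dim)
img = torch.randn(B, seq_img, heads * head_dim)

def to_bhnd(t, seq):
    return t.view(B, seq, heads, head_dim).transpose(1, 2)        # (B, H, N, D)

q = torch.cat([to_bhnd(txt, seq_txt), to_bhnd(img, seq_img)], dim=2)
k, v = q.clone(), q.clone()                                       # stand-ins for separate projections
out = F.scaled_dot_product_attention(q, k, v)                     # (B, H, seq_txt + seq_img, D)
out = out.transpose(1, 2).reshape(B, seq_txt + seq_img, heads * head_dim)
txt_out, img_out = out[:, :seq_txt], out[:, seq_txt:]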
@@ -234,10 +236,10 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
|
||||
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
|
||||
|
||||
img_normed = self.img_norm1(hidden_states)
|
||||
img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
|
||||
txt_normed = self.txt_norm1(encoder_hidden_states)
|
||||
txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
|
||||
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
|
||||
del img_mod1
|
||||
txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
|
||||
del txt_mod1
|
||||
|
||||
img_attn_output, txt_attn_output = self.attn(
|
||||
hidden_states=img_modulated,
|
||||
@@ -246,16 +248,20 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
del img_modulated
|
||||
del txt_modulated
|
||||
|
||||
hidden_states = hidden_states + img_gate1 * img_attn_output
|
||||
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
|
||||
del img_attn_output
|
||||
del txt_attn_output
|
||||
del img_gate1
|
||||
del txt_gate1
|
||||
|
||||
img_normed2 = self.img_norm2(hidden_states)
|
||||
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
|
||||
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
|
||||
hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
|
||||
|
||||
txt_normed2 = self.txt_norm2(encoder_hidden_states)
|
||||
txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
|
||||
txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
|
||||
encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
|
||||
|
||||
return encoder_hidden_states, hidden_states
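# Hedged sketch of the modulation pattern used in the block above. The body of
# _modulate is not shown in this hunk; a typical AdaLN-style implementation
# (assumption) splits each modulation vector into shift/scale/gate and returns the
# modulated tensor together with the gate:
import torch

def _modulate(x, mod_params):
    shift, scale, gate = mod_params.chunk(3, dim=-1)
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)

# The residual update above, torch.addcmul(hidden_states, gate, mlp_out), is then the
# fused form of hidden_states + gate * mlp_out.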
@@ -413,7 +419,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
||||
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
|
||||
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
|
||||
del ids, txt_ids, img_ids
|
||||
|
||||
hidden_states = self.img_in(hidden_states)
|
||||
|
||||
@@ -232,6 +232,7 @@ class WanAttentionBlock(nn.Module):
|
||||
# assert e[0].dtype == torch.float32
|
||||
|
||||
# self-attention
|
||||
x = x.contiguous() # otherwise implicit in LayerNorm
|
||||
y = self.self_attn(
|
||||
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
|
||||
freqs, transformer_options=transformer_options)
|
||||
@@ -588,7 +589,7 @@ class WanModel(torch.nn.Module):
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
|
||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None):
|
||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
|
||||
patch_size = self.patch_size
|
||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
|
||||
@@ -601,10 +602,22 @@ class WanModel(torch.nn.Module):
|
||||
if steps_w is None:
|
||||
steps_w = w_len
|
||||
|
||||
h_start = 0
|
||||
w_start = 0
|
||||
rope_options = transformer_options.get("rope_options", None)
|
||||
if rope_options is not None:
|
||||
t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
|
||||
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
|
||||
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
|
||||
|
||||
t_start += rope_options.get("shift_t", 0.0)
|
||||
h_start += rope_options.get("shift_y", 0.0)
|
||||
w_start += rope_options.get("shift_x", 0.0)
|
||||
|
||||
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
|
||||
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
|
||||
|
||||
freqs = self.rope_embedder(img_ids).movedim(1, 2)
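# Illustration of how the rope_options above remap the RoPE coordinate grid: the number
# of token positions stays the same, but the range they span is scaled and shifted.
import torch

w_len, steps_w = 10.0, 10
scale_x, shift_x = 0.5, 2.0
w_len = (w_len - 1.0) * scale_x + 1.0                                  # 5.5
w_start = 0.0 + shift_x                                                # 2.0
pos = torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w)    # 2.0 ... 6.5 over 10 steps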
@@ -630,7 +643,7 @@ class WanModel(torch.nn.Module):
|
||||
if self.ref_conv is not None and "reference_latent" in kwargs:
|
||||
t_len += 1
|
||||
|
||||
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype)
|
||||
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
|
||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
|
||||
|
||||
def unpatchify(self, x, grid_sizes):
|
||||
|
||||
@@ -657,51 +657,51 @@ class WanVAE(nn.Module):
|
||||
)
|
||||
|
||||
def encode(self, x):
|
||||
self.clear_cache()
|
||||
conv_idx = [0]
|
||||
feat_map = [None] * count_conv3d(self.encoder)
|
||||
x = patchify(x, patch_size=2)
|
||||
t = x.shape[2]
|
||||
iter_ = 1 + (t - 1) // 4
|
||||
for i in range(iter_):
|
||||
self._enc_conv_idx = [0]
|
||||
conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.encoder(
|
||||
x[:, :, :1, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx,
|
||||
feat_cache=feat_map,
|
||||
feat_idx=conv_idx,
|
||||
)
|
||||
else:
|
||||
out_ = self.encoder(
|
||||
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx,
|
||||
feat_cache=feat_map,
|
||||
feat_idx=conv_idx,
|
||||
)
|
||||
out = torch.cat([out, out_], 2)
|
||||
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
||||
self.clear_cache()
|
||||
return mu
|
||||
|
||||
def decode(self, z):
|
||||
self.clear_cache()
|
||||
conv_idx = [0]
|
||||
feat_map = [None] * count_conv3d(self.decoder)
|
||||
iter_ = z.shape[2]
|
||||
x = self.conv2(z)
|
||||
for i in range(iter_):
|
||||
self._conv_idx = [0]
|
||||
conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.decoder(
|
||||
x[:, :, i:i + 1, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx,
|
||||
feat_cache=feat_map,
|
||||
feat_idx=conv_idx,
|
||||
first_chunk=True,
|
||||
)
|
||||
else:
|
||||
out_ = self.decoder(
|
||||
x[:, :, i:i + 1, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx,
|
||||
feat_cache=feat_map,
|
||||
feat_idx=conv_idx,
|
||||
)
|
||||
out = torch.cat([out, out_], 2)
|
||||
out = unpatchify(out, patch_size=2)
|
||||
self.clear_cache()
|
||||
return out
|
||||
|
||||
def reparameterize(self, mu, log_var):
|
||||
@@ -715,12 +715,3 @@ class WanVAE(nn.Module):
|
||||
return mu
|
||||
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
|
||||
return mu + std * torch.randn_like(std)
|
||||
|
||||
def clear_cache(self):
|
||||
self._conv_num = count_conv3d(self.decoder)
|
||||
self._conv_idx = [0]
|
||||
self._feat_map = [None] * self._conv_num
|
||||
# cache encode
|
||||
self._enc_conv_num = count_conv3d(self.encoder)
|
||||
self._enc_conv_idx = [0]
|
||||
self._enc_feat_map = [None] * self._enc_conv_num
|
||||
|
||||
@@ -134,10 +134,11 @@ class BaseModel(torch.nn.Module):
|
||||
if not unet_config.get("disable_unet_model_creation", False):
|
||||
if model_config.custom_operations is None:
|
||||
fp8 = model_config.optimizations.get("fp8", False)
|
||||
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
|
||||
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
|
||||
else:
|
||||
operations = model_config.custom_operations
|
||||
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
||||
self.diffusion_model.eval()
|
||||
if comfy.model_management.force_channels_last():
|
||||
self.diffusion_model.to(memory_format=torch.channels_last)
|
||||
logging.debug("using channels last mode for diffusion model")
|
||||
@@ -196,8 +197,14 @@ class BaseModel(torch.nn.Module):
|
||||
extra_conds[o] = extra
|
||||
|
||||
t = self.process_timestep(t, x=x, **extra_conds)
|
||||
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
|
||||
return self.model_sampling.calculate_denoised(sigma, model_output, x)
|
||||
if "latent_shapes" in extra_conds:
|
||||
xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))
|
||||
|
||||
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
|
||||
if len(model_output) > 1 and not torch.is_tensor(model_output):
|
||||
model_output, _ = utils.pack_latents(model_output)
|
||||
|
||||
return self.model_sampling.calculate_denoised(sigma, model_output.float(), x)
|
||||
|
||||
def process_timestep(self, timestep, **kwargs):
|
||||
return timestep
|
||||
@@ -326,6 +333,14 @@ class BaseModel(torch.nn.Module):
|
||||
if self.model_config.scaled_fp8 is not None:
|
||||
unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
|
||||
|
||||
# Save mixed precision metadata
|
||||
if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
|
||||
metadata = {
|
||||
"format_version": "1.0",
|
||||
"layers": self.model_config.layer_quant_config
|
||||
}
|
||||
unet_state_dict["_quantization_metadata"] = metadata
|
||||
|
||||
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
|
||||
|
||||
if self.model_type == ModelType.V_PREDICTION:
|
||||
@@ -669,7 +684,6 @@ class Lotus(BaseModel):
|
||||
class StableCascade_C(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=StageC)
|
||||
self.diffusion_model.eval().requires_grad_(False)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
@@ -698,7 +712,6 @@ class StableCascade_C(BaseModel):
|
||||
class StableCascade_B(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=StageB)
|
||||
self.diffusion_model.eval().requires_grad_(False)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
|
||||
@@ -6,6 +6,20 @@ import math
|
||||
import logging
|
||||
import torch
|
||||
|
||||
|
||||
def detect_layer_quantization(metadata):
|
||||
quant_key = "_quantization_metadata"
|
||||
if metadata is not None and quant_key in metadata:
|
||||
quant_metadata = metadata.pop(quant_key)
|
||||
quant_metadata = json.loads(quant_metadata)
|
||||
if isinstance(quant_metadata, dict) and "layers" in quant_metadata:
|
||||
logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
|
||||
return quant_metadata["layers"]
|
||||
else:
|
||||
raise ValueError("Invalid quantization metadata format")
|
||||
return None
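# Example of the _quantization_metadata payload parsed by detect_layer_quantization()
# above. It is stored as a JSON string in the checkpoint metadata; only the
# {"format_version", "layers"} structure is required, and the layer names and format
# string shown here are hypothetical.
import json

metadata = {
    "_quantization_metadata": json.dumps({
        "format_version": "1.0",
        "layers": {
            "blocks.0.attn.to_q": {"format": "float8_e4m3fn"},
            "blocks.0.attn.to_k": {"format": "float8_e4m3fn"},
        },
    })
}
layer_quant_config = detect_layer_quantization(metadata)          # -> the "layers" dict, or None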
def count_blocks(state_dict_keys, prefix_string):
|
||||
count = 0
|
||||
while True:
|
||||
@@ -213,7 +227,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["nerf_mlp_ratio"] = 4
|
||||
dit_config["nerf_depth"] = 4
|
||||
dit_config["nerf_max_freqs"] = 8
|
||||
dit_config["nerf_tile_size"] = 32
|
||||
dit_config["nerf_tile_size"] = 512
|
||||
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
|
||||
dit_config["nerf_embedder_dtype"] = torch.float32
|
||||
else:
|
||||
@@ -701,6 +715,12 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
|
||||
else:
|
||||
model_config.optimizations["fp8"] = True
|
||||
|
||||
# Detect per-layer quantization (mixed precision)
|
||||
layer_quant_config = detect_layer_quantization(metadata)
|
||||
if layer_quant_config:
|
||||
model_config.layer_quant_config = layer_quant_config
|
||||
logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")
|
||||
|
||||
return model_config
|
||||
|
||||
def unet_prefix_from_state_dict(state_dict):
|
||||
|
||||
@@ -89,6 +89,7 @@ if args.deterministic:
|
||||
|
||||
directml_enabled = False
|
||||
if args.directml is not None:
|
||||
logging.warning("WARNING: torch-directml barely works, is very slow, has not been updated in over 1 year and might be removed soon, please don't use it, there are better options.")
|
||||
import torch_directml
|
||||
directml_enabled = True
|
||||
device_index = args.directml
|
||||
@@ -330,13 +331,21 @@ except:
|
||||
|
||||
|
||||
SUPPORT_FP8_OPS = args.supports_fp8_compute
|
||||
|
||||
AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]
|
||||
|
||||
try:
|
||||
if is_amd():
|
||||
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
|
||||
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
|
||||
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
|
||||
logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
|
||||
|
||||
try:
|
||||
rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
|
||||
except:
|
||||
rocm_version = (6, -1)
|
||||
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
|
||||
|
||||
logging.info("AMD arch: {}".format(arch))
|
||||
logging.info("ROCm version: {}".format(rocm_version))
|
||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
@@ -344,11 +353,11 @@ try:
|
||||
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
|
||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
# if torch_version_numeric >= (2, 8):
|
||||
# if any((a in arch) for a in ["gfx1201"]):
|
||||
# ENABLE_PYTORCH_ATTENTION = True
|
||||
if rocm_version >= (7, 0):
|
||||
if any((a in arch) for a in ["gfx1201"]):
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
|
||||
if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx942", "gfx950"]): # TODO: more arches
|
||||
if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx950"]): # TODO: more arches, "gfx942" gives error on pytorch nightly 2.10 1013 rocm7.0
|
||||
SUPPORT_FP8_OPS = True
|
||||
|
||||
except:
|
||||
@@ -370,6 +379,9 @@ try:
|
||||
except:
|
||||
pass
|
||||
|
||||
if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
try:
|
||||
if torch_version_numeric >= (2, 5):
|
||||
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
|
||||
@@ -492,6 +504,7 @@ class LoadedModel:
|
||||
if use_more_vram == 0:
|
||||
use_more_vram = 1e32
|
||||
self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
|
||||
|
||||
real_model = self.model.model
|
||||
|
||||
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None:
|
||||
@@ -677,7 +690,10 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
current_free_mem = get_free_memory(torch_dev) + loaded_memory
|
||||
|
||||
lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
|
||||
lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)
|
||||
lowvram_model_memory = lowvram_model_memory - loaded_memory
|
||||
|
||||
if lowvram_model_memory == 0:
|
||||
lowvram_model_memory = 0.1
|
||||
|
||||
if vram_set_state == VRAMState.NO_VRAM:
|
||||
lowvram_model_memory = 0.1
|
||||
@@ -925,11 +941,7 @@ def vae_dtype(device=None, allowed_dtypes=[]):
|
||||
if d == torch.float16 and should_use_fp16(device):
|
||||
return d
|
||||
|
||||
# NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
|
||||
# slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3
|
||||
# also a problem on RDNA4 except fp32 is also slow there.
|
||||
# This is due to large bf16 convolutions being extremely slow.
|
||||
if d == torch.bfloat16 and ((not is_amd()) or amd_min_version(device, min_rdna_version=4)) and should_use_bf16(device):
|
||||
if d == torch.bfloat16 and should_use_bf16(device):
|
||||
return d
|
||||
|
||||
return torch.float32
|
||||
@@ -991,12 +1003,6 @@ def device_supports_non_blocking(device):
|
||||
return False
|
||||
return True
|
||||
|
||||
def device_should_use_non_blocking(device):
|
||||
if not device_supports_non_blocking(device):
|
||||
return False
|
||||
return False
|
||||
# return True #TODO: figure out why this causes memory issues on Nvidia and possibly others
|
||||
|
||||
def force_channels_last():
|
||||
if args.force_channels_last:
|
||||
return True
|
||||
@@ -1011,6 +1017,16 @@ if args.async_offload:
|
||||
NUM_STREAMS = 2
|
||||
logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
|
||||
|
||||
def current_stream(device):
|
||||
if device is None:
|
||||
return None
|
||||
if is_device_cuda(device):
|
||||
return torch.cuda.current_stream()
|
||||
elif is_device_xpu(device):
|
||||
return torch.xpu.current_stream()
|
||||
else:
|
||||
return None
|
||||
|
||||
stream_counters = {}
|
||||
def get_offload_stream(device):
|
||||
stream_counter = stream_counters.get(device, 0)
|
||||
@@ -1019,21 +1035,17 @@ def get_offload_stream(device):
|
||||
|
||||
if device in STREAMS:
|
||||
ss = STREAMS[device]
|
||||
s = ss[stream_counter]
|
||||
#Sync the oldest stream in the queue with the current
|
||||
ss[stream_counter].wait_stream(current_stream(device))
|
||||
stream_counter = (stream_counter + 1) % len(ss)
|
||||
if is_device_cuda(device):
|
||||
ss[stream_counter].wait_stream(torch.cuda.current_stream())
|
||||
elif is_device_xpu(device):
|
||||
ss[stream_counter].wait_stream(torch.xpu.current_stream())
|
||||
stream_counters[device] = stream_counter
|
||||
return s
|
||||
return ss[stream_counter]
|
||||
elif is_device_cuda(device):
|
||||
ss = []
|
||||
for k in range(NUM_STREAMS):
|
||||
ss.append(torch.cuda.Stream(device=device, priority=0))
|
||||
STREAMS[device] = ss
|
||||
s = ss[stream_counter]
|
||||
stream_counter = (stream_counter + 1) % len(ss)
|
||||
stream_counters[device] = stream_counter
|
||||
return s
|
||||
elif is_device_xpu(device):
|
||||
@@ -1042,18 +1054,14 @@ def get_offload_stream(device):
|
||||
ss.append(torch.xpu.Stream(device=device, priority=0))
|
||||
STREAMS[device] = ss
|
||||
s = ss[stream_counter]
|
||||
stream_counter = (stream_counter + 1) % len(ss)
|
||||
stream_counters[device] = stream_counter
|
||||
return s
|
||||
return None
|
||||
|
||||
def sync_stream(device, stream):
|
||||
if stream is None:
|
||||
if stream is None or current_stream(device) is None:
|
||||
return
|
||||
if is_device_cuda(device):
|
||||
torch.cuda.current_stream().wait_stream(stream)
|
||||
elif is_device_xpu(device):
|
||||
torch.xpu.current_stream().wait_stream(stream)
|
||||
current_stream(device).wait_stream(stream)
|
||||
|
||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
|
||||
if device is None or weight.device == device:
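# Minimal standalone sketch (CUDA assumed) of the round-robin offload-stream pattern that
# get_offload_stream()/sync_stream() implement above: a copy runs on a side stream that is
# fenced against the compute stream on both ends. Note that non_blocking host-to-device
# copies only overlap compute when the CPU source tensor is pinned.
import torch

streams = [torch.cuda.Stream(), torch.cuda.Stream()]
counter = 0

def copy_async(t, device):
    global counter
    s = streams[counter]
    counter = (counter + 1) % len(streams)
    s.wait_stream(torch.cuda.current_stream())        # don't read t before pending work finishes
    with torch.cuda.stream(s):
        out = t.to(device, non_blocking=True)
    torch.cuda.current_stream().wait_stream(s)        # don't use out before the copy finishes
    return out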
@@ -1078,6 +1086,79 @@ def cast_to_device(tensor, device, dtype, copy=False):
|
||||
non_blocking = device_supports_non_blocking(device)
|
||||
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
|
||||
|
||||
|
||||
PINNED_MEMORY = {}
|
||||
TOTAL_PINNED_MEMORY = 0
|
||||
MAX_PINNED_MEMORY = -1
|
||||
if not args.disable_pinned_memory:
|
||||
if is_nvidia() or is_amd():
|
||||
if WINDOWS:
|
||||
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50%
|
||||
else:
|
||||
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
|
||||
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
|
||||
|
||||
|
||||
def pin_memory(tensor):
|
||||
global TOTAL_PINNED_MEMORY
|
||||
if MAX_PINNED_MEMORY <= 0:
|
||||
return False
|
||||
|
||||
if type(tensor) is not torch.nn.parameter.Parameter:
|
||||
return False
|
||||
|
||||
if not is_device_cpu(tensor.device):
|
||||
return False
|
||||
|
||||
if tensor.is_pinned():
|
||||
#NOTE: CUDA does detect when a tensor is already pinned and would
#error below, but there are proven cases where this also queues an error
#on the GPU asynchronously, so don't trust the CUDA API and guard here.
|
||||
return False
|
||||
|
||||
if not tensor.is_contiguous():
|
||||
return False
|
||||
|
||||
size = tensor.numel() * tensor.element_size()
|
||||
if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
|
||||
return False
|
||||
|
||||
ptr = tensor.data_ptr()
|
||||
if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0:
|
||||
PINNED_MEMORY[ptr] = size
|
||||
TOTAL_PINNED_MEMORY += size
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def unpin_memory(tensor):
|
||||
global TOTAL_PINNED_MEMORY
|
||||
if MAX_PINNED_MEMORY <= 0:
|
||||
return False
|
||||
|
||||
if not is_device_cpu(tensor.device):
|
||||
return False
|
||||
|
||||
ptr = tensor.data_ptr()
|
||||
size = tensor.numel() * tensor.element_size()
|
||||
|
||||
size_stored = PINNED_MEMORY.get(ptr, None)
|
||||
if size_stored is None:
|
||||
logging.warning("Tried to unpin tensor not pinned by ComfyUI")
|
||||
return False
|
||||
|
||||
if size != size_stored:
|
||||
logging.warning("Size of pinned tensor changed")
|
||||
return False
|
||||
|
||||
if torch.cuda.cudart().cudaHostUnregister(ptr) == 0:
|
||||
TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr)
|
||||
if len(PINNED_MEMORY) == 0:
|
||||
TOTAL_PINNED_MEMORY = 0
|
||||
return True
|
||||
|
||||
return False
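# Illustration (CUDA assumed) of what pin_memory()/unpin_memory() above provide: a
# page-locked CPU tensor can be copied to the GPU asynchronously. This standalone demo
# uses torch's own Tensor.pin_memory() instead of cudaHostRegister, purely to show the
# effect the in-place registration is after.
import torch

cpu_weight = torch.nn.Parameter(torch.randn(1024, 1024), requires_grad=False)
pinned = cpu_weight.data.pin_memory()                 # page-locked copy of the storage
gpu_weight = pinned.to("cuda", non_blocking=True)     # can overlap with other GPU work
torch.cuda.current_stream().synchronize()             # fence before relying on the result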
def sage_attention_enabled():
|
||||
return args.use_sage_attention
|
||||
|
||||
@@ -1330,7 +1411,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
|
||||
if is_amd():
|
||||
arch = torch.cuda.get_device_properties(device).gcnArchName
|
||||
if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16
|
||||
if any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH): # RDNA2 and older don't support bf16
|
||||
if manual_cast:
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -238,6 +238,7 @@ class ModelPatcher:
|
||||
self.force_cast_weights = False
|
||||
self.patches_uuid = uuid.uuid4()
|
||||
self.parent = None
|
||||
self.pinned = set()
|
||||
|
||||
self.attachments: dict[str] = {}
|
||||
self.additional_models: dict[str, list[ModelPatcher]] = {}
|
||||
@@ -275,6 +276,9 @@ class ModelPatcher:
|
||||
self.size = comfy.model_management.module_size(self.model)
|
||||
return self.size
|
||||
|
||||
def get_ram_usage(self):
|
||||
return self.model_size()
|
||||
|
||||
def loaded_size(self):
|
||||
return self.model.model_loaded_weight_memory
|
||||
|
||||
@@ -294,6 +298,7 @@ class ModelPatcher:
|
||||
n.backup = self.backup
|
||||
n.object_patches_backup = self.object_patches_backup
|
||||
n.parent = self
|
||||
n.pinned = self.pinned
|
||||
|
||||
n.force_cast_weights = self.force_cast_weights
|
||||
|
||||
@@ -450,6 +455,19 @@ class ModelPatcher:
|
||||
def set_model_post_input_patch(self, patch):
|
||||
self.set_model_patch(patch, "post_input")
|
||||
|
||||
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
|
||||
rope_options = self.model_options["transformer_options"].get("rope_options", {})
|
||||
rope_options["scale_x"] = scale_x
|
||||
rope_options["scale_y"] = scale_y
|
||||
rope_options["scale_t"] = scale_t
|
||||
|
||||
rope_options["shift_x"] = shift_x
|
||||
rope_options["shift_y"] = shift_y
|
||||
rope_options["shift_t"] = shift_t
|
||||
|
||||
self.model_options["transformer_options"]["rope_options"] = rope_options
|
||||
|
||||
|
||||
def add_object_patch(self, name, obj):
|
||||
self.object_patches[name] = obj
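# Sketch of how a custom node might drive the new ModelPatcher.set_model_rope_options()
# above (the node wrapper itself is hypothetical). Cloning first keeps the rope options
# local to this branch of the workflow instead of mutating the shared patcher.
def patch_rope(model, scale_x=1.0, shift_x=0.0, scale_y=1.0, shift_y=0.0, scale_t=1.0, shift_t=0.0):
    m = model.clone()
    m.set_model_rope_options(scale_x, shift_x, scale_y, shift_y, scale_t, shift_t)
    return (m,)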
@@ -618,6 +636,21 @@ class ModelPatcher:
|
||||
else:
|
||||
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
|
||||
|
||||
def pin_weight_to_device(self, key):
|
||||
weight, set_func, convert_func = get_key_weight(self.model, key)
|
||||
if comfy.model_management.pin_memory(weight):
|
||||
self.pinned.add(key)
|
||||
|
||||
def unpin_weight(self, key):
|
||||
if key in self.pinned:
|
||||
weight, set_func, convert_func = get_key_weight(self.model, key)
|
||||
comfy.model_management.unpin_memory(weight)
|
||||
self.pinned.remove(key)
|
||||
|
||||
def unpin_all_weights(self):
|
||||
for key in list(self.pinned):
|
||||
self.unpin_weight(key)
|
||||
|
||||
def _load_list(self):
|
||||
loading = []
|
||||
for n, m in self.model.named_modules():
|
||||
@@ -639,9 +672,11 @@ class ModelPatcher:
|
||||
mem_counter = 0
|
||||
patch_counter = 0
|
||||
lowvram_counter = 0
|
||||
lowvram_mem_counter = 0
|
||||
loading = self._load_list()
|
||||
|
||||
load_completely = []
|
||||
offloaded = []
|
||||
loading.sort(reverse=True)
|
||||
for x in loading:
|
||||
n = x[1]
|
||||
@@ -658,6 +693,7 @@ class ModelPatcher:
|
||||
if mem_counter + module_mem >= lowvram_model_memory:
|
||||
lowvram_weight = True
|
||||
lowvram_counter += 1
|
||||
lowvram_mem_counter += module_mem
|
||||
if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
|
||||
continue
|
||||
|
||||
@@ -683,6 +719,7 @@ class ModelPatcher:
|
||||
patch_counter += 1
|
||||
|
||||
cast_weight = True
|
||||
offloaded.append((module_mem, n, m, params))
|
||||
else:
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
wipe_lowvram_weight(m)
|
||||
@@ -713,7 +750,9 @@ class ModelPatcher:
|
||||
continue
|
||||
|
||||
for param in params:
|
||||
self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to)
|
||||
key = "{}.{}".format(n, param)
|
||||
self.unpin_weight(key)
|
||||
self.patch_weight_to_device(key, device_to=device_to)
|
||||
|
||||
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
||||
m.comfy_patched_weights = True
|
||||
@@ -721,11 +760,17 @@ class ModelPatcher:
|
||||
for x in load_completely:
|
||||
x[2].to(device_to)
|
||||
|
||||
for x in offloaded:
|
||||
n = x[1]
|
||||
params = x[3]
|
||||
for param in params:
|
||||
self.pin_weight_to_device("{}.{}".format(n, param))
|
||||
|
||||
if lowvram_counter > 0:
|
||||
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
|
||||
logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), patch_counter))
|
||||
self.model.model_lowvram = True
|
||||
else:
|
||||
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
|
||||
logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
|
||||
self.model.model_lowvram = False
|
||||
if full_load:
|
||||
self.model.to(device_to)
|
||||
@@ -762,6 +807,7 @@ class ModelPatcher:
|
||||
self.eject_model()
|
||||
if unpatch_weights:
|
||||
self.unpatch_hooks()
|
||||
self.unpin_all_weights()
|
||||
if self.model.model_lowvram:
|
||||
for m in self.model.modules():
|
||||
move_weight_functions(m, device_to)
|
||||
@@ -797,7 +843,7 @@ class ModelPatcher:
|
||||
|
||||
self.object_patches_backup.clear()
|
||||
|
||||
def partially_unload(self, device_to, memory_to_free=0):
|
||||
def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
|
||||
with self.use_ejected():
|
||||
hooks_unpatched = False
|
||||
memory_freed = 0
|
||||
@@ -841,13 +887,19 @@ class ModelPatcher:
|
||||
module_mem += move_weight_functions(m, device_to)
|
||||
if lowvram_possible:
|
||||
if weight_key in self.patches:
|
||||
_, set_func, convert_func = get_key_weight(self.model, weight_key)
|
||||
m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func))
|
||||
patch_counter += 1
|
||||
if force_patch_weights:
|
||||
self.patch_weight_to_device(weight_key)
|
||||
else:
|
||||
_, set_func, convert_func = get_key_weight(self.model, weight_key)
|
||||
m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func))
|
||||
patch_counter += 1
|
||||
if bias_key in self.patches:
|
||||
_, set_func, convert_func = get_key_weight(self.model, bias_key)
|
||||
m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func))
|
||||
patch_counter += 1
|
||||
if force_patch_weights:
|
||||
self.patch_weight_to_device(bias_key)
|
||||
else:
|
||||
_, set_func, convert_func = get_key_weight(self.model, bias_key)
|
||||
m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func))
|
||||
patch_counter += 1
|
||||
cast_weight = True
|
||||
|
||||
if cast_weight:
|
||||
@@ -857,9 +909,13 @@ class ModelPatcher:
|
||||
memory_freed += module_mem
|
||||
logging.debug("freed {}".format(n))
|
||||
|
||||
for param in params:
|
||||
self.pin_weight_to_device("{}.{}".format(n, param))
|
||||
|
||||
self.model.model_lowvram = True
|
||||
self.model.lowvram_patch_counter += patch_counter
|
||||
self.model.model_loaded_weight_memory -= memory_freed
|
||||
logging.info("loaded partially: {:.2f} MB loaded, lowvram patches: {}".format(self.model.model_loaded_weight_memory / (1024 * 1024), self.model.lowvram_patch_counter))
|
||||
return memory_freed
|
||||
|
||||
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
|
||||
@@ -872,6 +928,9 @@ class ModelPatcher:
|
||||
extra_memory += (used - self.model.model_loaded_weight_memory)
|
||||
|
||||
self.patch_model(load_weights=False)
|
||||
if extra_memory < 0 and not unpatch_weights:
|
||||
self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
|
||||
return 0
|
||||
full_load = False
|
||||
if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
|
||||
self.apply_hooks(self.forced_hooks, force_apply=True)
|
||||
@@ -1259,5 +1318,6 @@ class ModelPatcher:
|
||||
self.clear_cached_hook_weights()
|
||||
|
||||
def __del__(self):
|
||||
self.unpin_all_weights()
|
||||
self.detach(unpatch_all=False)
comfy/nested_tensor.py (new file, 91 lines)
@@ -0,0 +1,91 @@
import torch


class NestedTensor:
    def __init__(self, tensors):
        self.tensors = list(tensors)
        self.is_nested = True

    def _copy(self):
        return NestedTensor(self.tensors)

    def apply_operation(self, other, operation):
        o = self._copy()
        if isinstance(other, NestedTensor):
            for i, t in enumerate(o.tensors):
                o.tensors[i] = operation(t, other.tensors[i])
        else:
            for i, t in enumerate(o.tensors):
                o.tensors[i] = operation(t, other)
        return o

    def __add__(self, b):
        return self.apply_operation(b, lambda x, y: x + y)

    def __sub__(self, b):
        return self.apply_operation(b, lambda x, y: x - y)

    def __mul__(self, b):
        return self.apply_operation(b, lambda x, y: x * y)

    # def __itruediv__(self, b):
    #     return self.apply_operation(b, lambda x, y: x / y)

    def __truediv__(self, b):
        return self.apply_operation(b, lambda x, y: x / y)

    def __getitem__(self, *args, **kwargs):
        return self.apply_operation(None, lambda x, y: x.__getitem__(*args, **kwargs))

    def unbind(self):
        return self.tensors

    def to(self, *args, **kwargs):
        o = self._copy()
        for i, t in enumerate(o.tensors):
            o.tensors[i] = t.to(*args, **kwargs)
        return o

    def new_ones(self, *args, **kwargs):
        return self.tensors[0].new_ones(*args, **kwargs)

    def float(self):
        return self.to(dtype=torch.float)

    def chunk(self, *args, **kwargs):
        return self.apply_operation(None, lambda x, y: x.chunk(*args, **kwargs))

    def size(self):
        return self.tensors[0].size()

    @property
    def shape(self):
        return self.tensors[0].shape

    @property
    def ndim(self):
        dims = 0
        for t in self.tensors:
            dims = max(t.ndim, dims)
        return dims

    @property
    def device(self):
        return self.tensors[0].device

    @property
    def dtype(self):
        return self.tensors[0].dtype

    @property
    def layout(self):
        return self.tensors[0].layout


def cat_nested(tensors, *args, **kwargs):
    cated_tensors = []
    for i in range(len(tensors[0].tensors)):
        tens = []
        for j in range(len(tensors)):
            tens.append(tensors[j].tensors[i])
        cated_tensors.append(torch.cat(tens, *args, **kwargs))
    return NestedTensor(cated_tensors)
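# Quick usage sketch for the NestedTensor helper above: arithmetic is applied per wrapped
# tensor, scalars broadcast, and cat_nested() concatenates the wrapped tensors index-wise.
import torch

a = NestedTensor([torch.ones(1, 4, 8), torch.ones(1, 4, 6)])
b = NestedTensor([torch.full((1, 4, 8), 2.0), torch.full((1, 4, 6), 2.0)])
c = (a + b) * 0.5                          # still a NestedTensor, per-element shapes preserved
d = cat_nested([a, b], dim=-1)             # -> wrapped shapes (1, 4, 16) and (1, 4, 12)
print(c.shape, d.tensors[1].shape)         # .shape reports the first wrapped tensor's shape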
comfy/ops.py (331 changed lines)
@@ -24,13 +24,18 @@ import comfy.float
|
||||
import comfy.rmsnorm
|
||||
import contextlib
|
||||
|
||||
def run_every_op():
|
||||
if torch.compiler.is_compiling():
|
||||
return
|
||||
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
|
||||
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
|
||||
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
|
||||
|
||||
|
||||
try:
|
||||
if torch.cuda.is_available():
|
||||
if torch.cuda.is_available() and comfy.model_management.WINDOWS:
|
||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||
import inspect
|
||||
if "set_priority" in inspect.signature(sdpa_kernel).parameters:
|
||||
@@ -50,49 +55,90 @@ try:
|
||||
except (ModuleNotFoundError, TypeError):
|
||||
logging.warning("Could not set sdpa backend priority.")
|
||||
|
||||
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
|
||||
NVIDIA_MEMORY_CONV_BUG_WORKAROUND = False
|
||||
try:
|
||||
if comfy.model_management.is_nvidia():
|
||||
cudnn_version = torch.backends.cudnn.version()
|
||||
if (cudnn_version >= 91002 and cudnn_version < 91500) and comfy.model_management.torch_version_numeric >= (2, 9) and comfy.model_management.torch_version_numeric <= (2, 10):
|
||||
#TODO: change upper bound version once it's fixed
|
||||
NVIDIA_MEMORY_CONV_BUG_WORKAROUND = True
|
||||
logging.info("working around nvidia conv3d memory bug.")
|
||||
except:
|
||||
pass
|
||||
|
||||
if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
|
||||
torch.backends.cudnn.benchmark = True
|
||||
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
|
||||
|
||||
def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
||||
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
|
||||
# NOTE: offloadable=False is a legacy default. If you are a custom node author reading this, please pass
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
# will add async-offload support to your cast and improve performance.
|
||||
if input is not None:
|
||||
if dtype is None:
|
||||
dtype = input.dtype
|
||||
if isinstance(input, QuantizedTensor):
|
||||
dtype = input._layout_params["orig_dtype"]
|
||||
else:
|
||||
dtype = input.dtype
|
||||
if bias_dtype is None:
|
||||
bias_dtype = dtype
|
||||
if device is None:
|
||||
device = input.device
|
||||
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
if offloadable and (device != s.weight.device or
|
||||
(s.bias is not None and device != s.bias.device)):
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
else:
|
||||
offload_stream = None
|
||||
|
||||
if offload_stream is not None:
|
||||
wf_context = offload_stream
|
||||
else:
|
||||
wf_context = contextlib.nullcontext()
|
||||
|
||||
bias = None
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||
if s.bias is not None:
|
||||
has_function = len(s.bias_function) > 0
|
||||
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
|
||||
|
||||
if has_function:
|
||||
weight_has_function = len(s.weight_function) > 0
|
||||
bias_has_function = len(s.bias_function) > 0
|
||||
|
||||
weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
|
||||
|
||||
bias = None
|
||||
if s.bias is not None:
|
||||
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
|
||||
|
||||
if bias_has_function:
|
||||
with wf_context:
|
||||
for f in s.bias_function:
|
||||
bias = f(bias)
|
||||
|
||||
has_function = len(s.weight_function) > 0
|
||||
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
|
||||
if has_function:
|
||||
if weight_has_function or weight.dtype != dtype:
|
||||
with wf_context:
|
||||
weight = weight.to(dtype=dtype)
|
||||
for f in s.weight_function:
|
||||
weight = f(weight)
|
||||
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
return weight, bias
|
||||
if offloadable:
|
||||
return weight, bias, offload_stream
|
||||
else:
|
||||
#Legacy function signature
|
||||
return weight, bias
|
||||
|
||||
|
||||
def uncast_bias_weight(s, weight, bias, offload_stream):
|
||||
if offload_stream is None:
|
||||
return
|
||||
if weight is not None:
|
||||
device = weight.device
|
||||
else:
|
||||
if bias is None:
|
||||
return
|
||||
device = bias.device
|
||||
offload_stream.wait_stream(comfy.model_management.current_stream(device))
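# Sketch of the pattern the NOTE above asks custom node authors to adopt: request the
# offload-aware cast, use the weight/bias, then hand them back so the offload stream can
# be recycled. It mirrors the forward_comfy_cast_weights() changes later in this diff.
import torch

class MyLinear(disable_weight_init.Linear):
    def forward_comfy_cast_weights(self, input):
        weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
        out = torch.nn.functional.linear(input, weight, bias)
        uncast_bias_weight(self, weight, bias, offload_stream)
        return out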
class CastWeightBiasOp:
|
||||
comfy_cast_weights = False
|
||||
@@ -105,10 +151,13 @@ class disable_weight_init:
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = torch.nn.functional.linear(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
@@ -119,10 +168,13 @@ class disable_weight_init:
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return self._conv_forward(input, weight, bias)
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = self._conv_forward(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
@@ -133,10 +185,13 @@ class disable_weight_init:
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return self._conv_forward(input, weight, bias)
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = self._conv_forward(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
@@ -146,11 +201,23 @@ class disable_weight_init:
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def _conv_forward(self, input, weight, bias, *args, **kwargs):
|
||||
if NVIDIA_MEMORY_CONV_BUG_WORKAROUND and weight.dtype in (torch.float16, torch.bfloat16):
|
||||
out = torch.cudnn_convolution(input, weight, self.padding, self.stride, self.dilation, self.groups, benchmark=False, deterministic=False, allow_tf32=True)
|
||||
if bias is not None:
|
||||
out += bias.reshape((1, -1) + (1,) * (out.ndim - 2))
|
||||
return out
|
||||
else:
|
||||
return super()._conv_forward(input, weight, bias, *args, **kwargs)
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return self._conv_forward(input, weight, bias)
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = self._conv_forward(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
@@ -161,10 +228,13 @@ class disable_weight_init:
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
@@ -176,13 +246,17 @@ class disable_weight_init:
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
if self.weight is not None:
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
else:
|
||||
weight = None
|
||||
bias = None
|
||||
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
|
||||
offload_stream = None
|
||||
x = torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
@@ -195,13 +269,18 @@ class disable_weight_init:
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
if self.weight is not None:
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
else:
|
||||
weight = None
|
||||
return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
|
||||
# return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
|
||||
bias = None
|
||||
offload_stream = None
|
||||
x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
|
||||
# x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
@@ -217,12 +296,15 @@ class disable_weight_init:
|
||||
input, output_size, self.stride, self.padding, self.kernel_size,
|
||||
num_spatial_dims, self.dilation)
|
||||
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.conv_transpose2d(
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = torch.nn.functional.conv_transpose2d(
|
||||
input, weight, bias, self.stride, self.padding,
|
||||
output_padding, self.groups, self.dilation)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
@@ -238,12 +320,15 @@ class disable_weight_init:
|
||||
input, output_size, self.stride, self.padding, self.kernel_size,
|
||||
num_spatial_dims, self.dilation)
|
||||
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.conv_transpose1d(
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = torch.nn.functional.conv_transpose1d(
|
||||
input, weight, bias, self.stride, self.padding,
|
||||
output_padding, self.groups, self.dilation)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
@@ -258,10 +343,14 @@ class disable_weight_init:
|
||||
output_dtype = out_dtype
|
||||
if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
|
||||
out_dtype = None
|
||||
weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
|
||||
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
|
||||
weight, bias, offload_stream = cast_bias_weight(self, device=input.device, dtype=out_dtype, offloadable=True)
|
||||
x = torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
@@ -312,20 +401,18 @@ class manual_cast(disable_weight_init):
|
||||
|
||||
|
||||
def fp8_linear(self, input):
|
||||
"""
|
||||
Legacy FP8 linear function for backward compatibility.
|
||||
Uses QuantizedTensor subclass for dispatch.
|
||||
"""
|
||||
dtype = self.weight.dtype
|
||||
if dtype not in [torch.float8_e4m3fn]:
|
||||
return None
|
||||
|
||||
tensor_2d = False
|
||||
if len(input.shape) == 2:
|
||||
tensor_2d = True
|
||||
input = input.unsqueeze(1)
|
||||
|
||||
input_shape = input.shape
|
||||
input_dtype = input.dtype
|
||||
if len(input.shape) == 3:
|
||||
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
|
||||
w = w.t()
|
||||
|
||||
if input.ndim == 3 or input.ndim == 2:
|
||||
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
|
||||
|
||||
scale_weight = self.scale_weight
|
||||
scale_input = self.scale_input
|
||||
@@ -337,23 +424,20 @@ def fp8_linear(self, input):
|
||||
if scale_input is None:
|
||||
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
input = torch.clamp(input, min=-448, max=448, out=input)
|
||||
input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
|
||||
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
|
||||
quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
|
||||
else:
|
||||
scale_input = scale_input.to(input.device)
|
||||
input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()
|
||||
quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype)
|
||||
|
||||
if bias is not None:
|
||||
o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
|
||||
else:
|
||||
o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight)
|
||||
# Wrap weight in QuantizedTensor - this enables unified dispatch
|
||||
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
|
||||
layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
|
||||
quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
|
||||
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
|
||||
|
||||
if isinstance(o, tuple):
|
||||
o = o[0]
|
||||
|
||||
if tensor_2d:
|
||||
return o.reshape(input_shape[0], -1)
|
||||
|
||||
return o.reshape((-1, input_shape[1], self.weight.shape[0]))
|
||||
uncast_bias_weight(self, w, bias, offload_stream)
|
||||
return o
|
||||
|
||||
return None
|
||||
|
||||
@@ -373,8 +457,10 @@ class fp8_ops(manual_cast):
|
||||
except Exception as e:
|
||||
logging.info("Exception during fp8 op: {}".format(e))
|
||||
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = torch.nn.functional.linear(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
|
||||
logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
|
||||
@@ -402,12 +488,14 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
|
||||
if out is not None:
|
||||
return out
|
||||
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
|
||||
if weight.numel() < input.numel(): #TODO: optimize
|
||||
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
|
||||
x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
|
||||
else:
|
||||
return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
|
||||
x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def convert_weight(self, weight, inplace=False, **kwargs):
|
||||
if inplace:
|
||||
@@ -446,7 +534,120 @@ if CUBLAS_IS_AVAILABLE:
|
||||
def forward(self, *args, **kwargs):
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
|
||||
|
||||
# ==============================================================================
|
||||
# Mixed Precision Operations
|
||||
# ==============================================================================
|
||||
from .quant_ops import QuantizedTensor, QUANT_ALGOS
|
||||
|
||||
class MixedPrecisionOps(disable_weight_init):
|
||||
_layer_quant_config = {}
|
||||
_compute_dtype = torch.bfloat16
|
||||
|
||||
class Linear(torch.nn.Module, CastWeightBiasOp):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
|
||||
# self.factory_kwargs = {"device": device, "dtype": dtype}
|
||||
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
if bias:
|
||||
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
self.tensor_class = None
|
||||
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
strict, missing_keys, unexpected_keys, error_msgs):
|
||||
|
||||
device = self.factory_kwargs["device"]
|
||||
layer_name = prefix.rstrip('.')
|
||||
weight_key = f"{prefix}weight"
|
||||
weight = state_dict.pop(weight_key, None)
|
||||
if weight is None:
|
||||
raise ValueError(f"Missing weight for layer {layer_name}")
|
||||
|
||||
manually_loaded_keys = [weight_key]
|
||||
|
||||
if layer_name not in MixedPrecisionOps._layer_quant_config:
|
||||
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
|
||||
else:
|
||||
quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
|
||||
if quant_format is None:
|
||||
raise ValueError(f"Unknown quantization format for layer {layer_name}")
|
||||
|
||||
qconfig = QUANT_ALGOS[quant_format]
|
||||
self.layout_type = qconfig["comfy_tensor_layout"]
|
||||
|
||||
weight_scale_key = f"{prefix}weight_scale"
|
||||
layout_params = {
|
||||
'scale': state_dict.pop(weight_scale_key, None),
|
||||
'orig_dtype': MixedPrecisionOps._compute_dtype,
|
||||
'block_size': qconfig.get("group_size", None),
|
||||
}
|
||||
if layout_params['scale'] is not None:
|
||||
manually_loaded_keys.append(weight_scale_key)
|
||||
|
||||
self.weight = torch.nn.Parameter(
|
||||
QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
|
||||
requires_grad=False
|
||||
)
|
||||
|
||||
for param_name in qconfig["parameters"]:
|
||||
param_key = f"{prefix}{param_name}"
|
||||
_v = state_dict.pop(param_key, None)
|
||||
if _v is None:
|
||||
continue
|
||||
setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
|
||||
manually_loaded_keys.append(param_key)
|
||||
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
for key in manually_loaded_keys:
|
||||
if key in missing_keys:
|
||||
missing_keys.remove(key)
|
||||
|
||||
def _forward(self, input, weight, bias):
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = self._forward(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, input, *args, **kwargs):
|
||||
run_every_op()
|
||||
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(input, *args, **kwargs)
|
||||
if (getattr(self, 'layout_type', None) is not None and
|
||||
getattr(self, 'input_scale', None) is not None and
|
||||
not isinstance(input, QuantizedTensor)):
|
||||
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
|
||||
return self._forward(input, self.weight, self.bias)
|
||||
|
||||
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
|
||||
if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
|
||||
MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config
|
||||
MixedPrecisionOps._compute_dtype = compute_dtype
|
||||
logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
|
||||
return MixedPrecisionOps
|
||||
|
||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
|
||||
if scaled_fp8 is not None:
|
||||
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
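To make the routing above concrete, here is a minimal sketch (not part of the patch) of how a detected per-layer quantization config selects MixedPrecisionOps. The layer name and the _DummyModelConfig class are hypothetical; only the {"format": "float8_e4m3fn"} shape is taken from QUANT_ALGOS and _load_from_state_dict above.

import torch
import comfy.ops

class _DummyModelConfig:
    # hypothetical stand-in for a model config detected from checkpoint metadata
    layer_quant_config = {
        "diffusion_model.blocks.0.attn.qkv": {"format": "float8_e4m3fn"},
    }

ops = comfy.ops.pick_operations(torch.float8_e4m3fn, torch.bfloat16,
                                model_config=_DummyModelConfig())
assert ops is comfy.ops.MixedPrecisionOps  # chosen whenever layer_quant_config is non-empty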
|
||||
|
||||
@@ -150,7 +150,7 @@ def merge_nested_dicts(dict1: dict, dict2: dict, copy_dict1=True):
|
||||
for key, value in dict2.items():
|
||||
if isinstance(value, dict):
|
||||
curr_value = merged_dict.setdefault(key, {})
|
||||
merged_dict[key] = merge_nested_dicts(value, curr_value)
|
||||
merged_dict[key] = merge_nested_dicts(curr_value, value)
|
||||
elif isinstance(value, list):
|
||||
merged_dict.setdefault(key, []).extend(value)
|
||||
else:
|
||||
|
||||
comfy/quant_ops.py (new file, 545 lines)
@@ -0,0 +1,545 @@
|
||||
import torch
|
||||
import logging
|
||||
from typing import Tuple, Dict
|
||||
|
||||
_LAYOUT_REGISTRY = {}
|
||||
_GENERIC_UTILS = {}
|
||||
|
||||
|
||||
def register_layout_op(torch_op, layout_type):
|
||||
"""
|
||||
Decorator to register a layout-specific operation handler.
|
||||
Args:
|
||||
torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default)
|
||||
layout_type: Layout class (e.g., TensorCoreFP8Layout)
|
||||
Example:
|
||||
@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
|
||||
def fp8_linear(func, args, kwargs):
|
||||
# FP8-specific linear implementation
|
||||
...
|
||||
"""
|
||||
def decorator(handler_func):
|
||||
if torch_op not in _LAYOUT_REGISTRY:
|
||||
_LAYOUT_REGISTRY[torch_op] = {}
|
||||
_LAYOUT_REGISTRY[torch_op][layout_type] = handler_func
|
||||
return handler_func
|
||||
return decorator
|
||||
|
||||
|
||||
def register_generic_util(torch_op):
|
||||
"""
|
||||
Decorator to register a generic utility that works for all layouts.
|
||||
Args:
|
||||
torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default)
|
||||
|
||||
Example:
|
||||
@register_generic_util(torch.ops.aten.detach.default)
|
||||
def generic_detach(func, args, kwargs):
|
||||
# Works for any layout
|
||||
...
|
||||
"""
|
||||
def decorator(handler_func):
|
||||
_GENERIC_UTILS[torch_op] = handler_func
|
||||
return handler_func
|
||||
return decorator
|
||||
|
||||
|
||||
def _get_layout_from_args(args):
|
||||
for arg in args:
|
||||
if isinstance(arg, QuantizedTensor):
|
||||
return arg._layout_type
|
||||
elif isinstance(arg, (list, tuple)):
|
||||
for item in arg:
|
||||
if isinstance(item, QuantizedTensor):
|
||||
return item._layout_type
|
||||
return None
|
||||
|
||||
|
||||
def _move_layout_params_to_device(params, device):
|
||||
new_params = {}
|
||||
for k, v in params.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
new_params[k] = v.to(device=device)
|
||||
else:
|
||||
new_params[k] = v
|
||||
return new_params
|
||||
|
||||
|
||||
def _copy_layout_params(params):
|
||||
new_params = {}
|
||||
for k, v in params.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
new_params[k] = v.clone()
|
||||
else:
|
||||
new_params[k] = v
|
||||
return new_params
|
||||
|
||||
def _copy_layout_params_inplace(src, dst, non_blocking=False):
|
||||
for k, v in src.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
dst[k].copy_(v, non_blocking=non_blocking)
|
||||
else:
|
||||
dst[k] = v
|
||||
|
||||
class QuantizedLayout:
|
||||
"""
|
||||
Base class for quantization layouts.
|
||||
|
||||
A layout encapsulates the format-specific logic for quantization/dequantization
|
||||
and provides a uniform interface for extracting raw tensors needed for computation.
|
||||
|
||||
New quantization formats should subclass this and implement the required methods.
|
||||
"""
|
||||
@classmethod
|
||||
def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]:
|
||||
raise NotImplementedError(f"{cls.__name__} must implement quantize()")
|
||||
|
||||
@staticmethod
|
||||
def dequantize(qdata, **layout_params) -> torch.Tensor:
|
||||
raise NotImplementedError("TensorLayout must implement dequantize()")
|
||||
|
||||
@classmethod
|
||||
def get_plain_tensors(cls, qtensor) -> torch.Tensor:
|
||||
raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()")
|
||||
|
||||
|
||||
class QuantizedTensor(torch.Tensor):
|
||||
"""
|
||||
Universal quantized tensor that works with any layout.
|
||||
|
||||
This tensor subclass uses a pluggable layout system to support multiple
|
||||
quantization formats (FP8, INT4, INT8, etc.) without code duplication.
|
||||
|
||||
The layout_type determines format-specific behavior, while common operations
|
||||
(detach, clone, to) are handled generically.
|
||||
|
||||
Attributes:
|
||||
_qdata: The quantized tensor data
|
||||
_layout_type: Layout class (e.g., TensorCoreFP8Layout)
|
||||
_layout_params: Dict with layout-specific params (scale, zero_point, etc.)
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def __new__(cls, qdata, layout_type, layout_params):
|
||||
"""
|
||||
Create a quantized tensor.
|
||||
|
||||
Args:
|
||||
qdata: The quantized data tensor
|
||||
layout_type: Layout class (subclass of QuantizedLayout)
|
||||
layout_params: Dict with layout-specific parameters
|
||||
"""
|
||||
return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
|
||||
|
||||
def __init__(self, qdata, layout_type, layout_params):
|
||||
self._qdata = qdata
|
||||
self._layout_type = layout_type
|
||||
self._layout_params = layout_params
|
||||
|
||||
def __repr__(self):
|
||||
layout_name = self._layout_type
|
||||
param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
|
||||
return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"
|
||||
|
||||
@property
|
||||
def layout_type(self):
|
||||
return self._layout_type
|
||||
|
||||
def __tensor_flatten__(self):
|
||||
"""
|
||||
Tensor flattening protocol for proper device movement.
|
||||
"""
|
||||
inner_tensors = ["_qdata"]
|
||||
ctx = {
|
||||
"layout_type": self._layout_type,
|
||||
}
|
||||
|
||||
tensor_params = {}
|
||||
non_tensor_params = {}
|
||||
for k, v in self._layout_params.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
tensor_params[k] = v
|
||||
else:
|
||||
non_tensor_params[k] = v
|
||||
|
||||
ctx["tensor_param_keys"] = list(tensor_params.keys())
|
||||
ctx["non_tensor_params"] = non_tensor_params
|
||||
|
||||
for k, v in tensor_params.items():
|
||||
attr_name = f"_layout_param_{k}"
|
||||
object.__setattr__(self, attr_name, v)
|
||||
inner_tensors.append(attr_name)
|
||||
|
||||
return inner_tensors, ctx
|
||||
|
||||
@staticmethod
|
||||
def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
|
||||
"""
|
||||
Tensor unflattening protocol for proper device movement.
|
||||
Reconstructs the QuantizedTensor after device movement.
|
||||
"""
|
||||
layout_type = ctx["layout_type"]
|
||||
layout_params = dict(ctx["non_tensor_params"])
|
||||
|
||||
for key in ctx["tensor_param_keys"]:
|
||||
attr_name = f"_layout_param_{key}"
|
||||
layout_params[key] = inner_tensors[attr_name]
|
||||
|
||||
return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params)
|
||||
|
||||
@classmethod
|
||||
def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
|
||||
qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs)
|
||||
return cls(qdata, layout_type, layout_params)
|
||||
|
||||
def dequantize(self) -> torch.Tensor:
|
||||
return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
kwargs = kwargs or {}
|
||||
|
||||
# Step 1: Check generic utilities first (detach, clone, to, etc.)
|
||||
if func in _GENERIC_UTILS:
|
||||
return _GENERIC_UTILS[func](func, args, kwargs)
|
||||
|
||||
# Step 2: Check layout-specific handlers (linear, matmul, etc.)
|
||||
layout_type = _get_layout_from_args(args)
|
||||
if layout_type and func in _LAYOUT_REGISTRY:
|
||||
handler = _LAYOUT_REGISTRY[func].get(layout_type)
|
||||
if handler:
|
||||
return handler(func, args, kwargs)
|
||||
|
||||
# Step 3: Fallback to dequantization
|
||||
if isinstance(args[0] if args else None, QuantizedTensor):
|
||||
logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
|
||||
return cls._dequant_and_fallback(func, args, kwargs)
|
||||
|
||||
@classmethod
|
||||
def _dequant_and_fallback(cls, func, args, kwargs):
|
||||
def dequant_arg(arg):
|
||||
if isinstance(arg, QuantizedTensor):
|
||||
return arg.dequantize()
|
||||
elif isinstance(arg, (list, tuple)):
|
||||
return type(arg)(dequant_arg(a) for a in arg)
|
||||
return arg
|
||||
|
||||
new_args = dequant_arg(args)
|
||||
new_kwargs = dequant_arg(kwargs)
|
||||
return func(*new_args, **new_kwargs)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Generic Utilities (Layout-Agnostic Operations)
|
||||
# ==============================================================================
|
||||
|
||||
def _create_transformed_qtensor(qt, transform_fn):
|
||||
new_data = transform_fn(qt._qdata)
|
||||
new_params = _copy_layout_params(qt._layout_params)
|
||||
return QuantizedTensor(new_data, qt._layout_type, new_params)
|
||||
|
||||
|
||||
def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
|
||||
if target_dtype is not None and target_dtype != qt.dtype:
|
||||
logging.warning(
|
||||
f"QuantizedTensor: dtype conversion requested to {target_dtype}, "
|
||||
f"but not supported for quantized tensors. Ignoring dtype."
|
||||
)
|
||||
|
||||
if target_layout is not None and target_layout != torch.strided:
|
||||
logging.warning(
|
||||
f"QuantizedTensor: layout change requested to {target_layout}, "
|
||||
f"but not supported. Ignoring layout."
|
||||
)
|
||||
|
||||
# Handle device transfer
|
||||
current_device = qt._qdata.device
|
||||
if target_device is not None:
|
||||
# Normalize device for comparison
|
||||
if isinstance(target_device, str):
|
||||
target_device = torch.device(target_device)
|
||||
if isinstance(current_device, str):
|
||||
current_device = torch.device(current_device)
|
||||
|
||||
if target_device != current_device:
|
||||
logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
|
||||
new_q_data = qt._qdata.to(device=target_device)
|
||||
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
|
||||
new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
|
||||
logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
|
||||
return new_qt
|
||||
|
||||
logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original")
|
||||
return qt
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten.detach.default)
|
||||
def generic_detach(func, args, kwargs):
|
||||
"""Detach operation - creates a detached copy of the quantized tensor."""
|
||||
qt = args[0]
|
||||
if isinstance(qt, QuantizedTensor):
|
||||
return _create_transformed_qtensor(qt, lambda x: x.detach())
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten.clone.default)
|
||||
def generic_clone(func, args, kwargs):
|
||||
"""Clone operation - creates a deep copy of the quantized tensor."""
|
||||
qt = args[0]
|
||||
if isinstance(qt, QuantizedTensor):
|
||||
return _create_transformed_qtensor(qt, lambda x: x.clone())
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten._to_copy.default)
|
||||
def generic_to_copy(func, args, kwargs):
|
||||
"""Device/dtype transfer operation - handles .to(device) calls."""
|
||||
qt = args[0]
|
||||
if isinstance(qt, QuantizedTensor):
|
||||
return _handle_device_transfer(
|
||||
qt,
|
||||
target_device=kwargs.get('device', None),
|
||||
target_dtype=kwargs.get('dtype', None),
|
||||
op_name="_to_copy"
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten.to.dtype_layout)
|
||||
def generic_to_dtype_layout(func, args, kwargs):
|
||||
"""Handle .to(device) calls using the dtype_layout variant."""
|
||||
qt = args[0]
|
||||
if isinstance(qt, QuantizedTensor):
|
||||
return _handle_device_transfer(
|
||||
qt,
|
||||
target_device=kwargs.get('device', None),
|
||||
target_dtype=kwargs.get('dtype', None),
|
||||
target_layout=kwargs.get('layout', None),
|
||||
op_name="to"
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten.copy_.default)
|
||||
def generic_copy_(func, args, kwargs):
|
||||
qt_dest = args[0]
|
||||
src = args[1]
|
||||
non_blocking = args[2] if len(args) > 2 else False
|
||||
if isinstance(qt_dest, QuantizedTensor):
|
||||
if isinstance(src, QuantizedTensor):
|
||||
# Copy from another quantized tensor
|
||||
qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking)
|
||||
qt_dest._layout_type = src._layout_type
|
||||
_copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking)
|
||||
else:
|
||||
# Copy from regular tensor - just copy raw data
|
||||
qt_dest._qdata.copy_(src)
|
||||
return qt_dest
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
|
||||
def generic_has_compatible_shallow_copy_type(func, args, kwargs):
|
||||
return True
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten.empty_like.default)
|
||||
def generic_empty_like(func, args, kwargs):
|
||||
"""Empty_like operation - creates an empty tensor with the same quantized structure."""
|
||||
qt = args[0]
|
||||
if isinstance(qt, QuantizedTensor):
|
||||
# Create empty tensor with same shape and dtype as the quantized data
|
||||
hp_dtype = kwargs.pop('dtype', qt._layout_params["orig_dtype"])
|
||||
new_qdata = torch.empty_like(qt._qdata, **kwargs)
|
||||
|
||||
# Handle device transfer for layout params
|
||||
target_device = kwargs.get('device', new_qdata.device)
|
||||
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
|
||||
|
||||
# Update orig_dtype if dtype is specified
|
||||
new_params['orig_dtype'] = hp_dtype
|
||||
|
||||
return QuantizedTensor(new_qdata, qt._layout_type, new_params)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# ==============================================================================
|
||||
# FP8 Layout + Operation Handlers
|
||||
# ==============================================================================
|
||||
class TensorCoreFP8Layout(QuantizedLayout):
|
||||
"""
|
||||
Storage format:
|
||||
- qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
|
||||
- scale: Scalar tensor (float32) for dequantization
|
||||
- orig_dtype: Original dtype before quantization (for casting back)
|
||||
"""
|
||||
@classmethod
|
||||
def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn):
|
||||
orig_dtype = tensor.dtype
|
||||
|
||||
if scale is None:
|
||||
scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max
|
||||
|
||||
if not isinstance(scale, torch.Tensor):
|
||||
scale = torch.tensor(scale)
|
||||
scale = scale.to(device=tensor.device, dtype=torch.float32)
|
||||
|
||||
tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype)
|
||||
# TODO: uncomment this if it's actually needed, because the clamp has a small performance penalty
|
||||
# lp_amax = torch.finfo(dtype).max
|
||||
# torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
|
||||
qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)
|
||||
|
||||
layout_params = {
|
||||
'scale': scale,
|
||||
'orig_dtype': orig_dtype
|
||||
}
|
||||
return qdata, layout_params
|
||||
|
||||
@staticmethod
|
||||
def dequantize(qdata, scale, orig_dtype, **kwargs):
|
||||
plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
|
||||
return plain_tensor * scale
|
||||
|
||||
@classmethod
|
||||
def get_plain_tensors(cls, qtensor):
|
||||
return qtensor._qdata, qtensor._layout_params['scale']
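# Rough round-trip sketch of the scaling math above (illustrative only; assumes a
# torch build with float8_e4m3fn support). quantize() divides by
# scale = amax(|t|) / 448 (the e4m3 max) and dequantize() multiplies it back:
#
#   t = torch.randn(64, 64, dtype=torch.bfloat16) * 4
#   qdata, params = TensorCoreFP8Layout.quantize(t)           # params: {'scale': ..., 'orig_dtype': torch.bfloat16}
#   t_back = TensorCoreFP8Layout.dequantize(qdata, **params)  # qdata.to(bfloat16) * scale
#   # t_back matches t to within fp8 e4m3 rounding (~6% relative error)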
|
||||
|
||||
QUANT_ALGOS = {
|
||||
"float8_e4m3fn": {
|
||||
"storage_t": torch.float8_e4m3fn,
|
||||
"parameters": {"weight_scale", "input_scale"},
|
||||
"comfy_tensor_layout": "TensorCoreFP8Layout",
|
||||
},
|
||||
}
|
||||
|
||||
LAYOUTS = {
|
||||
"TensorCoreFP8Layout": TensorCoreFP8Layout,
|
||||
}
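# Minimal usage sketch of the dispatch path (illustrative, not executed by this module;
# the _scaled_mm route assumes a GPU with fp8 support, otherwise unhandled ops fall back
# to dequantization):
#
#   w = QuantizedTensor.from_float(torch.randn(128, 64, dtype=torch.bfloat16).cuda(), "TensorCoreFP8Layout")
#   x = QuantizedTensor.from_float(torch.randn(16, 64, dtype=torch.bfloat16).cuda(), "TensorCoreFP8Layout")
#   y = torch.nn.functional.linear(x, w)  # __torch_dispatch__ routes to fp8_linear below
#   z = torch.nn.functional.linear(torch.randn(16, 64, dtype=torch.bfloat16).cuda(), w)  # weight-only: dequantize fallback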
|
||||
|
||||
|
||||
@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout")
|
||||
def fp8_linear(func, args, kwargs):
|
||||
input_tensor = args[0]
|
||||
weight = args[1]
|
||||
bias = args[2] if len(args) > 2 else None
|
||||
|
||||
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
|
||||
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
|
||||
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
|
||||
|
||||
out_dtype = kwargs.get("out_dtype")
|
||||
if out_dtype is None:
|
||||
out_dtype = input_tensor._layout_params['orig_dtype']
|
||||
|
||||
weight_t = plain_weight.t()
|
||||
|
||||
tensor_2d = False
|
||||
if len(plain_input.shape) == 2:
|
||||
tensor_2d = True
|
||||
plain_input = plain_input.unsqueeze(1)
|
||||
|
||||
input_shape = plain_input.shape
|
||||
if len(input_shape) != 3:
|
||||
return None
|
||||
|
||||
try:
|
||||
output = torch._scaled_mm(
|
||||
plain_input.reshape(-1, input_shape[2]).contiguous(),
|
||||
weight_t,
|
||||
bias=bias,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=out_dtype,
|
||||
)
|
||||
|
||||
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
|
||||
output = output[0]
|
||||
|
||||
if not tensor_2d:
|
||||
output = output.reshape((-1, input_shape[1], weight.shape[0]))
|
||||
|
||||
if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
||||
output_scale = scale_a * scale_b
|
||||
output_params = {
|
||||
'scale': output_scale,
|
||||
'orig_dtype': input_tensor._layout_params['orig_dtype']
|
||||
}
|
||||
return QuantizedTensor(output, "TensorCoreFP8Layout", output_params)
|
||||
else:
|
||||
return output
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
|
||||
|
||||
# Case 2: DQ Fallback
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight = weight.dequantize()
|
||||
if isinstance(input_tensor, QuantizedTensor):
|
||||
input_tensor = input_tensor.dequantize()
|
||||
|
||||
return torch.nn.functional.linear(input_tensor, weight, bias)
|
||||
|
||||
def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None):
|
||||
if out_dtype is None:
|
||||
out_dtype = input_tensor._layout_params['orig_dtype']
|
||||
|
||||
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
|
||||
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
|
||||
|
||||
output = torch._scaled_mm(
|
||||
plain_input.contiguous(),
|
||||
plain_weight,
|
||||
bias=bias,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=out_dtype,
|
||||
)
|
||||
|
||||
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
|
||||
output = output[0]
|
||||
return output
|
||||
|
||||
@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout")
|
||||
def fp8_addmm(func, args, kwargs):
|
||||
input_tensor = args[1]
|
||||
weight = args[2]
|
||||
bias = args[0]
|
||||
|
||||
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
|
||||
return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None))
|
||||
|
||||
a = list(args)
|
||||
if isinstance(args[0], QuantizedTensor):
|
||||
a[0] = args[0].dequantize()
|
||||
if isinstance(args[1], QuantizedTensor):
|
||||
a[1] = args[1].dequantize()
|
||||
if isinstance(args[2], QuantizedTensor):
|
||||
a[2] = args[2].dequantize()
|
||||
|
||||
return func(*a, **kwargs)
|
||||
|
||||
@register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout")
|
||||
def fp8_mm(func, args, kwargs):
|
||||
input_tensor = args[0]
|
||||
weight = args[1]
|
||||
|
||||
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
|
||||
return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None))
|
||||
|
||||
a = list(args)
|
||||
if isinstance(args[0], QuantizedTensor):
|
||||
a[0] = args[0].dequantize()
|
||||
if isinstance(args[1], QuantizedTensor):
|
||||
a[1] = args[1].dequantize()
|
||||
return func(*a, **kwargs)
|
||||
|
||||
@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
|
||||
@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
|
||||
def fp8_func(func, args, kwargs):
|
||||
input_tensor = args[0]
|
||||
if isinstance(input_tensor, QuantizedTensor):
|
||||
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
|
||||
ar = list(args)
|
||||
ar[0] = plain_input
|
||||
return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
|
||||
return func(*args, **kwargs)
|
||||
@@ -4,13 +4,9 @@ import comfy.samplers
|
||||
import comfy.utils
|
||||
import numpy as np
|
||||
import logging
|
||||
import comfy.nested_tensor
|
||||
|
||||
def prepare_noise(latent_image, seed, noise_inds=None):
|
||||
"""
|
||||
creates random noise given a latent image and a seed.
|
||||
optional arg skip can be used to skip and discard x number of noise generations for a given seed
|
||||
"""
|
||||
generator = torch.manual_seed(seed)
|
||||
def prepare_noise_inner(latent_image, generator, noise_inds=None):
|
||||
if noise_inds is None:
|
||||
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
||||
|
||||
@@ -21,10 +17,29 @@ def prepare_noise(latent_image, seed, noise_inds=None):
|
||||
if i in unique_inds:
|
||||
noises.append(noise)
|
||||
noises = [noises[i] for i in inverse]
|
||||
noises = torch.cat(noises, axis=0)
|
||||
return torch.cat(noises, axis=0)
|
||||
|
||||
def prepare_noise(latent_image, seed, noise_inds=None):
|
||||
"""
|
||||
creates random noise given a latent image and a seed.
|
||||
optional arg skip can be used to skip and discard x number of noise generations for a given seed
|
||||
"""
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
if latent_image.is_nested:
|
||||
tensors = latent_image.unbind()
|
||||
noises = []
|
||||
for t in tensors:
|
||||
noises.append(prepare_noise_inner(t, generator, noise_inds))
|
||||
noises = comfy.nested_tensor.NestedTensor(noises)
|
||||
else:
|
||||
noises = prepare_noise_inner(latent_image, generator, noise_inds)
|
||||
|
||||
return noises
|
||||
|
||||
def fix_empty_latent_channels(model, latent_image):
|
||||
if latent_image.is_nested:
|
||||
return latent_image
|
||||
latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
|
||||
if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
|
||||
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
|
||||
|
||||
@@ -306,17 +306,10 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
copy_dict1=False)
|
||||
|
||||
if patches is not None:
|
||||
# TODO: replace with merge_nested_dicts function
|
||||
if "patches" in transformer_options:
|
||||
cur_patches = transformer_options["patches"].copy()
|
||||
for p in patches:
|
||||
if p in cur_patches:
|
||||
cur_patches[p] = cur_patches[p] + patches[p]
|
||||
else:
|
||||
cur_patches[p] = patches[p]
|
||||
transformer_options["patches"] = cur_patches
|
||||
else:
|
||||
transformer_options["patches"] = patches
|
||||
transformer_options["patches"] = comfy.patcher_extension.merge_nested_dicts(
|
||||
transformer_options.get("patches", {}),
|
||||
patches
|
||||
)
|
||||
|
||||
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
|
||||
transformer_options["uuids"] = uuids[:]
|
||||
@@ -789,7 +782,7 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
|
||||
return KSAMPLER(sampler_function, extra_options, inpaint_options)
|
||||
|
||||
|
||||
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None):
|
||||
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None, latent_shapes=None):
|
||||
for k in conds:
|
||||
conds[k] = conds[k][:]
|
||||
resolve_areas_and_cond_masks_multidim(conds[k], noise.shape[2:], device)
|
||||
@@ -799,7 +792,7 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
|
||||
|
||||
if hasattr(model, 'extra_conds'):
|
||||
for k in conds:
|
||||
conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
|
||||
conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed, latent_shapes=latent_shapes)
|
||||
|
||||
#make sure each cond area has an opposite one with the same area
|
||||
for k in conds:
|
||||
@@ -969,11 +962,11 @@ class CFGGuider:
|
||||
def predict_noise(self, x, timestep, model_options={}, seed=None):
|
||||
return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
|
||||
|
||||
def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed):
|
||||
def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=None):
|
||||
if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
|
||||
latent_image = self.inner_model.process_latent_in(latent_image)
|
||||
|
||||
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
|
||||
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed, latent_shapes=latent_shapes)
|
||||
|
||||
extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
|
||||
extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas
|
||||
@@ -987,7 +980,7 @@ class CFGGuider:
|
||||
samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
|
||||
return self.inner_model.process_latent_out(samples.to(torch.float32))
|
||||
|
||||
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None, latent_shapes=None):
|
||||
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
|
||||
device = self.model_patcher.load_device
|
||||
|
||||
@@ -1001,7 +994,7 @@ class CFGGuider:
|
||||
|
||||
try:
|
||||
self.model_patcher.pre_run()
|
||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
|
||||
finally:
|
||||
self.model_patcher.cleanup()
|
||||
|
||||
@@ -1014,6 +1007,12 @@ class CFGGuider:
|
||||
if sigmas.shape[-1] == 0:
|
||||
return latent_image
|
||||
|
||||
if latent_image.is_nested:
|
||||
latent_image, latent_shapes = comfy.utils.pack_latents(latent_image.unbind())
|
||||
noise, _ = comfy.utils.pack_latents(noise.unbind())
|
||||
else:
|
||||
latent_shapes = [latent_image.shape]
|
||||
|
||||
self.conds = {}
|
||||
for k in self.original_conds:
|
||||
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
|
||||
@@ -1033,7 +1032,7 @@ class CFGGuider:
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True)
|
||||
)
|
||||
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
|
||||
finally:
|
||||
cast_to_load_options(self.model_options, device=self.model_patcher.offload_device)
|
||||
self.model_options = orig_model_options
|
||||
@@ -1041,6 +1040,9 @@ class CFGGuider:
|
||||
self.model_patcher.restore_hook_patches()
|
||||
|
||||
del self.conds
|
||||
|
||||
if len(latent_shapes) > 1:
|
||||
output = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(output, latent_shapes))
|
||||
return output
|
||||
|
||||
|
||||
|
||||
comfy/sd.py (60 lines changed)
@@ -18,6 +18,7 @@ import comfy.ldm.wan.vae2_2
|
||||
import comfy.ldm.hunyuan3d.vae
|
||||
import comfy.ldm.ace.vae.music_dcae_pipeline
|
||||
import comfy.ldm.hunyuan_video.vae
|
||||
import comfy.ldm.mmaudio.vae.autoencoder
|
||||
import comfy.pixel_space_convert
|
||||
import yaml
|
||||
import math
|
||||
@@ -142,6 +143,9 @@ class CLIP:
|
||||
n.apply_hooks_to_conds = self.apply_hooks_to_conds
|
||||
return n
|
||||
|
||||
def get_ram_usage(self):
|
||||
return self.patcher.get_ram_usage()
|
||||
|
||||
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
||||
return self.patcher.add_patches(patches, strength_patch, strength_model)
|
||||
|
||||
@@ -275,8 +279,13 @@ class VAE:
|
||||
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
||||
sd = diffusers_convert.convert_vae_state_dict(sd)
|
||||
|
||||
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
|
||||
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
|
||||
if model_management.is_amd():
|
||||
VAE_KL_MEM_RATIO = 2.73
|
||||
else:
|
||||
VAE_KL_MEM_RATIO = 1.0
|
||||
|
||||
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) * VAE_KL_MEM_RATIO #These are for AutoencoderKL and need tweaking (should be lower)
|
||||
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype) * VAE_KL_MEM_RATIO
|
||||
self.downscale_ratio = 8
|
||||
self.upscale_ratio = 8
|
||||
self.latent_channels = 4
|
||||
@@ -287,10 +296,12 @@ class VAE:
|
||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||
self.disable_offload = False
|
||||
self.not_video = False
|
||||
self.size = None
|
||||
|
||||
self.downscale_index_formula = None
|
||||
self.upscale_index_formula = None
|
||||
self.extra_1d_channel = None
|
||||
self.crop_input = True
|
||||
|
||||
if config is None:
|
||||
if "decoder.mid.block_1.mix_factor" in sd:
|
||||
@@ -542,6 +553,25 @@ class VAE:
|
||||
self.latent_channels = 3
|
||||
self.latent_dim = 2
|
||||
self.output_channels = 3
|
||||
elif "vocoder.activation_post.downsample.lowpass.filter" in sd: #MMAudio VAE
|
||||
sample_rate = 16000
|
||||
if sample_rate == 16000:
|
||||
mode = '16k'
|
||||
else:
|
||||
mode = '44k'
|
||||
|
||||
self.first_stage_model = comfy.ldm.mmaudio.vae.autoencoder.AudioAutoencoder(mode=mode)
|
||||
self.memory_used_encode = lambda shape, dtype: (30 * shape[2]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (90 * shape[2] * 1411.2) * model_management.dtype_size(dtype)
|
||||
self.latent_channels = 20
|
||||
self.output_channels = 2
|
||||
self.upscale_ratio = 512 * (44100 / sample_rate)
|
||||
self.downscale_ratio = 512 * (44100 / sample_rate)
|
||||
self.latent_dim = 1
|
||||
self.process_output = lambda audio: audio
|
||||
self.process_input = lambda audio: audio
|
||||
self.working_dtypes = [torch.float32]
|
||||
self.crop_input = False
|
||||
else:
|
||||
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
||||
self.first_stage_model = None
|
||||
@@ -569,12 +599,25 @@ class VAE:
|
||||
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
|
||||
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
|
||||
self.model_size()
|
||||
|
||||
def model_size(self):
|
||||
if self.size is not None:
|
||||
return self.size
|
||||
self.size = comfy.model_management.module_size(self.first_stage_model)
|
||||
return self.size
|
||||
|
||||
def get_ram_usage(self):
|
||||
return self.model_size()
|
||||
|
||||
def throw_exception_if_invalid(self):
|
||||
if self.first_stage_model is None:
|
||||
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
|
||||
|
||||
def vae_encode_crop_pixels(self, pixels):
|
||||
if not self.crop_input:
|
||||
return pixels
|
||||
|
||||
downscale_ratio = self.spacial_compression_encode()
|
||||
|
||||
dims = pixels.shape[1:-1]
|
||||
@@ -1233,7 +1276,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
return (model_patcher, clip, vae, clipvision)
|
||||
|
||||
|
||||
def load_diffusion_model_state_dict(sd, model_options={}):
|
||||
def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
|
||||
"""
|
||||
Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
|
||||
|
||||
@@ -1267,7 +1310,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
|
||||
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||
|
||||
load_device = model_management.get_torch_device()
|
||||
model_config = model_detection.model_config_from_unet(sd, "")
|
||||
model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
|
||||
|
||||
if model_config is not None:
|
||||
new_sd = sd
|
||||
@@ -1301,7 +1344,10 @@ def load_diffusion_model_state_dict(sd, model_options={}):
|
||||
else:
|
||||
unet_dtype = dtype
|
||||
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
if model_config.layer_quant_config is not None:
|
||||
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
|
||||
else:
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
|
||||
if model_options.get("fp8_optimizations", False):
|
||||
@@ -1317,8 +1363,8 @@ def load_diffusion_model_state_dict(sd, model_options={}):
|
||||
|
||||
|
||||
def load_diffusion_model(unet_path, model_options={}):
|
||||
sd = comfy.utils.load_torch_file(unet_path)
|
||||
model = load_diffusion_model_state_dict(sd, model_options=model_options)
|
||||
sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
|
||||
model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata)
|
||||
if model is None:
|
||||
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
|
||||
|
||||
@@ -460,7 +460,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
|
||||
return embed_out
|
||||
|
||||
class SDTokenizer:
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data={}, tokenizer_args={}):
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, tokenizer_data={}, tokenizer_args={}):
|
||||
if tokenizer_path is None:
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
||||
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
|
||||
@@ -468,6 +468,7 @@ class SDTokenizer:
|
||||
self.min_length = tokenizer_data.get("{}_min_length".format(embedding_key), min_length)
|
||||
self.end_token = None
|
||||
self.min_padding = min_padding
|
||||
self.pad_left = pad_left
|
||||
|
||||
empty = self.tokenizer('')["input_ids"]
|
||||
self.tokenizer_adds_end_token = has_end_token
|
||||
@@ -522,6 +523,12 @@ class SDTokenizer:
|
||||
return (embed, "{} {}".format(embedding_name[len(stripped):], leftover))
|
||||
return (embed, leftover)
|
||||
|
||||
def pad_tokens(self, tokens, amount):
|
||||
if self.pad_left:
|
||||
for i in range(amount):
|
||||
tokens.insert(0, (self.pad_token, 1.0, 0))
|
||||
else:
|
||||
tokens.extend([(self.pad_token, 1.0, 0)] * amount)
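# Illustrative behaviour (entries are (token, weight, word_id); ids made up):
#   pad_left=False: [(start, 1.0, 0), (320, 1.0, 1), (end, 1.0, 0)] -> pads appended at the end
#   pad_left=True:  [(start, 1.0, 0), (320, 1.0, 1), (end, 1.0, 0)] -> pads inserted at the front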
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_options={}, **kwargs):
|
||||
'''
|
||||
@@ -600,7 +607,7 @@ class SDTokenizer:
|
||||
if self.end_token is not None:
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
if self.pad_to_max_length:
|
||||
batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
|
||||
self.pad_tokens(batch, remaining_length)
|
||||
#start new batch
|
||||
batch = []
|
||||
if self.start_token is not None:
|
||||
@@ -614,11 +621,11 @@ class SDTokenizer:
|
||||
if self.end_token is not None:
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
if min_padding is not None:
|
||||
batch.extend([(self.pad_token, 1.0, 0)] * min_padding)
|
||||
self.pad_tokens(batch, min_padding)
|
||||
if self.pad_to_max_length and len(batch) < self.max_length:
|
||||
batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
|
||||
self.pad_tokens(batch, self.max_length - len(batch))
|
||||
if min_length is not None and len(batch) < min_length:
|
||||
batch.extend([(self.pad_token, 1.0, 0)] * (min_length - len(batch)))
|
||||
self.pad_tokens(batch, min_length - len(batch))
|
||||
|
||||
if not return_word_ids:
|
||||
batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
|
||||
|
||||
@@ -50,6 +50,7 @@ class BASE:
|
||||
manual_cast_dtype = None
|
||||
custom_operations = None
|
||||
scaled_fp8 = None
|
||||
layer_quant_config = None # Per-layer quantization configuration for mixed precision
|
||||
optimizations = {"fp8": False}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -32,6 +32,7 @@ class Llama2Config:
|
||||
q_norm = None
|
||||
k_norm = None
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
|
||||
@dataclass
|
||||
class Qwen25_3BConfig:
|
||||
@@ -53,6 +54,7 @@ class Qwen25_3BConfig:
|
||||
q_norm = None
|
||||
k_norm = None
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
|
||||
@dataclass
|
||||
class Qwen25_7BVLI_Config:
|
||||
@@ -74,6 +76,7 @@ class Qwen25_7BVLI_Config:
|
||||
q_norm = None
|
||||
k_norm = None
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
|
||||
@dataclass
|
||||
class Gemma2_2B_Config:
|
||||
@@ -96,6 +99,7 @@ class Gemma2_2B_Config:
|
||||
k_norm = None
|
||||
sliding_attention = None
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
|
||||
@dataclass
|
||||
class Gemma3_4B_Config:
|
||||
@@ -118,6 +122,7 @@ class Gemma3_4B_Config:
|
||||
k_norm = "gemma3"
|
||||
sliding_attention = [False, False, False, False, False, 1024]
|
||||
rope_scale = [1.0, 8.0]
|
||||
final_norm: bool = True
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
|
||||
@@ -366,7 +371,12 @@ class Llama2_(nn.Module):
|
||||
transformer(config, index=i, device=device, dtype=dtype, ops=ops)
|
||||
for i in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
|
||||
if config.final_norm:
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
else:
|
||||
self.norm = None
|
||||
|
||||
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]):
|
||||
@@ -421,14 +431,16 @@ class Llama2_(nn.Module):
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
|
||||
x = self.norm(x)
|
||||
if self.norm is not None:
|
||||
x = self.norm(x)
|
||||
|
||||
if all_intermediate is not None:
|
||||
all_intermediate.append(x.unsqueeze(1).clone())
|
||||
|
||||
if all_intermediate is not None:
|
||||
intermediate = torch.cat(all_intermediate, dim=1)
|
||||
|
||||
if intermediate is not None and final_layer_norm_intermediate:
|
||||
if intermediate is not None and final_layer_norm_intermediate and self.norm is not None:
|
||||
intermediate = self.norm(intermediate)
|
||||
|
||||
return x, intermediate
|
||||
|
||||
@@ -39,7 +39,11 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
|
||||
pass
|
||||
ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"
|
||||
|
||||
from numpy.core.multiarray import scalar
|
||||
def scalar(*args, **kwargs):
|
||||
from numpy.core.multiarray import scalar as sc
|
||||
return sc(*args, **kwargs)
|
||||
scalar.__module__ = "numpy.core.multiarray"
|
||||
|
||||
from numpy import dtype
|
||||
from numpy.dtypes import Float64DType
|
||||
from _codecs import encode
|
||||
@@ -1102,3 +1106,25 @@ def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out):
|
||||
dim=1
|
||||
)
|
||||
return out
|
||||
|
||||
def pack_latents(latents):
|
||||
latent_shapes = []
|
||||
tensors = []
|
||||
for tensor in latents:
|
||||
latent_shapes.append(tensor.shape)
|
||||
tensors.append(tensor.reshape(tensor.shape[0], 1, -1))
|
||||
|
||||
latent = torch.cat(tensors, dim=-1)
|
||||
return latent, latent_shapes
|
||||
|
||||
def unpack_latents(combined_latent, latent_shapes):
|
||||
if len(latent_shapes) > 1:
|
||||
output_tensors = []
|
||||
for shape in latent_shapes:
|
||||
cut = math.prod(shape[1:])
|
||||
tens = combined_latent[:, :, :cut]
|
||||
combined_latent = combined_latent[:, :, cut:]
|
||||
output_tensors.append(tens.reshape([tens.shape[0]] + list(shape)[1:]))
|
||||
else:
|
||||
output_tensors = combined_latent
|
||||
return output_tensors
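A short round-trip sketch of the two helpers above (shapes are illustrative; the latents being packed must share the same batch dimension):

import torch
from comfy.utils import pack_latents, unpack_latents

a = torch.randn(1, 4, 32, 32)
b = torch.randn(1, 4, 64, 48)
packed, shapes = pack_latents([a, b])   # packed.shape == (1, 1, 4096 + 12288)
out = unpack_latents(packed, shapes)    # restores the original shapes
assert torch.equal(out[0], a) and torch.equal(out[1], b)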
|
||||
|
||||
@@ -7,7 +7,7 @@ from comfy_api.internal.singleton import ProxiedSingleton
|
||||
from comfy_api.internal.async_to_sync import create_sync_class
|
||||
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
|
||||
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
|
||||
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
|
||||
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
|
||||
from . import _io as io
|
||||
from . import _ui as ui
|
||||
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
|
||||
@@ -104,6 +104,8 @@ class Types:
|
||||
VideoCodec = VideoCodec
|
||||
VideoContainer = VideoContainer
|
||||
VideoComponents = VideoComponents
|
||||
MESH = MESH
|
||||
VOXEL = VOXEL
|
||||
|
||||
ComfyAPI = ComfyAPI_latest
|
||||
|
||||
@@ -114,7 +116,9 @@ if TYPE_CHECKING:
|
||||
ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
|
||||
ComfyAPISync = create_sync_class(ComfyAPI_latest)
|
||||
|
||||
comfy_io = io # create the new alias for io
|
||||
# create new aliases for io and ui
|
||||
IO = io
|
||||
UI = ui
|
||||
|
||||
__all__ = [
|
||||
"ComfyAPI",
|
||||
@@ -124,6 +128,7 @@ __all__ = [
|
||||
"Types",
|
||||
"ComfyExtension",
|
||||
"io",
|
||||
"comfy_io",
|
||||
"IO",
|
||||
"ui",
|
||||
"UI",
|
||||
]
|
||||
|
||||
@@ -27,6 +27,7 @@ from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classpr
|
||||
prune_dict, shallow_clone_class)
|
||||
from comfy_api.latest._resources import Resources, ResourcesLocal
|
||||
from comfy_execution.graph_utils import ExecutionBlocker
|
||||
from ._util import MESH, VOXEL
|
||||
|
||||
# from comfy_extras.nodes_images import SVG as SVG_ # NOTE: needs to be moved before can be imported due to circular reference
|
||||
|
||||
@@ -656,11 +657,11 @@ class LossMap(ComfyTypeIO):
|
||||
|
||||
@comfytype(io_type="VOXEL")
|
||||
class Voxel(ComfyTypeIO):
|
||||
Type = Any # TODO: VOXEL class is defined in comfy_extras/nodes_hunyuan3d.py; should be moved to somewhere else before referenced directly in v3
|
||||
Type = VOXEL
|
||||
|
||||
@comfytype(io_type="MESH")
|
||||
class Mesh(ComfyTypeIO):
|
||||
Type = Any # TODO: MESH class is defined in comfy_extras/nodes_hunyuan3d.py; should be moved to somewhere else before referenced directly in v3
|
||||
Type = MESH
|
||||
|
||||
@comfytype(io_type="HOOKS")
|
||||
class Hooks(ComfyTypeIO):
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
from .video_types import VideoContainer, VideoCodec, VideoComponents
|
||||
from .geometry_types import VOXEL, MESH
|
||||
|
||||
__all__ = [
|
||||
# Utility Types
|
||||
"VideoContainer",
|
||||
"VideoCodec",
|
||||
"VideoComponents",
|
||||
"VOXEL",
|
||||
"MESH",
|
||||
]
|
||||
|
||||
comfy_api/latest/_util/geometry_types.py (new file, 12 lines)
@@ -0,0 +1,12 @@
|
||||
import torch
|
||||
|
||||
|
||||
class VOXEL:
|
||||
def __init__(self, data: torch.Tensor):
|
||||
self.data = data
|
||||
|
||||
|
||||
class MESH:
|
||||
def __init__(self, vertices: torch.Tensor, faces: torch.Tensor):
|
||||
self.vertices = vertices
|
||||
self.faces = faces
|
||||
@@ -1,704 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import aiohttp
|
||||
import io
|
||||
import logging
|
||||
import mimetypes
|
||||
from typing import Optional, Union
|
||||
from comfy.utils import common_upscale
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy_api.util import VideoContainer, VideoCodec
|
||||
from comfy_api.input.video_types import VideoInput
|
||||
from comfy_api.input.basic_types import AudioInput
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiClient,
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
UploadRequest,
|
||||
UploadResponse,
|
||||
)
|
||||
from server import PromptServer
|
||||
from comfy.cli_args import args
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import torch
|
||||
import math
|
||||
import base64
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
import av
|
||||
|
||||
|
||||
async def download_url_to_video_output(
|
||||
video_url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None
|
||||
) -> VideoFromFile:
|
||||
"""Downloads a video from a URL and returns a `VIDEO` output.
|
||||
|
||||
Args:
|
||||
video_url: The URL of the video to download.
|
||||
|
||||
Returns:
|
||||
A Comfy node `VIDEO` output.
|
||||
"""
|
||||
video_io = await download_url_to_bytesio(video_url, timeout, auth_kwargs=auth_kwargs)
|
||||
if video_io is None:
|
||||
error_msg = f"Failed to download video from {video_url}"
|
||||
logging.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
return VideoFromFile(video_io)
|
||||
|
||||
|
||||
def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
|
||||
"""Downscale input image tensor to roughly the specified total pixels."""
|
||||
samples = image.movedim(-1, 1)
|
||||
total = int(total_pixels)
|
||||
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
|
||||
if scale_by >= 1:
|
||||
return image
|
||||
width = round(samples.shape[3] * scale_by)
|
||||
height = round(samples.shape[2] * scale_by)
|
||||
|
||||
s = common_upscale(samples, width, height, "lanczos", "disabled")
|
||||
s = s.movedim(1, -1)
|
||||
return s
|
||||
|
||||
|
||||
async def validate_and_cast_response(
|
||||
response, timeout: int = None, node_id: Union[str, None] = None
|
||||
) -> torch.Tensor:
|
||||
"""Validates and casts a response to a torch.Tensor.
|
||||
|
||||
Args:
|
||||
response: The response to validate and cast.
|
||||
timeout: Request timeout in seconds. Defaults to None (no timeout).
|
||||
|
||||
Returns:
|
||||
A torch.Tensor representing the image (1, H, W, C).
|
||||
|
||||
Raises:
|
||||
ValueError: If the response is not valid.
|
||||
"""
|
||||
# validate raw JSON response
|
||||
data = response.data
|
||||
if not data or len(data) == 0:
|
||||
raise ValueError("No images returned from API endpoint")
|
||||
|
||||
# Initialize list to store image tensors
|
||||
image_tensors: list[torch.Tensor] = []
|
||||
|
||||
# Process each image in the data array
|
||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
|
||||
for img_data in data:
|
||||
img_bytes: bytes
|
||||
if img_data.b64_json:
|
||||
img_bytes = base64.b64decode(img_data.b64_json)
|
||||
elif img_data.url:
|
||||
if node_id:
|
||||
PromptServer.instance.send_progress_text(f"Result URL: {img_data.url}", node_id)
|
||||
async with session.get(img_data.url) as resp:
|
||||
if resp.status != 200:
|
||||
raise ValueError("Failed to download generated image")
|
||||
img_bytes = await resp.read()
|
||||
else:
|
||||
raise ValueError("Invalid image payload – neither URL nor base64 data present.")
|
||||
|
||||
pil_img = Image.open(BytesIO(img_bytes)).convert("RGBA")
|
||||
arr = np.asarray(pil_img).astype(np.float32) / 255.0
|
||||
image_tensors.append(torch.from_numpy(arr))
|
||||
|
||||
return torch.stack(image_tensors, dim=0)
|
||||
|
||||
|
||||
def validate_aspect_ratio(
|
||||
aspect_ratio: str,
|
||||
minimum_ratio: float,
|
||||
maximum_ratio: float,
|
||||
minimum_ratio_str: str,
|
||||
maximum_ratio_str: str,
|
||||
) -> float:
|
||||
"""Validates and casts an aspect ratio string to a float.
|
||||
|
||||
Args:
|
||||
aspect_ratio: The aspect ratio string to validate.
|
||||
minimum_ratio: The minimum aspect ratio.
|
||||
maximum_ratio: The maximum aspect ratio.
|
||||
minimum_ratio_str: The minimum aspect ratio string.
|
||||
maximum_ratio_str: The maximum aspect ratio string.
|
||||
|
||||
Returns:
|
||||
The validated and cast aspect ratio.
|
||||
|
||||
Raises:
|
||||
Exception: If the aspect ratio is not valid.
|
||||
"""
|
||||
# get ratio values
|
||||
numbers = aspect_ratio.split(":")
|
||||
if len(numbers) != 2:
|
||||
raise TypeError(
|
||||
f"Aspect ratio must be in the format X:Y, such as 16:9, but was {aspect_ratio}."
|
||||
)
|
||||
try:
|
||||
numerator = int(numbers[0])
|
||||
denominator = int(numbers[1])
|
||||
except ValueError as exc:
|
||||
raise TypeError(
|
||||
f"Aspect ratio must contain numbers separated by ':', such as 16:9, but was {aspect_ratio}."
|
||||
) from exc
|
||||
calculated_ratio = numerator / denominator
|
||||
# if not close to minimum and maximum, check bounds
|
||||
if not math.isclose(calculated_ratio, minimum_ratio) or not math.isclose(
|
||||
calculated_ratio, maximum_ratio
|
||||
):
|
||||
if calculated_ratio < minimum_ratio:
|
||||
raise TypeError(
|
||||
f"Aspect ratio cannot reduce to any less than {minimum_ratio_str} ({minimum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
|
||||
)
|
||||
if calculated_ratio > maximum_ratio:
|
||||
raise TypeError(
|
||||
f"Aspect ratio cannot reduce to any greater than {maximum_ratio_str} ({maximum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
|
||||
)
|
||||
return aspect_ratio
|
||||
|
||||
|
||||
def mimetype_to_extension(mime_type: str) -> str:
|
||||
"""Converts a MIME type to a file extension."""
|
||||
return mime_type.split("/")[-1].lower()
|
||||
|
||||
|
||||
async def download_url_to_bytesio(
|
||||
url: str, timeout: Optional[int] = None, auth_kwargs: Optional[dict[str, str]] = None
|
||||
) -> BytesIO:
|
||||
"""Downloads content from a URL using requests and returns it as BytesIO.
|
||||
|
||||
Args:
|
||||
url: The URL to download.
|
||||
timeout: Request timeout in seconds. Defaults to None (no timeout).
auth_kwargs: Optional authentication token(s), used when downloading from a /proxy/ URL.
|
||||
|
||||
Returns:
|
||||
BytesIO object containing the downloaded content.
|
||||
"""
|
||||
headers = {}
|
||||
if url.startswith("/proxy/"):
|
||||
url = str(args.comfy_api_base).rstrip("/") + url
|
||||
auth_token = (auth_kwargs or {}).get("auth_token")
|
||||
comfy_api_key = (auth_kwargs or {}).get("comfy_api_key")
|
||||
if auth_token:
|
||||
headers["Authorization"] = f"Bearer {auth_token}"
|
||||
elif comfy_api_key:
|
||||
headers["X-API-KEY"] = comfy_api_key
|
||||
timeout_cfg = aiohttp.ClientTimeout(total=timeout) if timeout else None
|
||||
async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
|
||||
async with session.get(url, headers=headers) as resp:
|
||||
resp.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX)
|
||||
return BytesIO(await resp.read())
|
||||
|
||||
|
||||
def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor:
|
||||
"""Converts image data from BytesIO to a torch.Tensor.
|
||||
|
||||
Args:
|
||||
image_bytesio: BytesIO object containing the image data.
|
||||
mode: The PIL mode to convert the image to (e.g., "RGB", "RGBA").
|
||||
|
||||
Returns:
|
||||
A torch.Tensor representing the image (1, H, W, C).
|
||||
|
||||
Raises:
|
||||
PIL.UnidentifiedImageError: If the image data cannot be identified.
|
||||
ValueError: If the specified mode is invalid.
|
||||
"""
|
||||
image = Image.open(image_bytesio)
|
||||
image = image.convert(mode)
|
||||
image_array = np.array(image).astype(np.float32) / 255.0
|
||||
return torch.from_numpy(image_array).unsqueeze(0)
|
||||
|
||||
|
||||
async def download_url_to_image_tensor(url: str, timeout: Optional[int] = None) -> torch.Tensor:
|
||||
"""Downloads an image from a URL and returns a [B, H, W, C] tensor."""
|
||||
image_bytesio = await download_url_to_bytesio(url, timeout)
|
||||
return bytesio_to_image_tensor(image_bytesio)
|
||||
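# Illustrative usage sketch (not part of the original module): fetching a result
# image inside an async node function. The URL below is a placeholder assumption.
async def _example_fetch_result_image() -> torch.Tensor:
    # Returns a [1, H, W, C] float tensor suitable for a Comfy IMAGE output.
    return await download_url_to_image_tensor("https://example.com/result.png", timeout=60)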
|
||||
|
||||
def process_image_response(response_content: bytes | str) -> torch.Tensor:
|
||||
"""Uses content from a Response object and converts it to a torch.Tensor"""
|
||||
return bytesio_to_image_tensor(BytesIO(response_content))
|
||||
|
||||
|
||||
def _tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image:
|
||||
"""Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling."""
|
||||
if len(image.shape) > 3:
|
||||
image = image[0]
|
||||
# TODO: remove alpha if not allowed and present
|
||||
input_tensor = image.cpu()
|
||||
input_tensor = downscale_image_tensor(
|
||||
input_tensor.unsqueeze(0), total_pixels=total_pixels
|
||||
).squeeze()
|
||||
image_np = (input_tensor.numpy() * 255).astype(np.uint8)
|
||||
img = Image.fromarray(image_np)
|
||||
return img
|
||||
|
||||
|
||||
def _pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
|
||||
"""Converts a PIL Image to a BytesIO object."""
|
||||
if not mime_type:
|
||||
mime_type = "image/png"
|
||||
|
||||
img_byte_arr = io.BytesIO()
|
||||
# Derive PIL format from MIME type (e.g., 'image/png' -> 'PNG')
|
||||
pil_format = mime_type.split("/")[-1].upper()
|
||||
if pil_format == "JPG":
|
||||
pil_format = "JPEG"
|
||||
img.save(img_byte_arr, format=pil_format)
|
||||
img_byte_arr.seek(0)
|
||||
return img_byte_arr
|
||||
|
||||
|
||||
def tensor_to_bytesio(
|
||||
image: torch.Tensor,
|
||||
name: Optional[str] = None,
|
||||
total_pixels: int = 2048 * 2048,
|
||||
mime_type: str = "image/png",
|
||||
) -> BytesIO:
|
||||
"""Converts a torch.Tensor image to a named BytesIO object.
|
||||
|
||||
Args:
|
||||
image: Input torch.Tensor image.
|
||||
name: Optional filename for the BytesIO object.
|
||||
total_pixels: Maximum total pixels for potential downscaling.
|
||||
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
|
||||
|
||||
Returns:
|
||||
Named BytesIO object containing the image data, with pointer set to the start of buffer.
|
||||
"""
|
||||
if not mime_type:
|
||||
mime_type = "image/png"
|
||||
|
||||
pil_image = _tensor_to_pil(image, total_pixels=total_pixels)
|
||||
img_binary = _pil_to_bytesio(pil_image, mime_type=mime_type)
|
||||
img_binary.name = (
|
||||
f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}"
|
||||
)
|
||||
return img_binary
|
||||
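# Illustrative usage sketch (not part of the original module): packaging an IMAGE
# tensor as a named JPEG buffer for a multipart upload. The filename is an
# assumption for the example only.
def _example_image_to_jpeg_buffer(image: torch.Tensor) -> BytesIO:
    buf = tensor_to_bytesio(image, name="reference", mime_type="image/jpeg")
    # buf.name is "reference.jpeg" and the pointer is already at the start of the buffer.
    return buf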
|
||||
|
||||
def tensor_to_base64_string(
|
||||
image_tensor: torch.Tensor,
|
||||
total_pixels: int = 2048 * 2048,
|
||||
mime_type: str = "image/png",
|
||||
) -> str:
|
||||
"""Convert [B, H, W, C] or [H, W, C] tensor to a base64 string.
|
||||
|
||||
Args:
|
||||
image_tensor: Input torch.Tensor image.
|
||||
total_pixels: Maximum total pixels for potential downscaling.
|
||||
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
|
||||
|
||||
Returns:
|
||||
Base64 encoded string of the image.
|
||||
"""
|
||||
pil_image = _tensor_to_pil(image_tensor, total_pixels=total_pixels)
|
||||
img_byte_arr = _pil_to_bytesio(pil_image, mime_type=mime_type)
|
||||
img_bytes = img_byte_arr.getvalue()
|
||||
# Encode bytes to base64 string
|
||||
base64_encoded_string = base64.b64encode(img_bytes).decode("utf-8")
|
||||
return base64_encoded_string
|
||||
|
||||
|
||||
def tensor_to_data_uri(
|
||||
image_tensor: torch.Tensor,
|
||||
total_pixels: int = 2048 * 2048,
|
||||
mime_type: str = "image/png",
|
||||
) -> str:
|
||||
"""Converts a tensor image to a Data URI string.
|
||||
|
||||
Args:
|
||||
image_tensor: Input torch.Tensor image.
|
||||
total_pixels: Maximum total pixels for potential downscaling.
|
||||
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp').
|
||||
|
||||
Returns:
|
||||
Data URI string (e.g., 'data:image/png;base64,...').
|
||||
"""
|
||||
base64_string = tensor_to_base64_string(image_tensor, total_pixels, mime_type)
|
||||
return f"data:{mime_type};base64,{base64_string}"
|
||||
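# Illustrative usage sketch (not part of the original module): embedding an IMAGE
# tensor directly in a JSON request body as a data URI, as some endpoints expect.
def _example_image_payload(image: torch.Tensor) -> dict:
    return {"image": tensor_to_data_uri(image, mime_type="image/png")}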
|
||||
|
||||
def text_filepath_to_base64_string(filepath: str) -> str:
|
||||
"""Converts a text file to a base64 string."""
|
||||
with open(filepath, "rb") as f:
|
||||
file_content = f.read()
|
||||
return base64.b64encode(file_content).decode("utf-8")
|
||||
|
||||
|
||||
def text_filepath_to_data_uri(filepath: str) -> str:
|
||||
"""Converts a text file to a data URI."""
|
||||
base64_string = text_filepath_to_base64_string(filepath)
|
||||
mime_type, _ = mimetypes.guess_type(filepath)
|
||||
if mime_type is None:
|
||||
mime_type = "application/octet-stream"
|
||||
return f"data:{mime_type};base64,{base64_string}"
|
||||
|
||||
|
||||
async def upload_file_to_comfyapi(
|
||||
file_bytes_io: BytesIO,
|
||||
filename: str,
|
||||
upload_mime_type: Optional[str],
|
||||
auth_kwargs: Optional[dict[str, str]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Uploads a single file to ComfyUI API and returns its download URL.
|
||||
|
||||
Args:
|
||||
file_bytes_io: BytesIO object containing the file data.
|
||||
filename: The filename of the file.
|
||||
upload_mime_type: MIME type of the file.
|
||||
auth_kwargs: Optional authentication token(s).
|
||||
|
||||
Returns:
|
||||
The download URL for the uploaded file.
|
||||
"""
|
||||
if upload_mime_type is None:
|
||||
request_object = UploadRequest(file_name=filename)
|
||||
else:
|
||||
request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/customers/storage",
|
||||
method=HttpMethod.POST,
|
||||
request_model=UploadRequest,
|
||||
response_model=UploadResponse,
|
||||
),
|
||||
request=request_object,
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
|
||||
response: UploadResponse = await operation.execute()
|
||||
await ApiClient.upload_file(response.upload_url, file_bytes_io, content_type=upload_mime_type)
|
||||
return response.download_url
|
||||
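# Illustrative usage sketch (not part of the original module): uploading an
# arbitrary in-memory file and using the returned URL in a follow-up request.
# The filename and auth_kwargs contents are assumptions; real nodes receive the
# auth values from hidden inputs.
async def _example_upload_text_file(auth_kwargs: dict[str, str]) -> str:
    payload = BytesIO(b"hello world")
    return await upload_file_to_comfyapi(
        payload, filename="notes.txt", upload_mime_type="text/plain", auth_kwargs=auth_kwargs
    )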
|
||||
|
||||
def video_to_base64_string(
|
||||
video: VideoInput,
|
||||
container_format: VideoContainer = None,
|
||||
codec: VideoCodec = None
|
||||
) -> str:
|
||||
"""
|
||||
Converts a video input to a base64 string.
|
||||
|
||||
Args:
|
||||
video: The video input to convert
|
||||
container_format: Optional container format to use (defaults to video.container if available)
|
||||
codec: Optional codec to use (defaults to video.codec if available)
|
||||
"""
|
||||
video_bytes_io = io.BytesIO()
|
||||
|
||||
# Use provided format/codec if specified, otherwise use video's own if available
|
||||
format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4)
|
||||
codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264)
|
||||
|
||||
video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use)
|
||||
video_bytes_io.seek(0)
|
||||
return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")
|
||||
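# Illustrative usage sketch (not part of the original module): serializing a Comfy
# VIDEO input for an endpoint that accepts inline base64 video.
def _example_video_payload(video: VideoInput) -> dict:
    return {"video": video_to_base64_string(video, VideoContainer.MP4, VideoCodec.H264)}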
|
||||
|
||||
async def upload_video_to_comfyapi(
|
||||
video: VideoInput,
|
||||
auth_kwargs: Optional[dict[str, str]] = None,
|
||||
container: VideoContainer = VideoContainer.MP4,
|
||||
codec: VideoCodec = VideoCodec.H264,
|
||||
max_duration: Optional[int] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Uploads a single video to ComfyUI API and returns its download URL.
|
||||
Uses the specified container and codec for saving the video before upload.
|
||||
|
||||
Args:
|
||||
video: VideoInput object (Comfy VIDEO type).
|
||||
auth_kwargs: Optional authentication token(s).
|
||||
container: The video container format to use (default: MP4).
|
||||
codec: The video codec to use (default: H264).
|
||||
max_duration: Optional maximum duration of the video in seconds. If the video is longer than this, an error will be raised.
|
||||
|
||||
Returns:
|
||||
The download URL for the uploaded video file.
|
||||
"""
|
||||
if max_duration is not None:
|
||||
try:
|
||||
actual_duration = video.duration_seconds
|
||||
if actual_duration is not None and actual_duration > max_duration:
|
||||
raise ValueError(
|
||||
f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)."
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error("Error getting video duration: %s", str(e))
|
||||
raise ValueError(f"Could not verify video duration from source: {e}") from e
|
||||
|
||||
upload_mime_type = f"video/{container.value.lower()}"
|
||||
filename = f"uploaded_video.{container.value.lower()}"
|
||||
|
||||
# Convert VideoInput to BytesIO using specified container/codec
|
||||
video_bytes_io = io.BytesIO()
|
||||
video.save_to(video_bytes_io, format=container, codec=codec)
|
||||
video_bytes_io.seek(0)
|
||||
|
||||
return await upload_file_to_comfyapi(video_bytes_io, filename, upload_mime_type, auth_kwargs)
|
||||
|
||||
|
||||
def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray:
|
||||
"""
|
||||
Prepares audio waveform for av library by converting to a contiguous numpy array.
|
||||
|
||||
Args:
|
||||
waveform: a tensor of shape (1, channels, samples) derived from a Comfy `AUDIO` type.
|
||||
|
||||
Returns:
|
||||
Contiguous numpy array of the audio waveform. If the audio was batched,
|
||||
the first item is taken.
|
||||
"""
|
||||
if waveform.ndim != 3 or waveform.shape[0] != 1:
|
||||
raise ValueError("Expected waveform tensor shape (1, channels, samples)")
|
||||
|
||||
# If batch is > 1, take first item
|
||||
if waveform.shape[0] > 1:
|
||||
waveform = waveform[0]
|
||||
|
||||
# Prepare for av: remove batch dim, move to CPU, make contiguous, convert to numpy array
|
||||
audio_data_np = waveform.squeeze(0).cpu().contiguous().numpy()
|
||||
if audio_data_np.dtype != np.float32:
|
||||
audio_data_np = audio_data_np.astype(np.float32)
|
||||
|
||||
return audio_data_np
|
||||
|
||||
|
||||
def audio_ndarray_to_bytesio(
|
||||
audio_data_np: np.ndarray,
|
||||
sample_rate: int,
|
||||
container_format: str = "mp4",
|
||||
codec_name: str = "aac",
|
||||
) -> BytesIO:
|
||||
"""
|
||||
Encodes a numpy array of audio data into a BytesIO object.
|
||||
"""
|
||||
audio_bytes_io = io.BytesIO()
|
||||
with av.open(audio_bytes_io, mode="w", format=container_format) as output_container:
|
||||
audio_stream = output_container.add_stream(codec_name, rate=sample_rate)
|
||||
frame = av.AudioFrame.from_ndarray(
|
||||
audio_data_np,
|
||||
format="fltp",
|
||||
layout="stereo" if audio_data_np.shape[0] > 1 else "mono",
|
||||
)
|
||||
frame.sample_rate = sample_rate
|
||||
frame.pts = 0
|
||||
|
||||
for packet in audio_stream.encode(frame):
|
||||
output_container.mux(packet)
|
||||
|
||||
# Flush stream
|
||||
for packet in audio_stream.encode(None):
|
||||
output_container.mux(packet)
|
||||
|
||||
audio_bytes_io.seek(0)
|
||||
return audio_bytes_io
|
||||
|
||||
|
||||
async def upload_audio_to_comfyapi(
|
||||
audio: AudioInput,
|
||||
auth_kwargs: Optional[dict[str, str]] = None,
|
||||
container_format: str = "mp4",
|
||||
codec_name: str = "aac",
|
||||
mime_type: str = "audio/mp4",
|
||||
filename: str = "uploaded_audio.mp4",
|
||||
) -> str:
|
||||
"""
|
||||
Uploads a single audio input to ComfyUI API and returns its download URL.
|
||||
Encodes the raw waveform into the specified format before uploading.
|
||||
|
||||
Args:
|
||||
audio: a Comfy `AUDIO` type (contains waveform tensor and sample_rate)
|
||||
auth_kwargs: Optional authentication token(s).
|
||||
|
||||
Returns:
|
||||
The download URL for the uploaded audio file.
|
||||
"""
|
||||
sample_rate: int = audio["sample_rate"]
|
||||
waveform: torch.Tensor = audio["waveform"]
|
||||
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
|
||||
audio_bytes_io = audio_ndarray_to_bytesio(
|
||||
audio_data_np, sample_rate, container_format, codec_name
|
||||
)
|
||||
|
||||
return await upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
|
||||
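# Illustrative usage sketch (not part of the original module): uploading a Comfy
# AUDIO input with the default AAC-in-MP4 encoding and returning its download URL.
async def _example_upload_audio(audio: AudioInput, auth_kwargs: dict[str, str]) -> str:
    return await upload_audio_to_comfyapi(audio, auth_kwargs=auth_kwargs)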
|
||||
|
||||
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file."""
|
||||
if wav.dtype.is_floating_point:
|
||||
return wav
|
||||
elif wav.dtype == torch.int16:
|
||||
return wav.float() / (2 ** 15)
|
||||
elif wav.dtype == torch.int32:
|
||||
return wav.float() / (2 ** 31)
|
||||
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
|
||||
|
||||
|
||||
def audio_bytes_to_audio_input(audio_bytes: bytes) -> dict:
|
||||
"""
|
||||
Decode any common audio container from bytes using PyAV and return
|
||||
a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}.
|
||||
"""
|
||||
with av.open(io.BytesIO(audio_bytes)) as af:
|
||||
if not af.streams.audio:
|
||||
raise ValueError("No audio stream found in response.")
|
||||
stream = af.streams.audio[0]
|
||||
|
||||
in_sr = int(stream.codec_context.sample_rate)
|
||||
out_sr = in_sr
|
||||
|
||||
frames: list[torch.Tensor] = []
|
||||
n_channels = stream.channels or 1
|
||||
|
||||
for frame in af.decode(streams=stream.index):
|
||||
arr = frame.to_ndarray() # shape can be [C, T] or [T, C] or [T]
|
||||
buf = torch.from_numpy(arr)
|
||||
if buf.ndim == 1:
|
||||
buf = buf.unsqueeze(0) # [T] -> [1, T]
|
||||
elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels:
|
||||
buf = buf.transpose(0, 1).contiguous() # [T, C] -> [C, T]
|
||||
elif buf.shape[0] != n_channels:
|
||||
buf = buf.reshape(-1, n_channels).t().contiguous() # fallback to [C, T]
|
||||
frames.append(buf)
|
||||
|
||||
if not frames:
|
||||
raise ValueError("Decoded zero audio frames.")
|
||||
|
||||
wav = torch.cat(frames, dim=1) # [C, T]
|
||||
wav = f32_pcm(wav)
|
||||
return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr}
|
||||
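# Illustrative round-trip sketch (not part of the original module): decoding audio
# bytes returned by an API into a Comfy AUDIO dict, then re-encoding them as MP3.
def _example_decode_then_reencode(audio_bytes: bytes) -> io.BytesIO:
    audio = audio_bytes_to_audio_input(audio_bytes)  # {"waveform": [1, C, T], "sample_rate": int}
    return audio_input_to_mp3(audio)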
|
||||
|
||||
def audio_input_to_mp3(audio: AudioInput) -> io.BytesIO:
|
||||
waveform = audio["waveform"].cpu()
|
||||
|
||||
output_buffer = io.BytesIO()
|
||||
output_container = av.open(output_buffer, mode='w', format="mp3")
|
||||
|
||||
out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"])
|
||||
out_stream.bit_rate = 320000
|
||||
|
||||
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo')
|
||||
frame.sample_rate = audio["sample_rate"]
|
||||
frame.pts = 0
|
||||
output_container.mux(out_stream.encode(frame))
|
||||
output_container.mux(out_stream.encode(None))
|
||||
output_container.close()
|
||||
output_buffer.seek(0)
|
||||
return output_buffer
|
||||
|
||||
|
||||
def audio_to_base64_string(
|
||||
audio: AudioInput, container_format: str = "mp4", codec_name: str = "aac"
|
||||
) -> str:
|
||||
"""Converts an audio input to a base64 string."""
|
||||
sample_rate: int = audio["sample_rate"]
|
||||
waveform: torch.Tensor = audio["waveform"]
|
||||
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
|
||||
audio_bytes_io = audio_ndarray_to_bytesio(
|
||||
audio_data_np, sample_rate, container_format, codec_name
|
||||
)
|
||||
audio_bytes = audio_bytes_io.getvalue()
|
||||
return base64.b64encode(audio_bytes).decode("utf-8")
|
||||
|
||||
|
||||
async def upload_images_to_comfyapi(
|
||||
image: torch.Tensor,
|
||||
max_images=8,
|
||||
auth_kwargs: Optional[dict[str, str]] = None,
|
||||
mime_type: Optional[str] = None,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Uploads images to ComfyUI API and returns download URLs.
|
||||
To upload multiple images, stack them in the batch dimension first.
|
||||
|
||||
Args:
|
||||
image: Input torch.Tensor image.
|
||||
max_images: Maximum number of images to upload.
|
||||
auth_kwargs: Optional authentication token(s).
|
||||
mime_type: Optional MIME type for the image.
|
||||
"""
|
||||
# if batch, try to upload each file if max_images is greater than 0
|
||||
download_urls: list[str] = []
|
||||
is_batch = len(image.shape) > 3
|
||||
batch_len = image.shape[0] if is_batch else 1
|
||||
|
||||
for idx in range(min(batch_len, max_images)):
|
||||
tensor = image[idx] if is_batch else image
|
||||
img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
|
||||
url = await upload_file_to_comfyapi(img_io, img_io.name, mime_type, auth_kwargs)
|
||||
download_urls.append(url)
|
||||
return download_urls
|
||||
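# Illustrative usage sketch (not part of the original module): uploading a batched
# IMAGE tensor ([B, H, W, C]) as PNGs and collecting one download URL per image.
# The max_images value is an assumption for the example.
async def _example_upload_reference_images(images: torch.Tensor, auth_kwargs: dict[str, str]) -> list[str]:
    return await upload_images_to_comfyapi(images, max_images=4, auth_kwargs=auth_kwargs, mime_type="image/png")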
|
||||
|
||||
def resize_mask_to_image(
|
||||
mask: torch.Tensor,
|
||||
image: torch.Tensor,
|
||||
upscale_method="nearest-exact",
|
||||
crop="disabled",
|
||||
allow_gradient=True,
|
||||
add_channel_dim=False,
|
||||
):
|
||||
"""
|
||||
Resize mask to be the same dimensions as an image, while maintaining proper format for API calls.
|
||||
"""
|
||||
_, H, W, _ = image.shape
|
||||
mask = mask.unsqueeze(-1)
|
||||
mask = mask.movedim(-1, 1)
|
||||
mask = common_upscale(
|
||||
mask, width=W, height=H, upscale_method=upscale_method, crop=crop
|
||||
)
|
||||
mask = mask.movedim(1, -1)
|
||||
if not add_channel_dim:
|
||||
mask = mask.squeeze(-1)
|
||||
if not allow_gradient:
|
||||
mask = (mask > 0.5).float()
|
||||
return mask
|
||||
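# Illustrative usage sketch (not part of the original module): matching a MASK to an
# IMAGE before building an inpainting request, forcing a hard 0/1 mask with an
# explicit channel dimension so it can be converted like an image tensor.
def _example_prepare_inpaint_mask(mask: torch.Tensor, image: torch.Tensor) -> torch.Tensor:
    return resize_mask_to_image(mask, image, allow_gradient=False, add_channel_dim=True)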
|
||||
|
||||
def validate_string(
|
||||
string: str,
|
||||
strip_whitespace=True,
|
||||
field_name="prompt",
|
||||
min_length=None,
|
||||
max_length=None,
|
||||
):
|
||||
if string is None:
|
||||
raise Exception(f"Field '{field_name}' cannot be empty.")
|
||||
if strip_whitespace:
|
||||
string = string.strip()
|
||||
if min_length and len(string) < min_length:
|
||||
raise Exception(
|
||||
f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long."
|
||||
)
|
||||
if max_length and len(string) > max_length:
|
||||
raise Exception(
|
||||
f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long."
|
||||
)
|
||||
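# Illustrative usage sketch (not part of the original module): guarding a prompt
# field before an API call. The length limits are assumptions for the example.
def _example_check_prompt(prompt: str) -> None:
    validate_string(prompt, strip_whitespace=True, field_name="prompt", min_length=1, max_length=2048)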
|
||||
|
||||
def image_tensor_pair_to_batch(
|
||||
image1: torch.Tensor, image2: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Converts a pair of image tensors to a batch tensor.
|
||||
If the images are not the same size, the smaller image is resized to
|
||||
match the larger image.
|
||||
"""
|
||||
if image1.shape[1:] != image2.shape[1:]:
|
||||
image2 = common_upscale(
|
||||
image2.movedim(-1, 1),
|
||||
image1.shape[2],
|
||||
image1.shape[1],
|
||||
"bilinear",
|
||||
"center",
|
||||
).movedim(1, -1)
|
||||
return torch.cat((image1, image2), dim=0)
|
||||
@@ -1,17 +0,0 @@
|
||||
# generated by datamodel-codegen:
|
||||
# filename: filtered-openapi.yaml
|
||||
# timestamp: 2025-04-29T23:44:54+00:00
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from . import PixverseDto
|
||||
|
||||
|
||||
class ResponseData(BaseModel):
|
||||
ErrCode: Optional[int] = None
|
||||
ErrMsg: Optional[str] = None
|
||||
Resp: Optional[PixverseDto.V2OpenAPII2VResp] = None
|
||||
@@ -1,57 +0,0 @@
|
||||
# generated by datamodel-codegen:
|
||||
# filename: filtered-openapi.yaml
|
||||
# timestamp: 2025-04-29T23:44:54+00:00
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class V2OpenAPII2VResp(BaseModel):
|
||||
video_id: Optional[int] = Field(None, description='Video_id')
|
||||
|
||||
|
||||
class V2OpenAPIT2VReq(BaseModel):
|
||||
aspect_ratio: str = Field(
|
||||
..., description='Aspect ratio (16:9, 4:3, 1:1, 3:4, 9:16)', examples=['16:9']
|
||||
)
|
||||
duration: int = Field(
|
||||
...,
|
||||
description='Video duration (5, 8 seconds, --model=v3.5 only allows 5,8; --quality=1080p does not support 8s)',
|
||||
examples=[5],
|
||||
)
|
||||
model: str = Field(
|
||||
..., description='Model version (only supports v3.5)', examples=['v3.5']
|
||||
)
|
||||
motion_mode: Optional[str] = Field(
|
||||
'normal',
|
||||
description='Motion mode (normal, fast, --fast only available when duration=5; --quality=1080p does not support fast)',
|
||||
examples=['normal'],
|
||||
)
|
||||
negative_prompt: Optional[str] = Field(
|
||||
None, description='Negative prompt\n', max_length=2048
|
||||
)
|
||||
prompt: str = Field(..., description='Prompt', max_length=2048)
|
||||
quality: str = Field(
|
||||
...,
|
||||
description='Video quality ("360p"(Turbo model), "540p", "720p", "1080p")',
|
||||
examples=['540p'],
|
||||
)
|
||||
seed: Optional[int] = Field(None, description='Random seed, range: 0 - 2147483647')
|
||||
style: Optional[str] = Field(
|
||||
None,
|
||||
description='Style (effective when model=v3.5, "anime", "3d_animation", "clay", "comic", "cyberpunk") Do not include style parameter unless needed',
|
||||
examples=['anime'],
|
||||
)
|
||||
template_id: Optional[int] = Field(
|
||||
None,
|
||||
description='Template ID (template_id must be activated before use)',
|
||||
examples=[302325299692608],
|
||||
)
|
||||
water_mark: Optional[bool] = Field(
|
||||
False,
|
||||
description='Watermark (true: add watermark, false: no watermark)',
|
||||
examples=[False],
|
||||
)
|
||||
@@ -50,44 +50,6 @@ class BFLFluxFillImageRequest(BaseModel):
|
||||
mask: str = Field(None, description='A Base64-encoded string representing the mask of the areas you wish to modify.')
|
||||
|
||||
|
||||
class BFLFluxCannyImageRequest(BaseModel):
|
||||
prompt: str = Field(..., description='Text prompt for image generation')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||
)
|
||||
canny_low_threshold: Optional[int] = Field(None, description='Low threshold for Canny edge detection')
|
||||
canny_high_threshold: Optional[int] = Field(None, description='High threshold for Canny edge detection')
|
||||
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
|
||||
guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process')
|
||||
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
||||
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||
)
|
||||
output_format: Optional[BFLOutputFormat] = Field(
|
||||
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||
)
|
||||
control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided')
|
||||
preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step')
|
||||
|
||||
|
||||
class BFLFluxDepthImageRequest(BaseModel):
|
||||
prompt: str = Field(..., description='Text prompt for image generation')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||
)
|
||||
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
|
||||
guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process')
|
||||
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
||||
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||
)
|
||||
output_format: Optional[BFLOutputFormat] = Field(
|
||||
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||
)
|
||||
control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided')
|
||||
preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step')
|
||||
|
||||
|
||||
class BFLFluxProGenerateRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The text prompt for image generation.')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
@@ -160,15 +122,8 @@ class BFLStatus(str, Enum):
|
||||
error = "Error"
|
||||
|
||||
|
||||
class BFLFluxProStatusResponse(BaseModel):
|
||||
class BFLFluxStatusResponse(BaseModel):
|
||||
id: str = Field(..., description="The unique identifier for the generation task.")
|
||||
status: BFLStatus = Field(..., description="The status of the task.")
|
||||
result: Optional[Dict[str, Any]] = Field(
|
||||
None, description="The result of the task (null if not completed)."
|
||||
)
|
||||
progress: confloat(ge=0.0, le=1.0) = Field(
|
||||
..., description="The progress of the task (0.0 to 1.0)."
|
||||
)
|
||||
details: Optional[Dict[str, Any]] = Field(
|
||||
None, description="Additional details about the task (null if not available)."
|
||||
)
|
||||
result: Optional[Dict[str, Any]] = Field(None, description="The result of the task (null if not completed).")
|
||||
progress: Optional[float] = Field(None, description="The progress of the task (0.0 to 1.0).", ge=0.0, le=1.0)
|
||||
|
||||
@@ -1,972 +0,0 @@
|
||||
"""
|
||||
API Client Framework for api.comfy.org.
|
||||
|
||||
This module provides a flexible framework for making API requests from ComfyUI nodes.
|
||||
It supports both synchronous and asynchronous API operations with proper type validation.
|
||||
|
||||
Key Components:
|
||||
--------------
|
||||
1. ApiClient - Handles HTTP requests with authentication and error handling
|
||||
2. ApiEndpoint - Defines a single HTTP endpoint with its request/response models
|
||||
3. ApiOperation - Executes a single synchronous API operation
|
||||
|
||||
Usage Examples:
|
||||
--------------
|
||||
|
||||
# Example 1: Synchronous API Operation
|
||||
# ------------------------------------
|
||||
# For a simple API call that returns the result immediately:
|
||||
|
||||
# 1. Create the API client
|
||||
api_client = ApiClient(
|
||||
base_url="https://api.example.com",
|
||||
auth_token="your_auth_token_here",
|
||||
comfy_api_key="your_comfy_api_key_here",
|
||||
timeout=30.0,
|
||||
verify_ssl=True
|
||||
)
|
||||
|
||||
# 2. Define the endpoint
|
||||
user_info_endpoint = ApiEndpoint(
|
||||
path="/v1/users/me",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest, # No request body needed
|
||||
response_model=UserProfile, # Pydantic model for the response
|
||||
query_params=None
|
||||
)
|
||||
|
||||
# 3. Create the request object
|
||||
request = EmptyRequest()
|
||||
|
||||
# 4. Create and execute the operation
|
||||
operation = ApiOperation(
|
||||
endpoint=user_info_endpoint,
|
||||
request=request
|
||||
)
|
||||
user_profile = await operation.execute(client=api_client) # Returns immediately with the result
|
||||
|
||||
|
||||
# Example 2: Asynchronous API Operation with Polling
|
||||
# -------------------------------------------------
|
||||
# For an API that starts a task and requires polling for completion:
|
||||
|
||||
# 1. Define the endpoints (initial request and polling)
|
||||
generate_image_endpoint = ApiEndpoint(
|
||||
path="/v1/images/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=ImageGenerationRequest,
|
||||
response_model=TaskCreatedResponse,
|
||||
query_params=None
|
||||
)
|
||||
|
||||
check_task_endpoint = ApiEndpoint(
|
||||
path="/v1/tasks/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=ImageGenerationResult,
|
||||
query_params=None
|
||||
)
|
||||
|
||||
# 2. Create the request object
|
||||
request = ImageGenerationRequest(
|
||||
prompt="a beautiful sunset over mountains",
|
||||
width=1024,
|
||||
height=1024,
|
||||
num_images=1
|
||||
)
|
||||
|
||||
# 3. Create and execute the polling operation
|
||||
operation = PollingOperation(
|
||||
initial_endpoint=generate_image_endpoint,
|
||||
initial_request=request,
|
||||
poll_endpoint=check_task_endpoint,
|
||||
task_id_field="task_id",
|
||||
status_field="status",
|
||||
completed_statuses=["completed"],
|
||||
failed_statuses=["failed", "error"]
|
||||
)
|
||||
|
||||
# This will make the initial request and then poll until completion
|
||||
result = await operation.execute(client=api_client) # Returns the final ImageGenerationResult when done
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import logging
|
||||
import io
|
||||
import os
|
||||
import socket
|
||||
from aiohttp.client_exceptions import ClientError, ClientResponseError
|
||||
from typing import Type, Optional, Any, TypeVar, Generic, Callable
|
||||
from enum import Enum
|
||||
import json
|
||||
from urllib.parse import urljoin, urlparse
|
||||
from pydantic import BaseModel, Field
|
||||
import uuid # For generating unique operation IDs
|
||||
|
||||
from server import PromptServer
|
||||
from comfy.cli_args import args
|
||||
from comfy import utils
|
||||
from . import request_logger
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
R = TypeVar("R", bound=BaseModel)
|
||||
P = TypeVar("P", bound=BaseModel) # For poll response
|
||||
|
||||
PROGRESS_BAR_MAX = 100
|
||||
|
||||
|
||||
class NetworkError(Exception):
|
||||
"""Base exception for network-related errors with diagnostic information."""
|
||||
pass
|
||||
|
||||
|
||||
class LocalNetworkError(NetworkError):
|
||||
"""Exception raised when local network connectivity issues are detected."""
|
||||
pass
|
||||
|
||||
|
||||
class ApiServerError(NetworkError):
|
||||
"""Exception raised when the API server is unreachable but internet is working."""
|
||||
pass
|
||||
|
||||
|
||||
class EmptyRequest(BaseModel):
|
||||
"""Base class for empty request bodies.
|
||||
For GET requests, fields will be sent as query parameters."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class UploadRequest(BaseModel):
|
||||
file_name: str = Field(..., description="Filename to upload")
|
||||
content_type: Optional[str] = Field(
|
||||
None,
|
||||
description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
|
||||
)
|
||||
|
||||
|
||||
class UploadResponse(BaseModel):
|
||||
download_url: str = Field(..., description="URL to GET uploaded file")
|
||||
upload_url: str = Field(..., description="URL to PUT file to upload")
|
||||
|
||||
|
||||
class HttpMethod(str, Enum):
|
||||
GET = "GET"
|
||||
POST = "POST"
|
||||
PUT = "PUT"
|
||||
DELETE = "DELETE"
|
||||
PATCH = "PATCH"
|
||||
|
||||
|
||||
class ApiClient:
|
||||
"""
|
||||
Client for making HTTP requests to an API with authentication, error handling, and retry logic.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
auth_token: Optional[str] = None,
|
||||
comfy_api_key: Optional[str] = None,
|
||||
timeout: float = 3600.0,
|
||||
verify_ssl: bool = True,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff_factor: float = 2.0,
|
||||
retry_status_codes: Optional[tuple[int, ...]] = None,
|
||||
session: Optional[aiohttp.ClientSession] = None,
|
||||
):
|
||||
self.base_url = base_url
|
||||
self.auth_token = auth_token
|
||||
self.comfy_api_key = comfy_api_key
|
||||
self.timeout = timeout
|
||||
self.verify_ssl = verify_ssl
|
||||
self.max_retries = max_retries
|
||||
self.retry_delay = retry_delay
|
||||
self.retry_backoff_factor = retry_backoff_factor
|
||||
# Default retry status codes: 408 (Request Timeout), 429 (Too Many Requests),
|
||||
# 500, 502, 503, 504 (Server Errors)
|
||||
self.retry_status_codes = retry_status_codes or (408, 429, 500, 502, 503, 504)
|
||||
self._session: Optional[aiohttp.ClientSession] = session
|
||||
self._owns_session = session is None # Track if we have to close it
|
||||
|
||||
@staticmethod
|
||||
def _generate_operation_id(path: str) -> str:
|
||||
"""Generates a unique operation ID for logging."""
|
||||
return f"{path.strip('/').replace('/', '_')}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
@staticmethod
|
||||
def _create_json_payload_args(
|
||||
data: Optional[dict[str, Any]] = None,
|
||||
headers: Optional[dict[str, str]] = None,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"json": data,
|
||||
"headers": headers,
|
||||
}
|
||||
|
||||
def _create_form_data_args(
|
||||
self,
|
||||
data: dict[str, Any] | None,
|
||||
files: dict[str, Any] | None,
|
||||
headers: Optional[dict[str, str]] = None,
|
||||
multipart_parser: Callable | None = None,
|
||||
) -> dict[str, Any]:
|
||||
if headers and "Content-Type" in headers:
|
||||
del headers["Content-Type"]
|
||||
|
||||
if multipart_parser and data:
|
||||
data = multipart_parser(data)
|
||||
|
||||
if isinstance(data, aiohttp.FormData):
|
||||
form = data # If the parser already returned a FormData, pass it through
|
||||
else:
|
||||
form = aiohttp.FormData(default_to_multipart=True)
|
||||
if data: # regular text fields
|
||||
for k, v in data.items():
|
||||
if v is None:
|
||||
continue # aiohttp fails to serialize "None" values
|
||||
# aiohttp expects strings or bytes; convert enums etc.
|
||||
form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)
|
||||
|
||||
if files:
|
||||
file_iter = files if isinstance(files, list) else files.items()
|
||||
for field_name, file_obj in file_iter:
|
||||
if file_obj is None:
|
||||
continue # aiohttp fails to serialize "None" values
|
||||
# file_obj can be (filename, bytes/io.BytesIO, content_type) tuple
|
||||
if isinstance(file_obj, tuple):
|
||||
filename, file_value, content_type = self._unpack_tuple(file_obj)
|
||||
else:
|
||||
file_value = file_obj
|
||||
filename = getattr(file_obj, "name", field_name)
|
||||
content_type = "application/octet-stream"
|
||||
|
||||
form.add_field(
|
||||
name=field_name,
|
||||
value=file_value,
|
||||
filename=filename,
|
||||
content_type=content_type,
|
||||
)
|
||||
return {"data": form, "headers": headers or {}}
|
||||
|
||||
@staticmethod
|
||||
def _create_urlencoded_form_data_args(
|
||||
data: dict[str, Any],
|
||||
headers: Optional[dict[str, str]] = None,
|
||||
) -> dict[str, Any]:
|
||||
headers = headers or {}
|
||||
headers["Content-Type"] = "application/x-www-form-urlencoded"
|
||||
return {
|
||||
"data": data,
|
||||
"headers": headers,
|
||||
}
|
||||
|
||||
def get_headers(self) -> dict[str, str]:
|
||||
"""Get headers for API requests, including authentication if available"""
|
||||
headers = {"Content-Type": "application/json", "Accept": "application/json"}
|
||||
|
||||
if self.auth_token:
|
||||
headers["Authorization"] = f"Bearer {self.auth_token}"
|
||||
elif self.comfy_api_key:
|
||||
headers["X-API-KEY"] = self.comfy_api_key
|
||||
|
||||
return headers
|
||||
|
||||
async def _check_connectivity(self, target_url: str) -> dict[str, bool]:
|
||||
"""
|
||||
Check connectivity to determine if network issues are local or server-related.
|
||||
|
||||
Args:
|
||||
target_url: URL to check connectivity to
|
||||
|
||||
Returns:
|
||||
Dictionary with connectivity status details
|
||||
"""
|
||||
results = {
|
||||
"internet_accessible": False,
|
||||
"api_accessible": False,
|
||||
"is_local_issue": False,
|
||||
"is_api_issue": False,
|
||||
}
|
||||
timeout = aiohttp.ClientTimeout(total=5.0)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
try:
|
||||
async with session.get("https://www.google.com", ssl=self.verify_ssl) as resp:
|
||||
results["internet_accessible"] = resp.status < 500
|
||||
except (ClientError, asyncio.TimeoutError, socket.gaierror):
|
||||
results["is_local_issue"] = True
|
||||
return results # cannot reach the internet – early exit
|
||||
|
||||
# Now check API health endpoint
|
||||
parsed = urlparse(target_url)
|
||||
health_url = f"{parsed.scheme}://{parsed.netloc}/health"
|
||||
try:
|
||||
async with session.get(health_url, ssl=self.verify_ssl) as resp:
|
||||
results["api_accessible"] = resp.status < 500
|
||||
except ClientError:
|
||||
pass # leave as False
|
||||
|
||||
results["is_api_issue"] = results["internet_accessible"] and not results["api_accessible"]
|
||||
return results
|
||||
|
||||
async def request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
params: Optional[dict[str, Any]] = None,
|
||||
data: Optional[dict[str, Any]] = None,
|
||||
files: Optional[dict[str, Any] | list[tuple[str, Any]]] = None,
|
||||
headers: Optional[dict[str, str]] = None,
|
||||
content_type: str = "application/json",
|
||||
multipart_parser: Callable | None = None,
|
||||
retry_count: int = 0, # Used internally for tracking retries
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Make an HTTP request to the API with automatic retries for transient errors.
|
||||
|
||||
Args:
|
||||
method: HTTP method (GET, POST, etc.)
|
||||
path: API endpoint path (will be joined with base_url)
|
||||
params: Query parameters
|
||||
data: body data
|
||||
files: Files to upload
|
||||
headers: Additional headers
|
||||
content_type: Content type of the request. Defaults to application/json.
|
||||
retry_count: Internal parameter for tracking retries, do not set manually
|
||||
|
||||
Returns:
|
||||
Parsed JSON response
|
||||
|
||||
Raises:
|
||||
LocalNetworkError: If local network connectivity issues are detected
|
||||
ApiServerError: If the API server is unreachable but internet is working
|
||||
Exception: For other request failures
|
||||
"""
|
||||
|
||||
# Build full URL and merge headers
|
||||
relative_path = path.lstrip("/")
|
||||
url = urljoin(self.base_url, relative_path)
|
||||
self._check_auth(self.auth_token, self.comfy_api_key)
|
||||
|
||||
request_headers = self.get_headers()
|
||||
if headers:
|
||||
request_headers.update(headers)
|
||||
if files:
|
||||
request_headers.pop("Content-Type", None)
|
||||
if params:
|
||||
params = {k: v for k, v in params.items() if v is not None} # aiohttp fails to serialize None values
|
||||
|
||||
logging.debug("[DEBUG] Request Headers: %s", request_headers)
|
||||
logging.debug("[DEBUG] Files: %s", files)
|
||||
logging.debug("[DEBUG] Params: %s", params)
|
||||
logging.debug("[DEBUG] Data: %s", data)
|
||||
|
||||
if content_type == "application/x-www-form-urlencoded":
|
||||
payload_args = self._create_urlencoded_form_data_args(data or {}, request_headers)
|
||||
elif content_type == "multipart/form-data":
|
||||
payload_args = self._create_form_data_args(data, files, request_headers, multipart_parser)
|
||||
else:
|
||||
payload_args = self._create_json_payload_args(data, request_headers)
|
||||
|
||||
operation_id = self._generate_operation_id(path)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=request_headers,
|
||||
request_params=params,
|
||||
request_data=data if content_type == "application/json" else "[form-data or other]",
|
||||
)
|
||||
|
||||
session = await self._get_session()
|
||||
try:
|
||||
async with session.request(
|
||||
method,
|
||||
url,
|
||||
params=params,
|
||||
ssl=self.verify_ssl,
|
||||
**payload_args,
|
||||
) as resp:
|
||||
if resp.status >= 400:
|
||||
try:
|
||||
error_data = await resp.json()
|
||||
except (aiohttp.ContentTypeError, json.JSONDecodeError):
|
||||
error_data = await resp.text()
|
||||
|
||||
return await self._handle_http_error(
|
||||
ClientResponseError(resp.request_info, resp.history, status=resp.status, message=error_data),
|
||||
operation_id,
|
||||
method,
|
||||
url,
|
||||
params,
|
||||
data,
|
||||
files,
|
||||
headers,
|
||||
content_type,
|
||||
multipart_parser,
|
||||
retry_count=retry_count,
|
||||
response_content=error_data,
|
||||
)
|
||||
|
||||
# Success – parse JSON (safely) and log
|
||||
try:
|
||||
payload = await resp.json()
|
||||
response_content_to_log = payload
|
||||
except (aiohttp.ContentTypeError, json.JSONDecodeError):
|
||||
payload = {}
|
||||
response_content_to_log = await resp.text()
|
||||
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=response_content_to_log,
|
||||
)
|
||||
return payload
|
||||
|
||||
except (ClientError, asyncio.TimeoutError, socket.gaierror) as e:
|
||||
# Treat as *connection* problem – optionally retry, else escalate
|
||||
if retry_count < self.max_retries:
|
||||
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
|
||||
logging.warning("Connection error. Retrying in %.2fs (%s/%s): %s", delay, retry_count + 1,
|
||||
self.max_retries, str(e))
|
||||
await asyncio.sleep(delay)
|
||||
return await self.request(
|
||||
method,
|
||||
path,
|
||||
params=params,
|
||||
data=data,
|
||||
files=files,
|
||||
headers=headers,
|
||||
content_type=content_type,
|
||||
multipart_parser=multipart_parser,
|
||||
retry_count=retry_count + 1,
|
||||
)
|
||||
# One final connectivity check for diagnostics
|
||||
connectivity = await self._check_connectivity(self.base_url)
|
||||
if connectivity["is_local_issue"]:
|
||||
raise LocalNetworkError(
|
||||
"Unable to connect to the API server due to local network issues. "
|
||||
"Please check your internet connection and try again."
|
||||
) from e
|
||||
raise ApiServerError(
|
||||
f"The API server at {self.base_url} is currently unreachable. "
|
||||
f"The service may be experiencing issues. Please try again later."
|
||||
) from e
|
||||
|
||||
@staticmethod
|
||||
def _check_auth(auth_token, comfy_api_key):
|
||||
"""Verify that an auth token is present or comfy_api_key is present"""
|
||||
if auth_token is None and comfy_api_key is None:
|
||||
raise Exception("Unauthorized: Please login first to use this node.")
|
||||
return auth_token or comfy_api_key
|
||||
|
||||
@staticmethod
|
||||
async def upload_file(
|
||||
upload_url: str,
|
||||
file: io.BytesIO | str,
|
||||
content_type: str | None = None,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff_factor: float = 2.0,
|
||||
) -> aiohttp.ClientResponse:
|
||||
"""Upload a file to the API with retry logic.
|
||||
|
||||
Args:
|
||||
upload_url: The URL to upload to
|
||||
file: Either a file path string, BytesIO object, or tuple of (file_path, filename)
|
||||
content_type: Optional mime type to set for the upload
|
||||
max_retries: Maximum number of retry attempts
|
||||
retry_delay: Initial delay between retries in seconds
|
||||
retry_backoff_factor: Multiplier for the delay after each retry
|
||||
"""
|
||||
headers: dict[str, str] = {}
|
||||
skip_auto_headers: set[str] = set()
|
||||
if content_type:
|
||||
headers["Content-Type"] = content_type
|
||||
else:
|
||||
# tell aiohttp not to add Content-Type that will break the request signature and result in a 403 status.
|
||||
skip_auto_headers.add("Content-Type")
|
||||
|
||||
# Extract file bytes
|
||||
if isinstance(file, io.BytesIO):
|
||||
file.seek(0)
|
||||
data = file.read()
|
||||
elif isinstance(file, str):
|
||||
with open(file, "rb") as f:
|
||||
data = f.read()
|
||||
else:
|
||||
raise ValueError("File must be BytesIO or str path")
|
||||
|
||||
parsed = urlparse(upload_url)
|
||||
basename = os.path.basename(parsed.path) or parsed.netloc or "upload"
|
||||
operation_id = f"upload_{basename}_{uuid.uuid4().hex[:8]}"
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
request_headers=headers,
|
||||
request_data=f"[File data {len(data)} bytes]",
|
||||
)
|
||||
|
||||
delay = retry_delay
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
timeout = aiohttp.ClientTimeout(total=None) # honour server side timeouts
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.put(
|
||||
upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers,
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content="File uploaded successfully.",
|
||||
)
|
||||
return resp
|
||||
except (ClientError, asyncio.TimeoutError) as e:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
response_status_code=e.status if hasattr(e, "status") else None,
|
||||
response_headers=dict(e.headers) if hasattr(e, "headers") else None,
|
||||
response_content=None,
|
||||
error_message=f"{type(e).__name__}: {str(e)}",
|
||||
)
|
||||
if attempt < max_retries:
|
||||
logging.warning(
|
||||
"Upload failed (%s/%s). Retrying in %.2fs. %s", attempt + 1, max_retries, delay, str(e)
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
delay *= retry_backoff_factor
|
||||
else:
|
||||
raise NetworkError(f"Failed to upload file after {max_retries + 1} attempts: {e}") from e
|
||||
|
||||
async def _handle_http_error(
|
||||
self,
|
||||
exc: ClientResponseError,
|
||||
operation_id: str,
|
||||
*req_meta,
|
||||
retry_count: int,
|
||||
response_content: dict | str = "",
|
||||
) -> dict[str, Any]:
|
||||
status_code = exc.status
|
||||
if status_code == 401:
|
||||
user_friendly = "Unauthorized: Please login first to use this node."
|
||||
elif status_code == 402:
|
||||
user_friendly = "Payment Required: Please add credits to your account to use this node."
|
||||
elif status_code == 409:
|
||||
user_friendly = "There is a problem with your account. Please contact support@comfy.org."
|
||||
elif status_code == 429:
|
||||
user_friendly = "Rate Limit Exceeded: Please try again later."
|
||||
else:
|
||||
if isinstance(response_content, dict):
|
||||
if "error" in response_content and "message" in response_content["error"]:
|
||||
user_friendly = f"API Error: {response_content['error']['message']}"
|
||||
if "type" in response_content["error"]:
|
||||
user_friendly += f" (Type: {response_content['error']['type']})"
|
||||
else: # Handle cases where error is just a JSON dict with unknown format
|
||||
user_friendly = f"API Error: {json.dumps(response_content)}"
|
||||
else:
|
||||
if len(response_content) < 200: # Arbitrary limit for display
|
||||
user_friendly = f"API Error (raw): {response_content}"
|
||||
else:
|
||||
user_friendly = f"API Error (raw, status {response_content})"
|
||||
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=req_meta[0],
|
||||
request_url=req_meta[1],
|
||||
response_status_code=exc.status,
|
||||
response_headers=dict(req_meta[5]) if req_meta[5] else None,
|
||||
response_content=response_content,
|
||||
error_message=f"HTTP Error {exc.status}",
|
||||
)
|
||||
|
||||
logging.debug("[DEBUG] API Error: %s (Status: %s)", user_friendly, status_code)
|
||||
if response_content:
|
||||
logging.debug("[DEBUG] Response content: %s", response_content)
|
||||
|
||||
# Retry if eligible
|
||||
if status_code in self.retry_status_codes and retry_count < self.max_retries:
|
||||
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
|
||||
logging.warning(
|
||||
"HTTP error %s. Retrying in %.2fs (%s/%s)",
|
||||
status_code,
|
||||
delay,
|
||||
retry_count + 1,
|
||||
self.max_retries,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
return await self.request(
|
||||
req_meta[0], # method
|
||||
req_meta[1].replace(self.base_url, ""), # path
|
||||
params=req_meta[2],
|
||||
data=req_meta[3],
|
||||
files=req_meta[4],
|
||||
headers=req_meta[5],
|
||||
content_type=req_meta[6],
|
||||
multipart_parser=req_meta[7],
|
||||
retry_count=retry_count + 1,
|
||||
)
|
||||
|
||||
raise Exception(user_friendly) from exc
|
||||
|
||||
@staticmethod
|
||||
def _unpack_tuple(t):
|
||||
"""Helper to normalise (filename, file, content_type) tuples."""
|
||||
if len(t) == 3:
|
||||
return t
|
||||
elif len(t) == 2:
|
||||
return t[0], t[1], "application/octet-stream"
|
||||
else:
|
||||
raise ValueError("files tuple must be (filename, file[, content_type])")
|
||||
|
||||
async def _get_session(self) -> aiohttp.ClientSession:
|
||||
if self._session is None or self._session.closed:
|
||||
timeout = aiohttp.ClientTimeout(total=self.timeout)
|
||||
self._session = aiohttp.ClientSession(timeout=timeout)
|
||||
self._owns_session = True
|
||||
return self._session
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._owns_session and self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
|
||||
async def __aenter__(self) -> "ApiClient":
|
||||
"""Allow usage as async‑context‑manager – ensures clean teardown"""
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
await self.close()
|
||||
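# Illustrative usage sketch (not part of the original module): reusing one ApiClient
# (and therefore one aiohttp session) across several operations via the async
# context manager, instead of letting each operation create its own client.
async def _example_shared_client(auth_token: str) -> None:
    async with ApiClient(base_url=args.comfy_api_base, auth_token=auth_token) as client:
        endpoint = ApiEndpoint(
            path="/customers/storage",
            method=HttpMethod.POST,
            request_model=UploadRequest,
            response_model=UploadResponse,
        )
        operation = SynchronousOperation(endpoint=endpoint, request=UploadRequest(file_name="example.bin"))
        response = await operation.execute(client=client)
        logging.debug("Upload URL: %s", response.upload_url)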
|
||||
|
||||
class ApiEndpoint(Generic[T, R]):
|
||||
"""Defines an API endpoint with its request and response types"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
method: HttpMethod,
|
||||
request_model: Type[T],
|
||||
response_model: Type[R],
|
||||
query_params: Optional[dict[str, Any]] = None,
|
||||
):
|
||||
"""Initialize an API endpoint definition.
|
||||
|
||||
Args:
|
||||
path: The URL path for this endpoint, can include placeholders like {id}
|
||||
method: The HTTP method to use (GET, POST, etc.)
|
||||
request_model: Pydantic model class that defines the structure and validation rules for API requests to this endpoint
|
||||
response_model: Pydantic model class that defines the structure and validation rules for API responses from this endpoint
|
||||
query_params: Optional dictionary of query parameters to include in the request
|
||||
"""
|
||||
self.path = path
|
||||
self.method = method
|
||||
self.request_model = request_model
|
||||
self.response_model = response_model
|
||||
self.query_params = query_params or {}
|
||||
|
||||
|
||||
class SynchronousOperation(Generic[T, R]):
|
||||
"""Represents a single synchronous API operation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: ApiEndpoint[T, R],
|
||||
request: T,
|
||||
files: Optional[dict[str, Any] | list[tuple[str, Any]]] = None,
|
||||
api_base: str | None = None,
|
||||
auth_token: Optional[str] = None,
|
||||
comfy_api_key: Optional[str] = None,
|
||||
auth_kwargs: Optional[dict[str, str]] = None,
|
||||
timeout: float = 7200.0,
|
||||
verify_ssl: bool = True,
|
||||
content_type: str = "application/json",
|
||||
multipart_parser: Callable | None = None,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff_factor: float = 2.0,
|
||||
) -> None:
|
||||
self.endpoint = endpoint
|
||||
self.request = request
|
||||
self.files = files
|
||||
self.api_base: str = api_base or args.comfy_api_base
|
||||
self.auth_token = auth_token
|
||||
self.comfy_api_key = comfy_api_key
|
||||
if auth_kwargs is not None:
|
||||
self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
|
||||
self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
|
||||
self.timeout = timeout
|
||||
self.verify_ssl = verify_ssl
|
||||
self.content_type = content_type
|
||||
self.multipart_parser = multipart_parser
|
||||
self.max_retries = max_retries
|
||||
self.retry_delay = retry_delay
|
||||
self.retry_backoff_factor = retry_backoff_factor
|
||||
|
||||
async def execute(self, client: Optional[ApiClient] = None) -> R:
|
||||
owns_client = client is None
|
||||
if owns_client:
|
||||
client = ApiClient(
|
||||
base_url=self.api_base,
|
||||
auth_token=self.auth_token,
|
||||
comfy_api_key=self.comfy_api_key,
|
||||
timeout=self.timeout,
|
||||
verify_ssl=self.verify_ssl,
|
||||
max_retries=self.max_retries,
|
||||
retry_delay=self.retry_delay,
|
||||
retry_backoff_factor=self.retry_backoff_factor,
|
||||
)
|
||||
|
||||
try:
|
||||
request_dict: Optional[dict[str, Any]]
|
||||
if isinstance(self.request, EmptyRequest):
|
||||
request_dict = None
|
||||
else:
|
||||
request_dict = self.request.model_dump(exclude_none=True)
|
||||
for k, v in list(request_dict.items()):
|
||||
if isinstance(v, Enum):
|
||||
request_dict[k] = v.value
|
||||
|
||||
logging.debug("[DEBUG] API Request: %s %s", self.endpoint.method.value, self.endpoint.path)
|
||||
logging.debug("[DEBUG] Request Data: %s", json.dumps(request_dict, indent=2))
|
||||
logging.debug("[DEBUG] Query Params: %s", self.endpoint.query_params)
|
||||
|
||||
response_json = await client.request(
|
||||
self.endpoint.method.value,
|
||||
self.endpoint.path,
|
||||
params=self.endpoint.query_params,
|
||||
data=request_dict,
|
||||
files=self.files,
|
||||
content_type=self.content_type,
|
||||
multipart_parser=self.multipart_parser,
|
||||
)
|
||||
|
||||
logging.debug("=" * 50)
|
||||
logging.debug("[DEBUG] RESPONSE DETAILS:")
|
||||
logging.debug("[DEBUG] Status Code: 200 (Success)")
|
||||
logging.debug("[DEBUG] Response Body: %s", json.dumps(response_json, indent=2))
|
||||
logging.debug("=" * 50)
|
||||
|
||||
parsed_response = self.endpoint.response_model.model_validate(response_json)
|
||||
logging.debug("[DEBUG] Parsed Response: %s", parsed_response)
|
||||
return parsed_response
|
||||
finally:
|
||||
if owns_client:
|
||||
await client.close()
|
||||
|
||||
|
||||
class TaskStatus(str, Enum):
|
||||
"""Enum for task status values"""
|
||||
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
PENDING = "pending"
|
||||
|
||||
|
||||
class PollingOperation(Generic[T, R]):
|
||||
"""Represents an asynchronous API operation that requires polling for completion."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
poll_endpoint: ApiEndpoint[EmptyRequest, R],
|
||||
completed_statuses: list[str],
|
||||
failed_statuses: list[str],
|
||||
status_extractor: Callable[[R], Optional[str]],
|
||||
progress_extractor: Callable[[R], Optional[float]] | None = None,
|
||||
result_url_extractor: Callable[[R], Optional[str]] | None = None,
|
||||
request: Optional[T] = None,
|
||||
api_base: str | None = None,
|
||||
auth_token: Optional[str] = None,
|
||||
comfy_api_key: Optional[str] = None,
|
||||
auth_kwargs: Optional[dict[str, str]] = None,
|
||||
poll_interval: float = 5.0,
|
||||
max_poll_attempts: int = 120, # Default max polling attempts (10 minutes with 5s interval)
|
||||
max_retries: int = 3, # Max retries per individual API call
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff_factor: float = 2.0,
|
||||
estimated_duration: Optional[float] = None,
|
||||
node_id: Optional[str] = None,
|
||||
) -> None:
|
||||
self.poll_endpoint = poll_endpoint
|
||||
self.request = request
|
||||
self.api_base: str = api_base or args.comfy_api_base
|
||||
self.auth_token = auth_token
|
||||
self.comfy_api_key = comfy_api_key
|
||||
if auth_kwargs is not None:
|
||||
self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
|
||||
self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
|
||||
self.poll_interval = poll_interval
|
||||
self.max_poll_attempts = max_poll_attempts
|
||||
self.max_retries = max_retries
|
||||
self.retry_delay = retry_delay
|
||||
self.retry_backoff_factor = retry_backoff_factor
|
||||
self.estimated_duration = estimated_duration
|
||||
self.status_extractor = status_extractor or (lambda x: getattr(x, "status", None))
|
||||
self.progress_extractor = progress_extractor
|
||||
self.result_url_extractor = result_url_extractor
|
||||
self.node_id = node_id
|
||||
self.completed_statuses = completed_statuses
|
||||
self.failed_statuses = failed_statuses
|
||||
self.final_response: Optional[R] = None
|
||||
|
||||
async def execute(self, client: Optional[ApiClient] = None) -> R:
|
||||
owns_client = client is None
|
||||
if owns_client:
|
||||
client = ApiClient(
|
||||
base_url=self.api_base,
|
||||
auth_token=self.auth_token,
|
||||
comfy_api_key=self.comfy_api_key,
|
||||
max_retries=self.max_retries,
|
||||
retry_delay=self.retry_delay,
|
||||
retry_backoff_factor=self.retry_backoff_factor,
|
||||
)
|
||||
try:
|
||||
return await self._poll_until_complete(client)
|
||||
finally:
|
||||
if owns_client:
|
||||
await client.close()
|
||||
|
||||
def _display_text_on_node(self, text: str):
|
||||
if not self.node_id:
|
||||
return
|
||||
PromptServer.instance.send_progress_text(text, self.node_id)
|
||||
|
||||
def _display_time_progress_on_node(self, time_completed: int | float):
|
||||
if not self.node_id:
|
||||
return
|
||||
if self.estimated_duration is not None:
|
||||
remaining = max(0, int(self.estimated_duration) - time_completed)
|
||||
message = f"Task in progress: {time_completed}s (~{remaining}s remaining)"
|
||||
else:
|
||||
message = f"Task in progress: {time_completed}s"
|
||||
self._display_text_on_node(message)
|
||||
|
||||
def _check_task_status(self, response: R) -> TaskStatus:
|
||||
try:
|
||||
status = self.status_extractor(response)
|
||||
if status in self.completed_statuses:
|
||||
return TaskStatus.COMPLETED
|
||||
if status in self.failed_statuses:
|
||||
return TaskStatus.FAILED
|
||||
return TaskStatus.PENDING
|
||||
except Exception as e:
|
||||
logging.error("Error extracting status: %s", e)
|
||||
return TaskStatus.PENDING
|
||||
|
||||
async def _poll_until_complete(self, client: ApiClient) -> R:
|
||||
"""Poll until the task is complete"""
|
||||
consecutive_errors = 0
|
||||
max_consecutive_errors = min(5, self.max_retries * 2) # Limit consecutive errors
|
||||
|
||||
if self.progress_extractor:
|
||||
progress = utils.ProgressBar(PROGRESS_BAR_MAX)
|
||||
|
||||
status = TaskStatus.PENDING
|
||||
for poll_count in range(1, self.max_poll_attempts + 1):
|
||||
try:
|
||||
logging.debug("[DEBUG] Polling attempt #%s", poll_count)
|
||||
|
||||
request_dict = (
|
||||
None if self.request is None else self.request.model_dump(exclude_none=True)
|
||||
)
|
||||
|
||||
if poll_count == 1:
|
||||
logging.debug(
|
||||
"[DEBUG] Poll Request: %s %s",
|
||||
self.poll_endpoint.method.value,
|
||||
self.poll_endpoint.path,
|
||||
)
|
||||
logging.debug(
|
||||
"[DEBUG] Poll Request Data: %s",
|
||||
json.dumps(request_dict, indent=2) if request_dict else "None",
|
||||
)
|
||||
|
||||
# Query task status
|
||||
resp = await client.request(
|
||||
self.poll_endpoint.method.value,
|
||||
self.poll_endpoint.path,
|
||||
params=self.poll_endpoint.query_params,
|
||||
data=request_dict,
|
||||
)
|
||||
consecutive_errors = 0 # reset on success
|
||||
response_obj: R = self.poll_endpoint.response_model.model_validate(resp)
|
||||
|
||||
# Check if task is complete
|
||||
status = self._check_task_status(response_obj)
|
||||
logging.debug("[DEBUG] Task Status: %s", status)
|
||||
|
||||
# If progress extractor is provided, extract progress
|
||||
if self.progress_extractor:
|
||||
new_progress = self.progress_extractor(response_obj)
|
||||
if new_progress is not None:
|
||||
progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX)
|
||||
|
||||
if status == TaskStatus.COMPLETED:
|
||||
message = "Task completed successfully"
|
||||
if self.result_url_extractor:
|
||||
result_url = self.result_url_extractor(response_obj)
|
||||
if result_url:
|
||||
message = f"Result URL: {result_url}"
|
||||
logging.debug("[DEBUG] %s", message)
|
||||
self._display_text_on_node(message)
|
||||
self.final_response = response_obj
|
||||
if self.progress_extractor:
|
||||
progress.update(100)
|
||||
return self.final_response
|
||||
if status == TaskStatus.FAILED:
|
||||
message = f"Task failed: {json.dumps(resp)}"
|
||||
logging.error("[DEBUG] %s", message)
|
||||
raise Exception(message)
|
||||
logging.debug("[DEBUG] Task still pending, continuing to poll...")
|
||||
# Task pending – wait
|
||||
for i in range(int(self.poll_interval)):
|
||||
self._display_time_progress_on_node((poll_count - 1) * self.poll_interval + i)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
except (LocalNetworkError, ApiServerError, NetworkError) as e:
|
||||
consecutive_errors += 1
|
||||
if consecutive_errors >= max_consecutive_errors:
|
||||
raise Exception(
|
||||
f"Polling aborted after {consecutive_errors} network errors: {str(e)}"
|
||||
) from e
|
||||
logging.warning(
|
||||
"Network error (%s/%s): %s",
|
||||
consecutive_errors,
|
||||
max_consecutive_errors,
|
||||
str(e),
|
||||
)
|
||||
await asyncio.sleep(self.poll_interval)
|
||||
except Exception as e:
|
||||
# For other errors, increment count and potentially abort
|
||||
consecutive_errors += 1
|
||||
if consecutive_errors >= max_consecutive_errors or status == TaskStatus.FAILED:
|
||||
raise Exception(
|
||||
f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}"
|
||||
) from e
|
||||
|
||||
logging.error("[DEBUG] Polling error: %s", str(e))
|
||||
logging.warning(
|
||||
"Error during polling (attempt %s/%s): %s. Will retry in %s seconds.",
|
||||
poll_count,
|
||||
self.max_poll_attempts,
|
||||
str(e),
|
||||
self.poll_interval,
|
||||
)
|
||||
await asyncio.sleep(self.poll_interval)
|
||||
|
||||
# If we've exhausted all polling attempts
|
||||
raise Exception(
|
||||
f"Polling timed out after {self.max_poll_attempts} attempts (" f"{self.max_poll_attempts * self.poll_interval} seconds). "
|
||||
"The operation may still be running on the server but is taking longer than expected."
|
||||
)
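Illustrative sketch only, not part of this changeset: one way a caller might drive the PollingOperation defined above. The endpoint path, the MyTaskResponse model, the status strings, and the use of HttpMethod.GET are assumptions made for the example; the surrounding module's PollingOperation, ApiEndpoint, HttpMethod, and EmptyRequest are assumed to be in scope.

from pydantic import BaseModel

class MyTaskResponse(BaseModel):
    # Hypothetical response shape for the example.
    status: str
    result_url: str | None = None

async def wait_for_task(task_id: str, auth: dict[str, str]) -> MyTaskResponse:
    # Poll a hypothetical task endpoint until it reports a terminal status.
    operation = PollingOperation(
        poll_endpoint=ApiEndpoint(
            path=f"/proxy/example/tasks/{task_id}",  # hypothetical path
            method=HttpMethod.GET,                   # assumes a GET member exists
            request_model=EmptyRequest,
            response_model=MyTaskResponse,
        ),
        completed_statuses=["completed"],
        failed_statuses=["failed"],
        status_extractor=lambda r: r.status,
        result_url_extractor=lambda r: r.result_url,
        poll_interval=5.0,
        auth_kwargs=auth,
    )
    return await operation.execute()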
|
||||
@@ -1,19 +1,230 @@
from __future__ import annotations
from datetime import date
from enum import Enum
from typing import Any

from typing import List, Optional
from pydantic import BaseModel, Field

from comfy_api_nodes.apis import GeminiGenerationConfig, GeminiContent, GeminiSafetySetting, GeminiSystemInstructionContent, GeminiTool, GeminiVideoMetadata
from pydantic import BaseModel
|
||||
|
||||
class GeminiSafetyCategory(str, Enum):
|
||||
HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
|
||||
HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"
|
||||
HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"
|
||||
HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT"
|
||||
|
||||
|
||||
class GeminiSafetyThreshold(str, Enum):
|
||||
OFF = "OFF"
|
||||
BLOCK_NONE = "BLOCK_NONE"
|
||||
BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE"
|
||||
BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE"
|
||||
BLOCK_ONLY_HIGH = "BLOCK_ONLY_HIGH"
|
||||
|
||||
|
||||
class GeminiSafetySetting(BaseModel):
|
||||
category: GeminiSafetyCategory
|
||||
threshold: GeminiSafetyThreshold
|
||||
|
||||
|
||||
class GeminiRole(str, Enum):
|
||||
user = "user"
|
||||
model = "model"
|
||||
|
||||
|
||||
class GeminiMimeType(str, Enum):
|
||||
application_pdf = "application/pdf"
|
||||
audio_mpeg = "audio/mpeg"
|
||||
audio_mp3 = "audio/mp3"
|
||||
audio_wav = "audio/wav"
|
||||
image_png = "image/png"
|
||||
image_jpeg = "image/jpeg"
|
||||
image_webp = "image/webp"
|
||||
text_plain = "text/plain"
|
||||
video_mov = "video/mov"
|
||||
video_mpeg = "video/mpeg"
|
||||
video_mp4 = "video/mp4"
|
||||
video_mpg = "video/mpg"
|
||||
video_avi = "video/avi"
|
||||
video_wmv = "video/wmv"
|
||||
video_mpegps = "video/mpegps"
|
||||
video_flv = "video/flv"
|
||||
|
||||
|
||||
class GeminiInlineData(BaseModel):
|
||||
data: str | None = Field(
|
||||
None,
|
||||
description="The base64 encoding of the image, PDF, or video to include inline in the prompt. "
|
||||
"When including media inline, you must also specify the media type (mimeType) of the data. Size limit: 20MB",
|
||||
)
|
||||
mimeType: GeminiMimeType | None = Field(None)
|
||||
|
||||
|
||||
class GeminiPart(BaseModel):
|
||||
inlineData: GeminiInlineData | None = Field(None)
|
||||
text: str | None = Field(None)
|
||||
|
||||
|
||||
class GeminiTextPart(BaseModel):
|
||||
text: str | None = Field(None)
|
||||
|
||||
|
||||
class GeminiContent(BaseModel):
|
||||
parts: list[GeminiPart] = Field([])
|
||||
role: GeminiRole = Field(..., examples=["user"])
|
||||
|
||||
|
||||
class GeminiSystemInstructionContent(BaseModel):
|
||||
parts: list[GeminiTextPart] = Field(
|
||||
...,
|
||||
description="A list of ordered parts that make up a single message. "
|
||||
"Different parts may have different IANA MIME types.",
|
||||
)
|
||||
role: GeminiRole = Field(
|
||||
...,
|
||||
description="The identity of the entity that creates the message. "
|
||||
"The following values are supported: "
|
||||
"user: This indicates that the message is sent by a real person, typically a user-generated message. "
|
||||
"model: This indicates that the message is generated by the model. "
|
||||
"The model value is used to insert messages from model into the conversation during multi-turn conversations. "
|
||||
"For non-multi-turn conversations, this field can be left blank or unset.",
|
||||
)
|
||||
|
||||
|
||||
class GeminiFunctionDeclaration(BaseModel):
|
||||
description: str | None = Field(None)
|
||||
name: str = Field(...)
|
||||
parameters: dict[str, Any] = Field(..., description="JSON schema for the function parameters")
|
||||
|
||||
|
||||
class GeminiTool(BaseModel):
|
||||
functionDeclarations: list[GeminiFunctionDeclaration] | None = Field(None)
|
||||
|
||||
|
||||
class GeminiOffset(BaseModel):
|
||||
nanos: int | None = Field(None, ge=0, le=999999999)
|
||||
seconds: int | None = Field(None, ge=-315576000000, le=315576000000)
|
||||
|
||||
|
||||
class GeminiVideoMetadata(BaseModel):
|
||||
endOffset: GeminiOffset | None = Field(None)
|
||||
startOffset: GeminiOffset | None = Field(None)
|
||||
|
||||
|
||||
class GeminiGenerationConfig(BaseModel):
|
||||
maxOutputTokens: int | None = Field(None, ge=16, le=8192)
|
||||
seed: int | None = Field(None)
|
||||
stopSequences: list[str] | None = Field(None)
|
||||
temperature: float | None = Field(1, ge=0.0, le=2.0)
|
||||
topK: int | None = Field(40, ge=1)
|
||||
topP: float | None = Field(0.95, ge=0.0, le=1.0)
|
||||
|
||||
|
||||
class GeminiImageConfig(BaseModel):
|
||||
aspectRatio: str | None = Field(None)
|
||||
imageSize: str | None = Field(None)
|
||||
|
||||
|
||||
class GeminiImageGenerationConfig(GeminiGenerationConfig):
    responseModalities: Optional[List[str]] = None
    responseModalities: list[str] | None = Field(None)
    imageConfig: GeminiImageConfig | None = Field(None)


class GeminiImageGenerateContentRequest(BaseModel):
    contents: List[GeminiContent]
    generationConfig: Optional[GeminiImageGenerationConfig] = None
    safetySettings: Optional[List[GeminiSafetySetting]] = None
    systemInstruction: Optional[GeminiSystemInstructionContent] = None
    tools: Optional[List[GeminiTool]] = None
    videoMetadata: Optional[GeminiVideoMetadata] = None
    contents: list[GeminiContent] = Field(...)
    generationConfig: GeminiImageGenerationConfig | None = Field(None)
    safetySettings: list[GeminiSafetySetting] | None = Field(None)
    systemInstruction: GeminiSystemInstructionContent | None = Field(None)
    tools: list[GeminiTool] | None = Field(None)
    videoMetadata: GeminiVideoMetadata | None = Field(None)


class GeminiGenerateContentRequest(BaseModel):
    contents: list[GeminiContent] = Field(...)
    generationConfig: GeminiGenerationConfig | None = Field(None)
    safetySettings: list[GeminiSafetySetting] | None = Field(None)
    systemInstruction: GeminiSystemInstructionContent | None = Field(None)
    tools: list[GeminiTool] | None = Field(None)
    videoMetadata: GeminiVideoMetadata | None = Field(None)
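Illustrative sketch only, not part of this changeset: a minimal text-only request built from the models above. The prompt text and parameter values are arbitrary example choices.

example_request = GeminiGenerateContentRequest(
    contents=[
        GeminiContent(
            role=GeminiRole.user,
            parts=[GeminiPart(text="Describe this scene in one sentence.")],
        )
    ],
    # temperature must be in [0.0, 2.0]; maxOutputTokens in [16, 8192].
    generationConfig=GeminiGenerationConfig(temperature=0.7, maxOutputTokens=256),
)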
|
||||
|
||||
|
||||
class Modality(str, Enum):
|
||||
MODALITY_UNSPECIFIED = "MODALITY_UNSPECIFIED"
|
||||
TEXT = "TEXT"
|
||||
IMAGE = "IMAGE"
|
||||
VIDEO = "VIDEO"
|
||||
AUDIO = "AUDIO"
|
||||
DOCUMENT = "DOCUMENT"
|
||||
|
||||
|
||||
class ModalityTokenCount(BaseModel):
|
||||
modality: Modality | None = None
|
||||
tokenCount: int | None = Field(None, description="Number of tokens for the given modality.")
|
||||
|
||||
|
||||
class Probability(str, Enum):
|
||||
NEGLIGIBLE = "NEGLIGIBLE"
|
||||
LOW = "LOW"
|
||||
MEDIUM = "MEDIUM"
|
||||
HIGH = "HIGH"
|
||||
UNKNOWN = "UNKNOWN"
|
||||
|
||||
|
||||
class GeminiSafetyRating(BaseModel):
|
||||
category: GeminiSafetyCategory | None = None
|
||||
probability: Probability | None = Field(
|
||||
None,
|
||||
description="The probability that the content violates the specified safety category",
|
||||
)
|
||||
|
||||
|
||||
class GeminiCitation(BaseModel):
|
||||
authors: list[str] | None = None
|
||||
endIndex: int | None = None
|
||||
license: str | None = None
|
||||
publicationDate: date | None = None
|
||||
startIndex: int | None = None
|
||||
title: str | None = None
|
||||
uri: str | None = None
|
||||
|
||||
|
||||
class GeminiCitationMetadata(BaseModel):
|
||||
citations: list[GeminiCitation] | None = None
|
||||
|
||||
|
||||
class GeminiCandidate(BaseModel):
|
||||
citationMetadata: GeminiCitationMetadata | None = None
|
||||
content: GeminiContent | None = None
|
||||
finishReason: str | None = None
|
||||
safetyRatings: list[GeminiSafetyRating] | None = None
|
||||
|
||||
|
||||
class GeminiPromptFeedback(BaseModel):
|
||||
blockReason: str | None = None
|
||||
blockReasonMessage: str | None = None
|
||||
safetyRatings: list[GeminiSafetyRating] | None = None
|
||||
|
||||
|
||||
class GeminiUsageMetadata(BaseModel):
|
||||
cachedContentTokenCount: int | None = Field(
|
||||
None,
|
||||
description="Output only. Number of tokens in the cached part in the input (the cached content).",
|
||||
)
|
||||
candidatesTokenCount: int | None = Field(None, description="Number of tokens in the response(s).")
|
||||
candidatesTokensDetails: list[ModalityTokenCount] | None = Field(
|
||||
None, description="Breakdown of candidate tokens by modality."
|
||||
)
|
||||
promptTokenCount: int | None = Field(
|
||||
None,
|
||||
description="Number of tokens in the request. When cachedContent is set, this is still the total effective prompt size meaning this includes the number of tokens in the cached content.",
|
||||
)
|
||||
promptTokensDetails: list[ModalityTokenCount] | None = Field(
|
||||
None, description="Breakdown of prompt tokens by modality."
|
||||
)
|
||||
thoughtsTokenCount: int | None = Field(None, description="Number of tokens present in thoughts output.")
|
||||
toolUsePromptTokenCount: int | None = Field(None, description="Number of tokens present in tool-use prompt(s).")
|
||||
|
||||
|
||||
class GeminiGenerateContentResponse(BaseModel):
|
||||
candidates: list[GeminiCandidate] | None = Field(None)
|
||||
promptFeedback: GeminiPromptFeedback | None = Field(None)
|
||||
usageMetadata: GeminiUsageMetadata | None = Field(None)
|
||||
modelVersion: str | None = Field(None)
|
||||
|
||||
comfy_api_nodes/apis/minimax_api.py (new file, 120 lines)
@@ -0,0 +1,120 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MinimaxBaseResponse(BaseModel):
|
||||
status_code: int = Field(
|
||||
...,
|
||||
description='Status code. 0 indicates success, other values indicate errors.',
|
||||
)
|
||||
status_msg: str = Field(
|
||||
..., description='Specific error details or success message.'
|
||||
)
|
||||
|
||||
|
||||
class File(BaseModel):
|
||||
bytes: Optional[int] = Field(None, description='File size in bytes')
|
||||
created_at: Optional[int] = Field(
|
||||
None, description='Unix timestamp when the file was created, in seconds'
|
||||
)
|
||||
download_url: Optional[str] = Field(
|
||||
None, description='The URL to download the video'
|
||||
)
|
||||
backup_download_url: Optional[str] = Field(
|
||||
None, description='The backup URL to download the video'
|
||||
)
|
||||
|
||||
file_id: Optional[int] = Field(None, description='Unique identifier for the file')
|
||||
filename: Optional[str] = Field(None, description='The name of the file')
|
||||
purpose: Optional[str] = Field(None, description='The purpose of using the file')
|
||||
|
||||
|
||||
class MinimaxFileRetrieveResponse(BaseModel):
|
||||
base_resp: MinimaxBaseResponse
|
||||
file: File
|
||||
|
||||
|
||||
class MiniMaxModel(str, Enum):
|
||||
T2V_01_Director = 'T2V-01-Director'
|
||||
I2V_01_Director = 'I2V-01-Director'
|
||||
S2V_01 = 'S2V-01'
|
||||
I2V_01 = 'I2V-01'
|
||||
I2V_01_live = 'I2V-01-live'
|
||||
T2V_01 = 'T2V-01'
|
||||
Hailuo_02 = 'MiniMax-Hailuo-02'
|
||||
|
||||
|
||||
class Status6(str, Enum):
|
||||
Queueing = 'Queueing'
|
||||
Preparing = 'Preparing'
|
||||
Processing = 'Processing'
|
||||
Success = 'Success'
|
||||
Fail = 'Fail'
|
||||
|
||||
|
||||
class MinimaxTaskResultResponse(BaseModel):
|
||||
base_resp: MinimaxBaseResponse
|
||||
file_id: Optional[str] = Field(
|
||||
None,
|
||||
description='After the task status changes to Success, this field returns the file ID corresponding to the generated video.',
|
||||
)
|
||||
status: Status6 = Field(
|
||||
...,
|
||||
description="Task status: 'Queueing' (in queue), 'Preparing' (task is preparing), 'Processing' (generating), 'Success' (task completed successfully), or 'Fail' (task failed).",
|
||||
)
|
||||
task_id: str = Field(..., description='The task ID being queried.')
|
||||
|
||||
|
||||
class SubjectReferenceItem(BaseModel):
|
||||
image: Optional[str] = Field(
|
||||
None, description='URL or base64 encoding of the subject reference image.'
|
||||
)
|
||||
mask: Optional[str] = Field(
|
||||
None,
|
||||
description='URL or base64 encoding of the mask for the subject reference image.',
|
||||
)
|
||||
|
||||
|
||||
class MinimaxVideoGenerationRequest(BaseModel):
|
||||
callback_url: Optional[str] = Field(
|
||||
None,
|
||||
description='Optional. URL to receive real-time status updates about the video generation task.',
|
||||
)
|
||||
first_frame_image: Optional[str] = Field(
|
||||
None,
|
||||
description='URL or base64 encoding of the first frame image. Required when model is I2V-01, I2V-01-Director, or I2V-01-live.',
|
||||
)
|
||||
model: MiniMaxModel = Field(
|
||||
...,
|
||||
description='Required. ID of model. Options: T2V-01-Director, I2V-01-Director, S2V-01, I2V-01, I2V-01-live, T2V-01',
|
||||
)
|
||||
prompt: Optional[str] = Field(
|
||||
None,
|
||||
description='Description of the video. Should be less than 2000 characters. Supports camera movement instructions in [brackets].',
|
||||
max_length=2000,
|
||||
)
|
||||
prompt_optimizer: Optional[bool] = Field(
|
||||
True,
|
||||
description='If true (default), the model will automatically optimize the prompt. Set to false for more precise control.',
|
||||
)
|
||||
subject_reference: Optional[list[SubjectReferenceItem]] = Field(
|
||||
None,
|
||||
description='Only available when model is S2V-01. The model will generate a video based on the subject uploaded through this parameter.',
|
||||
)
|
||||
duration: Optional[int] = Field(
|
||||
None,
|
||||
description="The length of the output video in seconds."
|
||||
)
|
||||
resolution: Optional[str] = Field(
|
||||
None,
|
||||
description="The dimensions of the video display. 1080p corresponds to 1920 x 1080 pixels, 768p corresponds to 1366 x 768 pixels."
|
||||
)
|
||||
|
||||
|
||||
class MinimaxVideoGenerationResponse(BaseModel):
|
||||
base_resp: MinimaxBaseResponse
|
||||
task_id: str = Field(
|
||||
..., description='The task ID for the asynchronous video generation task.'
|
||||
)
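Illustrative sketch only, not part of this changeset: a minimal text-to-video request built from the MiniMax models above. The prompt text is an arbitrary example.

example_request = MinimaxVideoGenerationRequest(
    model=MiniMaxModel.T2V_01,
    prompt="A slow pan across a foggy harbor at dawn",
    prompt_optimizer=True,
)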
|
||||
comfy_api_nodes/apis/topaz_api.py (new file, 133 lines)
@@ -0,0 +1,133 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ImageEnhanceRequest(BaseModel):
|
||||
model: str = Field("Reimagine")
|
||||
output_format: str = Field("jpeg")
|
||||
subject_detection: str = Field("All")
|
||||
face_enhancement: bool = Field(True)
|
||||
face_enhancement_creativity: float = Field(0, description="Is ignored if face_enhancement is false")
|
||||
face_enhancement_strength: float = Field(0.8, description="Is ignored if face_enhancement is false")
|
||||
source_url: str = Field(...)
|
||||
output_width: Optional[int] = Field(None)
|
||||
output_height: Optional[int] = Field(None)
|
||||
crop_to_fill: bool = Field(False)
|
||||
prompt: Optional[str] = Field(None, description="Text prompt for creative upscaling guidance")
|
||||
creativity: int = Field(3, description="Creativity settings range from 1 to 9")
|
||||
face_preservation: str = Field("true", description="To preserve the identity of characters")
|
||||
color_preservation: str = Field("true", description="To preserve the original color")
|
||||
|
||||
|
||||
class ImageAsyncTaskResponse(BaseModel):
|
||||
process_id: str = Field(...)
|
||||
|
||||
|
||||
class ImageStatusResponse(BaseModel):
|
||||
process_id: str = Field(...)
|
||||
status: str = Field(...)
|
||||
progress: Optional[int] = Field(None)
|
||||
credits: int = Field(...)
|
||||
|
||||
|
||||
class ImageDownloadResponse(BaseModel):
|
||||
download_url: str = Field(...)
|
||||
expiry: int = Field(...)
|
||||
|
||||
|
||||
class Resolution(BaseModel):
|
||||
width: int = Field(...)
|
||||
height: int = Field(...)
|
||||
|
||||
|
||||
class CreateCreateVideoRequestSource(BaseModel):
|
||||
container: str = Field(...)
|
||||
size: int = Field(..., description="Size of the video file in bytes")
|
||||
duration: int = Field(..., description="Duration of the video file in seconds")
|
||||
frameCount: int = Field(..., description="Total number of frames in the video")
|
||||
frameRate: int = Field(...)
|
||||
resolution: Resolution = Field(...)
|
||||
|
||||
|
||||
class VideoFrameInterpolationFilter(BaseModel):
|
||||
model: str = Field(...)
|
||||
slowmo: Optional[int] = Field(None)
|
||||
fps: int = Field(...)
|
||||
duplicate: bool = Field(...)
|
||||
duplicate_threshold: float = Field(...)
|
||||
|
||||
|
||||
class VideoEnhancementFilter(BaseModel):
|
||||
model: str = Field(...)
|
||||
auto: Optional[str] = Field(None, description="Auto, Manual, Relative")
|
||||
focusFixLevel: Optional[str] = Field(None, description="Downscales video input for correction of blurred subjects")
|
||||
compression: Optional[float] = Field(None, description="Strength of compression recovery")
|
||||
details: Optional[float] = Field(None, description="Amount of detail reconstruction")
|
||||
prenoise: Optional[float] = Field(None, description="Amount of noise to add to input to reduce over-smoothing")
|
||||
noise: Optional[float] = Field(None, description="Amount of noise reduction")
|
||||
halo: Optional[float] = Field(None, description="Amount of halo reduction")
|
||||
preblur: Optional[float] = Field(None, description="Anti-aliasing and deblurring strength")
|
||||
blur: Optional[float] = Field(None, description="Amount of sharpness applied")
|
||||
grain: Optional[float] = Field(None, description="Grain after AI model processing")
|
||||
grainSize: Optional[float] = Field(None, description="Size of generated grain")
|
||||
recoverOriginalDetailValue: Optional[float] = Field(None, description="Source details into the output video")
|
||||
creativity: Optional[str] = Field(None, description="Creativity level(high, low) for slc-1 only")
|
||||
isOptimizedMode: Optional[bool] = Field(None, description="Set to true for Starlight Creative (slc-1) only")
|
||||
|
||||
|
||||
class OutputInformationVideo(BaseModel):
|
||||
resolution: Resolution = Field(...)
|
||||
frameRate: int = Field(...)
|
||||
audioCodec: Optional[str] = Field(..., description="Required if audioTransfer is Copy or Convert")
|
||||
audioTransfer: str = Field(..., description="Copy, Convert, None")
|
||||
dynamicCompressionLevel: str = Field(..., description="Low, Mid, High")
|
||||
|
||||
|
||||
class Overrides(BaseModel):
|
||||
isPaidDiffusion: bool = Field(True)
|
||||
|
||||
|
||||
class CreateVideoRequest(BaseModel):
|
||||
source: CreateCreateVideoRequestSource = Field(...)
|
||||
filters: list[Union[VideoFrameInterpolationFilter, VideoEnhancementFilter]] = Field(...)
|
||||
output: OutputInformationVideo = Field(...)
|
||||
overrides: Overrides = Field(Overrides(isPaidDiffusion=True))
|
||||
|
||||
|
||||
class CreateVideoResponse(BaseModel):
|
||||
requestId: str = Field(...)
|
||||
|
||||
|
||||
class VideoAcceptResponse(BaseModel):
|
||||
uploadId: str = Field(...)
|
||||
urls: list[str] = Field(...)
|
||||
|
||||
|
||||
class VideoCompleteUploadRequestPart(BaseModel):
|
||||
partNum: int = Field(...)
|
||||
eTag: str = Field(...)
|
||||
|
||||
|
||||
class VideoCompleteUploadRequest(BaseModel):
|
||||
uploadResults: list[VideoCompleteUploadRequestPart] = Field(...)
|
||||
|
||||
|
||||
class VideoCompleteUploadResponse(BaseModel):
|
||||
message: str = Field(..., description="Confirmation message")
|
||||
|
||||
|
||||
class VideoStatusResponseEstimates(BaseModel):
|
||||
cost: list[int] = Field(...)
|
||||
|
||||
|
||||
class VideoStatusResponseDownloadUrl(BaseModel):
|
||||
url: str = Field(...)
|
||||
|
||||
|
||||
class VideoStatusResponse(BaseModel):
|
||||
status: str = Field(...)
|
||||
estimates: Optional[VideoStatusResponseEstimates] = Field(None)
|
||||
progress: Optional[float] = Field(None)
|
||||
message: Optional[str] = Field("")
|
||||
download: Optional[VideoStatusResponseDownloadUrl] = Field(None)
|
||||
@@ -1,13 +1,20 @@
|
||||
from __future__ import annotations
|
||||
from comfy_api_nodes.apis import (
|
||||
TripoModelVersion,
|
||||
TripoTextureQuality,
|
||||
)
|
||||
from enum import Enum
|
||||
from typing import Optional, List, Dict, Any, Union
|
||||
|
||||
from pydantic import BaseModel, Field, RootModel
|
||||
|
||||
class TripoModelVersion(str, Enum):
|
||||
v2_5_20250123 = 'v2.5-20250123'
|
||||
v2_0_20240919 = 'v2.0-20240919'
|
||||
v1_4_20240625 = 'v1.4-20240625'
|
||||
|
||||
|
||||
class TripoTextureQuality(str, Enum):
|
||||
standard = 'standard'
|
||||
detailed = 'detailed'
|
||||
|
||||
|
||||
class TripoStyle(str, Enum):
|
||||
PERSON_TO_CARTOON = "person:person2cartoon"
|
||||
ANIMAL_VENOM = "animal:venom"
|
||||
|
||||
comfy_api_nodes/apis/veo_api.py (new file, 111 lines)
@@ -0,0 +1,111 @@
|
||||
from typing import Optional, Union
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Image2(BaseModel):
|
||||
bytesBase64Encoded: str
|
||||
gcsUri: Optional[str] = None
|
||||
mimeType: Optional[str] = None
|
||||
|
||||
|
||||
class Image3(BaseModel):
|
||||
bytesBase64Encoded: Optional[str] = None
|
||||
gcsUri: str
|
||||
mimeType: Optional[str] = None
|
||||
|
||||
|
||||
class Instance1(BaseModel):
|
||||
image: Optional[Union[Image2, Image3]] = Field(
|
||||
None, description='Optional image to guide video generation'
|
||||
)
|
||||
prompt: str = Field(..., description='Text description of the video')
|
||||
|
||||
|
||||
class PersonGeneration1(str, Enum):
|
||||
ALLOW = 'ALLOW'
|
||||
BLOCK = 'BLOCK'
|
||||
|
||||
|
||||
class Parameters1(BaseModel):
|
||||
aspectRatio: Optional[str] = Field(None, examples=['16:9'])
|
||||
durationSeconds: Optional[int] = None
|
||||
enhancePrompt: Optional[bool] = None
|
||||
generateAudio: Optional[bool] = Field(
|
||||
None,
|
||||
description='Generate audio for the video. Only supported by veo 3 models.',
|
||||
)
|
||||
negativePrompt: Optional[str] = None
|
||||
personGeneration: Optional[PersonGeneration1] = None
|
||||
sampleCount: Optional[int] = None
|
||||
seed: Optional[int] = None
|
||||
storageUri: Optional[str] = Field(
|
||||
None, description='Optional Cloud Storage URI to upload the video'
|
||||
)
|
||||
|
||||
|
||||
class VeoGenVidRequest(BaseModel):
|
||||
instances: Optional[list[Instance1]] = None
|
||||
parameters: Optional[Parameters1] = None
|
||||
|
||||
|
||||
class VeoGenVidResponse(BaseModel):
|
||||
name: str = Field(
|
||||
...,
|
||||
description='Operation resource name',
|
||||
examples=[
|
||||
'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/a1b07c8e-7b5a-4aba-bb34-3e1ccb8afcc8'
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class VeoGenVidPollRequest(BaseModel):
|
||||
operationName: str = Field(
|
||||
...,
|
||||
description='Full operation name (from predict response)',
|
||||
examples=[
|
||||
'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/OPERATION_ID'
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class Video(BaseModel):
|
||||
bytesBase64Encoded: Optional[str] = Field(
|
||||
None, description='Base64-encoded video content'
|
||||
)
|
||||
gcsUri: Optional[str] = Field(None, description='Cloud Storage URI of the video')
|
||||
mimeType: Optional[str] = Field(None, description='Video MIME type')
|
||||
|
||||
|
||||
class Error1(BaseModel):
|
||||
code: Optional[int] = Field(None, description='Error code')
|
||||
message: Optional[str] = Field(None, description='Error message')
|
||||
|
||||
|
||||
class Response1(BaseModel):
|
||||
field_type: Optional[str] = Field(
|
||||
None,
|
||||
alias='@type',
|
||||
examples=[
|
||||
'type.googleapis.com/cloud.ai.large_models.vision.GenerateVideoResponse'
|
||||
],
|
||||
)
|
||||
raiMediaFilteredCount: Optional[int] = Field(
|
||||
None, description='Count of media filtered by responsible AI policies'
|
||||
)
|
||||
raiMediaFilteredReasons: Optional[list[str]] = Field(
|
||||
None, description='Reasons why media was filtered by responsible AI policies'
|
||||
)
|
||||
videos: Optional[list[Video]] = None
|
||||
|
||||
|
||||
class VeoGenVidPollResponse(BaseModel):
|
||||
done: Optional[bool] = None
|
||||
error: Optional[Error1] = Field(
|
||||
None, description='Error details if operation failed'
|
||||
)
|
||||
name: Optional[str] = None
|
||||
response: Optional[Response1] = Field(
|
||||
None, description='The actual prediction response if done is true'
|
||||
)
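Illustrative sketch only, not part of this changeset: a minimal generate request and the matching poll request using the Veo models above. The prompt is arbitrary; the operation name reuses the example value documented in the model fields.

example_request = VeoGenVidRequest(
    instances=[Instance1(prompt="A paper boat drifting down a rain-soaked street")],
    parameters=Parameters1(aspectRatio="16:9", durationSeconds=8, sampleCount=1),
)
example_poll = VeoGenVidPollRequest(
    operationName="projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/OPERATION_ID"
)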
|
||||
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -1,6 +1,6 @@
from io import BytesIO
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io as comfy_io
from comfy_api.latest import IO, ComfyExtension
from PIL import Image
import numpy as np
import torch
@@ -11,19 +11,13 @@ from comfy_api_nodes.apis import (
    IdeogramV3Request,
    IdeogramV3EditRequest,
)

from comfy_api_nodes.apis.client import (
from comfy_api_nodes.util import (
    ApiEndpoint,
    HttpMethod,
    SynchronousOperation,
)

from comfy_api_nodes.apinode_utils import (
    download_url_to_bytesio,
    bytesio_to_image_tensor,
    download_url_as_bytesio,
    resize_mask_to_image,
    sync_op,
)
from server import PromptServer
|
||||
|
||||
V1_V1_RES_MAP = {
|
||||
"Auto":"AUTO",
|
||||
@@ -220,7 +214,7 @@ async def download_and_process_images(image_urls):
|
||||
|
||||
for image_url in image_urls:
|
||||
# Using functions from apinode_utils.py to handle downloading and processing
|
||||
image_bytesio = await download_url_to_bytesio(image_url) # Download image content to BytesIO
|
||||
image_bytesio = await download_url_as_bytesio(image_url) # Download image content to BytesIO
|
||||
img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB") # Convert to torch.Tensor with RGB mode
|
||||
image_tensors.append(img_tensor)
|
||||
|
||||
@@ -233,89 +227,76 @@ async def download_and_process_images(image_urls):
|
||||
return stacked_tensors
|
||||
|
||||
|
||||
def display_image_urls_on_node(image_urls, node_id):
|
||||
if node_id and image_urls:
|
||||
if len(image_urls) == 1:
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Generated Image URL:\n{image_urls[0]}", node_id
|
||||
)
|
||||
else:
|
||||
urls_text = "Generated Image URLs:\n" + "\n".join(
|
||||
f"{i+1}. {url}" for i, url in enumerate(image_urls)
|
||||
)
|
||||
PromptServer.instance.send_progress_text(urls_text, node_id)
|
||||
|
||||
|
||||
class IdeogramV1(comfy_io.ComfyNode):
|
||||
class IdeogramV1(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
return IO.Schema(
|
||||
node_id="IdeogramV1",
|
||||
display_name="Ideogram V1",
|
||||
category="api node/image/Ideogram",
|
||||
description="Generates images using the Ideogram V1 model.",
|
||||
is_api_node=True,
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the image generation",
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
IO.Boolean.Input(
|
||||
"turbo",
|
||||
default=False,
|
||||
tooltip="Whether to use turbo mode (faster generation, potentially lower quality)",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=list(V1_V2_RATIO_MAP.keys()),
|
||||
default="1:1",
|
||||
tooltip="The aspect ratio for image generation.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"magic_prompt_option",
|
||||
options=["AUTO", "ON", "OFF"],
|
||||
default="AUTO",
|
||||
tooltip="Determine if MagicPrompt should be used in generation",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
control_after_generate=True,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Description of what to exclude from the image",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"num_images",
|
||||
default=1,
|
||||
min=1,
|
||||
max=8,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Image.Output(),
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
)
|
||||
|
||||
@@ -334,77 +315,63 @@ class IdeogramV1(comfy_io.ComfyNode):
|
||||
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
|
||||
model = "V_1_TURBO" if turbo else "V_1"
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/ideogram/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=IdeogramGenerateRequest,
|
||||
response_model=IdeogramGenerateResponse,
|
||||
),
|
||||
request=IdeogramGenerateRequest(
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
|
||||
response_model=IdeogramGenerateResponse,
|
||||
data=IdeogramGenerateRequest(
|
||||
image_request=ImageRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
num_images=num_images,
|
||||
seed=seed,
|
||||
aspect_ratio=aspect_ratio if aspect_ratio != "ASPECT_1_1" else None,
|
||||
magic_prompt_option=(
|
||||
magic_prompt_option if magic_prompt_option != "AUTO" else None
|
||||
),
|
||||
magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
)
|
||||
),
|
||||
auth_kwargs=auth,
|
||||
max_retries=1,
|
||||
)
|
||||
|
||||
response = await operation.execute()
|
||||
|
||||
if not response.data or len(response.data) == 0:
|
||||
raise Exception("No images were generated in the response")
|
||||
|
||||
image_urls = [image_data.url for image_data in response.data if image_data.url]
|
||||
|
||||
if not image_urls:
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
|
||||
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
|
||||
return comfy_io.NodeOutput(await download_and_process_images(image_urls))
|
||||
return IO.NodeOutput(await download_and_process_images(image_urls))
|
||||
|
||||
|
||||
class IdeogramV2(comfy_io.ComfyNode):
|
||||
class IdeogramV2(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
return IO.Schema(
|
||||
node_id="IdeogramV2",
|
||||
display_name="Ideogram V2",
|
||||
category="api node/image/Ideogram",
|
||||
description="Generates images using the Ideogram V2 model.",
|
||||
is_api_node=True,
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the image generation",
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
IO.Boolean.Input(
|
||||
"turbo",
|
||||
default=False,
|
||||
tooltip="Whether to use turbo mode (faster generation, potentially lower quality)",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=list(V1_V2_RATIO_MAP.keys()),
|
||||
default="1:1",
|
||||
tooltip="The aspect ratio for image generation. Ignored if resolution is not set to AUTO.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=list(V1_V1_RES_MAP.keys()),
|
||||
default="Auto",
|
||||
@@ -412,44 +379,44 @@ class IdeogramV2(comfy_io.ComfyNode):
|
||||
"If not set to AUTO, this overrides the aspect_ratio setting.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"magic_prompt_option",
|
||||
options=["AUTO", "ON", "OFF"],
|
||||
default="AUTO",
|
||||
tooltip="Determine if MagicPrompt should be used in generation",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
control_after_generate=True,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"style_type",
|
||||
options=["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"],
|
||||
default="NONE",
|
||||
tooltip="Style type for generation (V2 only)",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Description of what to exclude from the image",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"num_images",
|
||||
default=1,
|
||||
min=1,
|
||||
max=8,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
#"color_palette": (
|
||||
@@ -462,12 +429,12 @@ class IdeogramV2(comfy_io.ComfyNode):
|
||||
#),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Image.Output(),
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
)
|
||||
|
||||
@@ -500,18 +467,11 @@ class IdeogramV2(comfy_io.ComfyNode):
|
||||
else:
|
||||
final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/ideogram/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=IdeogramGenerateRequest,
|
||||
response_model=IdeogramGenerateResponse,
|
||||
),
|
||||
request=IdeogramGenerateRequest(
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
|
||||
response_model=IdeogramGenerateResponse,
|
||||
data=IdeogramGenerateRequest(
|
||||
image_request=ImageRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
@@ -519,36 +479,28 @@ class IdeogramV2(comfy_io.ComfyNode):
|
||||
seed=seed,
|
||||
aspect_ratio=final_aspect_ratio,
|
||||
resolution=final_resolution,
|
||||
magic_prompt_option=(
|
||||
magic_prompt_option if magic_prompt_option != "AUTO" else None
|
||||
),
|
||||
magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
|
||||
style_type=style_type if style_type != "NONE" else None,
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
color_palette=color_palette if color_palette else None,
|
||||
)
|
||||
),
|
||||
auth_kwargs=auth,
|
||||
max_retries=1,
|
||||
)
|
||||
|
||||
response = await operation.execute()
|
||||
|
||||
if not response.data or len(response.data) == 0:
|
||||
raise Exception("No images were generated in the response")
|
||||
|
||||
image_urls = [image_data.url for image_data in response.data if image_data.url]
|
||||
|
||||
if not image_urls:
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
|
||||
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
|
||||
return comfy_io.NodeOutput(await download_and_process_images(image_urls))
|
||||
return IO.NodeOutput(await download_and_process_images(image_urls))
|
||||
|
||||
|
||||
class IdeogramV3(comfy_io.ComfyNode):
|
||||
class IdeogramV3(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
return IO.Schema(
|
||||
node_id="IdeogramV3",
|
||||
display_name="Ideogram V3",
|
||||
category="api node/image/Ideogram",
|
||||
@@ -556,30 +508,30 @@ class IdeogramV3(comfy_io.ComfyNode):
|
||||
"Supports both regular image generation from text prompts and image editing with mask.",
|
||||
is_api_node=True,
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the image generation or editing",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="Optional reference image for image editing.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Mask.Input(
|
||||
IO.Mask.Input(
|
||||
"mask",
|
||||
tooltip="Optional mask for inpainting (white areas will be replaced)",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=list(V3_RATIO_MAP.keys()),
|
||||
default="1:1",
|
||||
tooltip="The aspect ratio for image generation. Ignored if resolution is not set to Auto.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=V3_RESOLUTIONS,
|
||||
default="Auto",
|
||||
@@ -587,57 +539,57 @@ class IdeogramV3(comfy_io.ComfyNode):
|
||||
"If not set to Auto, this overrides the aspect_ratio setting.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"magic_prompt_option",
|
||||
options=["AUTO", "ON", "OFF"],
|
||||
default="AUTO",
|
||||
tooltip="Determine if MagicPrompt should be used in generation",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
control_after_generate=True,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"num_images",
|
||||
default=1,
|
||||
min=1,
|
||||
max=8,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"rendering_speed",
|
||||
options=["DEFAULT", "TURBO", "QUALITY"],
|
||||
default="DEFAULT",
|
||||
tooltip="Controls the trade-off between generation speed and quality",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"character_image",
|
||||
tooltip="Image to use as character reference.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Mask.Input(
|
||||
IO.Mask.Input(
|
||||
"character_mask",
|
||||
tooltip="Optional mask for character reference image.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Image.Output(),
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
)
|
||||
|
||||
@@ -656,10 +608,6 @@ class IdeogramV3(comfy_io.ComfyNode):
|
||||
character_image=None,
|
||||
character_mask=None,
|
||||
):
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
if rendering_speed == "BALANCED": # for backward compatibility
|
||||
rendering_speed = "DEFAULT"
|
||||
|
||||
@@ -694,9 +642,6 @@ class IdeogramV3(comfy_io.ComfyNode):
|
||||
|
||||
# Check if both image and mask are provided for editing mode
|
||||
if image is not None and mask is not None:
|
||||
# Edit mode
|
||||
path = "/proxy/ideogram/ideogram-v3/edit"
|
||||
|
||||
# Process image and mask
|
||||
input_tensor = image.squeeze().cpu()
|
||||
# Resize mask to match image dimension
|
||||
@@ -749,27 +694,20 @@ class IdeogramV3(comfy_io.ComfyNode):
|
||||
if character_mask_binary:
|
||||
files["character_mask_binary"] = character_mask_binary
|
||||
|
||||
# Execute the operation for edit mode
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=path,
|
||||
method=HttpMethod.POST,
|
||||
request_model=IdeogramV3EditRequest,
|
||||
response_model=IdeogramGenerateResponse,
|
||||
),
|
||||
request=edit_request,
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/ideogram/ideogram-v3/edit", method="POST"),
|
||||
response_model=IdeogramGenerateResponse,
|
||||
data=edit_request,
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
max_retries=1,
|
||||
)
|
||||
|
||||
elif image is not None or mask is not None:
|
||||
# If only one of image or mask is provided, raise an error
|
||||
raise Exception("Ideogram V3 image editing requires both an image AND a mask")
|
||||
else:
|
||||
# Generation mode
|
||||
path = "/proxy/ideogram/ideogram-v3/generate"
|
||||
|
||||
# Create generation request
|
||||
gen_request = IdeogramV3Request(
|
||||
prompt=prompt,
|
||||
@@ -800,43 +738,34 @@ class IdeogramV3(comfy_io.ComfyNode):
|
||||
if files:
|
||||
gen_request.style_type = "AUTO"
|
||||
|
||||
# Execute the operation for generation mode
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=path,
|
||||
method=HttpMethod.POST,
|
||||
request_model=IdeogramV3Request,
|
||||
response_model=IdeogramGenerateResponse,
|
||||
),
|
||||
request=gen_request,
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path="/proxy/ideogram/ideogram-v3/generate", method="POST"),
|
||||
response_model=IdeogramGenerateResponse,
|
||||
data=gen_request,
|
||||
files=files if files else None,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
max_retries=1,
|
||||
)
|
||||
|
||||
# Execute the operation and process response
|
||||
response = await operation.execute()
|
||||
|
||||
if not response.data or len(response.data) == 0:
|
||||
raise Exception("No images were generated in the response")
|
||||
|
||||
image_urls = [image_data.url for image_data in response.data if image_data.url]
|
||||
|
||||
if not image_urls:
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
|
||||
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
|
||||
return comfy_io.NodeOutput(await download_and_process_images(image_urls))
|
||||
return IO.NodeOutput(await download_and_process_images(image_urls))
|
||||
|
||||
|
||||
class IdeogramExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
IdeogramV1,
|
||||
IdeogramV2,
|
||||
IdeogramV3,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> IdeogramExtension:
|
||||
return IdeogramExtension()
|
||||
|
||||
File diff suppressed because it is too large
comfy_api_nodes/nodes_ltxv.py (new file, 199 lines)
@@ -0,0 +1,199 @@
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
get_number_of_images,
|
||||
sync_op_raw,
|
||||
upload_images_to_comfyapi,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
MODELS_MAP = {
|
||||
"LTX-2 (Pro)": "ltx-2-pro",
|
||||
"LTX-2 (Fast)": "ltx-2-fast",
|
||||
}
|
||||
|
||||
|
||||
class ExecuteTaskRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
model: str = Field(...)
|
||||
duration: int = Field(...)
|
||||
resolution: str = Field(...)
|
||||
fps: Optional[int] = Field(25)
|
||||
generate_audio: Optional[bool] = Field(True)
|
||||
image_uri: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class TextToVideoNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="LtxvApiTextToVideo",
|
||||
display_name="LTXV Text To Video",
|
||||
category="api node/video/LTXV",
|
||||
description="Professional-quality videos with customizable duration and resolution.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=list(MODELS_MAP.keys())),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
),
|
||||
IO.Combo.Input("duration", options=[6, 8, 10, 12, 14, 16, 18, 20], default=8),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=[
|
||||
"1920x1080",
|
||||
"2560x1440",
|
||||
"3840x2160",
|
||||
],
|
||||
),
|
||||
IO.Combo.Input("fps", options=[25, 50], default=25),
|
||||
IO.Boolean.Input(
|
||||
"generate_audio",
|
||||
default=False,
|
||||
optional=True,
|
||||
tooltip="When true, the generated video will include AI-generated audio matching the scene.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
duration: int,
|
||||
resolution: str,
|
||||
fps: int = 25,
|
||||
generate_audio: bool = False,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=10000)
|
||||
if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25):
|
||||
raise ValueError(
|
||||
"Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS."
|
||||
)
|
||||
response = await sync_op_raw(
|
||||
cls,
|
||||
ApiEndpoint("/proxy/ltx/v1/text-to-video", "POST"),
|
||||
data=ExecuteTaskRequest(
|
||||
prompt=prompt,
|
||||
model=MODELS_MAP[model],
|
||||
duration=duration,
|
||||
resolution=resolution,
|
||||
fps=fps,
|
||||
generate_audio=generate_audio,
|
||||
),
|
||||
as_binary=True,
|
||||
max_retries=1,
|
||||
)
|
||||
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
|
||||
|
||||
|
||||
class ImageToVideoNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="LtxvApiImageToVideo",
|
||||
display_name="LTXV Image To Video",
|
||||
category="api node/video/LTXV",
|
||||
description="Professional-quality videos with customizable duration and resolution based on start image.",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="First frame to be used for the video."),
|
||||
IO.Combo.Input("model", options=list(MODELS_MAP.keys())),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
),
|
||||
IO.Combo.Input("duration", options=[6, 8, 10, 12, 14, 16, 18, 20], default=8),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=[
|
||||
"1920x1080",
|
||||
"2560x1440",
|
||||
"3840x2160",
|
||||
],
|
||||
),
|
||||
IO.Combo.Input("fps", options=[25, 50], default=25),
|
||||
IO.Boolean.Input(
|
||||
"generate_audio",
|
||||
default=False,
|
||||
optional=True,
|
||||
tooltip="When true, the generated video will include AI-generated audio matching the scene.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: torch.Tensor,
|
||||
model: str,
|
||||
prompt: str,
|
||||
duration: int,
|
||||
resolution: str,
|
||||
fps: int = 25,
|
||||
generate_audio: bool = False,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=10000)
|
||||
if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25):
|
||||
raise ValueError(
|
||||
"Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS."
|
||||
)
|
||||
if get_number_of_images(image) != 1:
|
||||
raise ValueError("Currently only one input image is supported.")
|
||||
response = await sync_op_raw(
|
||||
cls,
|
||||
ApiEndpoint("/proxy/ltx/v1/image-to-video", "POST"),
|
||||
data=ExecuteTaskRequest(
|
||||
image_uri=(await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0],
|
||||
prompt=prompt,
|
||||
model=MODELS_MAP[model],
|
||||
duration=duration,
|
||||
resolution=resolution,
|
||||
fps=fps,
|
||||
generate_audio=generate_audio,
|
||||
),
|
||||
as_binary=True,
|
||||
max_retries=1,
|
||||
)
|
||||
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
|
||||
|
||||
|
||||
class LtxvApiExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
TextToVideoNode,
|
||||
ImageToVideoNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> LtxvApiExtension:
|
||||
return LtxvApiExtension()
|
||||
@@ -1,75 +1,57 @@
|
||||
from __future__ import annotations
|
||||
from inspect import cleandoc
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io as comfy_io
|
||||
from comfy_api.input_impl.video_types import VideoFromFile
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.apis.luma_api import (
|
||||
LumaImageModel,
|
||||
LumaVideoModel,
|
||||
LumaVideoOutputResolution,
|
||||
LumaVideoModelOutputDuration,
|
||||
LumaAspectRatio,
|
||||
LumaState,
|
||||
LumaImageGenerationRequest,
|
||||
LumaGenerationRequest,
|
||||
LumaGeneration,
|
||||
LumaCharacterRef,
|
||||
LumaModifyImageRef,
|
||||
LumaConceptChain,
|
||||
LumaGeneration,
|
||||
LumaGenerationRequest,
|
||||
LumaImageGenerationRequest,
|
||||
LumaImageIdentity,
|
||||
LumaImageModel,
|
||||
LumaImageReference,
|
||||
LumaIO,
|
||||
LumaKeyframes,
|
||||
LumaModifyImageRef,
|
||||
LumaReference,
|
||||
LumaReferenceChain,
|
||||
LumaImageReference,
|
||||
LumaKeyframes,
|
||||
LumaConceptChain,
|
||||
LumaIO,
|
||||
LumaVideoModel,
|
||||
LumaVideoModelOutputDuration,
|
||||
LumaVideoOutputResolution,
|
||||
get_luma_concepts,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
download_url_to_image_tensor,
|
||||
download_url_to_video_output,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_images_to_comfyapi,
|
||||
process_image_response,
|
||||
validate_string,
|
||||
)
|
||||
from server import PromptServer
|
||||
|
||||
import aiohttp
|
||||
import torch
|
||||
from io import BytesIO
|
||||
|
||||
LUMA_T2V_AVERAGE_DURATION = 105
|
||||
LUMA_I2V_AVERAGE_DURATION = 100
|
||||
|
||||
def image_result_url_extractor(response: LumaGeneration):
|
||||
return response.assets.image if hasattr(response, "assets") and hasattr(response.assets, "image") else None
|
||||
|
||||
def video_result_url_extractor(response: LumaGeneration):
|
||||
return response.assets.video if hasattr(response, "assets") and hasattr(response.assets, "video") else None
|
||||
|
||||
class LumaReferenceNode(comfy_io.ComfyNode):
|
||||
"""
|
||||
Holds an image and weight for use with Luma Generate Image node.
|
||||
"""
|
||||
|
||||
class LumaReferenceNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="LumaReferenceNode",
|
||||
display_name="Luma Reference",
|
||||
category="api node/image/Luma",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Holds an image and weight for use with Luma Generate Image node.",
|
||||
inputs=[
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="Image to use as reference.",
|
||||
),
|
||||
comfy_io.Float.Input(
|
||||
IO.Float.Input(
|
||||
"weight",
|
||||
default=1.0,
|
||||
min=0.0,
|
||||
@@ -77,72 +59,56 @@ class LumaReferenceNode(comfy_io.ComfyNode):
|
||||
step=0.01,
|
||||
tooltip="Weight of image reference.",
|
||||
),
|
||||
comfy_io.Custom(LumaIO.LUMA_REF).Input(
|
||||
IO.Custom(LumaIO.LUMA_REF).Input(
|
||||
"luma_ref",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Custom(LumaIO.LUMA_REF).Output(display_name="luma_ref")],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
outputs=[IO.Custom(LumaIO.LUMA_REF).Output(display_name="luma_ref")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(
|
||||
cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None
|
||||
) -> comfy_io.NodeOutput:
|
||||
def execute(cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None) -> IO.NodeOutput:
|
||||
if luma_ref is not None:
|
||||
luma_ref = luma_ref.clone()
|
||||
else:
|
||||
luma_ref = LumaReferenceChain()
|
||||
luma_ref.add(LumaReference(image=image, weight=round(weight, 2)))
|
||||
return comfy_io.NodeOutput(luma_ref)
|
||||
return IO.NodeOutput(luma_ref)
|
||||
|
||||
|
||||
class LumaConceptsNode(comfy_io.ComfyNode):
|
||||
"""
|
||||
Holds one or more Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.
|
||||
"""
|
||||
|
||||
class LumaConceptsNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="LumaConceptsNode",
|
||||
display_name="Luma Concepts",
|
||||
category="api node/video/Luma",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.",
|
||||
inputs=[
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"concept1",
|
||||
options=get_luma_concepts(include_none=True),
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"concept2",
|
||||
options=get_luma_concepts(include_none=True),
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"concept3",
|
||||
options=get_luma_concepts(include_none=True),
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"concept4",
|
||||
options=get_luma_concepts(include_none=True),
|
||||
),
|
||||
comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Input(
|
||||
IO.Custom(LumaIO.LUMA_CONCEPTS).Input(
|
||||
"luma_concepts",
|
||||
tooltip="Optional Camera Concepts to add to the ones chosen here.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Output(display_name="luma_concepts")],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
outputs=[IO.Custom(LumaIO.LUMA_CONCEPTS).Output(display_name="luma_concepts")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -153,42 +119,38 @@ class LumaConceptsNode(comfy_io.ComfyNode):
|
||||
concept3: str,
|
||||
concept4: str,
|
||||
luma_concepts: LumaConceptChain = None,
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
chain = LumaConceptChain(str_list=[concept1, concept2, concept3, concept4])
|
||||
if luma_concepts is not None:
|
||||
chain = luma_concepts.clone_and_merge(chain)
|
||||
return comfy_io.NodeOutput(chain)
|
||||
return IO.NodeOutput(chain)
|
||||
|
||||
|
||||
class LumaImageGenerationNode(comfy_io.ComfyNode):
|
||||
"""
|
||||
Generates images synchronously based on prompt and aspect ratio.
|
||||
"""
|
||||
|
||||
class LumaImageGenerationNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="LumaImageNode",
|
||||
display_name="Luma Text to Image",
|
||||
category="api node/image/Luma",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Generates images synchronously based on prompt and aspect ratio.",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the image generation",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=LumaImageModel,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=LumaAspectRatio,
|
||||
default=LumaAspectRatio.ratio_16_9,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
@@ -196,7 +158,7 @@ class LumaImageGenerationNode(comfy_io.ComfyNode):
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
comfy_io.Float.Input(
|
||||
IO.Float.Input(
|
||||
"style_image_weight",
|
||||
default=1.0,
|
||||
min=0.0,
|
||||
@@ -204,27 +166,27 @@ class LumaImageGenerationNode(comfy_io.ComfyNode):
|
||||
step=0.01,
|
||||
tooltip="Weight of style image. Ignored if no style_image provided.",
|
||||
),
|
||||
comfy_io.Custom(LumaIO.LUMA_REF).Input(
|
||||
IO.Custom(LumaIO.LUMA_REF).Input(
|
||||
"image_luma_ref",
|
||||
tooltip="Luma Reference node connection to influence generation with input images; up to 4 images can be considered.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"style_image",
|
||||
tooltip="Style reference image; only 1 image will be used.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"character_image",
|
||||
tooltip="Character reference images; can be a batch of multiple, up to 4 images can be considered.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Image.Output()],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -237,45 +199,30 @@ class LumaImageGenerationNode(comfy_io.ComfyNode):
|
||||
aspect_ratio: str,
|
||||
seed,
|
||||
style_image_weight: float,
|
||||
image_luma_ref: LumaReferenceChain = None,
|
||||
style_image: torch.Tensor = None,
|
||||
character_image: torch.Tensor = None,
|
||||
) -> comfy_io.NodeOutput:
|
||||
image_luma_ref: Optional[LumaReferenceChain] = None,
|
||||
style_image: Optional[torch.Tensor] = None,
|
||||
character_image: Optional[torch.Tensor] = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=3)
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
# handle image_luma_ref
|
||||
api_image_ref = None
|
||||
if image_luma_ref is not None:
|
||||
api_image_ref = await cls._convert_luma_refs(
|
||||
image_luma_ref, max_refs=4, auth_kwargs=auth_kwargs,
|
||||
)
|
||||
api_image_ref = await cls._convert_luma_refs(image_luma_ref, max_refs=4)
|
||||
# handle style_luma_ref
|
||||
api_style_ref = None
|
||||
if style_image is not None:
|
||||
api_style_ref = await cls._convert_style_image(
|
||||
style_image, weight=style_image_weight, auth_kwargs=auth_kwargs,
|
||||
)
|
||||
api_style_ref = await cls._convert_style_image(style_image, weight=style_image_weight)
|
||||
# handle character_ref images
|
||||
character_ref = None
|
||||
if character_image is not None:
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
character_image, max_images=4, auth_kwargs=auth_kwargs,
|
||||
)
|
||||
character_ref = LumaCharacterRef(
|
||||
identity0=LumaImageIdentity(images=download_urls)
|
||||
)
|
||||
download_urls = await upload_images_to_comfyapi(cls, character_image, max_images=4)
|
||||
character_ref = LumaCharacterRef(identity0=LumaImageIdentity(images=download_urls))
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/luma/generations/image",
|
||||
method=HttpMethod.POST,
|
||||
request_model=LumaImageGenerationRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
request=LumaImageGenerationRequest(
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/luma/generations/image", method="POST"),
|
||||
response_model=LumaGeneration,
|
||||
data=LumaImageGenerationRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
aspect_ratio=aspect_ratio,
|
||||
@@ -283,41 +230,21 @@ class LumaImageGenerationNode(comfy_io.ComfyNode):
|
||||
style_ref=api_style_ref,
|
||||
character_ref=character_ref,
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
response_api: LumaGeneration = await operation.execute()
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/luma/generations/{response_api.id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
completed_statuses=[LumaState.completed],
|
||||
failed_statuses=[LumaState.failed],
|
||||
response_poll = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
|
||||
response_model=LumaGeneration,
|
||||
status_extractor=lambda x: x.state,
|
||||
result_url_extractor=image_result_url_extractor,
|
||||
node_id=cls.hidden.unique_id,
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
response_poll = await operation.execute()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.assets.image) as img_response:
|
||||
img = process_image_response(await img_response.content.read())
|
||||
return comfy_io.NodeOutput(img)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(response_poll.assets.image))
|
||||
|
||||
@classmethod
|
||||
async def _convert_luma_refs(
|
||||
cls, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None
|
||||
):
|
||||
async def _convert_luma_refs(cls, luma_ref: LumaReferenceChain, max_refs: int):
|
||||
luma_urls = []
|
||||
ref_count = 0
|
||||
for ref in luma_ref.refs:
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
ref.image, max_images=1, auth_kwargs=auth_kwargs
|
||||
)
|
||||
download_urls = await upload_images_to_comfyapi(cls, ref.image, max_images=1)
|
||||
luma_urls.append(download_urls[0])
|
||||
ref_count += 1
|
||||
if ref_count >= max_refs:
|
||||
@@ -325,38 +252,30 @@ class LumaImageGenerationNode(comfy_io.ComfyNode):
|
||||
return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs)
|
||||
|
||||
@classmethod
|
||||
async def _convert_style_image(
|
||||
cls, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None
|
||||
):
|
||||
chain = LumaReferenceChain(
|
||||
first_ref=LumaReference(image=style_image, weight=weight)
|
||||
)
|
||||
return await cls._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)
|
||||
async def _convert_style_image(cls, style_image: torch.Tensor, weight: float):
|
||||
chain = LumaReferenceChain(first_ref=LumaReference(image=style_image, weight=weight))
|
||||
return await cls._convert_luma_refs(chain, max_refs=1)
|
||||
|
||||
|
||||
class LumaImageModifyNode(comfy_io.ComfyNode):
|
||||
"""
|
||||
Modifies images synchronously based on prompt and aspect ratio.
|
||||
"""
|
||||
|
||||
class LumaImageModifyNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="LumaImageModifyNode",
|
||||
display_name="Luma Image to Image",
|
||||
category="api node/image/Luma",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Modifies images synchronously based on prompt and aspect ratio.",
|
||||
inputs=[
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the image generation",
|
||||
),
|
||||
comfy_io.Float.Input(
|
||||
IO.Float.Input(
|
||||
"image_weight",
|
||||
default=0.1,
|
||||
min=0.0,
|
||||
@@ -364,11 +283,11 @@ class LumaImageModifyNode(comfy_io.ComfyNode):
|
||||
step=0.01,
|
||||
tooltip="Weight of the image; the closer to 1.0, the less the image will be modified.",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=LumaImageModel,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
@@ -377,11 +296,11 @@ class LumaImageModifyNode(comfy_io.ComfyNode):
|
||||
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Image.Output()],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -394,99 +313,68 @@ class LumaImageModifyNode(comfy_io.ComfyNode):
|
||||
image: torch.Tensor,
|
||||
image_weight: float,
|
||||
seed,
|
||||
) -> comfy_io.NodeOutput:
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
# first, upload image
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
image, max_images=1, auth_kwargs=auth_kwargs,
|
||||
)
|
||||
) -> IO.NodeOutput:
|
||||
download_urls = await upload_images_to_comfyapi(cls, image, max_images=1)
|
||||
image_url = download_urls[0]
|
||||
# next, make Luma call with download url provided
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/luma/generations/image",
|
||||
method=HttpMethod.POST,
|
||||
request_model=LumaImageGenerationRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
request=LumaImageGenerationRequest(
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/luma/generations/image", method="POST"),
|
||||
response_model=LumaGeneration,
|
||||
data=LumaImageGenerationRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
modify_image_ref=LumaModifyImageRef(
|
||||
url=image_url, weight=round(max(min(1.0-image_weight, 0.98), 0.0), 2)
|
||||
url=image_url, weight=round(max(min(1.0 - image_weight, 0.98), 0.0), 2)
|
||||
),
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
response_api: LumaGeneration = await operation.execute()
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/luma/generations/{response_api.id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
completed_statuses=[LumaState.completed],
|
||||
failed_statuses=[LumaState.failed],
|
||||
response_poll = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
|
||||
response_model=LumaGeneration,
|
||||
status_extractor=lambda x: x.state,
|
||||
result_url_extractor=image_result_url_extractor,
|
||||
node_id=cls.hidden.unique_id,
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
response_poll = await operation.execute()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.assets.image) as img_response:
|
||||
img = process_image_response(await img_response.content.read())
|
||||
return comfy_io.NodeOutput(img)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(response_poll.assets.image))
|
||||
|
||||
|
||||
class LumaTextToVideoGenerationNode(comfy_io.ComfyNode):
|
||||
"""
|
||||
Generates videos synchronously based on prompt and output_size.
|
||||
"""
|
||||
|
||||
class LumaTextToVideoGenerationNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="LumaVideoNode",
|
||||
display_name="Luma Text to Video",
|
||||
category="api node/video/Luma",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Generates videos synchronously based on prompt and output_size.",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the video generation",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=LumaVideoModel,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=LumaAspectRatio,
|
||||
default=LumaAspectRatio.ratio_16_9,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=LumaVideoOutputResolution,
|
||||
default=LumaVideoOutputResolution.res_540p,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"duration",
|
||||
options=LumaVideoModelOutputDuration,
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
IO.Boolean.Input(
|
||||
"loop",
|
||||
default=False,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
@@ -494,17 +382,17 @@ class LumaTextToVideoGenerationNode(comfy_io.ComfyNode):
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Input(
|
||||
IO.Custom(LumaIO.LUMA_CONCEPTS).Input(
|
||||
"luma_concepts",
|
||||
tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.",
|
||||
optional=True,
|
||||
)
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -519,24 +407,17 @@ class LumaTextToVideoGenerationNode(comfy_io.ComfyNode):
|
||||
duration: str,
|
||||
loop: bool,
|
||||
seed,
|
||||
luma_concepts: LumaConceptChain = None,
|
||||
) -> comfy_io.NodeOutput:
|
||||
luma_concepts: Optional[LumaConceptChain] = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False, min_length=3)
|
||||
duration = duration if model != LumaVideoModel.ray_1_6 else None
|
||||
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
|
||||
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/luma/generations",
|
||||
method=HttpMethod.POST,
|
||||
request_model=LumaGenerationRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
request=LumaGenerationRequest(
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/luma/generations", method="POST"),
|
||||
response_model=LumaGeneration,
|
||||
data=LumaGenerationRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
resolution=resolution,
|
||||
@@ -545,77 +426,55 @@ class LumaTextToVideoGenerationNode(comfy_io.ComfyNode):
|
||||
loop=loop,
|
||||
concepts=luma_concepts.create_api_model() if luma_concepts else None,
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
response_api: LumaGeneration = await operation.execute()
|
||||
|
||||
if cls.hidden.unique_id:
|
||||
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id)
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/luma/generations/{response_api.id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
completed_statuses=[LumaState.completed],
|
||||
failed_statuses=[LumaState.failed],
|
||||
response_poll = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
|
||||
response_model=LumaGeneration,
|
||||
status_extractor=lambda x: x.state,
|
||||
result_url_extractor=video_result_url_extractor,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=LUMA_T2V_AVERAGE_DURATION,
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
response_poll = await operation.execute()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.assets.video) as vid_response:
|
||||
return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
|
||||
return IO.NodeOutput(await download_url_to_video_output(response_poll.assets.video))
|
||||
|
||||
|
||||
class LumaImageToVideoGenerationNode(comfy_io.ComfyNode):
|
||||
"""
|
||||
Generates videos synchronously based on prompt, input images, and output_size.
|
||||
"""
|
||||
|
||||
class LumaImageToVideoGenerationNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="LumaImageToVideoNode",
|
||||
display_name="Luma Image to Video",
|
||||
category="api node/video/Luma",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Generates videos synchronously based on prompt, input images, and output_size.",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the video generation",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=LumaVideoModel,
|
||||
),
|
||||
# comfy_io.Combo.Input(
|
||||
# IO.Combo.Input(
|
||||
# "aspect_ratio",
|
||||
# options=[ratio.value for ratio in LumaAspectRatio],
|
||||
# default=LumaAspectRatio.ratio_16_9,
|
||||
# ),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=LumaVideoOutputResolution,
|
||||
default=LumaVideoOutputResolution.res_540p,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"duration",
|
||||
options=[dur.value for dur in LumaVideoModelOutputDuration],
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
IO.Boolean.Input(
|
||||
"loop",
|
||||
default=False,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
@@ -623,27 +482,27 @@ class LumaImageToVideoGenerationNode(comfy_io.ComfyNode):
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"first_image",
|
||||
tooltip="First frame of generated video.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"last_image",
|
||||
tooltip="Last frame of generated video.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Input(
|
||||
IO.Custom(LumaIO.LUMA_CONCEPTS).Input(
|
||||
"luma_concepts",
|
||||
tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.",
|
||||
optional=True,
|
||||
)
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -660,27 +519,17 @@ class LumaImageToVideoGenerationNode(comfy_io.ComfyNode):
|
||||
first_image: torch.Tensor = None,
|
||||
last_image: torch.Tensor = None,
|
||||
luma_concepts: LumaConceptChain = None,
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
if first_image is None and last_image is None:
|
||||
raise Exception(
|
||||
"At least one of first_image and last_image requires an input."
|
||||
)
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
keyframes = await cls._convert_to_keyframes(first_image, last_image, auth_kwargs=auth_kwargs)
|
||||
raise Exception("At least one of first_image and last_image requires an input.")
|
||||
keyframes = await cls._convert_to_keyframes(first_image, last_image)
|
||||
duration = duration if model != LumaVideoModel.ray_1_6 else None
|
||||
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/luma/generations",
|
||||
method=HttpMethod.POST,
|
||||
request_model=LumaGenerationRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
request=LumaGenerationRequest(
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/luma/generations", method="POST"),
|
||||
response_model=LumaGeneration,
|
||||
data=LumaGenerationRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
aspect_ratio=LumaAspectRatio.ratio_16_9, # ignored, but still needed by the API for some reason
|
||||
@@ -690,61 +539,38 @@ class LumaImageToVideoGenerationNode(comfy_io.ComfyNode):
|
||||
keyframes=keyframes,
|
||||
concepts=luma_concepts.create_api_model() if luma_concepts else None,
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
response_api: LumaGeneration = await operation.execute()
|
||||
|
||||
if cls.hidden.unique_id:
|
||||
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id)
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/luma/generations/{response_api.id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
completed_statuses=[LumaState.completed],
|
||||
failed_statuses=[LumaState.failed],
|
||||
response_poll = await poll_op(
|
||||
cls,
|
||||
poll_endpoint=ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
|
||||
response_model=LumaGeneration,
|
||||
status_extractor=lambda x: x.state,
|
||||
result_url_extractor=video_result_url_extractor,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=LUMA_I2V_AVERAGE_DURATION,
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
response_poll = await operation.execute()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.assets.video) as vid_response:
|
||||
return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
|
||||
return IO.NodeOutput(await download_url_to_video_output(response_poll.assets.video))
|
||||
|
||||
@classmethod
|
||||
async def _convert_to_keyframes(
|
||||
cls,
|
||||
first_image: torch.Tensor = None,
|
||||
last_image: torch.Tensor = None,
|
||||
auth_kwargs: Optional[dict[str,str]] = None,
|
||||
):
|
||||
if first_image is None and last_image is None:
|
||||
return None
|
||||
frame0 = None
|
||||
frame1 = None
|
||||
if first_image is not None:
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
first_image, max_images=1, auth_kwargs=auth_kwargs,
|
||||
)
|
||||
download_urls = await upload_images_to_comfyapi(cls, first_image, max_images=1)
|
||||
frame0 = LumaImageReference(type="image", url=download_urls[0])
|
||||
if last_image is not None:
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
last_image, max_images=1, auth_kwargs=auth_kwargs,
|
||||
)
|
||||
download_urls = await upload_images_to_comfyapi(cls, last_image, max_images=1)
|
||||
frame1 = LumaImageReference(type="image", url=download_urls[0])
|
||||
return LumaKeyframes(frame0=frame0, frame1=frame1)
|
||||
|
||||
|
||||
class LumaExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
LumaImageGenerationNode,
|
||||
LumaImageModifyNode,
|
||||
|
||||
@@ -1,71 +1,57 @@
|
||||
from inspect import cleandoc
|
||||
from typing import Optional
|
||||
import logging
|
||||
import torch
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io as comfy_io
|
||||
from comfy_api.input_impl.video_types import VideoFromFile
|
||||
from comfy_api_nodes.apis import (
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.apis.minimax_api import (
|
||||
MinimaxFileRetrieveResponse,
|
||||
MiniMaxModel,
|
||||
MinimaxTaskResultResponse,
|
||||
MinimaxVideoGenerationRequest,
|
||||
MinimaxVideoGenerationResponse,
|
||||
MinimaxFileRetrieveResponse,
|
||||
MinimaxTaskResultResponse,
|
||||
SubjectReferenceItem,
|
||||
MiniMaxModel,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
download_url_to_bytesio,
|
||||
download_url_to_video_output,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_images_to_comfyapi,
|
||||
validate_string,
|
||||
)
|
||||
from server import PromptServer
|
||||
|
||||
|
||||
I2V_AVERAGE_DURATION = 114
|
||||
T2V_AVERAGE_DURATION = 234
|
||||
|
||||
|
||||
async def _generate_mm_video(
|
||||
cls: type[IO.ComfyNode],
|
||||
*,
|
||||
auth: dict[str, str],
|
||||
node_id: str,
|
||||
prompt_text: str,
|
||||
seed: int,
|
||||
model: str,
|
||||
image: Optional[torch.Tensor] = None, # used for ImageToVideo
|
||||
subject: Optional[torch.Tensor] = None, # used for SubjectToVideo
|
||||
image: Optional[torch.Tensor] = None, # used for ImageToVideo
|
||||
subject: Optional[torch.Tensor] = None, # used for SubjectToVideo
|
||||
average_duration: Optional[int] = None,
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
if image is None:
|
||||
validate_string(prompt_text, field_name="prompt_text")
|
||||
# upload image, if passed in
|
||||
image_url = None
|
||||
if image is not None:
|
||||
image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=auth))[0]
|
||||
image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0]
|
||||
|
||||
# TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model
|
||||
subject_reference = None
|
||||
if subject is not None:
|
||||
subject_url = (await upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=auth))[0]
|
||||
subject_url = (await upload_images_to_comfyapi(cls, subject, max_images=1))[0]
|
||||
subject_reference = [SubjectReferenceItem(image=subject_url)]
|
||||
|
||||
|
||||
video_generate_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/minimax/video_generation",
|
||||
method=HttpMethod.POST,
|
||||
request_model=MinimaxVideoGenerationRequest,
|
||||
response_model=MinimaxVideoGenerationResponse,
|
||||
),
|
||||
request=MinimaxVideoGenerationRequest(
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/minimax/video_generation", method="POST"),
|
||||
response_model=MinimaxVideoGenerationResponse,
|
||||
data=MinimaxVideoGenerationRequest(
|
||||
model=MiniMaxModel(model),
|
||||
prompt=prompt_text,
|
||||
callback_url=None,
|
||||
@@ -73,95 +59,64 @@ async def _generate_mm_video(
|
||||
subject_reference=subject_reference,
|
||||
prompt_optimizer=None,
|
||||
),
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
response = await video_generate_operation.execute()
|
||||
|
||||
task_id = response.task_id
|
||||
if not task_id:
|
||||
raise Exception(f"MiniMax generation failed: {response.base_resp}")
|
||||
|
||||
video_generate_operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path="/proxy/minimax/query/video_generation",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=MinimaxTaskResultResponse,
|
||||
query_params={"task_id": task_id},
|
||||
),
|
||||
completed_statuses=["Success"],
|
||||
failed_statuses=["Fail"],
|
||||
task_result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/minimax/query/video_generation", query_params={"task_id": task_id}),
|
||||
response_model=MinimaxTaskResultResponse,
|
||||
status_extractor=lambda x: x.status.value,
|
||||
estimated_duration=average_duration,
|
||||
node_id=node_id,
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
task_result = await video_generate_operation.execute()
|
||||
|
||||
file_id = task_result.file_id
|
||||
if file_id is None:
|
||||
raise Exception("Request was not successful. Missing file ID.")
|
||||
file_retrieve_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/minimax/files/retrieve",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=MinimaxFileRetrieveResponse,
|
||||
query_params={"file_id": int(file_id)},
|
||||
),
|
||||
request=EmptyRequest(),
|
||||
auth_kwargs=auth,
|
||||
file_result = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/minimax/files/retrieve", query_params={"file_id": int(file_id)}),
|
||||
response_model=MinimaxFileRetrieveResponse,
|
||||
)
|
||||
file_result = await file_retrieve_operation.execute()
|
||||
|
||||
file_url = file_result.file.download_url
|
||||
if file_url is None:
|
||||
raise Exception(
|
||||
f"No video was found in the response. Full response: {file_result.model_dump()}"
|
||||
)
|
||||
logging.info("Generated video URL: %s", file_url)
|
||||
if node_id:
|
||||
if hasattr(file_result.file, "backup_download_url"):
|
||||
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
|
||||
else:
|
||||
message = f"Result URL: {file_url}"
|
||||
PromptServer.instance.send_progress_text(message, node_id)
|
||||
|
||||
# Download and return as VideoFromFile
|
||||
video_io = await download_url_to_bytesio(file_url)
|
||||
if video_io is None:
|
||||
error_msg = f"Failed to download video from {file_url}"
|
||||
logging.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
return comfy_io.NodeOutput(VideoFromFile(video_io))
|
||||
raise Exception(f"No video was found in the response. Full response: {file_result.model_dump()}")
|
||||
if file_result.file.backup_download_url:
|
||||
try:
|
||||
return IO.NodeOutput(await download_url_to_video_output(file_url, timeout=10, max_retries=2))
|
||||
except Exception: # if we have a second URL to retrieve the result, try again using that one
|
||||
return IO.NodeOutput(
|
||||
await download_url_to_video_output(file_result.file.backup_download_url, max_retries=3)
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(file_url))
|
||||
|
||||
|
||||
class MinimaxTextToVideoNode(comfy_io.ComfyNode):
|
||||
"""
|
||||
Generates videos synchronously based on a prompt, and optional parameters using MiniMax's API.
|
||||
"""
|
||||
|
||||
class MinimaxTextToVideoNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="MinimaxTextToVideoNode",
|
||||
display_name="MiniMax Text to Video",
|
||||
category="api node/video/MiniMax",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Generates videos synchronously based on a prompt, and optional parameters.",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt_text",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text prompt to guide the video generation",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["T2V-01", "T2V-01-Director"],
|
||||
default="T2V-01",
|
||||
tooltip="Model to use for video generation",
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
@@ -172,11 +127,11 @@ class MinimaxTextToVideoNode(comfy_io.ComfyNode):
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -187,13 +142,9 @@ class MinimaxTextToVideoNode(comfy_io.ComfyNode):
|
||||
prompt_text: str,
|
||||
model: str = "T2V-01",
|
||||
seed: int = 0,
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
return await _generate_mm_video(
|
||||
auth={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
node_id=cls.hidden.unique_id,
|
||||
cls,
|
||||
prompt_text=prompt_text,
|
||||
seed=seed,
|
||||
model=model,
|
||||
@@ -203,36 +154,32 @@ class MinimaxTextToVideoNode(comfy_io.ComfyNode):
|
||||
)
|
||||
|
||||
|
||||
class MinimaxImageToVideoNode(comfy_io.ComfyNode):
|
||||
"""
|
||||
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
|
||||
"""
|
||||
|
||||
class MinimaxImageToVideoNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="MinimaxImageToVideoNode",
|
||||
display_name="MiniMax Image to Video",
|
||||
category="api node/video/MiniMax",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Generates videos synchronously based on an image and prompt, and optional parameters.",
|
||||
inputs=[
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="Image to use as first frame of video generation",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt_text",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text prompt to guide the video generation",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["I2V-01-Director", "I2V-01", "I2V-01-live"],
|
||||
default="I2V-01",
|
||||
tooltip="Model to use for video generation",
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
@@ -243,11 +190,11 @@ class MinimaxImageToVideoNode(comfy_io.ComfyNode):
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -259,13 +206,9 @@ class MinimaxImageToVideoNode(comfy_io.ComfyNode):
|
||||
prompt_text: str,
|
||||
model: str = "I2V-01",
|
||||
seed: int = 0,
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
return await _generate_mm_video(
|
||||
auth={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
node_id=cls.hidden.unique_id,
|
||||
cls,
|
||||
prompt_text=prompt_text,
|
||||
seed=seed,
|
||||
model=model,
|
||||
@@ -275,36 +218,32 @@ class MinimaxImageToVideoNode(comfy_io.ComfyNode):
|
||||
)
|
||||
|
||||
|
||||
class MinimaxSubjectToVideoNode(comfy_io.ComfyNode):
|
||||
"""
|
||||
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
|
||||
"""
|
||||
|
||||
class MinimaxSubjectToVideoNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="MinimaxSubjectToVideoNode",
|
||||
display_name="MiniMax Subject to Video",
|
||||
category="api node/video/MiniMax",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Generates videos synchronously based on an image and prompt, and optional parameters.",
|
||||
inputs=[
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"subject",
|
||||
tooltip="Image of subject to reference for video generation",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt_text",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text prompt to guide the video generation",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["S2V-01"],
|
||||
default="S2V-01",
|
||||
tooltip="Model to use for video generation",
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
@@ -315,11 +254,11 @@ class MinimaxSubjectToVideoNode(comfy_io.ComfyNode):
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -331,13 +270,9 @@ class MinimaxSubjectToVideoNode(comfy_io.ComfyNode):
|
||||
prompt_text: str,
|
||||
model: str = "S2V-01",
|
||||
seed: int = 0,
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
return await _generate_mm_video(
|
||||
auth={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
node_id=cls.hidden.unique_id,
|
||||
cls,
|
||||
prompt_text=prompt_text,
|
||||
seed=seed,
|
||||
model=model,
|
||||
@@ -347,24 +282,22 @@ class MinimaxSubjectToVideoNode(comfy_io.ComfyNode):
|
||||
)
|
||||
|
||||
|
||||
class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
|
||||
"""Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model."""
|
||||
|
||||
class MinimaxHailuoVideoNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="MinimaxHailuoVideoNode",
|
||||
display_name="MiniMax Hailuo Video",
|
||||
category="api node/video/MiniMax",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model.",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt_text",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text prompt to guide the video generation.",
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
@@ -374,25 +307,25 @@ class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"first_frame_image",
|
||||
tooltip="Optional image to use as the first frame to generate a video.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
IO.Boolean.Input(
|
||||
"prompt_optimizer",
|
||||
default=True,
|
||||
tooltip="Optimize prompt to improve generation quality when needed.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"duration",
|
||||
options=[6, 10],
|
||||
default=6,
|
||||
tooltip="The length of the output video in seconds.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["768P", "1080P"],
|
||||
default="768P",
|
||||
@@ -400,11 +333,11 @@ class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -419,11 +352,7 @@ class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
|
||||
duration: int = 6,
|
||||
resolution: str = "768P",
|
||||
model: str = "MiniMax-Hailuo-02",
|
||||
) -> comfy_io.NodeOutput:
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
) -> IO.NodeOutput:
|
||||
if first_frame_image is None:
|
||||
validate_string(prompt_text, field_name="prompt_text")
|
||||
|
||||
@@ -435,16 +364,13 @@ class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
|
||||
# upload image, if passed in
|
||||
image_url = None
|
||||
if first_frame_image is not None:
|
||||
image_url = (await upload_images_to_comfyapi(first_frame_image, max_images=1, auth_kwargs=auth))[0]
|
||||
image_url = (await upload_images_to_comfyapi(cls, first_frame_image, max_images=1))[0]
|
||||
|
||||
video_generate_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/minimax/video_generation",
|
||||
method=HttpMethod.POST,
|
||||
request_model=MinimaxVideoGenerationRequest,
|
||||
response_model=MinimaxVideoGenerationResponse,
|
||||
),
|
||||
request=MinimaxVideoGenerationRequest(
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/minimax/video_generation", method="POST"),
|
||||
response_model=MinimaxVideoGenerationResponse,
|
||||
data=MinimaxVideoGenerationRequest(
|
||||
model=MiniMaxModel(model),
|
||||
prompt=prompt_text,
|
||||
callback_url=None,
|
||||
@@ -453,72 +379,47 @@ class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
|
||||
duration=duration,
|
||||
resolution=resolution,
|
||||
),
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
response = await video_generate_operation.execute()
|
||||
|
||||
task_id = response.task_id
|
||||
if not task_id:
|
||||
raise Exception(f"MiniMax generation failed: {response.base_resp}")
|
||||
|
||||
average_duration = 120 if resolution == "768P" else 240
|
||||
video_generate_operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path="/proxy/minimax/query/video_generation",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=MinimaxTaskResultResponse,
|
||||
query_params={"task_id": task_id},
|
||||
),
|
||||
completed_statuses=["Success"],
|
||||
failed_statuses=["Fail"],
|
||||
task_result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/minimax/query/video_generation", query_params={"task_id": task_id}),
|
||||
response_model=MinimaxTaskResultResponse,
|
||||
status_extractor=lambda x: x.status.value,
|
||||
estimated_duration=average_duration,
|
||||
node_id=cls.hidden.unique_id,
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
task_result = await video_generate_operation.execute()
|
||||
|
||||
file_id = task_result.file_id
|
||||
if file_id is None:
|
||||
raise Exception("Request was not successful. Missing file ID.")
|
||||
file_retrieve_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/minimax/files/retrieve",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=MinimaxFileRetrieveResponse,
|
||||
query_params={"file_id": int(file_id)},
|
||||
),
|
||||
request=EmptyRequest(),
|
||||
auth_kwargs=auth,
|
||||
file_result = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/minimax/files/retrieve", query_params={"file_id": int(file_id)}),
|
||||
response_model=MinimaxFileRetrieveResponse,
|
||||
)
|
||||
file_result = await file_retrieve_operation.execute()
|
||||
|
||||
file_url = file_result.file.download_url
|
||||
if file_url is None:
|
||||
raise Exception(
|
||||
f"No video was found in the response. Full response: {file_result.model_dump()}"
|
||||
)
|
||||
logging.info("Generated video URL: %s", file_url)
|
||||
if cls.hidden.unique_id:
|
||||
if hasattr(file_result.file, "backup_download_url"):
|
||||
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
|
||||
else:
|
||||
message = f"Result URL: {file_url}"
|
||||
PromptServer.instance.send_progress_text(message, cls.hidden.unique_id)
|
||||
raise Exception(f"No video was found in the response. Full response: {file_result.model_dump()}")
|
||||
|
||||
video_io = await download_url_to_bytesio(file_url)
|
||||
if video_io is None:
|
||||
error_msg = f"Failed to download video from {file_url}"
|
||||
logging.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
return comfy_io.NodeOutput(VideoFromFile(video_io))
|
||||
if file_result.file.backup_download_url:
|
||||
try:
|
||||
return IO.NodeOutput(await download_url_to_video_output(file_url, timeout=10, max_retries=2))
|
||||
except Exception: # if we have a second URL to retrieve the result, try again using that one
|
||||
return IO.NodeOutput(
|
||||
await download_url_to_video_output(file_result.file.backup_download_url, max_retries=3)
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(file_url))
|
||||
|
||||
|
||||
class MinimaxExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
MinimaxTextToVideoNode,
|
||||
MinimaxImageToVideoNode,
|
||||
|
||||
@@ -1,33 +1,30 @@
|
||||
import logging
|
||||
from typing import Any, Callable, Optional, TypeVar
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
from comfy_api_nodes.util.validation_utils import validate_image_dimensions
|
||||
|
||||
from comfy_api_nodes.apis import (
|
||||
MoonvalleyTextToVideoRequest,
|
||||
MoonvalleyTextToVideoInferenceParams,
|
||||
MoonvalleyVideoToVideoInferenceParams,
|
||||
MoonvalleyVideoToVideoRequest,
|
||||
MoonvalleyPromptResponse,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
download_url_to_video_output,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
)
|
||||
|
||||
from comfy_api.input import VideoInput
|
||||
from comfy_api.latest import ComfyExtension, InputImpl, io as comfy_io
|
||||
import av
|
||||
import io
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.apis import (
|
||||
MoonvalleyPromptResponse,
|
||||
MoonvalleyTextToVideoInferenceParams,
|
||||
MoonvalleyTextToVideoRequest,
|
||||
MoonvalleyVideoToVideoInferenceParams,
|
||||
MoonvalleyVideoToVideoRequest,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_video_output,
|
||||
poll_op,
|
||||
sync_op,
|
||||
trim_video,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
validate_container_format_is_mp4,
|
||||
validate_image_dimensions,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
API_UPLOADS_ENDPOINT = "/proxy/moonvalley/uploads"
|
||||
API_PROMPTS_ENDPOINT = "/proxy/moonvalley/prompts"
|
||||
@@ -50,13 +47,6 @@ MAX_VID_HEIGHT = 10000
|
||||
MAX_VIDEO_SIZE = 1024 * 1024 * 1024 # 1 GB max for in-memory video processing
|
||||
|
||||
MOONVALLEY_MAREY_MAX_PROMPT_LENGTH = 5000
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class MoonvalleyApiError(Exception):
|
||||
"""Base exception for Moonvalley API errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def is_valid_task_creation_response(response: MoonvalleyPromptResponse) -> bool:
|
||||
@@ -68,64 +58,7 @@ def validate_task_creation_response(response) -> None:
|
||||
if not is_valid_task_creation_response(response):
|
||||
error_msg = f"Moonvalley Marey API: Initial request failed. Code: {response.code}, Message: {response.message}, Data: {response}"
|
||||
logging.error(error_msg)
|
||||
raise MoonvalleyApiError(error_msg)
|
||||
|
||||
|
||||
def get_video_from_response(response):
|
||||
video = response.output_url
|
||||
logging.info(
|
||||
"Moonvalley Marey API: Task %s succeeded. Video URL: %s", response.id, video
|
||||
)
|
||||
return video
|
||||
|
||||
|
||||
def get_video_url_from_response(response) -> Optional[str]:
|
||||
"""Returns the first video url from the Moonvalley video generation task result.
|
||||
Will not raise an error if the response is not valid.
|
||||
"""
|
||||
if response:
|
||||
return str(get_video_from_response(response))
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
async def poll_until_finished(
|
||||
auth_kwargs: dict[str, str],
|
||||
api_endpoint: ApiEndpoint[Any, R],
|
||||
result_url_extractor: Optional[Callable[[R], str]] = None,
|
||||
node_id: Optional[str] = None,
|
||||
) -> R:
|
||||
"""Polls the Moonvalley API endpoint until the task reaches a terminal state, then returns the response."""
|
||||
return await PollingOperation(
|
||||
poll_endpoint=api_endpoint,
|
||||
completed_statuses=[
|
||||
"completed",
|
||||
],
|
||||
max_poll_attempts=240, # 64 minutes with 16s interval
|
||||
poll_interval=16.0,
|
||||
failed_statuses=["error"],
|
||||
status_extractor=lambda response: (
|
||||
response.status if response and response.status else None
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
result_url_extractor=result_url_extractor,
|
||||
node_id=node_id,
|
||||
).execute()
|
||||
|
||||
|
||||
def validate_prompts(
|
||||
prompt: str, negative_prompt: str, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH
|
||||
):
|
||||
"""Verifies that the prompt isn't empty and that neither prompt is too long."""
|
||||
if not prompt:
|
||||
raise ValueError("Positive prompt is empty")
|
||||
if len(prompt) > max_length:
|
||||
raise ValueError(f"Positive prompt is too long: {len(prompt)} characters")
|
||||
if negative_prompt and len(negative_prompt) > max_length:
|
||||
raise ValueError(
|
||||
f"Negative prompt is too long: {len(negative_prompt)} characters"
|
||||
)
|
||||
return True
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
|
||||
def validate_video_to_video_input(video: VideoInput) -> VideoInput:
|
||||
@@ -144,7 +77,7 @@ def validate_video_to_video_input(video: VideoInput) -> VideoInput:
|
||||
"""
|
||||
width, height = _get_video_dimensions(video)
|
||||
_validate_video_dimensions(width, height)
|
||||
_validate_container_format(video)
|
||||
validate_container_format_is_mp4(video)
|
||||
|
||||
return _validate_and_trim_duration(video)
|
||||
|
||||
@@ -169,21 +102,8 @@ def _validate_video_dimensions(width: int, height: int) -> None:
|
||||
}
|
||||
|
||||
if (width, height) not in supported_resolutions:
|
||||
supported_list = ", ".join(
|
||||
[f"{w}x{h}" for w, h in sorted(supported_resolutions)]
|
||||
)
|
||||
raise ValueError(
|
||||
f"Resolution {width}x{height} not supported. Supported: {supported_list}"
|
||||
)
|
||||
|
||||
|
||||
def _validate_container_format(video: VideoInput) -> None:
|
||||
"""Validates video container format is MP4."""
|
||||
container_format = video.get_container_format()
|
||||
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
|
||||
raise ValueError(
|
||||
f"Only MP4 container format supported. Got: {container_format}"
|
||||
)
|
||||
supported_list = ", ".join([f"{w}x{h}" for w, h in sorted(supported_resolutions)])
|
||||
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
|
||||
|
||||
|
||||
def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
|
||||
@@ -196,7 +116,7 @@ def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
|
||||
def _validate_minimum_duration(duration: float) -> None:
|
||||
"""Ensures video is at least 5 seconds long."""
|
||||
if duration < 5:
|
||||
raise MoonvalleyApiError("Input video must be at least 5 seconds long.")
|
||||
raise ValueError("Input video must be at least 5 seconds long.")
|
||||
|
||||
|
||||
def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
|
||||
@@ -206,123 +126,6 @@ def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
|
||||
return video
|
||||
|
||||
|
||||
def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
|
||||
"""
|
||||
Returns a new VideoInput object trimmed from the beginning to the specified duration,
|
||||
using av to avoid loading entire video into memory.
|
||||
|
||||
Args:
|
||||
video: Input video to trim
|
||||
duration_sec: Duration in seconds to keep from the beginning
|
||||
|
||||
Returns:
|
||||
VideoFromFile object that owns the output buffer
|
||||
"""
|
||||
output_buffer = io.BytesIO()
|
||||
|
||||
input_container = None
|
||||
output_container = None
|
||||
|
||||
try:
|
||||
# Get the stream source - this avoids loading entire video into memory
|
||||
# when the source is already a file path
|
||||
input_source = video.get_stream_source()
|
||||
|
||||
# Open containers
|
||||
input_container = av.open(input_source, mode="r")
|
||||
output_container = av.open(output_buffer, mode="w", format="mp4")
|
||||
|
||||
# Set up output streams for re-encoding
|
||||
video_stream = None
|
||||
audio_stream = None
|
||||
|
||||
for stream in input_container.streams:
|
||||
logging.info("Found stream: type=%s, class=%s", stream.type, type(stream))
|
||||
if isinstance(stream, av.VideoStream):
|
||||
# Create output video stream with same parameters
|
||||
video_stream = output_container.add_stream(
|
||||
"h264", rate=stream.average_rate
|
||||
)
|
||||
video_stream.width = stream.width
|
||||
video_stream.height = stream.height
|
||||
video_stream.pix_fmt = "yuv420p"
|
||||
logging.info(
|
||||
"Added video stream: %sx%s @ %sfps", stream.width, stream.height, stream.average_rate
|
||||
)
|
||||
elif isinstance(stream, av.AudioStream):
|
||||
# Create output audio stream with same parameters
|
||||
audio_stream = output_container.add_stream(
|
||||
"aac", rate=stream.sample_rate
|
||||
)
|
||||
audio_stream.sample_rate = stream.sample_rate
|
||||
audio_stream.layout = stream.layout
|
||||
logging.info("Added audio stream: %sHz, %s channels", stream.sample_rate, stream.channels)
|
||||
|
||||
# Calculate target frame count that's divisible by 16
|
||||
fps = input_container.streams.video[0].average_rate
|
||||
estimated_frames = int(duration_sec * fps)
|
||||
target_frames = (
|
||||
estimated_frames // 16
|
||||
) * 16 # Round down to nearest multiple of 16
|
||||
|
||||
if target_frames == 0:
|
||||
raise ValueError("Video too short: need at least 16 frames for Moonvalley")
|
||||
|
||||
frame_count = 0
|
||||
audio_frame_count = 0
|
||||
|
||||
# Decode and re-encode video frames
|
||||
if video_stream:
|
||||
for frame in input_container.decode(video=0):
|
||||
if frame_count >= target_frames:
|
||||
break
|
||||
|
||||
# Re-encode frame
|
||||
for packet in video_stream.encode(frame):
|
||||
output_container.mux(packet)
|
||||
frame_count += 1
|
||||
|
||||
# Flush encoder
|
||||
for packet in video_stream.encode():
|
||||
output_container.mux(packet)
|
||||
|
||||
logging.info("Encoded %s video frames (target: %s)", frame_count, target_frames)
|
||||
|
||||
# Decode and re-encode audio frames
|
||||
if audio_stream:
|
||||
input_container.seek(0) # Reset to beginning for audio
|
||||
for frame in input_container.decode(audio=0):
|
||||
if frame.time >= duration_sec:
|
||||
break
|
||||
|
||||
# Re-encode frame
|
||||
for packet in audio_stream.encode(frame):
|
||||
output_container.mux(packet)
|
||||
audio_frame_count += 1
|
||||
|
||||
# Flush encoder
|
||||
for packet in audio_stream.encode():
|
||||
output_container.mux(packet)
|
||||
|
||||
logging.info("Encoded %s audio frames", audio_frame_count)
|
||||
|
||||
# Close containers
|
||||
output_container.close()
|
||||
input_container.close()
|
||||
|
||||
# Return as VideoFromFile using the buffer
|
||||
output_buffer.seek(0)
|
||||
return InputImpl.VideoFromFile(output_buffer)
|
||||
|
||||
except Exception as e:
|
||||
# Clean up on error
|
||||
if input_container is not None:
|
||||
input_container.close()
|
||||
if output_container is not None:
|
||||
output_container.close()
|
||||
raise RuntimeError(f"Failed to trim video: {str(e)}") from e
|
||||
|
||||
|
||||
def parse_width_height_from_res(resolution: str):
|
||||
# Accepts a string like "16:9 (1920 x 1080)" and returns width, height as a dict
|
||||
res_map = {
|
||||
@@ -346,41 +149,36 @@ def parse_control_parameter(value):
    return control_map.get(value, control_map["Motion Transfer"])


async def get_response(
    task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> MoonvalleyPromptResponse:
    return await poll_until_finished(
        auth_kwargs,
        ApiEndpoint(
            path=f"{API_PROMPTS_ENDPOINT}/{task_id}",
            method=HttpMethod.GET,
            request_model=EmptyRequest,
            response_model=MoonvalleyPromptResponse,
        ),
        result_url_extractor=get_video_url_from_response,
        node_id=node_id,
async def get_response(cls: type[IO.ComfyNode], task_id: str) -> MoonvalleyPromptResponse:
    return await poll_op(
        cls,
        ApiEndpoint(path=f"{API_PROMPTS_ENDPOINT}/{task_id}"),
        response_model=MoonvalleyPromptResponse,
        status_extractor=lambda r: (r.status if r and r.status else None),
        poll_interval=16.0,
        max_poll_attempts=240,
    )

class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
|
||||
class MoonvalleyImg2VideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="MoonvalleyImg2VideoNode",
|
||||
display_name="Moonvalley Marey Image to Video",
|
||||
category="api node/video/Moonvalley Marey",
|
||||
description="Moonvalley Marey Image to Video Node",
|
||||
inputs=[
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="The reference image used to generate the video",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
|
||||
@@ -391,7 +189,7 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
|
||||
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
|
||||
tooltip="Negative prompt text",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=[
|
||||
"16:9 (1920 x 1080)",
|
||||
@@ -404,7 +202,7 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
|
||||
default="16:9 (1920 x 1080)",
|
||||
tooltip="Resolution of the output video",
|
||||
),
|
||||
comfy_io.Float.Input(
|
||||
IO.Float.Input(
|
||||
"prompt_adherence",
|
||||
default=4.5,
|
||||
min=1.0,
|
||||
@@ -412,17 +210,17 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
|
||||
step=1.0,
|
||||
tooltip="Guidance scale for generation control",
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=9,
|
||||
min=0,
|
||||
max=4294967295,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Random seed value",
|
||||
control_after_generate=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"steps",
|
||||
default=33,
|
||||
min=1,
|
||||
@@ -431,11 +229,11 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
|
||||
tooltip="Number of denoising steps",
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -450,16 +248,12 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
|
||||
prompt_adherence: float,
|
||||
seed: int,
|
||||
steps: int,
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
validate_image_dimensions(image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH)
|
||||
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
width_height = parse_width_height_from_res(resolution)
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
inference_params = MoonvalleyTextToVideoInferenceParams(
|
||||
negative_prompt=negative_prompt,
|
||||
steps=steps,
|
||||
@@ -472,53 +266,37 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
|
||||
|
||||
# Get MIME type from tensor - assuming PNG format for image tensors
|
||||
mime_type = "image/png"
|
||||
|
||||
image_url = (
|
||||
await upload_images_to_comfyapi(
|
||||
image, max_images=1, auth_kwargs=auth, mime_type=mime_type
|
||||
)
|
||||
)[0]
|
||||
|
||||
request = MoonvalleyTextToVideoRequest(
|
||||
image_url=image_url, prompt_text=prompt, inference_params=inference_params
|
||||
)
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=API_IMG2VIDEO_ENDPOINT,
|
||||
method=HttpMethod.POST,
|
||||
request_model=MoonvalleyTextToVideoRequest,
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
image_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type=mime_type))[0]
|
||||
task_creation_response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=API_IMG2VIDEO_ENDPOINT, method="POST"),
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
data=MoonvalleyTextToVideoRequest(
|
||||
image_url=image_url, prompt_text=prompt, inference_params=inference_params
|
||||
),
|
||||
request=request,
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.id
|
||||
|
||||
final_response = await get_response(
|
||||
task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id
|
||||
)
|
||||
final_response = await get_response(cls, task_creation_response.id)
|
||||
video = await download_url_to_video_output(final_response.output_url)
|
||||
return comfy_io.NodeOutput(video)
|
||||
return IO.NodeOutput(video)
|
||||
|
||||
|
||||
class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
|
||||
class MoonvalleyVideo2VideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="MoonvalleyVideo2VideoNode",
|
||||
display_name="Moonvalley Marey Video to Video",
|
||||
category="api node/video/Moonvalley Marey",
|
||||
description="",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="Describes the video to generate",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
|
||||
@@ -529,28 +307,28 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
|
||||
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
|
||||
tooltip="Negative prompt text",
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=9,
|
||||
min=0,
|
||||
max=4294967295,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Random seed value",
|
||||
control_after_generate=False,
|
||||
),
|
||||
comfy_io.Video.Input(
|
||||
IO.Video.Input(
|
||||
"video",
|
||||
tooltip="The reference video used to generate the output video. Must be at least 5 seconds long. "
|
||||
"Videos longer than 5s will be automatically trimmed. Only MP4 format supported.",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"control_type",
|
||||
options=["Motion Transfer", "Pose Transfer"],
|
||||
default="Motion Transfer",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"motion_intensity",
|
||||
default=100,
|
||||
min=0,
|
||||
@@ -559,21 +337,21 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
|
||||
tooltip="Only used if control_type is 'Motion Transfer'",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"steps",
|
||||
default=33,
|
||||
min=1,
|
||||
max=100,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Number of inference steps",
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -589,16 +367,11 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
|
||||
motion_intensity: Optional[int] = 100,
|
||||
steps=33,
|
||||
prompt_adherence=4.5,
|
||||
) -> comfy_io.NodeOutput:
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
) -> IO.NodeOutput:
|
||||
validated_video = validate_video_to_video_input(video)
|
||||
video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=auth)
|
||||
|
||||
validate_prompts(prompt, negative_prompt)
|
||||
video_url = await upload_video_to_comfyapi(cls, validated_video)
|
||||
validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
|
||||
# Only include motion_intensity for Motion Transfer
|
||||
control_params = {}
|
||||
@@ -613,52 +386,37 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
|
||||
guidance_scale=prompt_adherence,
|
||||
)
|
||||
|
||||
control = parse_control_parameter(control_type)
|
||||
|
||||
request = MoonvalleyVideoToVideoRequest(
|
||||
control_type=control,
|
||||
video_url=video_url,
|
||||
prompt_text=prompt,
|
||||
inference_params=inference_params,
|
||||
)
|
||||
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=API_VIDEO2VIDEO_ENDPOINT,
|
||||
method=HttpMethod.POST,
|
||||
request_model=MoonvalleyVideoToVideoRequest,
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
task_creation_response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=API_VIDEO2VIDEO_ENDPOINT, method="POST"),
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
data=MoonvalleyVideoToVideoRequest(
|
||||
control_type=parse_control_parameter(control_type),
|
||||
video_url=video_url,
|
||||
prompt_text=prompt,
|
||||
inference_params=inference_params,
|
||||
),
|
||||
request=request,
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.id
|
||||
|
||||
final_response = await get_response(
|
||||
task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id
|
||||
)
|
||||
|
||||
video = await download_url_to_video_output(final_response.output_url)
|
||||
return comfy_io.NodeOutput(video)
|
||||
final_response = await get_response(cls, task_creation_response.id)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.output_url))
|
||||
|
||||
|
||||
class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
|
||||
class MoonvalleyTxt2VideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="MoonvalleyTxt2VideoNode",
|
||||
display_name="Moonvalley Marey Text to Video",
|
||||
category="api node/video/Moonvalley Marey",
|
||||
description="",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
|
||||
@@ -669,7 +427,7 @@ class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
|
||||
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
|
||||
tooltip="Negative prompt text",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=[
|
||||
"16:9 (1920 x 1080)",
|
||||
@@ -682,7 +440,7 @@ class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
|
||||
default="16:9 (1920 x 1080)",
|
||||
tooltip="Resolution of the output video",
|
||||
),
|
||||
comfy_io.Float.Input(
|
||||
IO.Float.Input(
|
||||
"prompt_adherence",
|
||||
default=4.0,
|
||||
min=1.0,
|
||||
@@ -690,17 +448,17 @@ class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
|
||||
step=1.0,
|
||||
tooltip="Guidance scale for generation control",
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=9,
|
||||
min=0,
|
||||
max=4294967295,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Random seed value",
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"steps",
|
||||
default=33,
|
||||
min=1,
|
||||
@@ -709,11 +467,11 @@ class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
|
||||
tooltip="Inference steps",
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -727,15 +485,11 @@ class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
|
||||
prompt_adherence: float,
|
||||
seed: int,
|
||||
steps: int,
|
||||
) -> comfy_io.NodeOutput:
|
||||
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
width_height = parse_width_height_from_res(resolution)
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
inference_params = MoonvalleyTextToVideoInferenceParams(
|
||||
negative_prompt=negative_prompt,
|
||||
steps=steps,
|
||||
@@ -745,35 +499,21 @@ class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
|
||||
width=width_height["width"],
|
||||
height=width_height["height"],
|
||||
)
|
||||
request = MoonvalleyTextToVideoRequest(
|
||||
prompt_text=prompt, inference_params=inference_params
|
||||
)
|
||||
|
||||
init_op = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=API_TXT2VIDEO_ENDPOINT,
|
||||
method=HttpMethod.POST,
|
||||
request_model=MoonvalleyTextToVideoRequest,
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
),
|
||||
request=request,
|
||||
auth_kwargs=auth,
|
||||
task_creation_response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=API_TXT2VIDEO_ENDPOINT, method="POST"),
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
data=MoonvalleyTextToVideoRequest(prompt_text=prompt, inference_params=inference_params),
|
||||
)
|
||||
task_creation_response = await init_op.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.id
|
||||
|
||||
final_response = await get_response(
|
||||
task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id
|
||||
)
|
||||
|
||||
video = await download_url_to_video_output(final_response.output_url)
|
||||
return comfy_io.NodeOutput(video)
|
||||
final_response = await get_response(cls, task_creation_response.id)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.output_url))
|
||||
|
||||
|
||||
class MoonvalleyExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
MoonvalleyImg2VideoNode,
|
||||
MoonvalleyTxt2VideoNode,
|
File diff suppressed because it is too large
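Note: the Moonvalley changes above and the Pika changes below apply the same migration pattern. Direct construction of SynchronousOperation / PollingOperation objects (with explicit HttpMethod, request_model and auth_kwargs) is replaced by the sync_op and poll_op helpers from comfy_api_nodes.util, which take the node class and an ApiEndpoint and resolve auth from the node's hidden inputs. A minimal before/after sketch, assuming only the call shapes visible in this diff (any helper keyword argument not shown here is an assumption):

# Old style (removed): build the operation objects and call .execute() manually,
# passing auth extracted from cls.hidden by hand.
task = await SynchronousOperation(
    endpoint=ApiEndpoint(
        path=API_TXT2VIDEO_ENDPOINT,
        method=HttpMethod.POST,
        request_model=MoonvalleyTextToVideoRequest,
        response_model=MoonvalleyPromptResponse,
    ),
    request=request,
    auth_kwargs=auth,
).execute()

# New style (added): helpers receive the node class; auth and progress reporting
# are handled internally.
task = await sync_op(
    cls,
    endpoint=ApiEndpoint(path=API_TXT2VIDEO_ENDPOINT, method="POST"),
    response_model=MoonvalleyPromptResponse,
    data=request,
)
result = await poll_op(
    cls,
    ApiEndpoint(path=f"{API_PROMPTS_ENDPOINT}/{task.id}"),
    response_model=MoonvalleyPromptResponse,
    status_extractor=lambda r: (r.status if r and r.status else None),
    poll_interval=16.0,
    max_poll_attempts=240,
)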
@@ -7,27 +7,23 @@ from __future__ import annotations

from io import BytesIO
import logging
from typing import Optional, TypeVar
from typing import Optional

import torch

from typing_extensions import override
from comfy_api.latest import ComfyExtension, comfy_io
from comfy_api.latest import ComfyExtension, IO
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
from comfy_api_nodes.apinode_utils import (
from comfy_api_nodes.apis import pika_api as pika_defs
from comfy_api_nodes.util import (
    validate_string,
    download_url_to_video_output,
    tensor_to_bytesio,
)
from comfy_api_nodes.apis import pika_defs
from comfy_api_nodes.apis.client import (
    ApiEndpoint,
    EmptyRequest,
    HttpMethod,
    PollingOperation,
    SynchronousOperation,
    sync_op,
    poll_op,
)

R = TypeVar("R")

PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions"
PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps"
@@ -43,67 +39,57 @@ PATH_VIDEO_GET = "/proxy/pika/videos"
|
||||
|
||||
|
||||
async def execute_task(
|
||||
initial_operation: SynchronousOperation[R, pika_defs.PikaGenerateResponse],
|
||||
auth_kwargs: Optional[dict[str, str]] = None,
|
||||
node_id: Optional[str] = None,
|
||||
) -> comfy_io.NodeOutput:
|
||||
task_id = (await initial_operation.execute()).video_id
|
||||
final_response: pika_defs.PikaVideoResponse = await PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"{PATH_VIDEO_GET}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=pika_defs.PikaVideoResponse,
|
||||
),
|
||||
completed_statuses=["finished"],
|
||||
failed_statuses=["failed", "cancelled"],
|
||||
task_id: str,
|
||||
cls: type[IO.ComfyNode],
|
||||
) -> IO.NodeOutput:
|
||||
final_response: pika_defs.PikaVideoResponse = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{PATH_VIDEO_GET}/{task_id}"),
|
||||
response_model=pika_defs.PikaVideoResponse,
|
||||
status_extractor=lambda response: (response.status.value if response.status else None),
|
||||
progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None),
|
||||
auth_kwargs=auth_kwargs,
|
||||
result_url_extractor=lambda response: (response.url if hasattr(response, "url") else None),
|
||||
node_id=node_id,
|
||||
estimated_duration=60,
|
||||
max_poll_attempts=240,
|
||||
).execute()
|
||||
)
|
||||
if not final_response.url:
|
||||
error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}"
|
||||
logging.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
video_url = final_response.url
|
||||
logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)
|
||||
return comfy_io.NodeOutput(await download_url_to_video_output(video_url))
|
||||
return IO.NodeOutput(await download_url_to_video_output(video_url))
|
||||
|
||||
|
||||
def get_base_inputs_types() -> list[comfy_io.Input]:
|
||||
def get_base_inputs_types() -> list[IO.Input]:
|
||||
"""Get the base required inputs types common to all Pika nodes."""
|
||||
return [
|
||||
comfy_io.String.Input("prompt_text", multiline=True),
|
||||
comfy_io.String.Input("negative_prompt", multiline=True),
|
||||
comfy_io.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
|
||||
comfy_io.Combo.Input("resolution", options=["1080p", "720p"], default="1080p"),
|
||||
comfy_io.Combo.Input("duration", options=[5, 10], default=5),
|
||||
IO.String.Input("prompt_text", multiline=True),
|
||||
IO.String.Input("negative_prompt", multiline=True),
|
||||
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
|
||||
IO.Combo.Input("resolution", options=["1080p", "720p"], default="1080p"),
|
||||
IO.Combo.Input("duration", options=[5, 10], default=5),
|
||||
]
|
||||
|
||||
|
||||
class PikaImageToVideo(comfy_io.ComfyNode):
|
||||
class PikaImageToVideo(IO.ComfyNode):
|
||||
"""Pika 2.2 Image to Video Node."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="PikaImageToVideoNode2_2",
|
||||
display_name="Pika Image to Video",
|
||||
description="Sends an image and prompt to the Pika API v2.2 to generate a video.",
|
||||
category="api node/video/Pika",
|
||||
inputs=[
|
||||
comfy_io.Image.Input("image", tooltip="The image to convert to video"),
|
||||
IO.Image.Input("image", tooltip="The image to convert to video"),
|
||||
*get_base_inputs_types(),
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -117,7 +103,7 @@ class PikaImageToVideo(comfy_io.ComfyNode):
|
||||
seed: int,
|
||||
resolution: str,
|
||||
duration: int,
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
image_bytes_io = tensor_to_bytesio(image)
|
||||
pika_files = {"image": ("image.png", image_bytes_io, "image/png")}
|
||||
pika_request_data = pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost(
|
||||
@@ -127,38 +113,30 @@ class PikaImageToVideo(comfy_io.ComfyNode):
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_IMAGE_TO_VIDEO,
|
||||
method=HttpMethod.POST,
|
||||
request_model=pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost,
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
),
|
||||
request=pika_request_data,
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaTextToVideoNode(comfy_io.ComfyNode):
|
||||
class PikaTextToVideoNode(IO.ComfyNode):
|
||||
"""Pika Text2Video v2.2 Node."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="PikaTextToVideoNode2_2",
|
||||
display_name="Pika Text to Video",
|
||||
description="Sends a text prompt to the Pika API v2.2 to generate a video.",
|
||||
category="api node/video/Pika",
|
||||
inputs=[
|
||||
*get_base_inputs_types(),
|
||||
comfy_io.Float.Input(
|
||||
IO.Float.Input(
|
||||
"aspect_ratio",
|
||||
step=0.001,
|
||||
min=0.4,
|
||||
@@ -167,11 +145,11 @@ class PikaTextToVideoNode(comfy_io.ComfyNode):
|
||||
tooltip="Aspect ratio (width / height)",
|
||||
)
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -185,19 +163,12 @@ class PikaTextToVideoNode(comfy_io.ComfyNode):
|
||||
resolution: str,
|
||||
duration: int,
|
||||
aspect_ratio: float,
|
||||
) -> comfy_io.NodeOutput:
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_TEXT_TO_VIDEO,
|
||||
method=HttpMethod.POST,
|
||||
request_model=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost,
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
),
|
||||
request=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
|
||||
) -> IO.NodeOutput:
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
@@ -205,30 +176,29 @@ class PikaTextToVideoNode(comfy_io.ComfyNode):
|
||||
duration=duration,
|
||||
aspectRatio=aspect_ratio,
|
||||
),
|
||||
auth_kwargs=auth,
|
||||
content_type="application/x-www-form-urlencoded",
|
||||
)
|
||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaScenes(comfy_io.ComfyNode):
|
||||
class PikaScenes(IO.ComfyNode):
|
||||
"""PikaScenes v2.2 Node."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="PikaScenesV2_2",
|
||||
display_name="Pika Scenes (Video Image Composition)",
|
||||
description="Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them.",
|
||||
category="api node/video/Pika",
|
||||
inputs=[
|
||||
*get_base_inputs_types(),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"ingredients_mode",
|
||||
options=["creative", "precise"],
|
||||
default="creative",
|
||||
),
|
||||
comfy_io.Float.Input(
|
||||
IO.Float.Input(
|
||||
"aspect_ratio",
|
||||
step=0.001,
|
||||
min=0.4,
|
||||
@@ -236,37 +206,37 @@ class PikaScenes(comfy_io.ComfyNode):
|
||||
default=1.7777777777777777,
|
||||
tooltip="Aspect ratio (width / height)",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"image_ingredient_1",
|
||||
optional=True,
|
||||
tooltip="Image that will be used as ingredient to create a video.",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"image_ingredient_2",
|
||||
optional=True,
|
||||
tooltip="Image that will be used as ingredient to create a video.",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"image_ingredient_3",
|
||||
optional=True,
|
||||
tooltip="Image that will be used as ingredient to create a video.",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"image_ingredient_4",
|
||||
optional=True,
|
||||
tooltip="Image that will be used as ingredient to create a video.",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"image_ingredient_5",
|
||||
optional=True,
|
||||
tooltip="Image that will be used as ingredient to create a video.",
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -286,7 +256,7 @@ class PikaScenes(comfy_io.ComfyNode):
|
||||
image_ingredient_3: Optional[torch.Tensor] = None,
|
||||
image_ingredient_4: Optional[torch.Tensor] = None,
|
||||
image_ingredient_5: Optional[torch.Tensor] = None,
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
all_image_bytes_io = []
|
||||
for image in [
|
||||
image_ingredient_1,
|
||||
@@ -312,53 +282,45 @@ class PikaScenes(comfy_io.ComfyNode):
|
||||
duration=duration,
|
||||
aspectRatio=aspect_ratio,
|
||||
)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_PIKASCENES,
|
||||
method=HttpMethod.POST,
|
||||
request_model=pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost,
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
),
|
||||
request=pika_request_data,
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_PIKASCENES, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
|
||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikAdditionsNode(comfy_io.ComfyNode):
|
||||
class PikAdditionsNode(IO.ComfyNode):
|
||||
"""Pika Pikadditions Node. Add an image into a video."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="Pikadditions",
|
||||
display_name="Pikadditions (Video Object Insertion)",
|
||||
description="Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result.",
|
||||
category="api node/video/Pika",
|
||||
inputs=[
|
||||
comfy_io.Video.Input("video", tooltip="The video to add an image to."),
|
||||
comfy_io.Image.Input("image", tooltip="The image to add to the video."),
|
||||
comfy_io.String.Input("prompt_text", multiline=True),
|
||||
comfy_io.String.Input("negative_prompt", multiline=True),
|
||||
comfy_io.Int.Input(
|
||||
IO.Video.Input("video", tooltip="The video to add an image to."),
|
||||
IO.Image.Input("image", tooltip="The image to add to the video."),
|
||||
IO.String.Input("prompt_text", multiline=True),
|
||||
IO.String.Input("negative_prompt", multiline=True),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
min=0,
|
||||
max=0xFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -371,7 +333,7 @@ class PikAdditionsNode(comfy_io.ComfyNode):
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
video_bytes_io = BytesIO()
|
||||
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
|
||||
video_bytes_io.seek(0)
|
||||
@@ -386,63 +348,55 @@ class PikAdditionsNode(comfy_io.ComfyNode):
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_PIKADDITIONS,
|
||||
method=HttpMethod.POST,
|
||||
request_model=pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
),
|
||||
request=pika_request_data,
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_PIKADDITIONS, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
|
||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaSwapsNode(comfy_io.ComfyNode):
|
||||
class PikaSwapsNode(IO.ComfyNode):
|
||||
"""Pika Pikaswaps Node."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="Pikaswaps",
|
||||
display_name="Pika Swaps (Video Object Replacement)",
|
||||
description="Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates.",
|
||||
category="api node/video/Pika",
|
||||
inputs=[
|
||||
comfy_io.Video.Input("video", tooltip="The video to swap an object in."),
|
||||
comfy_io.Image.Input(
|
||||
IO.Video.Input("video", tooltip="The video to swap an object in."),
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="The image used to replace the masked object in the video.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Mask.Input(
|
||||
IO.Mask.Input(
|
||||
"mask",
|
||||
tooltip="Use the mask to define areas in the video to replace.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.String.Input("prompt_text", multiline=True, optional=True),
|
||||
comfy_io.String.Input("negative_prompt", multiline=True, optional=True),
|
||||
comfy_io.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True, optional=True),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input("prompt_text", multiline=True, optional=True),
|
||||
IO.String.Input("negative_prompt", multiline=True, optional=True),
|
||||
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True, optional=True),
|
||||
IO.String.Input(
|
||||
"region_to_modify",
|
||||
multiline=True,
|
||||
optional=True,
|
||||
tooltip="Plaintext description of the object / region to modify.",
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -457,7 +411,7 @@ class PikaSwapsNode(comfy_io.ComfyNode):
|
||||
negative_prompt: str = "",
|
||||
seed: int = 0,
|
||||
region_to_modify: str = "",
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
video_bytes_io = BytesIO()
|
||||
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
|
||||
video_bytes_io.seek(0)
|
||||
@@ -475,49 +429,41 @@ class PikaSwapsNode(comfy_io.ComfyNode):
|
||||
seed=seed,
|
||||
modifyRegionRoi=region_to_modify if region_to_modify else None,
|
||||
)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_PIKASWAPS,
|
||||
method=HttpMethod.POST,
|
||||
request_model=pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
),
|
||||
request=pika_request_data,
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_PIKASWAPS, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaffectsNode(comfy_io.ComfyNode):
|
||||
class PikaffectsNode(IO.ComfyNode):
|
||||
"""Pika Pikaffects Node."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="Pikaffects",
|
||||
display_name="Pikaffects (Video Effects)",
|
||||
description="Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear",
|
||||
category="api node/video/Pika",
|
||||
inputs=[
|
||||
comfy_io.Image.Input("image", tooltip="The reference image to apply the Pikaffect to."),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Image.Input("image", tooltip="The reference image to apply the Pikaffect to."),
|
||||
IO.Combo.Input(
|
||||
"pikaffect", options=pika_defs.Pikaffect, default="Cake-ify"
|
||||
),
|
||||
comfy_io.String.Input("prompt_text", multiline=True),
|
||||
comfy_io.String.Input("negative_prompt", multiline=True),
|
||||
comfy_io.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
|
||||
IO.String.Input("prompt_text", multiline=True),
|
||||
IO.String.Input("negative_prompt", multiline=True),
|
||||
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -530,19 +476,12 @@ class PikaffectsNode(comfy_io.ComfyNode):
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
) -> comfy_io.NodeOutput:
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_PIKAFFECTS,
|
||||
method=HttpMethod.POST,
|
||||
request_model=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
),
|
||||
request=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
|
||||
) -> IO.NodeOutput:
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_PIKAFFECTS, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
|
||||
pikaffect=pikaffect,
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
@@ -550,31 +489,30 @@ class PikaffectsNode(comfy_io.ComfyNode):
|
||||
),
|
||||
files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaStartEndFrameNode(comfy_io.ComfyNode):
|
||||
class PikaStartEndFrameNode(IO.ComfyNode):
|
||||
"""PikaFrames v2.2 Node."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="PikaStartEndFrameNode2_2",
|
||||
display_name="Pika Start and End Frame to Video",
|
||||
description="Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them.",
|
||||
category="api node/video/Pika",
|
||||
inputs=[
|
||||
comfy_io.Image.Input("image_start", tooltip="The first image to combine."),
|
||||
comfy_io.Image.Input("image_end", tooltip="The last image to combine."),
|
||||
IO.Image.Input("image_start", tooltip="The first image to combine."),
|
||||
IO.Image.Input("image_end", tooltip="The last image to combine."),
|
||||
*get_base_inputs_types(),
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -589,23 +527,17 @@ class PikaStartEndFrameNode(comfy_io.ComfyNode):
|
||||
seed: int,
|
||||
resolution: str,
|
||||
duration: int,
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt_text, field_name="prompt_text", min_length=1)
|
||||
pika_files = [
|
||||
("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
|
||||
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
|
||||
]
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_PIKAFRAMES,
|
||||
method=HttpMethod.POST,
|
||||
request_model=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
),
|
||||
request=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_PIKAFRAMES, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
@@ -614,14 +546,13 @@ class PikaStartEndFrameNode(comfy_io.ComfyNode):
|
||||
),
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaApiNodesExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
PikaImageToVideo,
|
||||
PikaTextToVideoNode,
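The Pika diff above reworks execute_task around the same helpers: each node first creates a generation task with sync_op (multipart form upload for images or video), then execute_task polls /proxy/pika/videos/{video_id} with poll_op until the task reports a terminal status and downloads the resulting URL. A rough sketch of the new flow, using only the names and call shapes visible in this diff (anything beyond them is an assumption):

# Create the task; the response carries a video_id used for polling.
response = await sync_op(
    cls,
    ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
    response_model=pika_defs.PikaGenerateResponse,
    data=pika_request_data,
    files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
    content_type="multipart/form-data",
)
# execute_task(response.video_id, cls) then polls until the video is ready:
final = await poll_op(
    cls,
    ApiEndpoint(path=f"{PATH_VIDEO_GET}/{response.video_id}"),
    response_model=pika_defs.PikaVideoResponse,
    status_extractor=lambda r: (r.status.value if r.status else None),
    completed_statuses=["finished"],
    failed_statuses=["failed", "cancelled"],
)
return IO.NodeOutput(await download_url_to_video_output(final.url))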
@@ -1,7 +1,6 @@
from inspect import cleandoc
from typing import Optional
import torch
from typing_extensions import override
from io import BytesIO
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.pixverse_api import (
    PixverseTextVideoRequest,
    PixverseImageVideoRequest,
@@ -17,125 +16,91 @@ from comfy_api_nodes.apis.pixverse_api import (
    PixverseIO,
    pixverse_templates,
)
from comfy_api_nodes.apis.client import (
from comfy_api_nodes.util import (
    ApiEndpoint,
    HttpMethod,
    SynchronousOperation,
    PollingOperation,
    EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
    download_url_to_video_output,
    poll_op,
    sync_op,
    tensor_to_bytesio,
    validate_string,
)
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import ComfyExtension, io as comfy_io

import torch
import aiohttp


AVERAGE_DURATION_T2V = 32
AVERAGE_DURATION_I2V = 30
AVERAGE_DURATION_T2T = 52

def get_video_url_from_response(
|
||||
response: PixverseGenerationStatusResponse,
|
||||
) -> Optional[str]:
|
||||
if response.Resp is None or response.Resp.url is None:
|
||||
return None
|
||||
return str(response.Resp.url)
|
||||
|
||||
|
||||
async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
|
||||
# first, upload image to Pixverse and get image id to use in actual generation call
|
||||
files = {"image": tensor_to_bytesio(image)}
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/pixverse/image/upload",
|
||||
method=HttpMethod.POST,
|
||||
request_model=EmptyRequest,
|
||||
response_model=PixverseImageUploadResponse,
|
||||
),
|
||||
request=EmptyRequest(),
|
||||
files=files,
|
||||
async def upload_image_to_pixverse(cls: type[IO.ComfyNode], image: torch.Tensor):
|
||||
response_upload = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/pixverse/image/upload", method="POST"),
|
||||
response_model=PixverseImageUploadResponse,
|
||||
files={"image": tensor_to_bytesio(image)},
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
response_upload: PixverseImageUploadResponse = await operation.execute()
|
||||
|
||||
if response_upload.Resp is None:
|
||||
raise Exception(
|
||||
f"PixVerse image upload request failed: '{response_upload.ErrMsg}'"
|
||||
)
|
||||
|
||||
raise Exception(f"PixVerse image upload request failed: '{response_upload.ErrMsg}'")
|
||||
return response_upload.Resp.img_id
|
||||
|
||||
|
||||
class PixverseTemplateNode(comfy_io.ComfyNode):
|
||||
class PixverseTemplateNode(IO.ComfyNode):
|
||||
"""
|
||||
Select template for PixVerse Video generation.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="PixverseTemplateNode",
|
||||
display_name="PixVerse Template",
|
||||
category="api node/video/PixVerse",
|
||||
inputs=[
|
||||
comfy_io.Combo.Input("template", options=list(pixverse_templates.keys())),
|
||||
IO.Combo.Input("template", options=list(pixverse_templates.keys())),
|
||||
],
|
||||
outputs=[comfy_io.Custom(PixverseIO.TEMPLATE).Output(display_name="pixverse_template")],
|
||||
outputs=[IO.Custom(PixverseIO.TEMPLATE).Output(display_name="pixverse_template")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, template: str) -> comfy_io.NodeOutput:
|
||||
def execute(cls, template: str) -> IO.NodeOutput:
|
||||
template_id = pixverse_templates.get(template, None)
|
||||
if template_id is None:
|
||||
raise Exception(f"Template '{template}' is not recognized.")
|
||||
# just return the integer
|
||||
return comfy_io.NodeOutput(template_id)
|
||||
return IO.NodeOutput(template_id)
|
||||
|
||||
|
||||
class PixverseTextToVideoNode(comfy_io.ComfyNode):
|
||||
"""
|
||||
Generates videos based on prompt and output_size.
|
||||
"""
|
||||
|
||||
class PixverseTextToVideoNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="PixverseTextToVideoNode",
|
||||
display_name="PixVerse Text to Video",
|
||||
category="api node/video/PixVerse",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Generates videos based on prompt and output_size.",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the video generation",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=PixverseAspectRatio,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"quality",
|
||||
options=PixverseQuality,
|
||||
default=PixverseQuality.res_540p,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"duration_seconds",
|
||||
options=PixverseDuration,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"motion_mode",
|
||||
options=PixverseMotionMode,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
@@ -143,24 +108,24 @@ class PixverseTextToVideoNode(comfy_io.ComfyNode):
|
||||
control_after_generate=True,
|
||||
tooltip="Seed for video generation.",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
default="",
|
||||
multiline=True,
|
||||
tooltip="An optional text description of undesired elements on an image.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Custom(PixverseIO.TEMPLATE).Input(
|
||||
IO.Custom(PixverseIO.TEMPLATE).Input(
|
||||
"pixverse_template",
|
||||
tooltip="An optional template to influence style of generation, created by the PixVerse Template node.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -176,8 +141,8 @@ class PixverseTextToVideoNode(comfy_io.ComfyNode):
|
||||
seed,
|
||||
negative_prompt: str = None,
|
||||
pixverse_template: int = None,
|
||||
) -> comfy_io.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False, min_length=1)
|
||||
# 1080p is limited to 5 seconds duration
|
||||
# only normal motion_mode supported for 1080p or for non-5 second duration
|
||||
if quality == PixverseQuality.res_1080p:
|
||||
@@ -186,18 +151,11 @@ class PixverseTextToVideoNode(comfy_io.ComfyNode):
|
||||
elif duration_seconds != PixverseDuration.dur_5:
|
||||
motion_mode = PixverseMotionMode.normal
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/pixverse/video/text/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=PixverseTextVideoRequest,
|
||||
response_model=PixverseVideoResponse,
|
||||
),
|
||||
request=PixverseTextVideoRequest(
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/pixverse/video/text/generate", method="POST"),
|
||||
response_model=PixverseVideoResponse,
|
||||
data=PixverseTextVideoRequest(
|
||||
prompt=prompt,
|
||||
aspect_ratio=aspect_ratio,
|
||||
quality=quality,
|
||||
@@ -207,20 +165,14 @@ class PixverseTextToVideoNode(comfy_io.ComfyNode):
|
||||
template_id=pixverse_template,
|
||||
seed=seed,
|
||||
),
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
response_api = await operation.execute()
|
||||
|
||||
if response_api.Resp is None:
|
||||
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=PixverseGenerationStatusResponse,
|
||||
),
|
||||
response_poll = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
|
||||
response_model=PixverseGenerationStatusResponse,
|
||||
completed_statuses=[PixverseStatus.successful],
|
||||
failed_statuses=[
|
||||
PixverseStatus.contents_moderation,
|
||||
@@ -228,52 +180,41 @@ class PixverseTextToVideoNode(comfy_io.ComfyNode):
|
||||
PixverseStatus.deleted,
|
||||
],
|
||||
status_extractor=lambda x: x.Resp.status,
|
||||
auth_kwargs=auth,
|
||||
node_id=cls.hidden.unique_id,
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
estimated_duration=AVERAGE_DURATION_T2V,
|
||||
)
|
||||
response_poll = await operation.execute()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.Resp.url) as vid_response:
|
||||
return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
|
||||
return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
|
||||
|
||||
|
||||
class PixverseImageToVideoNode(comfy_io.ComfyNode):
|
||||
"""
|
||||
Generates videos based on prompt and output_size.
|
||||
"""
|
||||
|
||||
class PixverseImageToVideoNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="PixverseImageToVideoNode",
|
||||
display_name="PixVerse Image to Video",
|
||||
category="api node/video/PixVerse",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Generates videos based on prompt and output_size.",
|
||||
inputs=[
|
||||
comfy_io.Image.Input("image"),
|
||||
comfy_io.String.Input(
|
||||
IO.Image.Input("image"),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the video generation",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"quality",
|
||||
options=PixverseQuality,
|
||||
default=PixverseQuality.res_540p,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"duration_seconds",
|
||||
options=PixverseDuration,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"motion_mode",
|
||||
options=PixverseMotionMode,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
@@ -281,24 +222,24 @@ class PixverseImageToVideoNode(comfy_io.ComfyNode):
|
||||
control_after_generate=True,
|
||||
tooltip="Seed for video generation.",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
default="",
|
||||
multiline=True,
|
||||
tooltip="An optional text description of undesired elements on an image.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Custom(PixverseIO.TEMPLATE).Input(
|
||||
IO.Custom(PixverseIO.TEMPLATE).Input(
|
||||
"pixverse_template",
|
||||
tooltip="An optional template to influence style of generation, created by the PixVerse Template node.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -314,13 +255,9 @@ class PixverseImageToVideoNode(comfy_io.ComfyNode):
|
||||
seed,
|
||||
negative_prompt: str = None,
|
||||
pixverse_template: int = None,
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
img_id = await upload_image_to_pixverse(image, auth_kwargs=auth)
|
||||
img_id = await upload_image_to_pixverse(cls, image)
|
||||
|
||||
# 1080p is limited to 5 seconds duration
|
||||
# only normal motion_mode supported for 1080p or for non-5 second duration
|
||||
@@ -330,14 +267,11 @@ class PixverseImageToVideoNode(comfy_io.ComfyNode):
|
||||
elif duration_seconds != PixverseDuration.dur_5:
|
||||
motion_mode = PixverseMotionMode.normal
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/pixverse/video/img/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=PixverseImageVideoRequest,
|
||||
response_model=PixverseVideoResponse,
|
||||
),
|
||||
request=PixverseImageVideoRequest(
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/pixverse/video/img/generate", method="POST"),
|
||||
response_model=PixverseVideoResponse,
|
||||
data=PixverseImageVideoRequest(
|
||||
img_id=img_id,
|
||||
prompt=prompt,
|
||||
quality=quality,
|
||||
@@ -347,20 +281,15 @@ class PixverseImageToVideoNode(comfy_io.ComfyNode):
|
||||
template_id=pixverse_template,
|
||||
seed=seed,
|
||||
),
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
response_api = await operation.execute()
|
||||
|
||||
if response_api.Resp is None:
|
||||
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=PixverseGenerationStatusResponse,
|
||||
),
|
||||
response_poll = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
|
||||
response_model=PixverseGenerationStatusResponse,
|
||||
completed_statuses=[PixverseStatus.successful],
|
||||
failed_statuses=[
|
||||
PixverseStatus.contents_moderation,
|
||||
@@ -368,53 +297,42 @@ class PixverseImageToVideoNode(comfy_io.ComfyNode):
|
||||
PixverseStatus.deleted,
|
||||
],
|
||||
status_extractor=lambda x: x.Resp.status,
|
||||
auth_kwargs=auth,
|
||||
node_id=cls.hidden.unique_id,
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
estimated_duration=AVERAGE_DURATION_I2V,
|
||||
)
|
||||
response_poll = await operation.execute()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.Resp.url) as vid_response:
|
||||
return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
|
||||
return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
|
||||
|
||||
|
||||
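The execute() rewrites above also swap the SynchronousOperation / PollingOperation objects for the sync_op and poll_op helpers imported from comfy_api_nodes.util, with the node class passed in so hidden auth fields no longer have to be collected by hand. A rough sketch of the new call shape under that assumption, keeping only keyword names visible in this diff; the endpoint paths, status strings, and request/response models are placeholders:

from comfy_api_nodes.util import ApiEndpoint, sync_op, poll_op

# submit the generation request (placeholder endpoint and models)
response = await sync_op(
    cls,
    ApiEndpoint(path="/proxy/example/generate", method="POST"),
    response_model=ExampleGenerateResponse,
    data=ExampleGenerateRequest(prompt=prompt, seed=seed),
)
# poll the task until it reaches a terminal status (placeholder statuses)
result = await poll_op(
    cls,
    ApiEndpoint(path=f"/proxy/example/result/{response.Resp.video_id}"),
    response_model=ExampleStatusResponse,
    completed_statuses=["successful"],
    failed_statuses=["failed"],
    status_extractor=lambda x: x.Resp.status,
    estimated_duration=60,
)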
class PixverseTransitionVideoNode(comfy_io.ComfyNode):
|
||||
"""
|
||||
Generates videos based on prompt and output_size.
|
||||
"""
|
||||
|
||||
class PixverseTransitionVideoNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="PixverseTransitionVideoNode",
|
||||
display_name="PixVerse Transition Video",
|
||||
category="api node/video/PixVerse",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Generates videos based on prompt and output_size.",
|
||||
inputs=[
|
||||
comfy_io.Image.Input("first_frame"),
|
||||
comfy_io.Image.Input("last_frame"),
|
||||
comfy_io.String.Input(
|
||||
IO.Image.Input("first_frame"),
|
||||
IO.Image.Input("last_frame"),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the video generation",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"quality",
|
||||
options=PixverseQuality,
|
||||
default=PixverseQuality.res_540p,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"duration_seconds",
|
||||
options=PixverseDuration,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"motion_mode",
|
||||
options=PixverseMotionMode,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
@@ -422,7 +340,7 @@ class PixverseTransitionVideoNode(comfy_io.ComfyNode):
|
||||
control_after_generate=True,
|
||||
tooltip="Seed for video generation.",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
default="",
|
||||
multiline=True,
|
||||
@@ -430,11 +348,11 @@ class PixverseTransitionVideoNode(comfy_io.ComfyNode):
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.Video.Output()],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -450,14 +368,10 @@ class PixverseTransitionVideoNode(comfy_io.ComfyNode):
|
||||
motion_mode: str,
|
||||
seed,
|
||||
negative_prompt: str = None,
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=auth)
|
||||
last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=auth)
|
||||
first_frame_id = await upload_image_to_pixverse(cls, first_frame)
|
||||
last_frame_id = await upload_image_to_pixverse(cls, last_frame)
|
||||
|
||||
# 1080p is limited to 5 seconds duration
|
||||
# only normal motion_mode supported for 1080p or for non-5 second duration
|
||||
@@ -467,14 +381,11 @@ class PixverseTransitionVideoNode(comfy_io.ComfyNode):
|
||||
elif duration_seconds != PixverseDuration.dur_5:
|
||||
motion_mode = PixverseMotionMode.normal
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/pixverse/video/transition/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=PixverseTransitionVideoRequest,
|
||||
response_model=PixverseVideoResponse,
|
||||
),
|
||||
request=PixverseTransitionVideoRequest(
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/pixverse/video/transition/generate", method="POST"),
|
||||
response_model=PixverseVideoResponse,
|
||||
data=PixverseTransitionVideoRequest(
|
||||
first_frame_img=first_frame_id,
|
||||
last_frame_img=last_frame_id,
|
||||
prompt=prompt,
|
||||
@@ -484,20 +395,15 @@ class PixverseTransitionVideoNode(comfy_io.ComfyNode):
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
seed=seed,
|
||||
),
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
response_api = await operation.execute()
|
||||
|
||||
if response_api.Resp is None:
|
||||
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=PixverseGenerationStatusResponse,
|
||||
),
|
||||
response_poll = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
|
||||
response_model=PixverseGenerationStatusResponse,
|
||||
completed_statuses=[PixverseStatus.successful],
|
||||
failed_statuses=[
|
||||
PixverseStatus.contents_moderation,
|
||||
@@ -505,21 +411,14 @@ class PixverseTransitionVideoNode(comfy_io.ComfyNode):
|
||||
PixverseStatus.deleted,
|
||||
],
|
||||
status_extractor=lambda x: x.Resp.status,
|
||||
auth_kwargs=auth,
|
||||
node_id=cls.hidden.unique_id,
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
estimated_duration=AVERAGE_DURATION_T2V,
|
||||
)
|
||||
response_poll = await operation.execute()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.Resp.url) as vid_response:
|
||||
return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
|
||||
return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
|
||||
|
||||
|
||||
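In all three PixVerse nodes the final download step changes the same way: the manual aiohttp session read that wrapped the bytes in VideoFromFile is collapsed into the download_url_to_video_output helper. Both forms, as they appear in the hunks above:

# removed form
async with aiohttp.ClientSession() as session:
    async with session.get(response_poll.Resp.url) as vid_response:
        return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))

# added form
return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))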
class PixVerseExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
PixverseTextToVideoNode,
|
||||
PixverseImageToVideoNode,
|
||||
|
||||
File diff suppressed because it is too large
@@ -5,12 +5,9 @@ Rodin API docs: https://developer.hyper3d.ai/

"""

from __future__ import annotations
from inspect import cleandoc
import folder_paths as comfy_paths
import aiohttp
import os
import asyncio
import logging
import math
from typing import Optional
@@ -26,26 +23,26 @@ from comfy_api_nodes.apis.rodin_api import (
Rodin3DDownloadResponse,
JobStatus,
)
from comfy_api_nodes.apis.client import (
from comfy_api_nodes.util import (
sync_op,
poll_op,
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
download_url_to_bytesio,
)
from comfy_api.latest import ComfyExtension, io as comfy_io
from comfy_api.latest import ComfyExtension, IO


COMMON_PARAMETERS = [
comfy_io.Int.Input(
IO.Int.Input(
"Seed",
default=0,
min=0,
max=65535,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
optional=True,
),
comfy_io.Combo.Input("Material_Type", options=["PBR", "Shaded"], default="PBR", optional=True),
comfy_io.Combo.Input(
IO.Combo.Input("Material_Type", options=["PBR", "Shaded"], default="PBR", optional=True),
IO.Combo.Input(
"Polygon_count",
options=["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "200K-Triangle"],
default="18K-Quad",
@@ -121,35 +118,31 @@ def tensor_to_filelike(tensor, max_pixels: int = 2048*2048):


async def create_generate_task(
cls: type[IO.ComfyNode],
images=None,
seed=1,
material="PBR",
quality_override=18000,
tier="Regular",
mesh_mode="Quad",
TAPose = False,
auth_kwargs: Optional[dict[str, str]] = None,
ta_pose: bool = False,
):
if images is None:
raise Exception("Rodin 3D generate requires at least 1 image.")
if len(images) > 5:
raise Exception("Rodin 3D generate requires up to 5 image.")

path = "/proxy/rodin/api/v2/rodin"
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=Rodin3DGenerateRequest,
response_model=Rodin3DGenerateResponse,
),
request=Rodin3DGenerateRequest(
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/rodin/api/v2/rodin", method="POST"),
response_model=Rodin3DGenerateResponse,
data=Rodin3DGenerateRequest(
seed=seed,
tier=tier,
material=material,
quality_override=quality_override,
mesh_mode=mesh_mode,
TAPose=TAPose,
TAPose=ta_pose,
),
files=[
(
@@ -159,11 +152,8 @@ async def create_generate_task(
for image in images if image is not None
],
content_type="multipart/form-data",
auth_kwargs=auth_kwargs,
)

response = await operation.execute()

if hasattr(response, "error"):
error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}"
logging.error(error_message)
@@ -187,96 +177,68 @@ def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str:
return "DONE"
return "Generating"

def extract_progress(response: Rodin3DCheckStatusResponse) -> Optional[int]:
if not response.jobs:
return None
completed_count = sum(1 for job in response.jobs if job.status == JobStatus.Done)
return int((completed_count / len(response.jobs)) * 100)

async def poll_for_task_status(
subscription_key, auth_kwargs: Optional[dict[str, str]] = None,
) -> Rodin3DCheckStatusResponse:
poll_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path="/proxy/rodin/api/v2/status",
method=HttpMethod.POST,
request_model=Rodin3DCheckStatusRequest,
response_model=Rodin3DCheckStatusResponse,
),
request=Rodin3DCheckStatusRequest(subscription_key=subscription_key),
completed_statuses=["DONE"],
failed_statuses=["FAILED"],
status_extractor=check_rodin_status,
poll_interval=3.0,
auth_kwargs=auth_kwargs,
)

async def poll_for_task_status(subscription_key: str, cls: type[IO.ComfyNode]) -> Rodin3DCheckStatusResponse:
logging.info("[ Rodin3D API - CheckStatus ] Generate Start!")
return await poll_operation.execute()


async def get_rodin_download_list(uuid, auth_kwargs: Optional[dict[str, str]] = None) -> Rodin3DDownloadResponse:
logging.info("[ Rodin3D API - Downloading ] Generate Successfully!")
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/rodin/api/v2/download",
method=HttpMethod.POST,
request_model=Rodin3DDownloadRequest,
response_model=Rodin3DDownloadResponse,
),
request=Rodin3DDownloadRequest(task_uuid=uuid),
auth_kwargs=auth_kwargs,
return await poll_op(
cls,
ApiEndpoint(path="/proxy/rodin/api/v2/status", method="POST"),
response_model=Rodin3DCheckStatusResponse,
data=Rodin3DCheckStatusRequest(subscription_key=subscription_key),
status_extractor=check_rodin_status,
progress_extractor=extract_progress,
)
return await operation.execute()


async def download_files(url_list, task_uuid):
save_path = os.path.join(comfy_paths.get_output_directory(), f"Rodin3D_{task_uuid}")
async def get_rodin_download_list(uuid: str, cls: type[IO.ComfyNode]) -> Rodin3DDownloadResponse:
logging.info("[ Rodin3D API - Downloading ] Generate Successfully!")
return await sync_op(
cls,
ApiEndpoint(path="/proxy/rodin/api/v2/download", method="POST"),
response_model=Rodin3DDownloadResponse,
data=Rodin3DDownloadRequest(task_uuid=uuid),
monitor_progress=False,
)


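The new poll_for_task_status passes extract_progress into poll_op as a progress_extractor, so the polling loop can report percentage completion from the per-job statuses. A worked example of that contract, using a hypothetical set of subtask completion flags rather than a real API response:

# hypothetical status response with four subtasks, three already finished
done_flags = [True, True, True, False]
completed = sum(1 for flag in done_flags if flag)    # 3
progress = int((completed / len(done_flags)) * 100)  # 75, surfaced to the UI by poll_op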
async def download_files(url_list, task_uuid: str):
result_folder_name = f"Rodin3D_{task_uuid}"
save_path = os.path.join(comfy_paths.get_output_directory(), result_folder_name)
os.makedirs(save_path, exist_ok=True)
model_file_path = None
async with aiohttp.ClientSession() as session:
for i in url_list.list:
url = i.url
file_name = i.name
file_path = os.path.join(save_path, file_name)
if file_path.endswith(".glb"):
model_file_path = file_path
logging.info("[ Rodin3D API - download_files ] Downloading file: %s", file_path)
max_retries = 5
for attempt in range(max_retries):
try:
async with session.get(url) as resp:
resp.raise_for_status()
with open(file_path, "wb") as f:
async for chunk in resp.content.iter_chunked(32 * 1024):
f.write(chunk)
break
except Exception as e:
logging.info("[ Rodin3D API - download_files ] Error downloading %s:%s", file_path, str(e))
if attempt < max_retries - 1:
logging.info("Retrying...")
await asyncio.sleep(2)
else:
logging.info(
"[ Rodin3D API - download_files ] Failed to download %s after %s attempts.",
file_path,
max_retries,
)
for i in url_list.list:
file_path = os.path.join(save_path, i.name)
if file_path.endswith(".glb"):
model_file_path = os.path.join(result_folder_name, i.name)
await download_url_to_bytesio(i.url, file_path)
return model_file_path


class Rodin3D_Regular(comfy_io.ComfyNode):
|
||||
class Rodin3D_Regular(IO.ComfyNode):
|
||||
"""Generate 3D Assets using Rodin API"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="Rodin3D_Regular",
|
||||
display_name="Rodin 3D Generate - Regular Generate",
|
||||
category="api node/3d/Rodin",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
comfy_io.Image.Input("Images"),
|
||||
IO.Image.Input("Images"),
|
||||
*COMMON_PARAMETERS,
|
||||
],
|
||||
outputs=[comfy_io.String.Output(display_name="3D Model Path")],
|
||||
outputs=[IO.String.Output(display_name="3D Model Path")],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -288,51 +250,48 @@ class Rodin3D_Regular(comfy_io.ComfyNode):
|
||||
Seed,
|
||||
Material_Type,
|
||||
Polygon_count,
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
tier = "Regular"
|
||||
num_images = Images.shape[0]
|
||||
m_images = []
|
||||
for i in range(num_images):
|
||||
m_images.append(Images[i])
|
||||
mesh_mode, quality_override = get_quality_mode(Polygon_count)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
task_uuid, subscription_key = await create_generate_task(
|
||||
cls,
|
||||
images=m_images,
|
||||
seed=Seed,
|
||||
material=Material_Type,
|
||||
quality_override=quality_override,
|
||||
tier=tier,
|
||||
mesh_mode=mesh_mode,
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
await poll_for_task_status(subscription_key, auth_kwargs=auth)
|
||||
download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
|
||||
await poll_for_task_status(subscription_key, cls)
|
||||
download_list = await get_rodin_download_list(task_uuid, cls)
|
||||
model = await download_files(download_list, task_uuid)
|
||||
|
||||
return comfy_io.NodeOutput(model)
|
||||
return IO.NodeOutput(model)
|
||||
|
||||
|
||||
class Rodin3D_Detail(comfy_io.ComfyNode):
|
||||
class Rodin3D_Detail(IO.ComfyNode):
|
||||
"""Generate 3D Assets using Rodin API"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="Rodin3D_Detail",
|
||||
display_name="Rodin 3D Generate - Detail Generate",
|
||||
category="api node/3d/Rodin",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
comfy_io.Image.Input("Images"),
|
||||
IO.Image.Input("Images"),
|
||||
*COMMON_PARAMETERS,
|
||||
],
|
||||
outputs=[comfy_io.String.Output(display_name="3D Model Path")],
|
||||
outputs=[IO.String.Output(display_name="3D Model Path")],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -344,51 +303,48 @@ class Rodin3D_Detail(comfy_io.ComfyNode):
|
||||
Seed,
|
||||
Material_Type,
|
||||
Polygon_count,
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
tier = "Detail"
|
||||
num_images = Images.shape[0]
|
||||
m_images = []
|
||||
for i in range(num_images):
|
||||
m_images.append(Images[i])
|
||||
mesh_mode, quality_override = get_quality_mode(Polygon_count)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
task_uuid, subscription_key = await create_generate_task(
|
||||
cls,
|
||||
images=m_images,
|
||||
seed=Seed,
|
||||
material=Material_Type,
|
||||
quality_override=quality_override,
|
||||
tier=tier,
|
||||
mesh_mode=mesh_mode,
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
await poll_for_task_status(subscription_key, auth_kwargs=auth)
|
||||
download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
|
||||
await poll_for_task_status(subscription_key, cls)
|
||||
download_list = await get_rodin_download_list(task_uuid, cls)
|
||||
model = await download_files(download_list, task_uuid)
|
||||
|
||||
return comfy_io.NodeOutput(model)
|
||||
return IO.NodeOutput(model)
|
||||
|
||||
|
||||
class Rodin3D_Smooth(comfy_io.ComfyNode):
|
||||
class Rodin3D_Smooth(IO.ComfyNode):
|
||||
"""Generate 3D Assets using Rodin API"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="Rodin3D_Smooth",
|
||||
display_name="Rodin 3D Generate - Smooth Generate",
|
||||
category="api node/3d/Rodin",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
comfy_io.Image.Input("Images"),
|
||||
IO.Image.Input("Images"),
|
||||
*COMMON_PARAMETERS,
|
||||
],
|
||||
outputs=[comfy_io.String.Output(display_name="3D Model Path")],
|
||||
outputs=[IO.String.Output(display_name="3D Model Path")],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -400,58 +356,54 @@ class Rodin3D_Smooth(comfy_io.ComfyNode):
|
||||
Seed,
|
||||
Material_Type,
|
||||
Polygon_count,
|
||||
) -> comfy_io.NodeOutput:
|
||||
tier = "Smooth"
|
||||
) -> IO.NodeOutput:
|
||||
num_images = Images.shape[0]
|
||||
m_images = []
|
||||
for i in range(num_images):
|
||||
m_images.append(Images[i])
|
||||
mesh_mode, quality_override = get_quality_mode(Polygon_count)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
task_uuid, subscription_key = await create_generate_task(
|
||||
cls,
|
||||
images=m_images,
|
||||
seed=Seed,
|
||||
material=Material_Type,
|
||||
quality_override=quality_override,
|
||||
tier=tier,
|
||||
tier="Smooth",
|
||||
mesh_mode=mesh_mode,
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
await poll_for_task_status(subscription_key, auth_kwargs=auth)
|
||||
download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
|
||||
await poll_for_task_status(subscription_key, cls)
|
||||
download_list = await get_rodin_download_list(task_uuid, cls)
|
||||
model = await download_files(download_list, task_uuid)
|
||||
|
||||
return comfy_io.NodeOutput(model)
|
||||
return IO.NodeOutput(model)
|
||||
|
||||
|
||||
class Rodin3D_Sketch(comfy_io.ComfyNode):
|
||||
class Rodin3D_Sketch(IO.ComfyNode):
|
||||
"""Generate 3D Assets using Rodin API"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="Rodin3D_Sketch",
|
||||
display_name="Rodin 3D Generate - Sketch Generate",
|
||||
category="api node/3d/Rodin",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
comfy_io.Image.Input("Images"),
|
||||
comfy_io.Int.Input(
|
||||
IO.Image.Input("Images"),
|
||||
IO.Int.Input(
|
||||
"Seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=65535,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[comfy_io.String.Output(display_name="3D Model Path")],
|
||||
outputs=[IO.String.Output(display_name="3D Model Path")],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -461,68 +413,61 @@ class Rodin3D_Sketch(comfy_io.ComfyNode):
|
||||
cls,
|
||||
Images,
|
||||
Seed,
|
||||
) -> comfy_io.NodeOutput:
|
||||
tier = "Sketch"
|
||||
) -> IO.NodeOutput:
|
||||
num_images = Images.shape[0]
|
||||
m_images = []
|
||||
for i in range(num_images):
|
||||
m_images.append(Images[i])
|
||||
material_type = "PBR"
|
||||
quality_override = 18000
|
||||
mesh_mode = "Quad"
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
task_uuid, subscription_key = await create_generate_task(
|
||||
cls,
|
||||
images=m_images,
|
||||
seed=Seed,
|
||||
material=material_type,
|
||||
quality_override=quality_override,
|
||||
tier=tier,
|
||||
mesh_mode=mesh_mode,
|
||||
auth_kwargs=auth,
|
||||
material="PBR",
|
||||
quality_override=18000,
|
||||
tier="Sketch",
|
||||
mesh_mode="Quad",
|
||||
)
|
||||
await poll_for_task_status(subscription_key, auth_kwargs=auth)
|
||||
download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
|
||||
await poll_for_task_status(subscription_key, cls)
|
||||
download_list = await get_rodin_download_list(task_uuid, cls)
|
||||
model = await download_files(download_list, task_uuid)
|
||||
|
||||
return comfy_io.NodeOutput(model)
|
||||
return IO.NodeOutput(model)
|
||||
|
||||
|
||||
class Rodin3D_Gen2(comfy_io.ComfyNode):
|
||||
class Rodin3D_Gen2(IO.ComfyNode):
|
||||
"""Generate 3D Assets using Rodin API"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> comfy_io.Schema:
|
||||
return comfy_io.Schema(
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="Rodin3D_Gen2",
|
||||
display_name="Rodin 3D Generate - Gen-2 Generate",
|
||||
category="api node/3d/Rodin",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
comfy_io.Image.Input("Images"),
|
||||
comfy_io.Int.Input(
|
||||
IO.Image.Input("Images"),
|
||||
IO.Int.Input(
|
||||
"Seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=65535,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input("Material_Type", options=["PBR", "Shaded"], default="PBR", optional=True),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input("Material_Type", options=["PBR", "Shaded"], default="PBR", optional=True),
|
||||
IO.Combo.Input(
|
||||
"Polygon_count",
|
||||
options=["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "2K-Triangle", "20K-Triangle", "150K-Triangle", "500K-Triangle"],
|
||||
default="500K-Triangle",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Boolean.Input("TAPose", default=False),
|
||||
IO.Boolean.Input("TAPose", default=False),
|
||||
],
|
||||
outputs=[comfy_io.String.Output(display_name="3D Model Path")],
|
||||
outputs=[IO.String.Output(display_name="3D Model Path")],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -535,37 +480,33 @@ class Rodin3D_Gen2(comfy_io.ComfyNode):
|
||||
Material_Type,
|
||||
Polygon_count,
|
||||
TAPose,
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
tier = "Gen-2"
|
||||
num_images = Images.shape[0]
|
||||
m_images = []
|
||||
for i in range(num_images):
|
||||
m_images.append(Images[i])
|
||||
mesh_mode, quality_override = get_quality_mode(Polygon_count)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
task_uuid, subscription_key = await create_generate_task(
|
||||
cls,
|
||||
images=m_images,
|
||||
seed=Seed,
|
||||
material=Material_Type,
|
||||
quality_override=quality_override,
|
||||
tier=tier,
|
||||
mesh_mode=mesh_mode,
|
||||
TAPose=TAPose,
|
||||
auth_kwargs=auth,
|
||||
ta_pose=TAPose,
|
||||
)
|
||||
await poll_for_task_status(subscription_key, auth_kwargs=auth)
|
||||
download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
|
||||
await poll_for_task_status(subscription_key, cls)
|
||||
download_list = await get_rodin_download_list(task_uuid, cls)
|
||||
model = await download_files(download_list, task_uuid)
|
||||
|
||||
return comfy_io.NodeOutput(model)
|
||||
return IO.NodeOutput(model)
|
||||
|
||||
|
||||
class Rodin3DExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
Rodin3D_Regular,
|
||||
Rodin3D_Detail,
|
||||
|
||||
@@ -11,7 +11,7 @@ User Guides:
|
||||
|
||||
"""
|
||||
|
||||
from typing import Union, Optional, Any
|
||||
from typing import Union, Optional
|
||||
from typing_extensions import override
|
||||
from enum import Enum
|
||||
|
||||
@@ -21,7 +21,6 @@ from comfy_api_nodes.apis import (
|
||||
RunwayImageToVideoRequest,
|
||||
RunwayImageToVideoResponse,
|
||||
RunwayTaskStatusResponse as TaskStatusResponse,
|
||||
RunwayTaskStatusEnum as TaskStatus,
|
||||
RunwayModelEnum as Model,
|
||||
RunwayDurationEnum as Duration,
|
||||
RunwayAspectRatioEnum as AspectRatio,
|
||||
@@ -33,23 +32,20 @@ from comfy_api_nodes.apis import (
|
||||
ReferenceImage,
|
||||
RunwayTextToImageAspectRatioEnum,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
upload_images_to_comfyapi,
|
||||
download_url_to_video_output,
|
||||
from comfy_api_nodes.util import (
|
||||
image_tensor_pair_to_batch,
|
||||
validate_string,
|
||||
validate_image_dimensions,
|
||||
validate_image_aspect_ratio,
|
||||
upload_images_to_comfyapi,
|
||||
download_url_to_video_output,
|
||||
download_url_to_image_tensor,
|
||||
ApiEndpoint,
|
||||
sync_op,
|
||||
poll_op,
|
||||
)
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy_api.latest import ComfyExtension, io as comfy_io
|
||||
from comfy_api_nodes.util.validation_utils import validate_image_dimensions, validate_image_aspect_ratio
|
||||
from comfy_api.latest import ComfyExtension, IO
|
||||
|
||||
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
|
||||
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
|
||||
@@ -91,31 +87,6 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
|
||||
return None
|
||||
|
||||
|
||||
async def poll_until_finished(
|
||||
auth_kwargs: dict[str, str],
|
||||
api_endpoint: ApiEndpoint[Any, TaskStatusResponse],
|
||||
estimated_duration: Optional[int] = None,
|
||||
node_id: Optional[str] = None,
|
||||
) -> TaskStatusResponse:
|
||||
"""Polls the Runway API endpoint until the task reaches a terminal state, then returns the response."""
|
||||
return await PollingOperation(
|
||||
poll_endpoint=api_endpoint,
|
||||
completed_statuses=[
|
||||
TaskStatus.SUCCEEDED.value,
|
||||
],
|
||||
failed_statuses=[
|
||||
TaskStatus.FAILED.value,
|
||||
TaskStatus.CANCELLED.value,
|
||||
],
|
||||
status_extractor=lambda response: response.status.value,
|
||||
auth_kwargs=auth_kwargs,
|
||||
result_url_extractor=get_video_url_from_task_status,
|
||||
estimated_duration=estimated_duration,
|
||||
node_id=node_id,
|
||||
progress_extractor=extract_progress_from_task_status,
|
||||
).execute()
|
||||
|
||||
|
||||
def extract_progress_from_task_status(
|
||||
response: TaskStatusResponse,
|
||||
) -> Union[float, None]:
|
||||
@@ -132,42 +103,32 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
|
||||
|
||||
|
||||
async def get_response(
task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None, estimated_duration: Optional[int] = None
cls: type[IO.ComfyNode], task_id: str, estimated_duration: Optional[int] = None
) -> TaskStatusResponse:
"""Poll the task status until it is finished then get the response."""
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=TaskStatusResponse,
),
return await poll_op(
cls,
ApiEndpoint(path=f"{PATH_GET_TASK_STATUS}/{task_id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: r.status.value,
estimated_duration=estimated_duration,
node_id=node_id,
progress_extractor=extract_progress_from_task_status,
)


async def generate_video(
|
||||
cls: type[IO.ComfyNode],
|
||||
request: RunwayImageToVideoRequest,
|
||||
auth_kwargs: dict[str, str],
|
||||
node_id: Optional[str] = None,
|
||||
estimated_duration: Optional[int] = None,
|
||||
) -> VideoFromFile:
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_IMAGE_TO_VIDEO,
|
||||
method=HttpMethod.POST,
|
||||
request_model=RunwayImageToVideoRequest,
|
||||
response_model=RunwayImageToVideoResponse,
|
||||
),
|
||||
request=request,
|
||||
auth_kwargs=auth_kwargs,
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
|
||||
response_model=RunwayImageToVideoResponse,
|
||||
data=request,
|
||||
)
|
||||
|
||||
initial_response = await initial_operation.execute()
|
||||
|
||||
final_response = await get_response(initial_response.id, auth_kwargs, node_id, estimated_duration)
|
||||
final_response = await get_response(cls, initial_response.id, estimated_duration)
|
||||
if not final_response.output:
|
||||
raise RunwayApiError("Runway task succeeded but no video data found in response.")
|
||||
|
||||
@@ -175,55 +136,55 @@ async def generate_video(
|
||||
return await download_url_to_video_output(video_url)
|
||||
|
||||
|
||||
class RunwayImageToVideoNodeGen3a(comfy_io.ComfyNode):
|
||||
class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
return IO.Schema(
|
||||
node_id="RunwayImageToVideoNodeGen3a",
|
||||
display_name="Runway Image to Video (Gen3a Turbo)",
|
||||
category="api node/video/Runway",
|
||||
description="Generate a video from a single starting frame using Gen3a Turbo model. "
|
||||
"Before diving in, review these best practices to ensure that "
|
||||
"your input selections will set your generation up for success: "
|
||||
"https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo.",
|
||||
"Before diving in, review these best practices to ensure that "
|
||||
"your input selections will set your generation up for success: "
|
||||
"https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo.",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text prompt for the generation",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"start_frame",
|
||||
tooltip="Start frame to be used for the video",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"duration",
|
||||
options=Duration,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"ratio",
|
||||
options=RunwayGen3aAspectRatio,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967295,
|
||||
step=1,
|
||||
control_after_generate=True,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Random seed for generation",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Video.Output(),
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -236,25 +197,21 @@ class RunwayImageToVideoNodeGen3a(comfy_io.ComfyNode):
|
||||
duration: str,
|
||||
ratio: str,
|
||||
seed: int,
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1)
|
||||
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
|
||||
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
|
||||
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))
|
||||
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
cls,
|
||||
start_frame,
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
|
||||
return comfy_io.NodeOutput(
|
||||
return IO.NodeOutput(
|
||||
await generate_video(
|
||||
cls,
|
||||
RunwayImageToVideoRequest(
|
||||
promptText=prompt,
|
||||
seed=seed,
|
||||
@@ -262,68 +219,62 @@ class RunwayImageToVideoNodeGen3a(comfy_io.ComfyNode):
|
||||
duration=Duration(duration),
|
||||
ratio=AspectRatio(ratio),
|
||||
promptImage=RunwayPromptImageObject(
|
||||
root=[
|
||||
RunwayPromptImageDetailedObject(
|
||||
uri=str(download_urls[0]), position="first"
|
||||
)
|
||||
]
|
||||
root=[RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first")]
|
||||
),
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
node_id=cls.hidden.unique_id,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class RunwayImageToVideoNodeGen4(comfy_io.ComfyNode):
|
||||
class RunwayImageToVideoNodeGen4(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
return IO.Schema(
|
||||
node_id="RunwayImageToVideoNodeGen4",
|
||||
display_name="Runway Image to Video (Gen4 Turbo)",
|
||||
category="api node/video/Runway",
|
||||
description="Generate a video from a single starting frame using Gen4 Turbo model. "
|
||||
"Before diving in, review these best practices to ensure that "
|
||||
"your input selections will set your generation up for success: "
|
||||
"https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video.",
|
||||
"Before diving in, review these best practices to ensure that "
|
||||
"your input selections will set your generation up for success: "
|
||||
"https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video.",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text prompt for the generation",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"start_frame",
|
||||
tooltip="Start frame to be used for the video",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"duration",
|
||||
options=Duration,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"ratio",
|
||||
options=RunwayGen4TurboAspectRatio,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967295,
|
||||
step=1,
|
||||
control_after_generate=True,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Random seed for generation",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Video.Output(),
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -336,25 +287,21 @@ class RunwayImageToVideoNodeGen4(comfy_io.ComfyNode):
|
||||
duration: str,
|
||||
ratio: str,
|
||||
seed: int,
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1)
|
||||
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
|
||||
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
|
||||
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))
|
||||
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
cls,
|
||||
start_frame,
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
|
||||
return comfy_io.NodeOutput(
|
||||
return IO.NodeOutput(
|
||||
await generate_video(
|
||||
cls,
|
||||
RunwayImageToVideoRequest(
|
||||
promptText=prompt,
|
||||
seed=seed,
|
||||
@@ -362,76 +309,70 @@ class RunwayImageToVideoNodeGen4(comfy_io.ComfyNode):
|
||||
duration=Duration(duration),
|
||||
ratio=AspectRatio(ratio),
|
||||
promptImage=RunwayPromptImageObject(
|
||||
root=[
|
||||
RunwayPromptImageDetailedObject(
|
||||
uri=str(download_urls[0]), position="first"
|
||||
)
|
||||
]
|
||||
root=[RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first")]
|
||||
),
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class RunwayFirstLastFrameNode(comfy_io.ComfyNode):
|
||||
class RunwayFirstLastFrameNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
return IO.Schema(
|
||||
node_id="RunwayFirstLastFrameNode",
|
||||
display_name="Runway First-Last-Frame to Video",
|
||||
category="api node/video/Runway",
|
||||
description="Upload first and last keyframes, draft a prompt, and generate a video. "
|
||||
"More complex transitions, such as cases where the Last frame is completely different "
|
||||
"from the First frame, may benefit from the longer 10s duration. "
|
||||
"This would give the generation more time to smoothly transition between the two inputs. "
|
||||
"Before diving in, review these best practices to ensure that your input selections "
|
||||
"will set your generation up for success: "
|
||||
"https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3.",
|
||||
"More complex transitions, such as cases where the Last frame is completely different "
|
||||
"from the First frame, may benefit from the longer 10s duration. "
|
||||
"This would give the generation more time to smoothly transition between the two inputs. "
|
||||
"Before diving in, review these best practices to ensure that your input selections "
|
||||
"will set your generation up for success: "
|
||||
"https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3.",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text prompt for the generation",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"start_frame",
|
||||
tooltip="Start frame to be used for the video",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"end_frame",
|
||||
tooltip="End frame to be used for the video. Supported for gen3a_turbo only.",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"duration",
|
||||
options=Duration,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"ratio",
|
||||
options=RunwayGen3aAspectRatio,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967295,
|
||||
step=1,
|
||||
control_after_generate=True,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Random seed for generation",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Video.Output(),
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -445,30 +386,26 @@ class RunwayFirstLastFrameNode(comfy_io.ComfyNode):
|
||||
duration: str,
|
||||
ratio: str,
|
||||
seed: int,
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1)
|
||||
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
|
||||
validate_image_dimensions(end_frame, max_width=7999, max_height=7999)
|
||||
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
|
||||
validate_image_aspect_ratio(end_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
|
||||
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))
|
||||
validate_image_aspect_ratio(end_frame, (1, 2), (2, 1))
|
||||
|
||||
stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame)
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
cls,
|
||||
stacked_input_images,
|
||||
max_images=2,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
if len(download_urls) != 2:
|
||||
raise RunwayApiError("Failed to upload one or more images to comfy api.")
|
||||
|
||||
return comfy_io.NodeOutput(
|
||||
return IO.NodeOutput(
|
||||
await generate_video(
|
||||
cls,
|
||||
RunwayImageToVideoRequest(
|
||||
promptText=prompt,
|
||||
seed=seed,
|
||||
@@ -477,56 +414,50 @@ class RunwayFirstLastFrameNode(comfy_io.ComfyNode):
|
||||
ratio=AspectRatio(ratio),
|
||||
promptImage=RunwayPromptImageObject(
|
||||
root=[
|
||||
RunwayPromptImageDetailedObject(
|
||||
uri=str(download_urls[0]), position="first"
|
||||
),
|
||||
RunwayPromptImageDetailedObject(
|
||||
uri=str(download_urls[1]), position="last"
|
||||
),
|
||||
RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first"),
|
||||
RunwayPromptImageDetailedObject(uri=str(download_urls[1]), position="last"),
|
||||
]
|
||||
),
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class RunwayTextToImageNode(comfy_io.ComfyNode):
|
||||
class RunwayTextToImageNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
return IO.Schema(
|
||||
node_id="RunwayTextToImageNode",
|
||||
display_name="Runway Text to Image",
|
||||
category="api node/image/Runway",
|
||||
description="Generate an image from a text prompt using Runway's Gen 4 model. "
|
||||
"You can also include reference image to guide the generation.",
|
||||
"You can also include reference image to guide the generation.",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text prompt for the generation",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"ratio",
|
||||
options=[model.value for model in RunwayTextToImageAspectRatioEnum],
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"reference_image",
|
||||
tooltip="Optional reference image to guide the generation",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Image.Output(),
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -537,63 +468,48 @@ class RunwayTextToImageNode(comfy_io.ComfyNode):
        prompt: str,
        ratio: str,
        reference_image: Optional[torch.Tensor] = None,
-   ) -> comfy_io.NodeOutput:
+   ) -> IO.NodeOutput:
        validate_string(prompt, min_length=1)

-       auth_kwargs = {
-           "auth_token": cls.hidden.auth_token_comfy_org,
-           "comfy_api_key": cls.hidden.api_key_comfy_org,
-       }
-
        # Prepare reference images if provided
        reference_images = None
        if reference_image is not None:
            validate_image_dimensions(reference_image, max_width=7999, max_height=7999)
-           validate_image_aspect_ratio(reference_image, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
+           validate_image_aspect_ratio(reference_image, (1, 2), (2, 1))
            download_urls = await upload_images_to_comfyapi(
+               cls,
                reference_image,
                max_images=1,
                mime_type="image/png",
-               auth_kwargs=auth_kwargs,
            )
            reference_images = [ReferenceImage(uri=str(download_urls[0]))]

-       request = RunwayTextToImageRequest(
-           promptText=prompt,
-           model=Model4.gen4_image,
-           ratio=ratio,
-           referenceImages=reference_images,
-       )
-
-       initial_operation = SynchronousOperation(
-           endpoint=ApiEndpoint(
-               path=PATH_TEXT_TO_IMAGE,
-               method=HttpMethod.POST,
-               request_model=RunwayTextToImageRequest,
-               response_model=RunwayTextToImageResponse,
+       initial_response = await sync_op(
+           cls,
+           endpoint=ApiEndpoint(path=PATH_TEXT_TO_IMAGE, method="POST"),
+           response_model=RunwayTextToImageResponse,
+           data=RunwayTextToImageRequest(
+               promptText=prompt,
+               model=Model4.gen4_image,
+               ratio=ratio,
+               referenceImages=reference_images,
            ),
-           request=request,
-           auth_kwargs=auth_kwargs,
        )

-       initial_response = await initial_operation.execute()
-
        # Poll for completion
        final_response = await get_response(
+           cls,
            initial_response.id,
-           auth_kwargs=auth_kwargs,
            node_id=cls.hidden.unique_id,
            estimated_duration=AVERAGE_DURATION_T2I_SECONDS,
        )
        if not final_response.output:
            raise RunwayApiError("Runway task succeeded but no image data found in response.")

-       return comfy_io.NodeOutput(await download_url_to_image_tensor(get_image_url_from_task_status(final_response)))
+       return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_task_status(final_response)))


class RunwayExtension(ComfyExtension):
    @override
-   async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
+   async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        return [
            RunwayFirstLastFrameNode,
            RunwayImageToVideoNodeGen3a,
@@ -601,5 +517,6 @@ class RunwayExtension(ComfyExtension):
            RunwayTextToImageNode,
        ]


async def comfy_entrypoint() -> RunwayExtension:
    return RunwayExtension()
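For context on the hunk above: the old flow built explicit request/operation objects (SynchronousOperation plus a fully specified ApiEndpoint) and threaded auth_kwargs by hand, while the new flow calls the sync_op helper with the node class, a lighter endpoint, the response model, and the request body as data. The sketch below is illustrative only; it is condensed from the lines in this hunk, and the helper signatures are assumed from this diff rather than taken from upstream documentation.

# Minimal sketch of the migration pattern, assuming the signatures shown in this diff.
# Before (removed): explicit operation objects and manual auth plumbing.
#     initial_operation = SynchronousOperation(
#         endpoint=ApiEndpoint(
#             path=PATH_TEXT_TO_IMAGE,
#             method=HttpMethod.POST,
#             request_model=RunwayTextToImageRequest,
#             response_model=RunwayTextToImageResponse,
#         ),
#         request=RunwayTextToImageRequest(promptText=prompt, model=Model4.gen4_image, ratio=ratio),
#         auth_kwargs=auth_kwargs,
#     )
#     initial_response = await initial_operation.execute()
# After (added): one awaitable helper; auth is resolved from the node class.
initial_response = await sync_op(
    cls,
    endpoint=ApiEndpoint(path=PATH_TEXT_TO_IMAGE, method="POST"),
    response_model=RunwayTextToImageResponse,
    data=RunwayTextToImageRequest(
        promptText=prompt,
        model=Model4.gen4_image,
        ratio=ratio,
        referenceImages=reference_images,
    ),
)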
@@ -1,23 +1,20 @@
from typing import Optional
-from typing_extensions import override

import torch
from pydantic import BaseModel, Field
-from comfy_api.latest import ComfyExtension, io as comfy_io
-from comfy_api_nodes.apis.client import (
-    ApiEndpoint,
-    HttpMethod,
-    SynchronousOperation,
-    PollingOperation,
-    EmptyRequest,
-)
-from comfy_api_nodes.util.validation_utils import get_number_of_images
+from typing_extensions import override

-from comfy_api_nodes.apinode_utils import (
+from comfy_api.latest import IO, ComfyExtension
+from comfy_api_nodes.util import (
    ApiEndpoint,
    download_url_to_video_output,
    get_number_of_images,
    poll_op,
    sync_op,
    tensor_to_bytesio,
)


class Sora2GenerationRequest(BaseModel):
    prompt: str = Field(...)
    model: str = Field(...)
@@ -31,27 +28,27 @@ class Sora2GenerationResponse(BaseModel):
|
||||
status: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class OpenAIVideoSora2(comfy_io.ComfyNode):
|
||||
class OpenAIVideoSora2(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
return IO.Schema(
|
||||
node_id="OpenAIVideoSora2",
|
||||
display_name="OpenAI Sora - Video",
|
||||
category="api node/video/Sora",
|
||||
description="OpenAI video and audio generation.",
|
||||
inputs=[
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["sora-2", "sora-2-pro"],
|
||||
default="sora-2",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Guiding text; may be empty if an input image is present.",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"size",
|
||||
options=[
|
||||
"720x1280",
|
||||
@@ -61,35 +58,35 @@ class OpenAIVideoSora2(comfy_io.ComfyNode):
|
||||
],
|
||||
default="1280x720",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"duration",
|
||||
options=[4, 8, 12],
|
||||
default=8,
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
optional=True,
|
||||
tooltip="Seed to determine if node should re-run; "
|
||||
"actual results are nondeterministic regardless of seed.",
|
||||
"actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Video.Output(),
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -111,61 +108,40 @@ class OpenAIVideoSora2(comfy_io.ComfyNode):
|
||||
if get_number_of_images(image) != 1:
|
||||
raise ValueError("Currently only one input image is supported.")
|
||||
files_input = {"input_reference": ("image.png", tensor_to_bytesio(image), "image/png")}
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
payload = Sora2GenerationRequest(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
seconds=str(duration),
|
||||
size=size,
|
||||
)
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/openai/v1/videos",
|
||||
method=HttpMethod.POST,
|
||||
request_model=Sora2GenerationRequest,
|
||||
response_model=Sora2GenerationResponse
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path="/proxy/openai/v1/videos", method="POST"),
|
||||
data=Sora2GenerationRequest(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
seconds=str(duration),
|
||||
size=size,
|
||||
),
|
||||
request=payload,
|
||||
files=files_input,
|
||||
auth_kwargs=auth,
|
||||
response_model=Sora2GenerationResponse,
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
initial_response = await initial_operation.execute()
|
||||
if initial_response.error:
|
||||
raise Exception(initial_response.error.message)
|
||||
raise Exception(initial_response.error["message"])
|
||||
|
||||
model_time_multiplier = 1 if model == "sora-2" else 2
|
||||
poll_operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/openai/v1/videos/{initial_response.id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=Sora2GenerationResponse
|
||||
),
|
||||
completed_statuses=["completed"],
|
||||
failed_statuses=["failed"],
|
||||
await poll_op(
|
||||
cls,
|
||||
poll_endpoint=ApiEndpoint(path=f"/proxy/openai/v1/videos/{initial_response.id}"),
|
||||
response_model=Sora2GenerationResponse,
|
||||
status_extractor=lambda x: x.status,
|
||||
auth_kwargs=auth,
|
||||
poll_interval=8.0,
|
||||
max_poll_attempts=160,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=45 * (duration / 4) * model_time_multiplier,
|
||||
estimated_duration=int(45 * (duration / 4) * model_time_multiplier),
|
||||
)
|
||||
await poll_operation.execute()
|
||||
return comfy_io.NodeOutput(
|
||||
await download_url_to_video_output(
|
||||
f"/proxy/openai/v1/videos/{initial_response.id}/content",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
return IO.NodeOutput(
|
||||
await download_url_to_video_output(f"/proxy/openai/v1/videos/{initial_response.id}/content", cls=cls),
|
||||
)
|
||||
|
||||
|
||||
class OpenAISoraExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
OpenAIVideoSora2,
|
||||
]
|
||||
|
||||
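The Sora hunks above make the same move for long-running jobs: PollingOperation with an EmptyRequest poll endpoint is replaced by the poll_op helper, with completion and failure derived from the status field. The condensed call below is illustrative, lifted from this hunk with the helper signature assumed from the diff rather than from upstream documentation.

# Illustrative sketch, condensed from the hunk above (assumed helper signature).
await poll_op(
    cls,
    poll_endpoint=ApiEndpoint(path=f"/proxy/openai/v1/videos/{initial_response.id}"),
    response_model=Sora2GenerationResponse,
    status_extractor=lambda x: x.status,
    poll_interval=8.0,
    max_poll_attempts=160,
    node_id=cls.hidden.unique_id,
    estimated_duration=int(45 * (duration / 4) * model_time_multiplier),
)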
@@ -2,7 +2,7 @@ from inspect import cleandoc
|
||||
from typing import Optional
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, Input, io as comfy_io
|
||||
from comfy_api.latest import ComfyExtension, Input, IO
|
||||
from comfy_api_nodes.apis.stability_api import (
|
||||
StabilityUpscaleConservativeRequest,
|
||||
StabilityUpscaleCreativeRequest,
|
||||
@@ -20,21 +20,17 @@ from comfy_api_nodes.apis.stability_api import (
|
||||
StabilityAudioInpaintRequest,
|
||||
StabilityAudioResponse,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
from comfy_api_nodes.util import (
|
||||
validate_audio_duration,
|
||||
validate_string,
|
||||
audio_input_to_mp3,
|
||||
bytesio_to_image_tensor,
|
||||
tensor_to_bytesio,
|
||||
validate_string,
|
||||
audio_bytes_to_audio_input,
|
||||
audio_input_to_mp3,
|
||||
sync_op,
|
||||
poll_op,
|
||||
ApiEndpoint,
|
||||
)
|
||||
from comfy_api_nodes.util.validation_utils import validate_audio_duration
|
||||
|
||||
import torch
|
||||
import base64
|
||||
@@ -56,20 +52,20 @@ def get_async_dummy_status(x: StabilityResultsGetResponse):
|
||||
return StabilityPollStatus.in_progress
|
||||
|
||||
|
||||
class StabilityStableImageUltraNode(comfy_io.ComfyNode):
|
||||
class StabilityStableImageUltraNode(IO.ComfyNode):
|
||||
"""
|
||||
Generates images synchronously based on prompt and resolution.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
return IO.Schema(
|
||||
node_id="StabilityStableImageUltraNode",
|
||||
display_name="Stability AI Stable Image Ultra",
|
||||
category="api node/image/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
@@ -80,39 +76,39 @@ class StabilityStableImageUltraNode(comfy_io.ComfyNode):
|
||||
"is a value between 0 and 1. For example: `The sky was a crisp (blue:0.3) and (green:0.8)`" +
|
||||
"would convey a sky that was blue and green, but more green than blue.",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=StabilityAspectRatio,
|
||||
default=StabilityAspectRatio.ratio_1_1,
|
||||
tooltip="Aspect ratio of generated image.",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"style_preset",
|
||||
options=get_stability_style_presets(),
|
||||
tooltip="Optional desired style of generated image.",
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967294,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
default="",
|
||||
tooltip="A blurb of text describing what you do not wish to see in the output image. This is an advanced feature.",
|
||||
force_input=True,
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Float.Input(
|
||||
IO.Float.Input(
|
||||
"image_denoise",
|
||||
default=0.5,
|
||||
min=0.0,
|
||||
@@ -123,12 +119,12 @@ class StabilityStableImageUltraNode(comfy_io.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Image.Output(),
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -143,7 +139,7 @@ class StabilityStableImageUltraNode(comfy_io.ComfyNode):
|
||||
image: Optional[torch.Tensor] = None,
|
||||
negative_prompt: str = "",
|
||||
image_denoise: Optional[float] = 0.5,
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
# prepare image binary if image present
|
||||
image_binary = None
|
||||
@@ -161,19 +157,11 @@ class StabilityStableImageUltraNode(comfy_io.ComfyNode):
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/stability/v2beta/stable-image/generate/ultra",
|
||||
method=HttpMethod.POST,
|
||||
request_model=StabilityStableUltraRequest,
|
||||
response_model=StabilityStableUltraResponse,
|
||||
),
|
||||
request=StabilityStableUltraRequest(
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/ultra", method="POST"),
|
||||
response_model=StabilityStableUltraResponse,
|
||||
data=StabilityStableUltraRequest(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
aspect_ratio=aspect_ratio,
|
||||
@@ -183,9 +171,7 @@ class StabilityStableImageUltraNode(comfy_io.ComfyNode):
|
||||
),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
response_api = await operation.execute()
|
||||
|
||||
if response_api.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.")
|
||||
@@ -193,44 +179,44 @@ class StabilityStableImageUltraNode(comfy_io.ComfyNode):
|
||||
image_data = base64.b64decode(response_api.image)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
|
||||
return comfy_io.NodeOutput(returned_image)
|
||||
return IO.NodeOutput(returned_image)
|
||||
|
||||
|
||||
class StabilityStableImageSD_3_5Node(comfy_io.ComfyNode):
|
||||
class StabilityStableImageSD_3_5Node(IO.ComfyNode):
|
||||
"""
|
||||
Generates images synchronously based on prompt and resolution.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
return IO.Schema(
|
||||
node_id="StabilityStableImageSD_3_5Node",
|
||||
display_name="Stability AI Stable Diffusion 3.5 Image",
|
||||
category="api node/image/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=Stability_SD3_5_Model,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=StabilityAspectRatio,
|
||||
default=StabilityAspectRatio.ratio_1_1,
|
||||
tooltip="Aspect ratio of generated image.",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"style_preset",
|
||||
options=get_stability_style_presets(),
|
||||
tooltip="Optional desired style of generated image.",
|
||||
),
|
||||
comfy_io.Float.Input(
|
||||
IO.Float.Input(
|
||||
"cfg_scale",
|
||||
default=4.0,
|
||||
min=1.0,
|
||||
@@ -238,28 +224,28 @@ class StabilityStableImageSD_3_5Node(comfy_io.ComfyNode):
|
||||
step=0.1,
|
||||
tooltip="How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)",
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967294,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
default="",
|
||||
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
|
||||
force_input=True,
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Float.Input(
|
||||
IO.Float.Input(
|
||||
"image_denoise",
|
||||
default=0.5,
|
||||
min=0.0,
|
||||
@@ -270,12 +256,12 @@ class StabilityStableImageSD_3_5Node(comfy_io.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Image.Output(),
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -292,7 +278,7 @@ class StabilityStableImageSD_3_5Node(comfy_io.ComfyNode):
|
||||
image: Optional[torch.Tensor] = None,
|
||||
negative_prompt: str = "",
|
||||
image_denoise: Optional[float] = 0.5,
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
# prepare image binary if image present
|
||||
image_binary = None
|
||||
@@ -313,19 +299,11 @@ class StabilityStableImageSD_3_5Node(comfy_io.ComfyNode):
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/stability/v2beta/stable-image/generate/sd3",
|
||||
method=HttpMethod.POST,
|
||||
request_model=StabilityStable3_5Request,
|
||||
response_model=StabilityStableUltraResponse,
|
||||
),
|
||||
request=StabilityStable3_5Request(
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/sd3", method="POST"),
|
||||
response_model=StabilityStableUltraResponse,
|
||||
data=StabilityStable3_5Request(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
aspect_ratio=aspect_ratio,
|
||||
@@ -338,9 +316,7 @@ class StabilityStableImageSD_3_5Node(comfy_io.ComfyNode):
|
||||
),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
response_api = await operation.execute()
|
||||
|
||||
if response_api.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.")
|
||||
@@ -348,30 +324,30 @@ class StabilityStableImageSD_3_5Node(comfy_io.ComfyNode):
|
||||
image_data = base64.b64decode(response_api.image)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
|
||||
return comfy_io.NodeOutput(returned_image)
|
||||
return IO.NodeOutput(returned_image)
|
||||
|
||||
|
||||
class StabilityUpscaleConservativeNode(comfy_io.ComfyNode):
|
||||
class StabilityUpscaleConservativeNode(IO.ComfyNode):
|
||||
"""
|
||||
Upscale image with minimal alterations to 4K resolution.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
return IO.Schema(
|
||||
node_id="StabilityUpscaleConservativeNode",
|
||||
display_name="Stability AI Upscale Conservative",
|
||||
category="api node/image/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
comfy_io.Image.Input("image"),
|
||||
comfy_io.String.Input(
|
||||
IO.Image.Input("image"),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
|
||||
),
|
||||
comfy_io.Float.Input(
|
||||
IO.Float.Input(
|
||||
"creativity",
|
||||
default=0.35,
|
||||
min=0.2,
|
||||
@@ -379,17 +355,17 @@ class StabilityUpscaleConservativeNode(comfy_io.ComfyNode):
|
||||
step=0.01,
|
||||
tooltip="Controls the likelihood of creating additional details not heavily conditioned by the init image.",
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967294,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
default="",
|
||||
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
|
||||
@@ -398,12 +374,12 @@ class StabilityUpscaleConservativeNode(comfy_io.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Image.Output(),
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -416,7 +392,7 @@ class StabilityUpscaleConservativeNode(comfy_io.ComfyNode):
|
||||
creativity: float,
|
||||
seed: int,
|
||||
negative_prompt: str = "",
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
|
||||
|
||||
@@ -427,19 +403,11 @@ class StabilityUpscaleConservativeNode(comfy_io.ComfyNode):
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/stability/v2beta/stable-image/upscale/conservative",
|
||||
method=HttpMethod.POST,
|
||||
request_model=StabilityUpscaleConservativeRequest,
|
||||
response_model=StabilityStableUltraResponse,
|
||||
),
|
||||
request=StabilityUpscaleConservativeRequest(
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/conservative", method="POST"),
|
||||
response_model=StabilityStableUltraResponse,
|
||||
data=StabilityUpscaleConservativeRequest(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
creativity=round(creativity,2),
|
||||
@@ -447,9 +415,7 @@ class StabilityUpscaleConservativeNode(comfy_io.ComfyNode):
|
||||
),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
response_api = await operation.execute()
|
||||
|
||||
if response_api.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.")
|
||||
@@ -457,30 +423,30 @@ class StabilityUpscaleConservativeNode(comfy_io.ComfyNode):
|
||||
image_data = base64.b64decode(response_api.image)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
|
||||
return comfy_io.NodeOutput(returned_image)
|
||||
return IO.NodeOutput(returned_image)
|
||||
|
||||
|
||||
class StabilityUpscaleCreativeNode(comfy_io.ComfyNode):
|
||||
class StabilityUpscaleCreativeNode(IO.ComfyNode):
|
||||
"""
|
||||
Upscale image with minimal alterations to 4K resolution.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
return IO.Schema(
|
||||
node_id="StabilityUpscaleCreativeNode",
|
||||
display_name="Stability AI Upscale Creative",
|
||||
category="api node/image/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
comfy_io.Image.Input("image"),
|
||||
comfy_io.String.Input(
|
||||
IO.Image.Input("image"),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
|
||||
),
|
||||
comfy_io.Float.Input(
|
||||
IO.Float.Input(
|
||||
"creativity",
|
||||
default=0.3,
|
||||
min=0.1,
|
||||
@@ -488,22 +454,22 @@ class StabilityUpscaleCreativeNode(comfy_io.ComfyNode):
|
||||
step=0.01,
|
||||
tooltip="Controls the likelihood of creating additional details not heavily conditioned by the init image.",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"style_preset",
|
||||
options=get_stability_style_presets(),
|
||||
tooltip="Optional desired style of generated image.",
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967294,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
default="",
|
||||
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
|
||||
@@ -512,12 +478,12 @@ class StabilityUpscaleCreativeNode(comfy_io.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Image.Output(),
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -531,7 +497,7 @@ class StabilityUpscaleCreativeNode(comfy_io.ComfyNode):
|
||||
style_preset: str,
|
||||
seed: int,
|
||||
negative_prompt: str = "",
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
|
||||
|
||||
@@ -544,19 +510,11 @@ class StabilityUpscaleCreativeNode(comfy_io.ComfyNode):
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/stability/v2beta/stable-image/upscale/creative",
|
||||
method=HttpMethod.POST,
|
||||
request_model=StabilityUpscaleCreativeRequest,
|
||||
response_model=StabilityAsyncResponse,
|
||||
),
|
||||
request=StabilityUpscaleCreativeRequest(
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/creative", method="POST"),
|
||||
response_model=StabilityAsyncResponse,
|
||||
data=StabilityUpscaleCreativeRequest(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
creativity=round(creativity,2),
|
||||
@@ -565,25 +523,15 @@ class StabilityUpscaleCreativeNode(comfy_io.ComfyNode):
|
||||
),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
response_api = await operation.execute()
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/stability/v2beta/results/{response_api.id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=StabilityResultsGetResponse,
|
||||
),
|
||||
response_poll = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/stability/v2beta/results/{response_api.id}"),
|
||||
response_model=StabilityResultsGetResponse,
|
||||
poll_interval=3,
|
||||
completed_statuses=[StabilityPollStatus.finished],
|
||||
failed_statuses=[StabilityPollStatus.failed],
|
||||
status_extractor=lambda x: get_async_dummy_status(x),
|
||||
auth_kwargs=auth,
|
||||
node_id=cls.hidden.unique_id,
|
||||
)
|
||||
response_poll: StabilityResultsGetResponse = await operation.execute()
|
||||
|
||||
if response_poll.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.")
|
||||
@@ -591,61 +539,50 @@ class StabilityUpscaleCreativeNode(comfy_io.ComfyNode):
|
||||
image_data = base64.b64decode(response_poll.result)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
|
||||
return comfy_io.NodeOutput(returned_image)
|
||||
return IO.NodeOutput(returned_image)
|
||||
|
||||
|
||||
class StabilityUpscaleFastNode(comfy_io.ComfyNode):
|
||||
class StabilityUpscaleFastNode(IO.ComfyNode):
|
||||
"""
|
||||
Quickly upscales an image via Stability API call to 4x its original size; intended for upscaling low-quality/compressed images.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
return IO.Schema(
|
||||
node_id="StabilityUpscaleFastNode",
|
||||
display_name="Stability AI Upscale Fast",
|
||||
category="api node/image/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
comfy_io.Image.Input("image"),
|
||||
IO.Image.Input("image"),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Image.Output(),
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(cls, image: torch.Tensor) -> comfy_io.NodeOutput:
|
||||
async def execute(cls, image: torch.Tensor) -> IO.NodeOutput:
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read()
|
||||
|
||||
files = {
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/stability/v2beta/stable-image/upscale/fast",
|
||||
method=HttpMethod.POST,
|
||||
request_model=EmptyRequest,
|
||||
response_model=StabilityStableUltraResponse,
|
||||
),
|
||||
request=EmptyRequest(),
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/fast", method="POST"),
|
||||
response_model=StabilityStableUltraResponse,
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
response_api = await operation.execute()
|
||||
|
||||
if response_api.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.")
|
||||
@@ -653,26 +590,26 @@ class StabilityUpscaleFastNode(comfy_io.ComfyNode):
|
||||
image_data = base64.b64decode(response_api.image)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
|
||||
return comfy_io.NodeOutput(returned_image)
|
||||
return IO.NodeOutput(returned_image)
|
||||
|
||||
|
||||
class StabilityTextToAudio(comfy_io.ComfyNode):
|
||||
class StabilityTextToAudio(IO.ComfyNode):
|
||||
"""Generates high-quality music and sound effects from text descriptions."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
return IO.Schema(
|
||||
node_id="StabilityTextToAudio",
|
||||
display_name="Stability AI Text To Audio",
|
||||
category="api node/audio/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["stable-audio-2.5"],
|
||||
),
|
||||
comfy_io.String.Input("prompt", multiline=True, default=""),
|
||||
comfy_io.Int.Input(
|
||||
IO.String.Input("prompt", multiline=True, default=""),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=190,
|
||||
min=1,
|
||||
@@ -681,18 +618,18 @@ class StabilityTextToAudio(comfy_io.ComfyNode):
|
||||
tooltip="Controls the duration in seconds of the generated audio.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967294,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for generation.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"steps",
|
||||
default=8,
|
||||
min=4,
|
||||
@@ -703,58 +640,50 @@ class StabilityTextToAudio(comfy_io.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Audio.Output(),
|
||||
IO.Audio.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(cls, model: str, prompt: str, duration: int, seed: int, steps: int) -> comfy_io.NodeOutput:
|
||||
async def execute(cls, model: str, prompt: str, duration: int, seed: int, steps: int) -> IO.NodeOutput:
|
||||
validate_string(prompt, max_length=10000)
|
||||
payload = StabilityTextToAudioRequest(prompt=prompt, model=model, duration=duration, seed=seed, steps=steps)
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio",
|
||||
method=HttpMethod.POST,
|
||||
request_model=StabilityTextToAudioRequest,
|
||||
response_model=StabilityAudioResponse,
|
||||
),
|
||||
request=payload,
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio", method="POST"),
|
||||
response_model=StabilityAudioResponse,
|
||||
data=payload,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs= {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
)
|
||||
response_api = await operation.execute()
|
||||
if not response_api.audio:
|
||||
raise ValueError("No audio file was received in response.")
|
||||
return comfy_io.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
|
||||
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
|
||||
|
||||
|
||||
class StabilityAudioToAudio(comfy_io.ComfyNode):
|
||||
class StabilityAudioToAudio(IO.ComfyNode):
|
||||
"""Transforms existing audio samples into new high-quality compositions using text instructions."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
return IO.Schema(
|
||||
node_id="StabilityAudioToAudio",
|
||||
display_name="Stability AI Audio To Audio",
|
||||
category="api node/audio/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["stable-audio-2.5"],
|
||||
),
|
||||
comfy_io.String.Input("prompt", multiline=True, default=""),
|
||||
comfy_io.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."),
|
||||
comfy_io.Int.Input(
|
||||
IO.String.Input("prompt", multiline=True, default=""),
|
||||
IO.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=190,
|
||||
min=1,
|
||||
@@ -763,18 +692,18 @@ class StabilityAudioToAudio(comfy_io.ComfyNode):
|
||||
tooltip="Controls the duration in seconds of the generated audio.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967294,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for generation.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"steps",
|
||||
default=8,
|
||||
min=4,
|
||||
@@ -783,24 +712,24 @@ class StabilityAudioToAudio(comfy_io.ComfyNode):
|
||||
tooltip="Controls the number of sampling steps.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Float.Input(
|
||||
IO.Float.Input(
|
||||
"strength",
|
||||
default=1,
|
||||
min=0.01,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
display_mode=comfy_io.NumberDisplay.slider,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Parameter controls how much influence the audio parameter has on the generated audio.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Audio.Output(),
|
||||
IO.Audio.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -808,51 +737,43 @@ class StabilityAudioToAudio(comfy_io.ComfyNode):
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls, model: str, prompt: str, audio: Input.Audio, duration: int, seed: int, steps: int, strength: float
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, max_length=10000)
|
||||
validate_audio_duration(audio, 6, 190)
|
||||
payload = StabilityAudioToAudioRequest(
|
||||
prompt=prompt, model=model, duration=duration, seed=seed, steps=steps, strength=strength
|
||||
)
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio",
|
||||
method=HttpMethod.POST,
|
||||
request_model=StabilityAudioToAudioRequest,
|
||||
response_model=StabilityAudioResponse,
|
||||
),
|
||||
request=payload,
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio", method="POST"),
|
||||
response_model=StabilityAudioResponse,
|
||||
data=payload,
|
||||
content_type="multipart/form-data",
|
||||
files={"audio": audio_input_to_mp3(audio)},
|
||||
auth_kwargs= {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
)
|
||||
response_api = await operation.execute()
|
||||
if not response_api.audio:
|
||||
raise ValueError("No audio file was received in response.")
|
||||
return comfy_io.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
|
||||
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
|
||||
|
||||
|
||||
class StabilityAudioInpaint(comfy_io.ComfyNode):
|
||||
class StabilityAudioInpaint(IO.ComfyNode):
|
||||
"""Transforms part of existing audio sample using text instructions."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
return IO.Schema(
|
||||
node_id="StabilityAudioInpaint",
|
||||
display_name="Stability AI Audio Inpaint",
|
||||
category="api node/audio/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["stable-audio-2.5"],
|
||||
),
|
||||
comfy_io.String.Input("prompt", multiline=True, default=""),
|
||||
comfy_io.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."),
|
||||
comfy_io.Int.Input(
|
||||
IO.String.Input("prompt", multiline=True, default=""),
|
||||
IO.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=190,
|
||||
min=1,
|
||||
@@ -861,18 +782,18 @@ class StabilityAudioInpaint(comfy_io.ComfyNode):
|
||||
tooltip="Controls the duration in seconds of the generated audio.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967294,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for generation.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"steps",
|
||||
default=8,
|
||||
min=4,
|
||||
@@ -881,7 +802,7 @@ class StabilityAudioInpaint(comfy_io.ComfyNode):
|
||||
tooltip="Controls the number of sampling steps.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"mask_start",
|
||||
default=30,
|
||||
min=0,
|
||||
@@ -889,7 +810,7 @@ class StabilityAudioInpaint(comfy_io.ComfyNode):
|
||||
step=1,
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"mask_end",
|
||||
default=190,
|
||||
min=0,
|
||||
@@ -899,12 +820,12 @@ class StabilityAudioInpaint(comfy_io.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Audio.Output(),
|
||||
IO.Audio.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -920,7 +841,7 @@ class StabilityAudioInpaint(comfy_io.ComfyNode):
|
||||
steps: int,
|
||||
mask_start: int,
|
||||
mask_end: int,
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, max_length=10000)
|
||||
if mask_end <= mask_start:
|
||||
raise ValueError(f"Value of mask_end({mask_end}) should be greater then mask_start({mask_start})")
|
||||
@@ -935,30 +856,22 @@ class StabilityAudioInpaint(comfy_io.ComfyNode):
|
||||
mask_start=mask_start,
|
||||
mask_end=mask_end,
|
||||
)
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint",
|
||||
method=HttpMethod.POST,
|
||||
request_model=StabilityAudioInpaintRequest,
|
||||
response_model=StabilityAudioResponse,
|
||||
),
|
||||
request=payload,
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint", method="POST"),
|
||||
response_model=StabilityAudioResponse,
|
||||
data=payload,
|
||||
content_type="multipart/form-data",
|
||||
files={"audio": audio_input_to_mp3(audio)},
|
||||
auth_kwargs={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
)
|
||||
response_api = await operation.execute()
|
||||
if not response_api.audio:
|
||||
raise ValueError("No audio file was received in response.")
|
||||
return comfy_io.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
|
||||
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
|
||||
|
||||
|
||||
class StabilityExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
StabilityStableImageUltraNode,
|
||||
StabilityStableImageSD_3_5Node,
|
||||
|
||||
comfy_api_nodes/nodes_topaz.py (new file, 421 lines)
@@ -0,0 +1,421 @@
|
||||
import builtins
|
||||
from io import BytesIO
|
||||
|
||||
import aiohttp
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.input.video_types import VideoInput
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.apis import topaz_api
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_image_tensor,
|
||||
download_url_to_video_output,
|
||||
get_fs_object_size,
|
||||
get_number_of_images,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_images_to_comfyapi,
|
||||
validate_container_format_is_mp4,
|
||||
)
|
||||
|
||||
UPSCALER_MODELS_MAP = {
|
||||
"Starlight (Astra) Fast": "slf-1",
|
||||
"Starlight (Astra) Creative": "slc-1",
|
||||
}
|
||||
UPSCALER_VALUES_MAP = {
|
||||
"FullHD (1080p)": 1920,
|
||||
"4K (2160p)": 3840,
|
||||
}
|
||||
|
||||
|
||||
class TopazImageEnhance(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TopazImageEnhance",
|
||||
display_name="Topaz Image Enhance",
|
||||
category="api node/image/Topaz",
|
||||
description="Industry-standard upscaling and image enhancement.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["Reimagine"]),
|
||||
IO.Image.Input("image"),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Optional text prompt for creative upscaling guidance.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"subject_detection",
|
||||
options=["All", "Foreground", "Background"],
|
||||
optional=True,
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"face_enhancement",
|
||||
default=True,
|
||||
optional=True,
|
||||
tooltip="Enhance faces (if present) during processing.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"face_enhancement_creativity",
|
||||
default=0.0,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
tooltip="Set the creativity level for face enhancement.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"face_enhancement_strength",
|
||||
default=1.0,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
tooltip="Controls how sharp enhanced faces are relative to the background.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"crop_to_fill",
|
||||
default=False,
|
||||
optional=True,
|
||||
tooltip="By default, the image is letterboxed when the output aspect ratio differs. "
|
||||
"Enable to crop the image to fill the output dimensions.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"output_width",
|
||||
default=0,
|
||||
min=0,
|
||||
max=32000,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
tooltip="Zero value means to calculate automatically (usually it will be original size or output_height if specified).",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"output_height",
|
||||
default=0,
|
||||
min=0,
|
||||
max=32000,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
tooltip="Zero value means to output in the same height as original or output width.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"creativity",
|
||||
default=3,
|
||||
min=1,
|
||||
max=9,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
optional=True,
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"face_preservation",
|
||||
default=True,
|
||||
optional=True,
|
||||
tooltip="Preserve subjects' facial identity.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"color_preservation",
|
||||
default=True,
|
||||
optional=True,
|
||||
tooltip="Preserve the original colors.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
image: torch.Tensor,
|
||||
prompt: str = "",
|
||||
subject_detection: str = "All",
|
||||
face_enhancement: bool = True,
|
||||
face_enhancement_creativity: float = 1.0,
|
||||
face_enhancement_strength: float = 0.8,
|
||||
crop_to_fill: bool = False,
|
||||
output_width: int = 0,
|
||||
output_height: int = 0,
|
||||
creativity: int = 3,
|
||||
face_preservation: bool = True,
|
||||
color_preservation: bool = True,
|
||||
) -> IO.NodeOutput:
|
||||
if get_number_of_images(image) != 1:
|
||||
raise ValueError("Only one input image is supported.")
|
||||
download_url = await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png")
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/topaz/image/v1/enhance-gen/async", method="POST"),
|
||||
response_model=topaz_api.ImageAsyncTaskResponse,
|
||||
data=topaz_api.ImageEnhanceRequest(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
subject_detection=subject_detection,
|
||||
face_enhancement=face_enhancement,
|
||||
face_enhancement_creativity=face_enhancement_creativity,
|
||||
face_enhancement_strength=face_enhancement_strength,
|
||||
crop_to_fill=crop_to_fill,
|
||||
output_width=output_width if output_width else None,
|
||||
output_height=output_height if output_height else None,
|
||||
creativity=creativity,
|
||||
face_preservation=str(face_preservation).lower(),
|
||||
color_preservation=str(color_preservation).lower(),
|
||||
source_url=download_url[0],
|
||||
output_format="png",
|
||||
),
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
|
||||
await poll_op(
|
||||
cls,
|
||||
poll_endpoint=ApiEndpoint(path=f"/proxy/topaz/image/v1/status/{initial_response.process_id}"),
|
||||
response_model=topaz_api.ImageStatusResponse,
|
||||
status_extractor=lambda x: x.status,
|
||||
progress_extractor=lambda x: getattr(x, "progress", 0),
|
||||
price_extractor=lambda x: x.credits * 0.08,
|
||||
poll_interval=8.0,
|
||||
max_poll_attempts=160,
|
||||
estimated_duration=60,
|
||||
)
|
||||
|
||||
results = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/topaz/image/v1/download/{initial_response.process_id}"),
|
||||
response_model=topaz_api.ImageDownloadResponse,
|
||||
monitor_progress=False,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(results.download_url))
|
||||
|
||||
|
||||
class TopazVideoEnhance(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TopazVideoEnhance",
|
||||
display_name="Topaz Video Enhance",
|
||||
category="api node/video/Topaz",
|
||||
description="Breathe new life into video with powerful upscaling and recovery technology.",
|
||||
inputs=[
|
||||
IO.Video.Input("video"),
|
||||
IO.Boolean.Input("upscaler_enabled", default=True),
|
||||
IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())),
|
||||
IO.Combo.Input("upscaler_resolution", options=list(UPSCALER_VALUES_MAP.keys())),
|
||||
IO.Combo.Input(
|
||||
"upscaler_creativity",
|
||||
options=["low", "middle", "high"],
|
||||
default="low",
|
||||
tooltip="Creativity level (applies only to Starlight (Astra) Creative).",
|
||||
optional=True,
|
||||
),
|
||||
IO.Boolean.Input("interpolation_enabled", default=False, optional=True),
|
||||
IO.Combo.Input("interpolation_model", options=["apo-8"], default="apo-8", optional=True),
|
||||
IO.Int.Input(
|
||||
"interpolation_slowmo",
|
||||
default=1,
|
||||
min=1,
|
||||
max=16,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Slow-motion factor applied to the input video. "
|
||||
"For example, 2 makes the output twice as slow and doubles the duration.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"interpolation_frame_rate",
|
||||
default=60,
|
||||
min=15,
|
||||
max=240,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Output frame rate.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"interpolation_duplicate",
|
||||
default=False,
|
||||
tooltip="Analyze the input for duplicate frames and remove them.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"interpolation_duplicate_threshold",
|
||||
default=0.01,
|
||||
min=0.001,
|
||||
max=0.1,
|
||||
step=0.001,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Detection sensitivity for duplicate frames.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"dynamic_compression_level",
|
||||
options=["Low", "Mid", "High"],
|
||||
default="Low",
|
||||
tooltip="CQP level.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
video: VideoInput,
|
||||
upscaler_enabled: bool,
|
||||
upscaler_model: str,
|
||||
upscaler_resolution: str,
|
||||
upscaler_creativity: str = "low",
|
||||
interpolation_enabled: bool = False,
|
||||
interpolation_model: str = "apo-8",
|
||||
interpolation_slowmo: int = 1,
|
||||
interpolation_frame_rate: int = 60,
|
||||
interpolation_duplicate: bool = False,
|
||||
interpolation_duplicate_threshold: float = 0.01,
|
||||
dynamic_compression_level: str = "Low",
|
||||
) -> IO.NodeOutput:
|
||||
if upscaler_enabled is False and interpolation_enabled is False:
|
||||
raise ValueError("There is nothing to do: both upscaling and interpolation are disabled.")
|
||||
src_width, src_height = video.get_dimensions()
|
||||
video_components = video.get_components()
|
||||
src_frame_rate = int(video_components.frame_rate)
|
||||
duration_sec = video.get_duration()
|
||||
estimated_frames = int(duration_sec * src_frame_rate)
|
||||
validate_container_format_is_mp4(video)
|
||||
src_video_stream = video.get_stream_source()
|
||||
target_width = src_width
|
||||
target_height = src_height
|
||||
target_frame_rate = src_frame_rate
|
||||
filters = []
|
||||
if upscaler_enabled:
|
||||
target_width = UPSCALER_VALUES_MAP[upscaler_resolution]
|
||||
target_height = UPSCALER_VALUES_MAP[upscaler_resolution]
|
||||
filters.append(
|
||||
topaz_api.VideoEnhancementFilter(
|
||||
model=UPSCALER_MODELS_MAP[upscaler_model],
|
||||
creativity=(upscaler_creativity if UPSCALER_MODELS_MAP[upscaler_model] == "slc-1" else None),
|
||||
isOptimizedMode=(True if UPSCALER_MODELS_MAP[upscaler_model] == "slc-1" else None),
|
||||
),
|
||||
)
|
||||
if interpolation_enabled:
|
||||
target_frame_rate = interpolation_frame_rate
|
||||
filters.append(
|
||||
topaz_api.VideoFrameInterpolationFilter(
|
||||
model=interpolation_model,
|
||||
slowmo=interpolation_slowmo,
|
||||
fps=interpolation_frame_rate,
|
||||
duplicate=interpolation_duplicate,
|
||||
duplicate_threshold=interpolation_duplicate_threshold,
|
||||
),
|
||||
)
|
||||
initial_res = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/topaz/video/", method="POST"),
|
||||
response_model=topaz_api.CreateVideoResponse,
|
||||
data=topaz_api.CreateVideoRequest(
|
||||
source=topaz_api.CreateCreateVideoRequestSource(
|
||||
container="mp4",
|
||||
size=get_fs_object_size(src_video_stream),
|
||||
duration=int(duration_sec),
|
||||
frameCount=estimated_frames,
|
||||
frameRate=src_frame_rate,
|
||||
resolution=topaz_api.Resolution(width=src_width, height=src_height),
|
||||
),
|
||||
filters=filters,
|
||||
output=topaz_api.OutputInformationVideo(
|
||||
resolution=topaz_api.Resolution(width=target_width, height=target_height),
|
||||
frameRate=target_frame_rate,
|
||||
audioCodec="AAC",
|
||||
audioTransfer="Copy",
|
||||
dynamicCompressionLevel=dynamic_compression_level,
|
||||
),
|
||||
),
|
||||
wait_label="Creating task",
|
||||
final_label_on_success="Task created",
|
||||
)
|
||||
upload_res = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(
|
||||
path=f"/proxy/topaz/video/{initial_res.requestId}/accept",
|
||||
method="PATCH",
|
||||
),
|
||||
response_model=topaz_api.VideoAcceptResponse,
|
||||
wait_label="Preparing upload",
|
||||
final_label_on_success="Upload started",
|
||||
)
|
||||
if len(upload_res.urls) > 1:
|
||||
raise NotImplementedError(
|
||||
"Large files are not currently supported. Please open an issue in the ComfyUI repository."
|
||||
)
|
||||
async with aiohttp.ClientSession(headers={"Content-Type": "video/mp4"}) as session:
|
||||
if isinstance(src_video_stream, BytesIO):
|
||||
src_video_stream.seek(0)
|
||||
async with session.put(upload_res.urls[0], data=src_video_stream, raise_for_status=True) as res:
|
||||
upload_etag = res.headers["Etag"]
|
||||
else:
|
||||
with builtins.open(src_video_stream, "rb") as video_file:
|
||||
async with session.put(upload_res.urls[0], data=video_file, raise_for_status=True) as res:
|
||||
upload_etag = res.headers["Etag"]
|
||||
await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(
|
||||
path=f"/proxy/topaz/video/{initial_res.requestId}/complete-upload",
|
||||
method="PATCH",
|
||||
),
|
||||
response_model=topaz_api.VideoCompleteUploadResponse,
|
||||
data=topaz_api.VideoCompleteUploadRequest(
|
||||
uploadResults=[
|
||||
topaz_api.VideoCompleteUploadRequestPart(
|
||||
partNum=1,
|
||||
eTag=upload_etag,
|
||||
),
|
||||
],
|
||||
),
|
||||
wait_label="Finalizing upload",
|
||||
final_label_on_success="Upload completed",
|
||||
)
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/topaz/video/{initial_res.requestId}/status"),
|
||||
response_model=topaz_api.VideoStatusResponse,
|
||||
status_extractor=lambda x: x.status,
|
||||
progress_extractor=lambda x: getattr(x, "progress", 0),
|
||||
price_extractor=lambda x: (x.estimates.cost[0] * 0.08 if x.estimates and x.estimates.cost[0] else None),
|
||||
poll_interval=10.0,
|
||||
max_poll_attempts=320,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.download.url))
|
||||
|
||||
|
||||
class TopazExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
TopazImageEnhance,
|
||||
TopazVideoEnhance,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> TopazExtension:
|
||||
return TopazExtension()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,57 +1,35 @@
|
||||
import logging
|
||||
import base64
|
||||
import aiohttp
|
||||
import torch
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io as comfy_io
|
||||
from comfy_api.input_impl.video_types import VideoFromFile
|
||||
from comfy_api_nodes.apis import (
|
||||
VeoGenVidRequest,
|
||||
VeoGenVidResponse,
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.apis.veo_api import (
|
||||
VeoGenVidPollRequest,
|
||||
VeoGenVidPollResponse,
|
||||
VeoGenVidRequest,
|
||||
VeoGenVidResponse,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
)
|
||||
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
downscale_image_tensor,
|
||||
download_url_to_video_output,
|
||||
poll_op,
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
)
|
||||
|
||||
AVERAGE_DURATION_VIDEO_GEN = 32
|
||||
|
||||
def convert_image_to_base64(image: torch.Tensor):
|
||||
if image is None:
|
||||
return None
|
||||
|
||||
scaled_image = downscale_image_tensor(image, total_pixels=2048*2048)
|
||||
return tensor_to_base64_string(scaled_image)
|
||||
MODELS_MAP = {
|
||||
"veo-2.0-generate-001": "veo-2.0-generate-001",
|
||||
"veo-3.1-generate": "veo-3.1-generate-preview",
|
||||
"veo-3.1-fast-generate": "veo-3.1-fast-generate-preview",
|
||||
"veo-3.0-generate-001": "veo-3.0-generate-001",
|
||||
"veo-3.0-fast-generate-001": "veo-3.0-fast-generate-001",
|
||||
}
|
||||
|
||||
|
||||
def get_video_url_from_response(poll_response: VeoGenVidPollResponse) -> Optional[str]:
|
||||
if (
|
||||
poll_response.response
|
||||
and hasattr(poll_response.response, "videos")
|
||||
and poll_response.response.videos
|
||||
and len(poll_response.response.videos) > 0
|
||||
):
|
||||
video = poll_response.response.videos[0]
|
||||
else:
|
||||
return None
|
||||
if hasattr(video, "gcsUri") and video.gcsUri:
|
||||
return str(video.gcsUri)
|
||||
return None
|
||||
|
||||
|
||||
class VeoVideoGenerationNode(comfy_io.ComfyNode):
|
||||
class VeoVideoGenerationNode(IO.ComfyNode):
|
||||
"""
|
||||
Generates videos from text prompts using Google's Veo API.
|
||||
|
||||
@@ -61,71 +39,71 @@ class VeoVideoGenerationNode(comfy_io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
return IO.Schema(
|
||||
node_id="VeoVideoGenerationNode",
|
||||
display_name="Google Veo 2 Video Generation",
|
||||
category="api node/video/Veo",
|
||||
description="Generates videos from text prompts using Google's Veo 2 API",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text description of the video",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=["16:9", "9:16"],
|
||||
default="16:9",
|
||||
tooltip="Aspect ratio of the output video",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Negative text prompt to guide what to avoid in the video",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"duration_seconds",
|
||||
default=5,
|
||||
min=5,
|
||||
max=8,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Duration of the output video in seconds",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
IO.Boolean.Input(
|
||||
"enhance_prompt",
|
||||
default=True,
|
||||
tooltip="Whether to enhance the prompt with AI assistance",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"person_generation",
|
||||
options=["ALLOW", "BLOCK"],
|
||||
default="ALLOW",
|
||||
tooltip="Whether to allow generating people in the video",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=0xFFFFFFFF,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed for video generation (0 for random)",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="Optional reference image to guide video generation",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["veo-2.0-generate-001"],
|
||||
default="veo-2.0-generate-001",
|
||||
@@ -134,12 +112,12 @@ class VeoVideoGenerationNode(comfy_io.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Video.Output(),
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -158,21 +136,17 @@ class VeoVideoGenerationNode(comfy_io.ComfyNode):
|
||||
model="veo-2.0-generate-001",
|
||||
generate_audio=False,
|
||||
):
|
||||
model = MODELS_MAP[model]
|
||||
# Prepare the instances for the request
|
||||
instances = []
|
||||
|
||||
instance = {
|
||||
"prompt": prompt
|
||||
}
|
||||
instance = {"prompt": prompt}
|
||||
|
||||
# Add image if provided
|
||||
if image is not None:
|
||||
image_base64 = convert_image_to_base64(image)
|
||||
image_base64 = tensor_to_base64_string(image)
|
||||
if image_base64:
|
||||
instance["image"] = {
|
||||
"bytesBase64Encoded": image_base64,
|
||||
"mimeType": "image/png"
|
||||
}
|
||||
instance["image"] = {"bytesBase64Encoded": image_base64, "mimeType": "image/png"}
|
||||
|
||||
instances.append(instance)
|
||||
|
||||
@@ -190,119 +164,77 @@ class VeoVideoGenerationNode(comfy_io.ComfyNode):
|
||||
if seed > 0:
|
||||
parameters["seed"] = seed
|
||||
# Only add generateAudio for Veo 3 models
|
||||
if "veo-3.0" in model:
|
||||
if model.find("veo-2.0") == -1:
|
||||
parameters["generateAudio"] = generate_audio
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
# Initial request to start video generation
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=f"/proxy/veo/{model}/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=VeoGenVidRequest,
|
||||
response_model=VeoGenVidResponse
|
||||
),
|
||||
request=VeoGenVidRequest(
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"),
|
||||
response_model=VeoGenVidResponse,
|
||||
data=VeoGenVidRequest(
|
||||
instances=instances,
|
||||
parameters=parameters
|
||||
parameters=parameters,
|
||||
),
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
|
||||
initial_response = await initial_operation.execute()
|
||||
operation_name = initial_response.name
|
||||
|
||||
logging.info("Veo generation started with operation name: %s", operation_name)
|
||||
|
||||
# Define status extractor function
|
||||
def status_extractor(response):
|
||||
# Only return "completed" if the operation is done, regardless of success or failure
|
||||
# We'll check for errors after polling completes
|
||||
return "completed" if response.done else "pending"
|
||||
|
||||
# Define progress extractor function
|
||||
def progress_extractor(response):
|
||||
# Could be enhanced if the API provides progress information
|
||||
return None
|
||||
|
||||
# Define the polling operation
|
||||
poll_operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/veo/{model}/poll",
|
||||
method=HttpMethod.POST,
|
||||
request_model=VeoGenVidPollRequest,
|
||||
response_model=VeoGenVidPollResponse
|
||||
),
|
||||
completed_statuses=["completed"],
|
||||
failed_statuses=[], # No failed statuses, we'll handle errors after polling
|
||||
poll_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"),
|
||||
response_model=VeoGenVidPollResponse,
|
||||
status_extractor=status_extractor,
|
||||
progress_extractor=progress_extractor,
|
||||
request=VeoGenVidPollRequest(
|
||||
operationName=operation_name
|
||||
data=VeoGenVidPollRequest(
|
||||
operationName=initial_response.name,
|
||||
),
|
||||
auth_kwargs=auth,
|
||||
poll_interval=5.0,
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
|
||||
)
|
||||
|
||||
# Execute the polling operation
|
||||
poll_response = await poll_operation.execute()
|
||||
|
||||
# Now check for errors in the final response
|
||||
# Check for error in poll response
|
||||
if hasattr(poll_response, 'error') and poll_response.error:
|
||||
error_message = f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})"
|
||||
logging.error(error_message)
|
||||
raise Exception(error_message)
|
||||
if poll_response.error:
|
||||
raise Exception(f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})")
|
||||
|
||||
# Check for RAI filtered content
|
||||
if (hasattr(poll_response.response, 'raiMediaFilteredCount') and
|
||||
poll_response.response.raiMediaFilteredCount > 0):
|
||||
if (
|
||||
hasattr(poll_response.response, "raiMediaFilteredCount")
|
||||
and poll_response.response.raiMediaFilteredCount > 0
|
||||
):
|
||||
|
||||
# Extract reason message if available
|
||||
if (hasattr(poll_response.response, 'raiMediaFilteredReasons') and
|
||||
poll_response.response.raiMediaFilteredReasons):
|
||||
if (
|
||||
hasattr(poll_response.response, "raiMediaFilteredReasons")
|
||||
and poll_response.response.raiMediaFilteredReasons
|
||||
):
|
||||
reason = poll_response.response.raiMediaFilteredReasons[0]
|
||||
error_message = f"Content filtered by Google's Responsible AI practices: {reason} ({poll_response.response.raiMediaFilteredCount} videos filtered.)"
|
||||
else:
|
||||
error_message = f"Content filtered by Google's Responsible AI practices ({poll_response.response.raiMediaFilteredCount} videos filtered.)"
|
||||
|
||||
logging.error(error_message)
|
||||
raise Exception(error_message)
|
||||
|
||||
# Extract video data
|
||||
if poll_response.response and hasattr(poll_response.response, 'videos') and poll_response.response.videos and len(poll_response.response.videos) > 0:
|
||||
if (
|
||||
poll_response.response
|
||||
and hasattr(poll_response.response, "videos")
|
||||
and poll_response.response.videos
|
||||
and len(poll_response.response.videos) > 0
|
||||
):
|
||||
video = poll_response.response.videos[0]
|
||||
|
||||
# Check if video is provided as base64 or URL
|
||||
if hasattr(video, 'bytesBase64Encoded') and video.bytesBase64Encoded:
|
||||
# Decode base64 string to bytes
|
||||
video_data = base64.b64decode(video.bytesBase64Encoded)
|
||||
elif hasattr(video, 'gcsUri') and video.gcsUri:
|
||||
# Download from URL
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(video.gcsUri) as video_response:
|
||||
video_data = await video_response.content.read()
|
||||
else:
|
||||
raise Exception("Video returned but no data or URL was provided")
|
||||
else:
|
||||
raise Exception("Video generation completed but no video was returned")
|
||||
if hasattr(video, "bytesBase64Encoded") and video.bytesBase64Encoded:
|
||||
return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
|
||||
|
||||
if not video_data:
|
||||
raise Exception("No video data was returned")
|
||||
if hasattr(video, "gcsUri") and video.gcsUri:
|
||||
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
|
||||
|
||||
logging.info("Video generation completed successfully")
|
||||
|
||||
# Convert video data to BytesIO object
|
||||
video_io = BytesIO(video_data)
|
||||
|
||||
# Return VideoFromFile object
|
||||
return comfy_io.NodeOutput(VideoFromFile(video_io))
|
||||
raise Exception("Video returned but no data or URL was provided")
|
||||
raise Exception("Video generation completed but no video was returned")
|
||||
|
||||
|
||||
class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
||||
@@ -319,78 +251,83 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
return IO.Schema(
|
||||
node_id="Veo3VideoGenerationNode",
|
||||
display_name="Google Veo 3 Video Generation",
|
||||
category="api node/video/Veo",
|
||||
description="Generates videos from text prompts using Google's Veo 3 API",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text description of the video",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=["16:9", "9:16"],
|
||||
default="16:9",
|
||||
tooltip="Aspect ratio of the output video",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Negative text prompt to guide what to avoid in the video",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"duration_seconds",
|
||||
default=8,
|
||||
min=8,
|
||||
max=8,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Duration of the output video in seconds (Veo 3 only supports 8 seconds)",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
IO.Boolean.Input(
|
||||
"enhance_prompt",
|
||||
default=True,
|
||||
tooltip="Whether to enhance the prompt with AI assistance",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"person_generation",
|
||||
options=["ALLOW", "BLOCK"],
|
||||
default="ALLOW",
|
||||
tooltip="Whether to allow generating people in the video",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=0xFFFFFFFF,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed for video generation (0 for random)",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="Optional reference image to guide video generation",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["veo-3.0-generate-001", "veo-3.0-fast-generate-001"],
|
||||
options=[
|
||||
"veo-3.1-generate",
|
||||
"veo-3.1-fast-generate",
|
||||
"veo-3.0-generate-001",
|
||||
"veo-3.0-fast-generate-001",
|
||||
],
|
||||
default="veo-3.0-generate-001",
|
||||
tooltip="Veo 3 model to use for video generation",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
IO.Boolean.Input(
|
||||
"generate_audio",
|
||||
default=False,
|
||||
tooltip="Generate audio for the video. Supported by all Veo 3 models.",
|
||||
@@ -398,12 +335,12 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Video.Output(),
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -411,11 +348,12 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
||||
|
||||
class VeoExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
VeoVideoGenerationNode,
|
||||
Veo3VideoGenerationNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> VeoExtension:
|
||||
return VeoExtension()
|
||||
|
||||
@@ -1,27 +1,23 @@
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Optional, Literal, TypeVar
|
||||
from typing_extensions import override
|
||||
from typing import Literal, Optional, TypeVar
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io as comfy_io
|
||||
from comfy_api_nodes.util.validation_utils import (
|
||||
validate_aspect_ratio_closeness,
|
||||
validate_image_dimensions,
|
||||
validate_image_aspect_ratio_range,
|
||||
get_number_of_images,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
download_url_to_video_output,
|
||||
get_number_of_images,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_images_to_comfyapi,
|
||||
validate_image_aspect_ratio,
|
||||
validate_image_dimensions,
|
||||
validate_images_aspect_ratio_closeness,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import download_url_to_video_output, upload_images_to_comfyapi
|
||||
|
||||
|
||||
VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video"
|
||||
VIDU_IMAGE_TO_VIDEO = "/proxy/vidu/img2video"
|
||||
@@ -31,8 +27,9 @@ VIDU_GET_GENERATION_STATUS = "/proxy/vidu/tasks/%s/creations"
|
||||
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class VideoModelName(str, Enum):
|
||||
vidu_q1 = 'viduq1'
|
||||
vidu_q1 = "viduq1"
|
||||
|
||||
|
||||
class AspectRatio(str, Enum):
|
||||
@@ -63,17 +60,9 @@ class TaskCreationRequest(BaseModel):
|
||||
images: Optional[list[str]] = Field(None, description="Base64 encoded string or image URL")
|
||||
|
||||
|
||||
class TaskStatus(str, Enum):
|
||||
created = "created"
|
||||
queueing = "queueing"
|
||||
processing = "processing"
|
||||
success = "success"
|
||||
failed = "failed"
|
||||
|
||||
|
||||
class TaskCreationResponse(BaseModel):
|
||||
task_id: str = Field(...)
|
||||
state: TaskStatus = Field(...)
|
||||
state: str = Field(...)
|
||||
created_at: str = Field(...)
|
||||
code: Optional[int] = Field(None, description="Error code")
|
||||
|
||||
@@ -85,32 +74,11 @@ class TaskResult(BaseModel):
|
||||
|
||||
|
||||
class TaskStatusResponse(BaseModel):
|
||||
state: TaskStatus = Field(...)
|
||||
state: str = Field(...)
|
||||
err_code: Optional[str] = Field(None)
|
||||
creations: list[TaskResult] = Field(..., description="Generated results")
|
||||
|
||||
|
||||
async def poll_until_finished(
|
||||
auth_kwargs: dict[str, str],
|
||||
api_endpoint: ApiEndpoint[Any, R],
|
||||
result_url_extractor: Optional[Callable[[R], str]] = None,
|
||||
estimated_duration: Optional[int] = None,
|
||||
node_id: Optional[str] = None,
|
||||
) -> R:
|
||||
return await PollingOperation(
|
||||
poll_endpoint=api_endpoint,
|
||||
completed_statuses=[TaskStatus.success.value],
|
||||
failed_statuses=[TaskStatus.failed.value],
|
||||
status_extractor=lambda response: response.state.value,
|
||||
auth_kwargs=auth_kwargs,
|
||||
result_url_extractor=result_url_extractor,
|
||||
estimated_duration=estimated_duration,
|
||||
node_id=node_id,
|
||||
poll_interval=16.0,
|
||||
max_poll_attempts=256,
|
||||
).execute()
|
||||
|
||||
|
||||
def get_video_url_from_response(response) -> Optional[str]:
|
||||
if response.creations:
|
||||
return response.creations[0].url
|
||||
@@ -127,97 +95,87 @@ def get_video_from_response(response) -> TaskResult:
|
||||
|
||||
|
||||
async def execute_task(
|
||||
cls: type[IO.ComfyNode],
|
||||
vidu_endpoint: str,
|
||||
auth_kwargs: Optional[dict[str, str]],
|
||||
payload: TaskCreationRequest,
|
||||
estimated_duration: int,
|
||||
node_id: str,
|
||||
) -> R:
|
||||
response = await SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=vidu_endpoint,
|
||||
method=HttpMethod.POST,
|
||||
request_model=TaskCreationRequest,
|
||||
response_model=TaskCreationResponse,
|
||||
),
|
||||
request=payload,
|
||||
auth_kwargs=auth_kwargs,
|
||||
).execute()
|
||||
if response.state == TaskStatus.failed:
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=vidu_endpoint, method="POST"),
|
||||
response_model=TaskCreationResponse,
|
||||
data=payload,
|
||||
)
|
||||
if response.state == "failed":
|
||||
error_msg = f"Vidu request failed. Code: {response.code}"
|
||||
logging.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
return await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=VIDU_GET_GENERATION_STATUS % response.task_id,
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=TaskStatusResponse,
|
||||
),
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
return await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % response.task_id),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: r.state,
|
||||
estimated_duration=estimated_duration,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
|
||||
class ViduTextToVideoNode(comfy_io.ComfyNode):
|
||||
class ViduTextToVideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
return IO.Schema(
|
||||
node_id="ViduTextToVideoNode",
|
||||
display_name="Vidu Text To Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
description="Generate video from text prompt",
|
||||
inputs=[
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=VideoModelName,
|
||||
default=VideoModelName.vidu_q1,
|
||||
tooltip="Model name",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="A textual description for video generation",
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=5,
|
||||
max=5,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Duration of the output video in seconds",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed for video generation (0 for random)",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=AspectRatio,
|
||||
default=AspectRatio.r_16_9,
|
||||
tooltip="The aspect ratio of the output video",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=Resolution,
|
||||
default=Resolution.r_1080p,
|
||||
tooltip="Supported values may vary by model & duration",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"movement_amplitude",
|
||||
options=MovementAmplitude,
|
||||
default=MovementAmplitude.auto,
|
||||
@@ -226,12 +184,12 @@ class ViduTextToVideoNode(comfy_io.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Video.Output(),
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -246,7 +204,7 @@ class ViduTextToVideoNode(comfy_io.ComfyNode):
|
||||
aspect_ratio: str,
|
||||
resolution: str,
|
||||
movement_amplitude: str,
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
if not prompt:
|
||||
raise ValueError("The prompt field is required and cannot be empty.")
|
||||
payload = TaskCreationRequest(
|
||||
@@ -258,70 +216,66 @@ class ViduTextToVideoNode(comfy_io.ComfyNode):
|
||||
resolution=resolution,
|
||||
movement_amplitude=movement_amplitude,
|
||||
)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
results = await execute_task(VIDU_TEXT_TO_VIDEO, auth, payload, 320, cls.hidden.unique_id)
|
||||
return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||
results = await execute_task(cls, VIDU_TEXT_TO_VIDEO, payload, 320)
|
||||
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||
|
||||
|
||||
class ViduImageToVideoNode(comfy_io.ComfyNode):
|
||||
class ViduImageToVideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
return IO.Schema(
|
||||
node_id="ViduImageToVideoNode",
|
||||
display_name="Vidu Image To Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
description="Generate video from image and optional prompt",
|
||||
inputs=[
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=VideoModelName,
|
||||
default=VideoModelName.vidu_q1,
|
||||
tooltip="Model name",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="An image to be used as the start frame of the generated video",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="A textual description for video generation",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=5,
|
||||
max=5,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Duration of the output video in seconds",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed for video generation (0 for random)",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=Resolution,
|
||||
default=Resolution.r_1080p,
|
||||
tooltip="Supported values may vary by model & duration",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"movement_amplitude",
|
||||
options=MovementAmplitude,
|
||||
default=MovementAmplitude.auto.value,
|
||||
@@ -330,12 +284,12 @@ class ViduImageToVideoNode(comfy_io.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Video.Output(),
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -350,10 +304,10 @@ class ViduImageToVideoNode(comfy_io.ComfyNode):
|
||||
seed: int,
|
||||
resolution: str,
|
||||
movement_amplitude: str,
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
if get_number_of_images(image) > 1:
|
||||
raise ValueError("Only one input image is allowed.")
|
||||
validate_image_aspect_ratio_range(image, (1, 4), (4, 1))
|
||||
validate_image_aspect_ratio(image, (1, 4), (4, 1))
|
||||
payload = TaskCreationRequest(
|
||||
model_name=model,
|
||||
prompt=prompt,
|
||||
@@ -362,81 +316,77 @@ class ViduImageToVideoNode(comfy_io.ComfyNode):
|
||||
resolution=resolution,
|
||||
movement_amplitude=movement_amplitude,
|
||||
)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
payload.images = await upload_images_to_comfyapi(
|
||||
cls,
|
||||
image,
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
results = await execute_task(VIDU_IMAGE_TO_VIDEO, auth, payload, 120, cls.hidden.unique_id)
|
||||
return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||
results = await execute_task(cls, VIDU_IMAGE_TO_VIDEO, payload, 120)
|
||||
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||
|
||||
|
||||
class ViduReferenceVideoNode(comfy_io.ComfyNode):
|
||||
class ViduReferenceVideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
return IO.Schema(
|
||||
node_id="ViduReferenceVideoNode",
|
||||
display_name="Vidu Reference To Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
description="Generate video from multiple images and prompt",
|
||||
inputs=[
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=VideoModelName,
|
||||
default=VideoModelName.vidu_q1,
|
||||
tooltip="Model name",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"images",
|
||||
tooltip="Images to use as references to generate a video with consistent subjects (max 7 images).",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="A textual description for video generation",
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=5,
|
||||
max=5,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Duration of the output video in seconds",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed for video generation (0 for random)",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=AspectRatio,
|
||||
default=AspectRatio.r_16_9,
|
||||
tooltip="The aspect ratio of the output video",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=[model.value for model in Resolution],
|
||||
default=Resolution.r_1080p.value,
|
||||
tooltip="Supported values may vary by model & duration",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"movement_amplitude",
|
||||
options=[model.value for model in MovementAmplitude],
|
||||
default=MovementAmplitude.auto.value,
|
||||
@@ -445,12 +395,12 @@ class ViduReferenceVideoNode(comfy_io.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Video.Output(),
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -466,14 +416,14 @@ class ViduReferenceVideoNode(comfy_io.ComfyNode):
|
||||
aspect_ratio: str,
|
||||
resolution: str,
|
||||
movement_amplitude: str,
|
||||
) -> comfy_io.NodeOutput:
|
||||
) -> IO.NodeOutput:
|
||||
if not prompt:
|
||||
raise ValueError("The prompt field is required and cannot be empty.")
|
||||
a = get_number_of_images(images)
|
||||
if a > 7:
|
||||
raise ValueError("Too many images, maximum allowed is 7.")
|
||||
for image in images:
|
||||
validate_image_aspect_ratio_range(image, (1, 4), (4, 1))
|
||||
validate_image_aspect_ratio(image, (1, 4), (4, 1))
|
||||
validate_image_dimensions(image, min_width=128, min_height=128)
|
||||
payload = TaskCreationRequest(
|
||||
model_name=model,
|
||||
@@ -484,79 +434,75 @@ class ViduReferenceVideoNode(comfy_io.ComfyNode):
|
||||
resolution=resolution,
|
||||
movement_amplitude=movement_amplitude,
|
||||
)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
payload.images = await upload_images_to_comfyapi(
|
||||
cls,
|
||||
images,
|
||||
max_images=7,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
results = await execute_task(VIDU_REFERENCE_VIDEO, auth, payload, 120, cls.hidden.unique_id)
|
||||
return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||
results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload, 120)
|
||||
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||
|
||||
|
||||
class ViduStartEndToVideoNode(comfy_io.ComfyNode):
|
||||
class ViduStartEndToVideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
return IO.Schema(
|
||||
node_id="ViduStartEndToVideoNode",
|
||||
display_name="Vidu Start End To Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
description="Generate a video from start and end frames and a prompt",
|
||||
inputs=[
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=[model.value for model in VideoModelName],
|
||||
default=VideoModelName.vidu_q1.value,
|
||||
tooltip="Model name",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"first_frame",
|
||||
tooltip="Start frame",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"end_frame",
|
||||
tooltip="End frame",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="A textual description for video generation",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=5,
|
||||
max=5,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Duration of the output video in seconds",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed for video generation (0 for random)",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=[model.value for model in Resolution],
|
||||
default=Resolution.r_1080p.value,
|
||||
tooltip="Supported values may vary by model & duration",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"movement_amplitude",
|
||||
options=[model.value for model in MovementAmplitude],
|
||||
default=MovementAmplitude.auto.value,
|
||||
@@ -565,12 +511,12 @@ class ViduStartEndToVideoNode(comfy_io.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Video.Output(),
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -586,8 +532,8 @@ class ViduStartEndToVideoNode(comfy_io.ComfyNode):
|
||||
seed: int,
|
||||
resolution: str,
|
||||
movement_amplitude: str,
|
||||
) -> comfy_io.NodeOutput:
|
||||
validate_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
|
||||
) -> IO.NodeOutput:
|
||||
validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
|
||||
payload = TaskCreationRequest(
|
||||
model_name=model,
|
||||
prompt=prompt,
|
||||
@@ -596,21 +542,17 @@ class ViduStartEndToVideoNode(comfy_io.ComfyNode):
|
||||
resolution=resolution,
|
||||
movement_amplitude=movement_amplitude,
|
||||
)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
payload.images = [
|
||||
(await upload_images_to_comfyapi(frame, max_images=1, mime_type="image/png", auth_kwargs=auth))[0]
|
||||
(await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0]
|
||||
for frame in (first_frame, end_frame)
|
||||
]
|
||||
results = await execute_task(VIDU_START_END_VIDEO, auth, payload, 96, cls.hidden.unique_id)
|
||||
return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||
results = await execute_task(cls, VIDU_START_END_VIDEO, payload, 96)
|
||||
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||
|
||||
|
||||
class ViduExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
ViduTextToVideoNode,
|
||||
ViduImageToVideoNode,
|
||||
@@ -618,5 +560,6 @@ class ViduExtension(ComfyExtension):
|
||||
ViduStartEndToVideoNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> ViduExtension:
|
||||
return ViduExtension()
|
||||
|
||||
@@ -1,28 +1,24 @@
|
||||
import re
|
||||
from typing import Optional, Type, Union
|
||||
from typing_extensions import override
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
from comfy_api.latest import ComfyExtension, Input, io as comfy_io
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
R,
|
||||
T,
|
||||
)
|
||||
from comfy_api_nodes.util.validation_utils import get_number_of_images, validate_audio_duration
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
audio_to_base64_string,
|
||||
download_url_to_image_tensor,
|
||||
download_url_to_video_output,
|
||||
get_number_of_images,
|
||||
poll_op,
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
audio_to_base64_string,
|
||||
validate_audio_duration,
|
||||
)
|
||||
|
||||
|
||||
class Text2ImageInputField(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
@@ -146,84 +142,38 @@ class VideoTaskStatusResponse(BaseModel):
|
||||
request_id: str = Field(...)
|
||||
|
||||
|
||||
RES_IN_PARENS = re.compile(r'\((\d+)\s*[x×]\s*(\d+)\)')
|
||||
RES_IN_PARENS = re.compile(r"\((\d+)\s*[x×]\s*(\d+)\)")
|
||||
|
||||
|
||||
async def process_task(
|
||||
auth_kwargs: dict[str, str],
|
||||
url: str,
|
||||
request_model: Type[T],
|
||||
response_model: Type[R],
|
||||
payload: Union[
|
||||
Text2ImageTaskCreationRequest,
|
||||
Image2ImageTaskCreationRequest,
|
||||
Text2VideoTaskCreationRequest,
|
||||
Image2VideoTaskCreationRequest,
|
||||
],
|
||||
node_id: str,
|
||||
estimated_duration: int,
|
||||
poll_interval: int,
|
||||
) -> Type[R]:
|
||||
initial_response = await SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=url,
|
||||
method=HttpMethod.POST,
|
||||
request_model=request_model,
|
||||
response_model=TaskCreationResponse,
|
||||
),
|
||||
request=payload,
|
||||
auth_kwargs=auth_kwargs,
|
||||
).execute()
|
||||
|
||||
if not initial_response.output:
|
||||
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||
|
||||
return await PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=response_model,
|
||||
),
|
||||
completed_statuses=["SUCCEEDED"],
|
||||
failed_statuses=["FAILED", "CANCELED", "UNKNOWN"],
|
||||
status_extractor=lambda x: x.output.task_status,
|
||||
estimated_duration=estimated_duration,
|
||||
poll_interval=poll_interval,
|
||||
node_id=node_id,
|
||||
auth_kwargs=auth_kwargs,
|
||||
).execute()
|
||||
|
||||
|
||||
class WanTextToImageApi(comfy_io.ComfyNode):
|
||||
class WanTextToImageApi(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
return IO.Schema(
|
||||
node_id="WanTextToImageApi",
|
||||
display_name="Wan Text to Image",
|
||||
category="api node/image/Wan",
|
||||
description="Generates image based on text prompt.",
|
||||
inputs=[
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["wan2.5-t2i-preview"],
|
||||
default="wan2.5-t2i-preview",
|
||||
tooltip="Model to use.",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Negative text prompt to guide what to avoid.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"width",
|
||||
default=1024,
|
||||
min=768,
|
||||
@@ -231,7 +181,7 @@ class WanTextToImageApi(comfy_io.ComfyNode):
|
||||
step=32,
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"height",
|
||||
default=1024,
|
||||
min=768,
|
||||
@@ -239,37 +189,37 @@ class WanTextToImageApi(comfy_io.ComfyNode):
|
||||
step=32,
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to use for generation.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
IO.Boolean.Input(
|
||||
"prompt_extend",
|
||||
default=True,
|
||||
tooltip="Whether to enhance the prompt with AI assistance.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
tooltip="Whether to add an \"AI generated\" watermark to the result.",
|
||||
tooltip='Whether to add an "AI generated" watermark to the result.',
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Image.Output(),
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -286,59 +236,61 @@ class WanTextToImageApi(comfy_io.ComfyNode):
|
||||
prompt_extend: bool = True,
|
||||
watermark: bool = True,
|
||||
):
|
||||
payload = Text2ImageTaskCreationRequest(
|
||||
model=model,
|
||||
input=Text2ImageInputField(prompt=prompt, negative_prompt=negative_prompt),
|
||||
parameters=Txt2ImageParametersField(
|
||||
size=f"{width}*{height}",
|
||||
seed=seed,
|
||||
prompt_extend=prompt_extend,
|
||||
watermark=watermark,
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/text2image/image-synthesis", method="POST"),
|
||||
response_model=TaskCreationResponse,
|
||||
data=Text2ImageTaskCreationRequest(
|
||||
model=model,
|
||||
input=Text2ImageInputField(prompt=prompt, negative_prompt=negative_prompt),
|
||||
parameters=Txt2ImageParametersField(
|
||||
size=f"{width}*{height}",
|
||||
seed=seed,
|
||||
prompt_extend=prompt_extend,
|
||||
watermark=watermark,
|
||||
),
|
||||
),
|
||||
)
|
||||
response = await process_task(
|
||||
{
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
"/proxy/wan/api/v1/services/aigc/text2image/image-synthesis",
|
||||
request_model=Text2ImageTaskCreationRequest,
|
||||
if not initial_response.output:
|
||||
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
||||
response_model=ImageTaskStatusResponse,
|
||||
payload=payload,
|
||||
node_id=cls.hidden.unique_id,
|
||||
status_extractor=lambda x: x.output.task_status,
|
||||
estimated_duration=9,
|
||||
poll_interval=3,
|
||||
)
|
||||
return comfy_io.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url)))
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url)))
|
||||
|
||||
|
||||
class WanImageToImageApi(comfy_io.ComfyNode):
|
||||
class WanImageToImageApi(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
return IO.Schema(
|
||||
node_id="WanImageToImageApi",
|
||||
display_name="Wan Image to Image",
|
||||
category="api node/image/Wan",
|
||||
description="Generates an image from one or two input images and a text prompt. "
|
||||
"The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).",
|
||||
"The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).",
|
||||
inputs=[
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["wan2.5-i2i-preview"],
|
||||
default="wan2.5-i2i-preview",
|
||||
tooltip="Model to use.",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="Single-image editing or multi-image fusion, maximum 2 images.",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
@@ -346,7 +298,7 @@ class WanImageToImageApi(comfy_io.ComfyNode):
|
||||
optional=True,
|
||||
),
|
||||
# redo this later as an optional combo of recommended resolutions
|
||||
# comfy_io.Int.Input(
|
||||
# IO.Int.Input(
|
||||
# "width",
|
||||
# default=1280,
|
||||
# min=384,
|
||||
@@ -354,7 +306,7 @@ class WanImageToImageApi(comfy_io.ComfyNode):
|
||||
# step=16,
|
||||
# optional=True,
|
||||
# ),
|
||||
# comfy_io.Int.Input(
|
||||
# IO.Int.Input(
|
||||
# "height",
|
||||
# default=1280,
|
||||
# min=384,
|
||||
@@ -362,31 +314,31 @@ class WanImageToImageApi(comfy_io.ComfyNode):
|
||||
# step=16,
|
||||
# optional=True,
|
||||
# ),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to use for generation.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
tooltip="Whether to add an \"AI generated\" watermark to the result.",
|
||||
tooltip='Whether to add an "AI generated" watermark to the result.',
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Image.Output(),
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -408,61 +360,63 @@ class WanImageToImageApi(comfy_io.ComfyNode):
|
||||
raise ValueError(f"Expected 1 or 2 input images, got {n_images}.")
|
||||
images = []
|
||||
for i in image:
|
||||
images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096*4096))
|
||||
payload = Image2ImageTaskCreationRequest(
|
||||
model=model,
|
||||
input=Image2ImageInputField(prompt=prompt, negative_prompt=negative_prompt, images=images),
|
||||
parameters=Image2ImageParametersField(
|
||||
# size=f"{width}*{height}",
|
||||
seed=seed,
|
||||
watermark=watermark,
|
||||
images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096 * 4096))
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/image2image/image-synthesis", method="POST"),
|
||||
response_model=TaskCreationResponse,
|
||||
data=Image2ImageTaskCreationRequest(
|
||||
model=model,
|
||||
input=Image2ImageInputField(prompt=prompt, negative_prompt=negative_prompt, images=images),
|
||||
parameters=Image2ImageParametersField(
|
||||
# size=f"{width}*{height}",
|
||||
seed=seed,
|
||||
watermark=watermark,
|
||||
),
|
||||
),
|
||||
)
|
||||
response = await process_task(
|
||||
{
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
"/proxy/wan/api/v1/services/aigc/image2image/image-synthesis",
|
||||
request_model=Image2ImageTaskCreationRequest,
|
||||
if not initial_response.output:
|
||||
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
||||
response_model=ImageTaskStatusResponse,
|
||||
payload=payload,
|
||||
node_id=cls.hidden.unique_id,
|
||||
status_extractor=lambda x: x.output.task_status,
|
||||
estimated_duration=42,
|
||||
poll_interval=3,
|
||||
poll_interval=4,
|
||||
)
|
||||
return comfy_io.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url)))
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url)))
|
||||
|
||||
|
||||
class WanTextToVideoApi(comfy_io.ComfyNode):
|
||||
class WanTextToVideoApi(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
return IO.Schema(
|
||||
node_id="WanTextToVideoApi",
|
||||
display_name="Wan Text to Video",
|
||||
category="api node/video/Wan",
|
||||
description="Generates video based on text prompt.",
|
||||
inputs=[
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["wan2.5-t2v-preview"],
|
||||
default="wan2.5-t2v-preview",
|
||||
tooltip="Model to use.",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Negative text prompt to guide what to avoid.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"size",
|
||||
options=[
|
||||
"480p: 1:1 (624x624)",
|
||||
@@ -482,58 +436,58 @@ class WanTextToVideoApi(comfy_io.ComfyNode):
|
||||
default="480p: 1:1 (624x624)",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=5,
|
||||
max=10,
|
||||
step=5,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Available durations: 5 and 10 seconds",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Audio.Input(
|
||||
IO.Audio.Input(
|
||||
"audio",
|
||||
optional=True,
|
||||
tooltip="Audio must contain a clear, loud voice, without extraneous noise, background music.",
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to use for generation.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
IO.Boolean.Input(
|
||||
"generate_audio",
|
||||
default=False,
|
||||
optional=True,
|
||||
tooltip="If there is no audio input, generate audio automatically.",
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
IO.Boolean.Input(
|
||||
"prompt_extend",
|
||||
default=True,
|
||||
tooltip="Whether to enhance the prompt with AI assistance.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
tooltip="Whether to add an \"AI generated\" watermark to the result.",
|
||||
tooltip='Whether to add an "AI generated" watermark to the result.',
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Video.Output(),
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -557,66 +511,69 @@ class WanTextToVideoApi(comfy_io.ComfyNode):
|
||||
if audio is not None:
|
||||
validate_audio_duration(audio, 3.0, 29.0)
|
||||
audio_url = "data:audio/mp3;base64," + audio_to_base64_string(audio, "mp3", "libmp3lame")
|
||||
payload = Text2VideoTaskCreationRequest(
|
||||
model=model,
|
||||
input=Text2VideoInputField(prompt=prompt, negative_prompt=negative_prompt, audio_url=audio_url),
|
||||
parameters=Text2VideoParametersField(
|
||||
size=f"{width}*{height}",
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
audio=generate_audio,
|
||||
prompt_extend=prompt_extend,
|
||||
watermark=watermark,
|
||||
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"),
|
||||
response_model=TaskCreationResponse,
|
||||
data=Text2VideoTaskCreationRequest(
|
||||
model=model,
|
||||
input=Text2VideoInputField(prompt=prompt, negative_prompt=negative_prompt, audio_url=audio_url),
|
||||
parameters=Text2VideoParametersField(
|
||||
size=f"{width}*{height}",
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
audio=generate_audio,
|
||||
prompt_extend=prompt_extend,
|
||||
watermark=watermark,
|
||||
),
|
||||
),
|
||||
)
|
||||
response = await process_task(
|
||||
{
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
"/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis",
|
||||
request_model=Text2VideoTaskCreationRequest,
|
||||
if not initial_response.output:
|
||||
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
||||
response_model=VideoTaskStatusResponse,
|
||||
payload=payload,
|
||||
node_id=cls.hidden.unique_id,
|
||||
status_extractor=lambda x: x.output.task_status,
|
||||
estimated_duration=120 * int(duration / 5),
|
||||
poll_interval=6,
|
||||
)
|
||||
return comfy_io.NodeOutput(await download_url_to_video_output(response.output.video_url))
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.output.video_url))
|
||||
|
||||
|
||||
class WanImageToVideoApi(comfy_io.ComfyNode):
|
||||
class WanImageToVideoApi(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
return IO.Schema(
|
||||
node_id="WanImageToVideoApi",
|
||||
display_name="Wan Image to Video",
|
||||
category="api node/video/Wan",
|
||||
description="Generates video based on the first frame and text prompt.",
|
||||
inputs=[
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["wan2.5-i2v-preview"],
|
||||
default="wan2.5-i2v-preview",
|
||||
tooltip="Model to use.",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Negative text prompt to guide what to avoid.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=[
|
||||
"480P",
|
||||
@@ -626,58 +583,58 @@ class WanImageToVideoApi(comfy_io.ComfyNode):
|
||||
default="480P",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=5,
|
||||
max=10,
|
||||
step=5,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Available durations: 5 and 10 seconds",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Audio.Input(
|
||||
IO.Audio.Input(
|
||||
"audio",
|
||||
optional=True,
|
||||
tooltip="Audio must contain a clear, loud voice, without extraneous noise, background music.",
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to use for generation.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
IO.Boolean.Input(
|
||||
"generate_audio",
|
||||
default=False,
|
||||
optional=True,
|
||||
tooltip="If there is no audio input, generate audio automatically.",
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
IO.Boolean.Input(
|
||||
"prompt_extend",
|
||||
default=True,
|
||||
tooltip="Whether to enhance the prompt with AI assistance.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
tooltip="Whether to add an \"AI generated\" watermark to the result.",
|
||||
tooltip='Whether to add an "AI generated" watermark to the result.',
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Video.Output(),
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
@@ -699,44 +656,46 @@ class WanImageToVideoApi(comfy_io.ComfyNode):
|
||||
):
|
||||
if get_number_of_images(image) != 1:
|
||||
raise ValueError("Exactly one input image is required.")
|
||||
image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000*2000)
|
||||
image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000 * 2000)
|
||||
audio_url = None
|
||||
if audio is not None:
|
||||
validate_audio_duration(audio, 3.0, 29.0)
|
||||
audio_url = "data:audio/mp3;base64," + audio_to_base64_string(audio, "mp3", "libmp3lame")
|
||||
payload = Image2VideoTaskCreationRequest(
|
||||
model=model,
|
||||
input=Image2VideoInputField(
|
||||
prompt=prompt, negative_prompt=negative_prompt, img_url=image_url, audio_url=audio_url
|
||||
),
|
||||
parameters=Image2VideoParametersField(
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
audio=generate_audio,
|
||||
prompt_extend=prompt_extend,
|
||||
watermark=watermark,
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"),
|
||||
response_model=TaskCreationResponse,
|
||||
data=Image2VideoTaskCreationRequest(
|
||||
model=model,
|
||||
input=Image2VideoInputField(
|
||||
prompt=prompt, negative_prompt=negative_prompt, img_url=image_url, audio_url=audio_url
|
||||
),
|
||||
parameters=Image2VideoParametersField(
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
audio=generate_audio,
|
||||
prompt_extend=prompt_extend,
|
||||
watermark=watermark,
|
||||
),
|
||||
),
|
||||
)
|
||||
response = await process_task(
|
||||
{
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
"/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis",
|
||||
request_model=Image2VideoTaskCreationRequest,
|
||||
if not initial_response.output:
|
||||
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
||||
response_model=VideoTaskStatusResponse,
|
||||
payload=payload,
|
||||
node_id=cls.hidden.unique_id,
|
||||
status_extractor=lambda x: x.output.task_status,
|
||||
estimated_duration=120 * int(duration / 5),
|
||||
poll_interval=6,
|
||||
)
|
||||
return comfy_io.NodeOutput(await download_url_to_video_output(response.output.video_url))
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.output.video_url))
|
||||
|
||||
|
||||
class WanApiExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
WanTextToImageApi,
|
||||
WanImageToImageApi,
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
from ._helpers import get_fs_object_size
|
||||
from .client import (
|
||||
ApiEndpoint,
|
||||
poll_op,
|
||||
poll_op_raw,
|
||||
sync_op,
|
||||
sync_op_raw,
|
||||
)
|
||||
from .conversions import (
|
||||
audio_bytes_to_audio_input,
|
||||
audio_input_to_mp3,
|
||||
audio_to_base64_string,
|
||||
bytesio_to_image_tensor,
|
||||
downscale_image_tensor,
|
||||
image_tensor_pair_to_batch,
|
||||
pil_to_bytesio,
|
||||
resize_mask_to_image,
|
||||
tensor_to_base64_string,
|
||||
tensor_to_bytesio,
|
||||
tensor_to_pil,
|
||||
text_filepath_to_base64_string,
|
||||
text_filepath_to_data_uri,
|
||||
trim_video,
|
||||
video_to_base64_string,
|
||||
)
|
||||
from .download_helpers import (
|
||||
download_url_as_bytesio,
|
||||
download_url_to_bytesio,
|
||||
download_url_to_image_tensor,
|
||||
download_url_to_video_output,
|
||||
)
|
||||
from .upload_helpers import (
|
||||
upload_audio_to_comfyapi,
|
||||
upload_file_to_comfyapi,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
)
|
||||
from .validation_utils import (
|
||||
get_number_of_images,
|
||||
validate_aspect_ratio_string,
|
||||
validate_audio_duration,
|
||||
validate_container_format_is_mp4,
|
||||
validate_image_aspect_ratio,
|
||||
validate_image_dimensions,
|
||||
validate_images_aspect_ratio_closeness,
|
||||
validate_string,
|
||||
validate_video_dimensions,
|
||||
validate_video_duration,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# API client
|
||||
"ApiEndpoint",
|
||||
"poll_op",
|
||||
"poll_op_raw",
|
||||
"sync_op",
|
||||
"sync_op_raw",
|
||||
# Upload helpers
|
||||
"upload_audio_to_comfyapi",
|
||||
"upload_file_to_comfyapi",
|
||||
"upload_images_to_comfyapi",
|
||||
"upload_video_to_comfyapi",
|
||||
# Download helpers
|
||||
"download_url_as_bytesio",
|
||||
"download_url_to_bytesio",
|
||||
"download_url_to_image_tensor",
|
||||
"download_url_to_video_output",
|
||||
# Conversions
|
||||
"audio_bytes_to_audio_input",
|
||||
"audio_input_to_mp3",
|
||||
"audio_to_base64_string",
|
||||
"bytesio_to_image_tensor",
|
||||
"downscale_image_tensor",
|
||||
"image_tensor_pair_to_batch",
|
||||
"pil_to_bytesio",
|
||||
"resize_mask_to_image",
|
||||
"tensor_to_base64_string",
|
||||
"tensor_to_bytesio",
|
||||
"tensor_to_pil",
|
||||
"text_filepath_to_base64_string",
|
||||
"text_filepath_to_data_uri",
|
||||
"trim_video",
|
||||
"video_to_base64_string",
|
||||
# Validation utilities
|
||||
"get_number_of_images",
|
||||
"validate_aspect_ratio_string",
|
||||
"validate_audio_duration",
|
||||
"validate_container_format_is_mp4",
|
||||
"validate_image_aspect_ratio",
|
||||
"validate_image_dimensions",
|
||||
"validate_images_aspect_ratio_closeness",
|
||||
"validate_string",
|
||||
"validate_video_dimensions",
|
||||
"validate_video_duration",
|
||||
# Misc functions
|
||||
"get_fs_object_size",
|
||||
]
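A minimal usage sketch of the single import surface this package exposes, mirroring the image-validation and base64 steps the Wan nodes above perform; the wrapper function name `image_to_data_uri` is illustrative and not part of the package.

import torch

from comfy_api_nodes.util import get_number_of_images, tensor_to_base64_string


def image_to_data_uri(image: torch.Tensor) -> str:
    # Reject batched inputs, as the Wan image-to-video node does, then inline the
    # image as a PNG data URI capped at roughly 2000x2000 pixels.
    if get_number_of_images(image) != 1:
        raise ValueError("Exactly one input image is required.")
    return "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000 * 2000)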

comfy_api_nodes/util/_helpers.py (new file, 71 lines)
@@ -0,0 +1,71 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import os
|
||||
import time
|
||||
from io import BytesIO
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
from comfy.cli_args import args
|
||||
from comfy.model_management import processing_interrupted
|
||||
from comfy_api.latest import IO
|
||||
|
||||
from .common_exceptions import ProcessingInterrupted
|
||||
|
||||
|
||||
def is_processing_interrupted() -> bool:
|
||||
"""Return True if user/runtime requested interruption."""
|
||||
return processing_interrupted()
|
||||
|
||||
|
||||
def get_node_id(node_cls: type[IO.ComfyNode]) -> str:
|
||||
return node_cls.hidden.unique_id
|
||||
|
||||
|
||||
def get_auth_header(node_cls: type[IO.ComfyNode]) -> dict[str, str]:
|
||||
if node_cls.hidden.auth_token_comfy_org:
|
||||
return {"Authorization": f"Bearer {node_cls.hidden.auth_token_comfy_org}"}
|
||||
if node_cls.hidden.api_key_comfy_org:
|
||||
return {"X-API-KEY": node_cls.hidden.api_key_comfy_org}
|
||||
return {}
|
||||
|
||||
|
||||
def default_base_url() -> str:
|
||||
return getattr(args, "comfy_api_base", "https://api.comfy.org")
|
||||
|
||||
|
||||
async def sleep_with_interrupt(
|
||||
seconds: float,
|
||||
node_cls: Optional[type[IO.ComfyNode]],
|
||||
label: Optional[str] = None,
|
||||
start_ts: Optional[float] = None,
|
||||
estimated_total: Optional[int] = None,
|
||||
*,
|
||||
display_callback: Optional[Callable[[type[IO.ComfyNode], str, int, Optional[int]], None]] = None,
|
||||
):
|
||||
"""
|
||||
Sleep in 1s slices while:
|
||||
- Checking for interruption (raises ProcessingInterrupted).
|
||||
- Optionally emitting time progress via display_callback (if provided).
|
||||
"""
|
||||
end = time.monotonic() + seconds
|
||||
while True:
|
||||
if is_processing_interrupted():
|
||||
raise ProcessingInterrupted("Task cancelled")
|
||||
now = time.monotonic()
|
||||
if start_ts is not None and label and display_callback:
|
||||
with contextlib.suppress(Exception):
|
||||
display_callback(node_cls, label, int(now - start_ts), estimated_total)
|
||||
if now >= end:
|
||||
break
|
||||
await asyncio.sleep(min(1.0, end - now))
|
||||
|
||||
|
||||
def mimetype_to_extension(mime_type: str) -> str:
|
||||
"""Converts a MIME type to a file extension."""
|
||||
return mime_type.split("/")[-1].lower()
|
||||
|
||||
|
||||
def get_fs_object_size(path_or_object: Union[str, BytesIO]) -> int:
|
||||
if isinstance(path_or_object, str):
|
||||
return os.path.getsize(path_or_object)
|
||||
return len(path_or_object.getvalue())
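A brief sketch of interruptible waiting with the helper above; the six-second delay and the calling node class are arbitrary placeholders.

from comfy_api_nodes.util._helpers import sleep_with_interrupt


async def wait_with_cancellation(node_cls):
    # Sleeps in one-second slices and raises ProcessingInterrupted if the user
    # cancels the workflow while we are waiting.
    await sleep_with_interrupt(6.0, node_cls)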

comfy_api_nodes/util/client.py (new file, 946 lines)
@@ -0,0 +1,946 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from typing import Any, Callable, Iterable, Literal, Optional, Type, TypeVar, Union
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import aiohttp
|
||||
from aiohttp.client_exceptions import ClientError, ContentTypeError
|
||||
from pydantic import BaseModel
|
||||
|
||||
from comfy import utils
|
||||
from comfy_api.latest import IO
|
||||
from server import PromptServer
|
||||
|
||||
from . import request_logger
|
||||
from ._helpers import (
|
||||
default_base_url,
|
||||
get_auth_header,
|
||||
get_node_id,
|
||||
is_processing_interrupted,
|
||||
sleep_with_interrupt,
|
||||
)
|
||||
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
|
||||
|
||||
M = TypeVar("M", bound=BaseModel)
|
||||
|
||||
|
||||
class ApiEndpoint:
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "GET",
|
||||
*,
|
||||
query_params: Optional[dict[str, Any]] = None,
|
||||
headers: Optional[dict[str, str]] = None,
|
||||
):
|
||||
self.path = path
|
||||
self.method = method
|
||||
self.query_params = query_params or {}
|
||||
self.headers = headers or {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class _RequestConfig:
|
||||
node_cls: type[IO.ComfyNode]
|
||||
endpoint: ApiEndpoint
|
||||
timeout: float
|
||||
content_type: str
|
||||
data: Optional[dict[str, Any]]
|
||||
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]]
|
||||
multipart_parser: Optional[Callable]
|
||||
max_retries: int
|
||||
retry_delay: float
|
||||
retry_backoff: float
|
||||
wait_label: str = "Waiting"
|
||||
monitor_progress: bool = True
|
||||
estimated_total: Optional[int] = None
|
||||
final_label_on_success: Optional[str] = "Completed"
|
||||
progress_origin_ts: Optional[float] = None
|
||||
price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _PollUIState:
|
||||
started: float
|
||||
status_label: str = "Queued"
|
||||
is_queued: bool = True
|
||||
price: Optional[float] = None
|
||||
estimated_duration: Optional[int] = None
|
||||
base_processing_elapsed: float = 0.0 # sum of completed active intervals
|
||||
active_since: Optional[float] = None # start time of current active interval (None if queued)
|
||||
|
||||
|
||||
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
|
||||
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
|
||||
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
|
||||
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing"]
|
||||
|
||||
|
||||
async def sync_op(
|
||||
cls: type[IO.ComfyNode],
|
||||
endpoint: ApiEndpoint,
|
||||
*,
|
||||
response_model: Type[M],
|
||||
price_extractor: Optional[Callable[[M], Optional[float]]] = None,
|
||||
data: Optional[BaseModel] = None,
|
||||
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
|
||||
content_type: str = "application/json",
|
||||
timeout: float = 3600.0,
|
||||
multipart_parser: Optional[Callable] = None,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff: float = 2.0,
|
||||
wait_label: str = "Waiting for server",
|
||||
estimated_duration: Optional[int] = None,
|
||||
final_label_on_success: Optional[str] = "Completed",
|
||||
progress_origin_ts: Optional[float] = None,
|
||||
monitor_progress: bool = True,
|
||||
) -> M:
|
||||
raw = await sync_op_raw(
|
||||
cls,
|
||||
endpoint,
|
||||
price_extractor=_wrap_model_extractor(response_model, price_extractor),
|
||||
data=data,
|
||||
files=files,
|
||||
content_type=content_type,
|
||||
timeout=timeout,
|
||||
multipart_parser=multipart_parser,
|
||||
max_retries=max_retries,
|
||||
retry_delay=retry_delay,
|
||||
retry_backoff=retry_backoff,
|
||||
wait_label=wait_label,
|
||||
estimated_duration=estimated_duration,
|
||||
as_binary=False,
|
||||
final_label_on_success=final_label_on_success,
|
||||
progress_origin_ts=progress_origin_ts,
|
||||
monitor_progress=monitor_progress,
|
||||
)
|
||||
if not isinstance(raw, dict):
|
||||
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
|
||||
return _validate_or_raise(response_model, raw)
|
||||
|
||||
|
||||
async def poll_op(
|
||||
cls: type[IO.ComfyNode],
|
||||
poll_endpoint: ApiEndpoint,
|
||||
*,
|
||||
response_model: Type[M],
|
||||
status_extractor: Callable[[M], Optional[Union[str, int]]],
|
||||
progress_extractor: Optional[Callable[[M], Optional[int]]] = None,
|
||||
price_extractor: Optional[Callable[[M], Optional[float]]] = None,
|
||||
completed_statuses: Optional[list[Union[str, int]]] = None,
|
||||
failed_statuses: Optional[list[Union[str, int]]] = None,
|
||||
queued_statuses: Optional[list[Union[str, int]]] = None,
|
||||
data: Optional[BaseModel] = None,
|
||||
poll_interval: float = 5.0,
|
||||
max_poll_attempts: int = 120,
|
||||
timeout_per_poll: float = 120.0,
|
||||
max_retries_per_poll: int = 3,
|
||||
retry_delay_per_poll: float = 1.0,
|
||||
retry_backoff_per_poll: float = 2.0,
|
||||
estimated_duration: Optional[int] = None,
|
||||
cancel_endpoint: Optional[ApiEndpoint] = None,
|
||||
cancel_timeout: float = 10.0,
|
||||
) -> M:
|
||||
raw = await poll_op_raw(
|
||||
cls,
|
||||
poll_endpoint=poll_endpoint,
|
||||
status_extractor=_wrap_model_extractor(response_model, status_extractor),
|
||||
progress_extractor=_wrap_model_extractor(response_model, progress_extractor),
|
||||
price_extractor=_wrap_model_extractor(response_model, price_extractor),
|
||||
completed_statuses=completed_statuses,
|
||||
failed_statuses=failed_statuses,
|
||||
queued_statuses=queued_statuses,
|
||||
data=data,
|
||||
poll_interval=poll_interval,
|
||||
max_poll_attempts=max_poll_attempts,
|
||||
timeout_per_poll=timeout_per_poll,
|
||||
max_retries_per_poll=max_retries_per_poll,
|
||||
retry_delay_per_poll=retry_delay_per_poll,
|
||||
retry_backoff_per_poll=retry_backoff_per_poll,
|
||||
estimated_duration=estimated_duration,
|
||||
cancel_endpoint=cancel_endpoint,
|
||||
cancel_timeout=cancel_timeout,
|
||||
)
|
||||
if not isinstance(raw, dict):
|
||||
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
|
||||
return _validate_or_raise(response_model, raw)
|
||||
|
||||
|
||||
async def sync_op_raw(
|
||||
cls: type[IO.ComfyNode],
|
||||
endpoint: ApiEndpoint,
|
||||
*,
|
||||
price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
|
||||
data: Optional[Union[dict[str, Any], BaseModel]] = None,
|
||||
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
|
||||
content_type: str = "application/json",
|
||||
timeout: float = 3600.0,
|
||||
multipart_parser: Optional[Callable] = None,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff: float = 2.0,
|
||||
wait_label: str = "Waiting for server",
|
||||
estimated_duration: Optional[int] = None,
|
||||
as_binary: bool = False,
|
||||
final_label_on_success: Optional[str] = "Completed",
|
||||
progress_origin_ts: Optional[float] = None,
|
||||
monitor_progress: bool = True,
|
||||
) -> Union[dict[str, Any], bytes]:
|
||||
"""
|
||||
Make a single network request.
|
||||
- If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
|
||||
- If as_binary=True: returns bytes.
|
||||
"""
|
||||
if isinstance(data, BaseModel):
|
||||
data = data.model_dump(exclude_none=True)
|
||||
for k, v in list(data.items()):
|
||||
if isinstance(v, Enum):
|
||||
data[k] = v.value
|
||||
cfg = _RequestConfig(
|
||||
node_cls=cls,
|
||||
endpoint=endpoint,
|
||||
timeout=timeout,
|
||||
content_type=content_type,
|
||||
data=data,
|
||||
files=files,
|
||||
multipart_parser=multipart_parser,
|
||||
max_retries=max_retries,
|
||||
retry_delay=retry_delay,
|
||||
retry_backoff=retry_backoff,
|
||||
wait_label=wait_label,
|
||||
monitor_progress=monitor_progress,
|
||||
estimated_total=estimated_duration,
|
||||
final_label_on_success=final_label_on_success,
|
||||
progress_origin_ts=progress_origin_ts,
|
||||
price_extractor=price_extractor,
|
||||
)
|
||||
return await _request_base(cfg, expect_binary=as_binary)
|
||||
|
||||
|
||||
async def poll_op_raw(
|
||||
cls: type[IO.ComfyNode],
|
||||
poll_endpoint: ApiEndpoint,
|
||||
*,
|
||||
status_extractor: Callable[[dict[str, Any]], Optional[Union[str, int]]],
|
||||
progress_extractor: Optional[Callable[[dict[str, Any]], Optional[int]]] = None,
|
||||
price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
|
||||
completed_statuses: Optional[list[Union[str, int]]] = None,
|
||||
failed_statuses: Optional[list[Union[str, int]]] = None,
|
||||
queued_statuses: Optional[list[Union[str, int]]] = None,
|
||||
data: Optional[Union[dict[str, Any], BaseModel]] = None,
|
||||
poll_interval: float = 5.0,
|
||||
max_poll_attempts: int = 120,
|
||||
timeout_per_poll: float = 120.0,
|
||||
max_retries_per_poll: int = 3,
|
||||
retry_delay_per_poll: float = 1.0,
|
||||
retry_backoff_per_poll: float = 2.0,
|
||||
estimated_duration: Optional[int] = None,
|
||||
cancel_endpoint: Optional[ApiEndpoint] = None,
|
||||
cancel_timeout: float = 10.0,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Polls an endpoint until the task reaches a terminal state. Displays elapsed time while queued/processing,
checks for interruption every second, and calls the cancel endpoint (if provided) on interruption.

Uses the default completed, failed, and queued status sets unless overridden.
|
||||
|
||||
Returns the final JSON response from the poll endpoint.
|
||||
"""
|
||||
completed_states = _normalize_statuses(COMPLETED_STATUSES if completed_statuses is None else completed_statuses)
|
||||
failed_states = _normalize_statuses(FAILED_STATUSES if failed_statuses is None else failed_statuses)
|
||||
queued_states = _normalize_statuses(QUEUED_STATUSES if queued_statuses is None else queued_statuses)
|
||||
started = time.monotonic()
|
||||
consumed_attempts = 0 # counts only non-queued polls
|
||||
|
||||
progress_bar = utils.ProgressBar(100) if progress_extractor else None
|
||||
last_progress: Optional[int] = None
|
||||
|
||||
state = _PollUIState(started=started, estimated_duration=estimated_duration)
|
||||
stop_ticker = asyncio.Event()
|
||||
|
||||
async def _ticker():
|
||||
"""Emit a UI update every second while polling is in progress."""
|
||||
try:
|
||||
while not stop_ticker.is_set():
|
||||
if is_processing_interrupted():
|
||||
break
|
||||
now = time.monotonic()
|
||||
proc_elapsed = state.base_processing_elapsed + (
|
||||
(now - state.active_since) if state.active_since is not None else 0.0
|
||||
)
|
||||
_display_time_progress(
|
||||
cls,
|
||||
status=state.status_label,
|
||||
elapsed_seconds=int(now - state.started),
|
||||
estimated_total=state.estimated_duration,
|
||||
price=state.price,
|
||||
is_queued=state.is_queued,
|
||||
processing_elapsed_seconds=int(proc_elapsed),
|
||||
)
|
||||
await asyncio.sleep(1.0)
|
||||
except Exception as exc:
|
||||
logging.debug("Polling ticker exited: %s", exc)
|
||||
|
||||
ticker_task = asyncio.create_task(_ticker())
|
||||
try:
|
||||
while consumed_attempts < max_poll_attempts:
|
||||
try:
|
||||
resp_json = await sync_op_raw(
|
||||
cls,
|
||||
poll_endpoint,
|
||||
data=data,
|
||||
timeout=timeout_per_poll,
|
||||
max_retries=max_retries_per_poll,
|
||||
retry_delay=retry_delay_per_poll,
|
||||
retry_backoff=retry_backoff_per_poll,
|
||||
wait_label="Checking",
|
||||
estimated_duration=None,
|
||||
as_binary=False,
|
||||
final_label_on_success=None,
|
||||
monitor_progress=False,
|
||||
)
|
||||
if not isinstance(resp_json, dict):
|
||||
raise Exception("Polling endpoint returned non-JSON response.")
|
||||
except ProcessingInterrupted:
|
||||
if cancel_endpoint:
|
||||
with contextlib.suppress(Exception):
|
||||
await sync_op_raw(
|
||||
cls,
|
||||
cancel_endpoint,
|
||||
timeout=cancel_timeout,
|
||||
max_retries=0,
|
||||
wait_label="Cancelling task",
|
||||
estimated_duration=None,
|
||||
as_binary=False,
|
||||
final_label_on_success=None,
|
||||
monitor_progress=False,
|
||||
)
|
||||
raise
|
||||
|
||||
try:
|
||||
status = _normalize_status_value(status_extractor(resp_json))
|
||||
except Exception as e:
|
||||
logging.error("Status extraction failed: %s", e)
|
||||
status = None
|
||||
|
||||
if price_extractor:
|
||||
new_price = price_extractor(resp_json)
|
||||
if new_price is not None:
|
||||
state.price = new_price
|
||||
|
||||
if progress_extractor:
|
||||
new_progress = progress_extractor(resp_json)
|
||||
if new_progress is not None and last_progress != new_progress:
|
||||
progress_bar.update_absolute(new_progress, total=100)
|
||||
last_progress = new_progress
|
||||
|
||||
now_ts = time.monotonic()
|
||||
is_queued = status in queued_states
|
||||
|
||||
if is_queued:
|
||||
if state.active_since is not None: # If we just moved from active -> queued, close the active interval
|
||||
state.base_processing_elapsed += now_ts - state.active_since
|
||||
state.active_since = None
|
||||
else:
|
||||
if state.active_since is None: # If we just moved from queued -> active, open a new active interval
|
||||
state.active_since = now_ts
|
||||
|
||||
state.is_queued = is_queued
|
||||
state.status_label = status or ("Queued" if is_queued else "Processing")
|
||||
if status in completed_states:
|
||||
if state.active_since is not None:
|
||||
state.base_processing_elapsed += now_ts - state.active_since
|
||||
state.active_since = None
|
||||
stop_ticker.set()
|
||||
with contextlib.suppress(Exception):
|
||||
await ticker_task
|
||||
|
||||
if progress_bar and last_progress != 100:
|
||||
progress_bar.update_absolute(100, total=100)
|
||||
|
||||
_display_time_progress(
|
||||
cls,
|
||||
status=status if status else "Completed",
|
||||
elapsed_seconds=int(now_ts - started),
|
||||
estimated_total=estimated_duration,
|
||||
price=state.price,
|
||||
is_queued=False,
|
||||
processing_elapsed_seconds=int(state.base_processing_elapsed),
|
||||
)
|
||||
return resp_json
|
||||
|
||||
if status in failed_states:
|
||||
msg = f"Task failed: {json.dumps(resp_json)}"
|
||||
logging.error(msg)
|
||||
raise Exception(msg)
|
||||
|
||||
try:
|
||||
await sleep_with_interrupt(poll_interval, cls, None, None, None)
|
||||
except ProcessingInterrupted:
|
||||
if cancel_endpoint:
|
||||
with contextlib.suppress(Exception):
|
||||
await sync_op_raw(
|
||||
cls,
|
||||
cancel_endpoint,
|
||||
timeout=cancel_timeout,
|
||||
max_retries=0,
|
||||
wait_label="Cancelling task",
|
||||
estimated_duration=None,
|
||||
as_binary=False,
|
||||
final_label_on_success=None,
|
||||
monitor_progress=False,
|
||||
)
|
||||
raise
|
||||
if not is_queued:
|
||||
consumed_attempts += 1
|
||||
|
||||
raise Exception(
|
||||
f"Polling timed out after {max_poll_attempts} non-queued attempts "
|
||||
f"(~{int(max_poll_attempts * poll_interval)}s of active polling)."
|
||||
)
|
||||
except ProcessingInterrupted:
|
||||
raise
|
||||
except (LocalNetworkError, ApiServerError):
|
||||
raise
|
||||
except Exception as e:
|
||||
raise Exception(f"Polling aborted due to error: {e}") from e
|
||||
finally:
|
||||
stop_ticker.set()
|
||||
with contextlib.suppress(Exception):
|
||||
await ticker_task
|
||||
|
||||
|
||||
def _display_text(
|
||||
node_cls: type[IO.ComfyNode],
|
||||
text: Optional[str],
|
||||
*,
|
||||
status: Optional[Union[str, int]] = None,
|
||||
price: Optional[float] = None,
|
||||
) -> None:
|
||||
display_lines: list[str] = []
|
||||
if status:
|
||||
display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}")
|
||||
if price is not None:
|
||||
p = f"{float(price):,.4f}".rstrip("0").rstrip(".")
|
||||
if p != "0":
|
||||
display_lines.append(f"Price: ${p}")
|
||||
if text is not None:
|
||||
display_lines.append(text)
|
||||
if display_lines:
|
||||
PromptServer.instance.send_progress_text("\n".join(display_lines), get_node_id(node_cls))
|
||||
|
||||
|
||||
def _display_time_progress(
|
||||
node_cls: type[IO.ComfyNode],
|
||||
status: Optional[Union[str, int]],
|
||||
elapsed_seconds: int,
|
||||
estimated_total: Optional[int] = None,
|
||||
*,
|
||||
price: Optional[float] = None,
|
||||
is_queued: Optional[bool] = None,
|
||||
processing_elapsed_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
if estimated_total is not None and estimated_total > 0 and is_queued is False:
|
||||
pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
|
||||
remaining = max(0, int(estimated_total) - int(pe))
|
||||
time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)"
|
||||
else:
|
||||
time_line = f"Time elapsed: {int(elapsed_seconds)}s"
|
||||
_display_text(node_cls, time_line, status=status, price=price)
|
||||
|
||||
|
||||
async def _diagnose_connectivity() -> dict[str, bool]:
|
||||
"""Best-effort connectivity diagnostics to distinguish local vs. server issues."""
|
||||
results = {
|
||||
"internet_accessible": False,
|
||||
"api_accessible": False,
|
||||
}
|
||||
timeout = aiohttp.ClientTimeout(total=5.0)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
with contextlib.suppress(ClientError, OSError):
|
||||
async with session.get("https://www.google.com") as resp:
|
||||
results["internet_accessible"] = resp.status < 500
|
||||
if not results["internet_accessible"]:
|
||||
return results
|
||||
|
||||
parsed = urlparse(default_base_url())
|
||||
health_url = f"{parsed.scheme}://{parsed.netloc}/health"
|
||||
with contextlib.suppress(ClientError, OSError):
|
||||
async with session.get(health_url) as resp:
|
||||
results["api_accessible"] = resp.status < 500
|
||||
return results
|
||||
|
||||
|
||||
def _unpack_tuple(t: tuple) -> tuple[str, Any, str]:
|
||||
"""Normalize (filename, value, content_type)."""
|
||||
if len(t) == 2:
|
||||
return t[0], t[1], "application/octet-stream"
|
||||
if len(t) == 3:
|
||||
return t[0], t[1], t[2]
|
||||
raise ValueError("files tuple must be (filename, file[, content_type])")
|
||||
|
||||
|
||||
def _merge_params(endpoint_params: dict[str, Any], method: str, data: Optional[dict[str, Any]]) -> dict[str, Any]:
|
||||
params = dict(endpoint_params or {})
|
||||
if method.upper() == "GET" and data:
|
||||
for k, v in data.items():
|
||||
if v is not None:
|
||||
params[k] = v
|
||||
return params
|
||||
|
||||
|
||||
def _friendly_http_message(status: int, body: Any) -> str:
|
||||
if status == 401:
|
||||
return "Unauthorized: Please login first to use this node."
|
||||
if status == 402:
|
||||
return "Payment Required: Please add credits to your account to use this node."
|
||||
if status == 409:
|
||||
return "There is a problem with your account. Please contact support@comfy.org."
|
||||
if status == 429:
|
||||
return "Rate Limit Exceeded: Please try again later."
|
||||
try:
|
||||
if isinstance(body, dict):
|
||||
err = body.get("error")
|
||||
if isinstance(err, dict):
|
||||
msg = err.get("message")
|
||||
typ = err.get("type")
|
||||
if msg and typ:
|
||||
return f"API Error: {msg} (Type: {typ})"
|
||||
if msg:
|
||||
return f"API Error: {msg}"
|
||||
return f"API Error: {json.dumps(body)}"
|
||||
else:
|
||||
txt = str(body)
|
||||
if len(txt) <= 200:
|
||||
return f"API Error (raw): {txt}"
|
||||
return f"API Error (status {status})"
|
||||
except Exception:
|
||||
return f"HTTP {status}: Unknown error"
|
||||
|
||||
|
||||
def _generate_operation_id(method: str, path: str, attempt: int) -> str:
|
||||
slug = path.strip("/").replace("/", "_") or "op"
|
||||
return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
|
||||
def _snapshot_request_body_for_logging(
|
||||
content_type: str,
|
||||
method: str,
|
||||
data: Optional[dict[str, Any]],
|
||||
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]],
|
||||
) -> Optional[Union[dict[str, Any], str]]:
|
||||
if method.upper() == "GET":
|
||||
return None
|
||||
if content_type == "multipart/form-data":
|
||||
form_fields = sorted([k for k, v in (data or {}).items() if v is not None])
|
||||
file_fields: list[dict[str, str]] = []
|
||||
if files:
|
||||
file_iter = files if isinstance(files, list) else list(files.items())
|
||||
for field_name, file_obj in file_iter:
|
||||
if file_obj is None:
|
||||
continue
|
||||
if isinstance(file_obj, tuple):
|
||||
filename = file_obj[0]
|
||||
else:
|
||||
filename = getattr(file_obj, "name", field_name)
|
||||
file_fields.append({"field": field_name, "filename": str(filename or "")})
|
||||
return {"_multipart": True, "form_fields": form_fields, "file_fields": file_fields}
|
||||
if content_type == "application/x-www-form-urlencoded":
|
||||
return data or {}
|
||||
return data or {}
|
||||
|
||||
|
||||
async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
"""Core request with retries, per-second interruption monitoring, true cancellation, and friendly errors."""
|
||||
url = cfg.endpoint.path
|
||||
parsed_url = urlparse(url)
|
||||
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
|
||||
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
|
||||
|
||||
method = cfg.endpoint.method
|
||||
params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None)
|
||||
|
||||
async def _monitor(stop_evt: asyncio.Event, start_ts: float):
|
||||
"""Every second: update elapsed time and signal interruption."""
|
||||
try:
|
||||
while not stop_evt.is_set():
|
||||
if is_processing_interrupted():
|
||||
return
|
||||
if cfg.monitor_progress:
|
||||
_display_time_progress(
|
||||
cfg.node_cls, cfg.wait_label, int(time.monotonic() - start_ts), cfg.estimated_total
|
||||
)
|
||||
await asyncio.sleep(1.0)
|
||||
except asyncio.CancelledError:
|
||||
return # normal shutdown
|
||||
|
||||
start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic()
|
||||
attempt = 0
|
||||
delay = cfg.retry_delay
|
||||
operation_succeeded: bool = False
|
||||
final_elapsed_seconds: Optional[int] = None
|
||||
extracted_price: Optional[float] = None
|
||||
while True:
|
||||
attempt += 1
|
||||
stop_event = asyncio.Event()
|
||||
monitor_task: Optional[asyncio.Task] = None
|
||||
sess: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
|
||||
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
|
||||
|
||||
payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
|
||||
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
|
||||
payload_headers.update(get_auth_header(cfg.node_cls))
|
||||
if cfg.endpoint.headers:
|
||||
payload_headers.update(cfg.endpoint.headers)
|
||||
|
||||
payload_kw: dict[str, Any] = {"headers": payload_headers}
|
||||
if method == "GET":
|
||||
payload_headers.pop("Content-Type", None)
|
||||
request_body_log = _snapshot_request_body_for_logging(cfg.content_type, method, cfg.data, cfg.files)
|
||||
try:
|
||||
if cfg.monitor_progress:
|
||||
monitor_task = asyncio.create_task(_monitor(stop_event, start_time))
|
||||
|
||||
timeout = aiohttp.ClientTimeout(total=cfg.timeout)
|
||||
sess = aiohttp.ClientSession(timeout=timeout)
|
||||
|
||||
if cfg.content_type == "multipart/form-data" and method != "GET":
|
||||
# aiohttp will set Content-Type boundary; remove any fixed Content-Type
|
||||
payload_headers.pop("Content-Type", None)
|
||||
if cfg.multipart_parser and cfg.data:
|
||||
form = cfg.multipart_parser(cfg.data)
|
||||
if not isinstance(form, aiohttp.FormData):
|
||||
raise ValueError("multipart_parser must return aiohttp.FormData")
|
||||
else:
|
||||
form = aiohttp.FormData(default_to_multipart=True)
|
||||
if cfg.data:
|
||||
for k, v in cfg.data.items():
|
||||
if v is None:
|
||||
continue
|
||||
form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)
|
||||
if cfg.files:
|
||||
file_iter = cfg.files if isinstance(cfg.files, list) else cfg.files.items()
|
||||
for field_name, file_obj in file_iter:
|
||||
if file_obj is None:
|
||||
continue
|
||||
if isinstance(file_obj, tuple):
|
||||
filename, file_value, content_type = _unpack_tuple(file_obj)
|
||||
else:
|
||||
filename = getattr(file_obj, "name", field_name)
|
||||
file_value = file_obj
|
||||
content_type = "application/octet-stream"
|
||||
# Attempt to rewind BytesIO for retries
|
||||
if isinstance(file_value, BytesIO):
|
||||
with contextlib.suppress(Exception):
|
||||
file_value.seek(0)
|
||||
form.add_field(field_name, file_value, filename=filename, content_type=content_type)
|
||||
payload_kw["data"] = form
|
||||
elif cfg.content_type == "application/x-www-form-urlencoded" and method != "GET":
|
||||
payload_headers["Content-Type"] = "application/x-www-form-urlencoded"
|
||||
payload_kw["data"] = cfg.data or {}
|
||||
elif method != "GET":
|
||||
payload_headers["Content-Type"] = "application/json"
|
||||
payload_kw["json"] = cfg.data or {}
|
||||
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] request logging failed: %s", _log_e)
|
||||
|
||||
req_coro = sess.request(method, url, params=params, **payload_kw)
|
||||
req_task = asyncio.create_task(req_coro)
|
||||
|
||||
# Race: request vs. monitor (interruption)
|
||||
tasks = {req_task}
|
||||
if monitor_task:
|
||||
tasks.add(monitor_task)
|
||||
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||
|
||||
if monitor_task and monitor_task in done:
|
||||
# Interrupted – cancel the request and abort
|
||||
if req_task in pending:
|
||||
req_task.cancel()
|
||||
raise ProcessingInterrupted("Task cancelled")
|
||||
|
||||
# Otherwise, request finished
|
||||
resp = await req_task
|
||||
async with resp:
|
||||
if resp.status >= 400:
|
||||
try:
|
||||
body = await resp.json()
|
||||
except (ContentTypeError, json.JSONDecodeError):
|
||||
body = await resp.text()
|
||||
if resp.status in _RETRY_STATUS and attempt <= cfg.max_retries:
|
||||
logging.warning(
|
||||
"HTTP %s %s -> %s. Retrying in %.2fs (retry %d of %d).",
|
||||
method,
|
||||
url,
|
||||
resp.status,
|
||||
delay,
|
||||
attempt,
|
||||
cfg.max_retries,
|
||||
)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=body,
|
||||
error_message=_friendly_http_message(resp.status, body),
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||||
|
||||
await sleep_with_interrupt(
|
||||
delay,
|
||||
cfg.node_cls,
|
||||
cfg.wait_label if cfg.monitor_progress else None,
|
||||
start_time if cfg.monitor_progress else None,
|
||||
cfg.estimated_total,
|
||||
display_callback=_display_time_progress if cfg.monitor_progress else None,
|
||||
)
|
||||
delay *= cfg.retry_backoff
|
||||
continue
|
||||
msg = _friendly_http_message(resp.status, body)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=body,
|
||||
error_message=msg,
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||||
raise Exception(msg)
|
||||
|
||||
if expect_binary:
|
||||
buff = bytearray()
|
||||
last_tick = time.monotonic()
|
||||
async for chunk in resp.content.iter_chunked(64 * 1024):
|
||||
buff.extend(chunk)
|
||||
now = time.monotonic()
|
||||
if now - last_tick >= 1.0:
|
||||
last_tick = now
|
||||
if is_processing_interrupted():
|
||||
raise ProcessingInterrupted("Task cancelled")
|
||||
if cfg.monitor_progress:
|
||||
_display_time_progress(
|
||||
cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total
|
||||
)
|
||||
bytes_payload = bytes(buff)
|
||||
operation_succeeded = True
|
||||
final_elapsed_seconds = int(time.monotonic() - start_time)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=bytes_payload,
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||||
return bytes_payload
|
||||
else:
|
||||
try:
|
||||
payload = await resp.json()
|
||||
response_content_to_log: Any = payload
|
||||
except (ContentTypeError, json.JSONDecodeError):
|
||||
text = await resp.text()
|
||||
try:
|
||||
payload = json.loads(text) if text else {}
|
||||
except json.JSONDecodeError:
|
||||
payload = {"_raw": text}
|
||||
response_content_to_log = payload if isinstance(payload, dict) else text
|
||||
with contextlib.suppress(Exception):
|
||||
extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None
|
||||
operation_succeeded = True
|
||||
final_elapsed_seconds = int(time.monotonic() - start_time)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=response_content_to_log,
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||||
return payload
|
||||
|
||||
except ProcessingInterrupted:
|
||||
logging.debug("Polling was interrupted by user")
|
||||
raise
|
||||
except (ClientError, OSError) as e:
|
||||
if attempt <= cfg.max_retries:
|
||||
logging.warning(
|
||||
"Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s",
|
||||
method,
|
||||
url,
|
||||
delay,
|
||||
attempt,
|
||||
cfg.max_retries,
|
||||
str(e),
|
||||
)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] request error logging failed: %s", _log_e)
|
||||
await sleep_with_interrupt(
|
||||
delay,
|
||||
cfg.node_cls,
|
||||
cfg.wait_label if cfg.monitor_progress else None,
|
||||
start_time if cfg.monitor_progress else None,
|
||||
cfg.estimated_total,
|
||||
display_callback=_display_time_progress if cfg.monitor_progress else None,
|
||||
)
|
||||
delay *= cfg.retry_backoff
|
||||
continue
|
||||
diag = await _diagnose_connectivity()
|
||||
if not diag["internet_accessible"]:
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
error_message=f"LocalNetworkError: {str(e)}",
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] final error logging failed: %s", _log_e)
|
||||
raise LocalNetworkError(
|
||||
"Unable to connect to the API server due to local network issues. "
|
||||
"Please check your internet connection and try again."
|
||||
) from e
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
error_message=f"ApiServerError: {str(e)}",
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] final error logging failed: %s", _log_e)
|
||||
raise ApiServerError(
|
||||
f"The API server at {default_base_url()} is currently unreachable. "
|
||||
f"The service may be experiencing issues."
|
||||
) from e
|
||||
finally:
|
||||
stop_event.set()
|
||||
if monitor_task:
|
||||
monitor_task.cancel()
|
||||
with contextlib.suppress(Exception):
|
||||
await monitor_task
|
||||
if sess:
|
||||
with contextlib.suppress(Exception):
|
||||
await sess.close()
|
||||
if operation_succeeded and cfg.monitor_progress and cfg.final_label_on_success:
|
||||
_display_time_progress(
|
||||
cfg.node_cls,
|
||||
status=cfg.final_label_on_success,
|
||||
elapsed_seconds=(
|
||||
final_elapsed_seconds
|
||||
if final_elapsed_seconds is not None
|
||||
else int(time.monotonic() - start_time)
|
||||
),
|
||||
estimated_total=cfg.estimated_total,
|
||||
price=extracted_price,
|
||||
is_queued=False,
|
||||
processing_elapsed_seconds=final_elapsed_seconds,
|
||||
)
|
||||
|
||||
|
||||
def _validate_or_raise(response_model: Type[M], payload: Any) -> M:
|
||||
try:
|
||||
return response_model.model_validate(payload)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
"Response validation failed for %s: %s",
|
||||
getattr(response_model, "__name__", response_model),
|
||||
e,
|
||||
)
|
||||
raise Exception(
|
||||
f"Response validation failed for {getattr(response_model, '__name__', response_model)}: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
def _wrap_model_extractor(
|
||||
response_model: Type[M],
|
||||
extractor: Optional[Callable[[M], Any]],
|
||||
) -> Optional[Callable[[dict[str, Any]], Any]]:
|
||||
"""Wrap a typed extractor so it can be used by the dict-based poller.
|
||||
Validates the dict into `response_model` before invoking `extractor`.
|
||||
Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating
|
||||
the same response for multiple extractors in a single poll attempt.
|
||||
"""
|
||||
if extractor is None:
|
||||
return None
|
||||
_cache: dict[int, M] = {}
|
||||
|
||||
def _wrapped(d: dict[str, Any]) -> Any:
|
||||
try:
|
||||
key = id(d)
|
||||
model = _cache.get(key)
|
||||
if model is None:
|
||||
model = response_model.model_validate(d)
|
||||
_cache[key] = model
|
||||
return extractor(model)
|
||||
except Exception as e:
|
||||
logging.error("Extractor failed (typed -> dict wrapper): %s", e)
|
||||
raise
|
||||
|
||||
return _wrapped
|
||||
|
||||
|
||||
def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Union[str, int]]:
|
||||
if not values:
|
||||
return set()
|
||||
out: set[Union[str, int]] = set()
|
||||
for v in values:
|
||||
nv = _normalize_status_value(v)
|
||||
if nv is not None:
|
||||
out.add(nv)
|
||||
return out
|
||||
|
||||
|
||||
def _normalize_status_value(val: Union[str, int, None]) -> Union[str, int, None]:
|
||||
if isinstance(val, str):
|
||||
return val.strip().lower()
|
||||
return val
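The create-then-poll pattern implemented above can be summarized with a short sketch; the endpoint paths and Pydantic models below are hypothetical stand-ins rather than real Comfy API routes.

from typing import Optional

from pydantic import BaseModel

from comfy_api_nodes.util import ApiEndpoint, poll_op, sync_op


class CreateResponse(BaseModel):  # hypothetical task-creation response
    task_id: str


class StatusResponse(BaseModel):  # hypothetical task-status response
    status: str
    video_url: Optional[str] = None


async def create_and_wait(cls, payload: BaseModel) -> Optional[str]:
    # Submit the task synchronously, then poll its status until a terminal state.
    created = await sync_op(
        cls,
        ApiEndpoint(path="/proxy/example/tasks", method="POST"),  # hypothetical path
        response_model=CreateResponse,
        data=payload,
    )
    final = await poll_op(
        cls,
        ApiEndpoint(path=f"/proxy/example/tasks/{created.task_id}"),  # hypothetical path
        response_model=StatusResponse,
        status_extractor=lambda r: r.status,
        poll_interval=6,
        estimated_duration=120,
    )
    return final.video_url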

comfy_api_nodes/util/common_exceptions.py (new file, 14 lines)
@@ -0,0 +1,14 @@
|
||||
class NetworkError(Exception):
|
||||
"""Base exception for network-related errors with diagnostic information."""
|
||||
|
||||
|
||||
class LocalNetworkError(NetworkError):
|
||||
"""Exception raised when local network connectivity issues are detected."""
|
||||
|
||||
|
||||
class ApiServerError(NetworkError):
|
||||
"""Exception raised when the API server is unreachable but internet is working."""
|
||||
|
||||
|
||||
class ProcessingInterrupted(Exception):
|
||||
"""Operation was interrupted by user/runtime via processing_interrupted()."""

comfy_api_nodes/util/conversions.py (new file, 470 lines)
@@ -0,0 +1,470 @@
|
||||
import base64
|
||||
import logging
|
||||
import math
|
||||
import mimetypes
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from comfy.utils import common_upscale
|
||||
from comfy_api.latest import Input, InputImpl
|
||||
from comfy_api.util import VideoCodec, VideoContainer
|
||||
|
||||
from ._helpers import mimetype_to_extension
|
||||
|
||||
|
||||
def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor:
|
||||
"""Converts image data from BytesIO to a torch.Tensor.
|
||||
|
||||
Args:
|
||||
image_bytesio: BytesIO object containing the image data.
|
||||
mode: The PIL mode to convert the image to (e.g., "RGB", "RGBA").
|
||||
|
||||
Returns:
|
||||
A torch.Tensor representing the image (1, H, W, C).
|
||||
|
||||
Raises:
|
||||
PIL.UnidentifiedImageError: If the image data cannot be identified.
|
||||
ValueError: If the specified mode is invalid.
|
||||
"""
|
||||
image = Image.open(image_bytesio)
|
||||
image = image.convert(mode)
|
||||
image_array = np.array(image).astype(np.float32) / 255.0
|
||||
return torch.from_numpy(image_array).unsqueeze(0)
|
||||
|
||||
|
||||
def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Converts a pair of image tensors to a batch tensor.
|
||||
If the images are not the same size, the second image is resized to
match the first image.
|
||||
"""
|
||||
if image1.shape[1:] != image2.shape[1:]:
|
||||
image2 = common_upscale(
|
||||
image2.movedim(-1, 1),
|
||||
image1.shape[2],
|
||||
image1.shape[1],
|
||||
"bilinear",
|
||||
"center",
|
||||
).movedim(1, -1)
|
||||
return torch.cat((image1, image2), dim=0)
|
||||
|
||||
|
||||
def tensor_to_bytesio(
|
||||
image: torch.Tensor,
|
||||
name: Optional[str] = None,
|
||||
total_pixels: int = 2048 * 2048,
|
||||
mime_type: str = "image/png",
|
||||
) -> BytesIO:
|
||||
"""Converts a torch.Tensor image to a named BytesIO object.
|
||||
|
||||
Args:
|
||||
image: Input torch.Tensor image.
|
||||
name: Optional filename for the BytesIO object.
|
||||
total_pixels: Maximum total pixels for potential downscaling.
|
||||
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
|
||||
|
||||
Returns:
|
||||
Named BytesIO object containing the image data, with pointer set to the start of buffer.
|
||||
"""
|
||||
if not mime_type:
|
||||
mime_type = "image/png"
|
||||
|
||||
pil_image = tensor_to_pil(image, total_pixels=total_pixels)
|
||||
img_binary = pil_to_bytesio(pil_image, mime_type=mime_type)
|
||||
img_binary.name = f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}"
|
||||
return img_binary
|
||||
|
||||
|
||||
def tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image:
|
||||
"""Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling."""
|
||||
if len(image.shape) > 3:
|
||||
image = image[0]
|
||||
# TODO: remove alpha if not allowed and present
|
||||
input_tensor = image.cpu()
|
||||
input_tensor = downscale_image_tensor(input_tensor.unsqueeze(0), total_pixels=total_pixels).squeeze()
|
||||
image_np = (input_tensor.numpy() * 255).astype(np.uint8)
|
||||
img = Image.fromarray(image_np)
|
||||
return img
|
||||
|
||||
|
||||
def tensor_to_base64_string(
|
||||
image_tensor: torch.Tensor,
|
||||
total_pixels: int = 2048 * 2048,
|
||||
mime_type: str = "image/png",
|
||||
) -> str:
|
||||
"""Convert [B, H, W, C] or [H, W, C] tensor to a base64 string.
|
||||
|
||||
Args:
|
||||
image_tensor: Input torch.Tensor image.
|
||||
total_pixels: Maximum total pixels for potential downscaling.
|
||||
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
|
||||
|
||||
Returns:
|
||||
Base64 encoded string of the image.
|
||||
"""
|
||||
pil_image = tensor_to_pil(image_tensor, total_pixels=total_pixels)
|
||||
img_byte_arr = pil_to_bytesio(pil_image, mime_type=mime_type)
|
||||
img_bytes = img_byte_arr.getvalue()
|
||||
# Encode bytes to base64 string
|
||||
base64_encoded_string = base64.b64encode(img_bytes).decode("utf-8")
|
||||
return base64_encoded_string
|
||||
|
||||
|
||||
def pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
|
||||
"""Converts a PIL Image to a BytesIO object."""
|
||||
if not mime_type:
|
||||
mime_type = "image/png"
|
||||
|
||||
img_byte_arr = BytesIO()
|
||||
# Derive PIL format from MIME type (e.g., 'image/png' -> 'PNG')
|
||||
pil_format = mime_type.split("/")[-1].upper()
|
||||
if pil_format == "JPG":
|
||||
pil_format = "JPEG"
|
||||
img.save(img_byte_arr, format=pil_format)
|
||||
img_byte_arr.seek(0)
|
||||
return img_byte_arr
|
||||
|
||||
|
||||
def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
|
||||
"""Downscale input image tensor to roughly the specified total pixels."""
|
||||
samples = image.movedim(-1, 1)
|
||||
total = int(total_pixels)
|
||||
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
|
||||
if scale_by >= 1:
|
||||
return image
|
||||
width = round(samples.shape[3] * scale_by)
|
||||
height = round(samples.shape[2] * scale_by)
|
||||
|
||||
s = common_upscale(samples, width, height, "lanczos", "disabled")
|
||||
s = s.movedim(1, -1)
|
||||
return s
|
||||
|
||||
|
||||
def tensor_to_data_uri(
|
||||
image_tensor: torch.Tensor,
|
||||
total_pixels: int = 2048 * 2048,
|
||||
mime_type: str = "image/png",
|
||||
) -> str:
|
||||
"""Converts a tensor image to a Data URI string.
|
||||
|
||||
Args:
|
||||
image_tensor: Input torch.Tensor image.
|
||||
total_pixels: Maximum total pixels for potential downscaling.
|
||||
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp').
|
||||
|
||||
Returns:
|
||||
Data URI string (e.g., 'data:image/png;base64,...').
|
||||
"""
|
||||
base64_string = tensor_to_base64_string(image_tensor, total_pixels, mime_type)
|
||||
return f"data:{mime_type};base64,{base64_string}"
|
||||
|
||||
|
||||
def audio_to_base64_string(audio: Input.Audio, container_format: str = "mp4", codec_name: str = "aac") -> str:
|
||||
"""Converts an audio input to a base64 string."""
|
||||
sample_rate: int = audio["sample_rate"]
|
||||
waveform: torch.Tensor = audio["waveform"]
|
||||
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
|
||||
audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name)
|
||||
audio_bytes = audio_bytes_io.getvalue()
|
||||
return base64.b64encode(audio_bytes).decode("utf-8")
|
||||
|
||||
|
||||
def video_to_base64_string(
|
||||
video: Input.Video,
|
||||
container_format: VideoContainer = None,
|
||||
codec: VideoCodec = None
|
||||
) -> str:
|
||||
"""
|
||||
Converts a video input to a base64 string.
|
||||
|
||||
Args:
|
||||
video: The video input to convert
|
||||
container_format: Optional container format to use (defaults to video.container if available)
|
||||
codec: Optional codec to use (defaults to video.codec if available)
|
||||
"""
|
||||
video_bytes_io = BytesIO()
|
||||
|
||||
# Use provided format/codec if specified, otherwise use video's own if available
|
||||
format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4)
|
||||
codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264)
|
||||
|
||||
video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use)
|
||||
video_bytes_io.seek(0)
|
||||
return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")
|
||||
|
||||
|
||||
def audio_ndarray_to_bytesio(
|
||||
audio_data_np: np.ndarray,
|
||||
sample_rate: int,
|
||||
container_format: str = "mp4",
|
||||
codec_name: str = "aac",
|
||||
) -> BytesIO:
|
||||
"""
|
||||
Encodes a numpy array of audio data into a BytesIO object.
|
||||
"""
|
||||
audio_bytes_io = BytesIO()
|
||||
with av.open(audio_bytes_io, mode="w", format=container_format) as output_container:
|
||||
audio_stream = output_container.add_stream(codec_name, rate=sample_rate)
|
||||
frame = av.AudioFrame.from_ndarray(
|
||||
audio_data_np,
|
||||
format="fltp",
|
||||
layout="stereo" if audio_data_np.shape[0] > 1 else "mono",
|
||||
)
|
||||
frame.sample_rate = sample_rate
|
||||
frame.pts = 0
|
||||
|
||||
for packet in audio_stream.encode(frame):
|
||||
output_container.mux(packet)
|
||||
|
||||
# Flush stream
|
||||
for packet in audio_stream.encode(None):
|
||||
output_container.mux(packet)
|
||||
|
||||
audio_bytes_io.seek(0)
|
||||
return audio_bytes_io
|
||||
|
||||
|
||||
def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray:
|
||||
"""
|
||||
Prepares audio waveform for av library by converting to a contiguous numpy array.
|
||||
|
||||
Args:
|
||||
waveform: a tensor of shape (1, channels, samples) derived from a Comfy `AUDIO` type.
|
||||
|
||||
Returns:
|
||||
Contiguous float32 numpy array of shape (channels, samples).
|
||||
"""
|
||||
if waveform.ndim != 3 or waveform.shape[0] != 1:
|
||||
raise ValueError("Expected waveform tensor shape (1, channels, samples)")
|
||||
|
||||
# Shape is validated above as (1, channels, samples): drop the batch dim,
# move to CPU, make contiguous, and convert to a numpy array.
audio_data_np = waveform.squeeze(0).cpu().contiguous().numpy()
|
||||
if audio_data_np.dtype != np.float32:
|
||||
audio_data_np = audio_data_np.astype(np.float32)
|
||||
|
||||
return audio_data_np
|
||||
|
||||
|
||||
def audio_input_to_mp3(audio: Input.Audio) -> BytesIO:
|
||||
waveform = audio["waveform"].cpu()
|
||||
|
||||
output_buffer = BytesIO()
|
||||
output_container = av.open(output_buffer, mode="w", format="mp3")
|
||||
|
||||
out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"])
|
||||
out_stream.bit_rate = 320000
|
||||
|
||||
frame = av.AudioFrame.from_ndarray(
|
||||
waveform.movedim(0, 1).reshape(1, -1).float().numpy(),
|
||||
format="flt",
|
||||
layout="mono" if waveform.shape[0] == 1 else "stereo",
|
||||
)
|
||||
frame.sample_rate = audio["sample_rate"]
|
||||
frame.pts = 0
|
||||
output_container.mux(out_stream.encode(frame))
|
||||
output_container.mux(out_stream.encode(None))
|
||||
output_container.close()
|
||||
output_buffer.seek(0)
|
||||
return output_buffer
|
||||
|
||||
|
||||
def trim_video(video: Input.Video, duration_sec: float) -> Input.Video:
|
||||
"""
|
||||
Returns a new VideoInput object trimmed from the beginning to the specified duration,
|
||||
using av to avoid loading entire video into memory.
|
||||
|
||||
Args:
|
||||
video: Input video to trim
|
||||
duration_sec: Duration in seconds to keep from the beginning
|
||||
|
||||
Returns:
|
||||
VideoFromFile object that owns the output buffer
|
||||
"""
|
||||
output_buffer = BytesIO()
|
||||
input_container = None
|
||||
output_container = None
|
||||
|
||||
try:
|
||||
# Get the stream source - this avoids loading entire video into memory
|
||||
# when the source is already a file path
|
||||
input_source = video.get_stream_source()
|
||||
|
||||
# Open containers
|
||||
input_container = av.open(input_source, mode="r")
|
||||
output_container = av.open(output_buffer, mode="w", format="mp4")
|
||||
|
||||
# Set up output streams for re-encoding
|
||||
video_stream = None
|
||||
audio_stream = None
|
||||
|
||||
for stream in input_container.streams:
|
||||
logging.info("Found stream: type=%s, class=%s", stream.type, type(stream))
|
||||
if isinstance(stream, av.VideoStream):
|
||||
# Create output video stream with same parameters
|
||||
video_stream = output_container.add_stream("h264", rate=stream.average_rate)
|
||||
video_stream.width = stream.width
|
||||
video_stream.height = stream.height
|
||||
video_stream.pix_fmt = "yuv420p"
|
||||
logging.info("Added video stream: %sx%s @ %sfps", stream.width, stream.height, stream.average_rate)
|
||||
elif isinstance(stream, av.AudioStream):
|
||||
# Create output audio stream with same parameters
|
||||
audio_stream = output_container.add_stream("aac", rate=stream.sample_rate)
|
||||
audio_stream.sample_rate = stream.sample_rate
|
||||
audio_stream.layout = stream.layout
|
||||
logging.info("Added audio stream: %sHz, %s channels", stream.sample_rate, stream.channels)
|
||||
|
||||
# Calculate target frame count that's divisible by 16
|
||||
fps = input_container.streams.video[0].average_rate
|
||||
estimated_frames = int(duration_sec * fps)
|
||||
target_frames = (estimated_frames // 16) * 16 # Round down to nearest multiple of 16
|
||||
|
||||
if target_frames == 0:
|
||||
raise ValueError("Video too short: need at least 16 frames for Moonvalley")
|
||||
|
||||
frame_count = 0
|
||||
audio_frame_count = 0
|
||||
|
||||
# Decode and re-encode video frames
|
||||
if video_stream:
|
||||
for frame in input_container.decode(video=0):
|
||||
if frame_count >= target_frames:
|
||||
break
|
||||
|
||||
# Re-encode frame
|
||||
for packet in video_stream.encode(frame):
|
||||
output_container.mux(packet)
|
||||
frame_count += 1
|
||||
|
||||
# Flush encoder
|
||||
for packet in video_stream.encode():
|
||||
output_container.mux(packet)
|
||||
|
||||
logging.info("Encoded %s video frames (target: %s)", frame_count, target_frames)
|
||||
|
||||
# Decode and re-encode audio frames
|
||||
if audio_stream:
|
||||
input_container.seek(0) # Reset to beginning for audio
|
||||
for frame in input_container.decode(audio=0):
|
||||
if frame.time >= duration_sec:
|
||||
break
|
||||
|
||||
# Re-encode frame
|
||||
for packet in audio_stream.encode(frame):
|
||||
output_container.mux(packet)
|
||||
audio_frame_count += 1
|
||||
|
||||
# Flush encoder
|
||||
for packet in audio_stream.encode():
|
||||
output_container.mux(packet)
|
||||
|
||||
logging.info("Encoded %s audio frames", audio_frame_count)
|
||||
|
||||
# Close containers
|
||||
output_container.close()
|
||||
input_container.close()
|
||||
|
||||
# Return as VideoFromFile using the buffer
|
||||
output_buffer.seek(0)
|
||||
return InputImpl.VideoFromFile(output_buffer)
|
||||
|
||||
except Exception as e:
|
||||
# Clean up on error
|
||||
if input_container is not None:
|
||||
input_container.close()
|
||||
if output_container is not None:
|
||||
output_container.close()
|
||||
raise RuntimeError(f"Failed to trim video: {str(e)}") from e
|
||||
|
||||
|
||||
def _f32_pcm(wav: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file."""
|
||||
if wav.dtype.is_floating_point:
|
||||
return wav
|
||||
elif wav.dtype == torch.int16:
|
||||
return wav.float() / (2**15)
|
||||
elif wav.dtype == torch.int32:
|
||||
return wav.float() / (2**31)
|
||||
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
|
||||
|
||||
|
||||
def audio_bytes_to_audio_input(audio_bytes: bytes) -> dict:
|
||||
"""
|
||||
Decode any common audio container from bytes using PyAV and return
|
||||
a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}.
|
||||
"""
|
||||
with av.open(BytesIO(audio_bytes)) as af:
|
||||
if not af.streams.audio:
|
||||
raise ValueError("No audio stream found in response.")
|
||||
stream = af.streams.audio[0]
|
||||
|
||||
in_sr = int(stream.codec_context.sample_rate)
|
||||
out_sr = in_sr
|
||||
|
||||
frames: list[torch.Tensor] = []
|
||||
n_channels = stream.channels or 1
|
||||
|
||||
for frame in af.decode(streams=stream.index):
|
||||
arr = frame.to_ndarray() # shape can be [C, T] or [T, C] or [T]
|
||||
buf = torch.from_numpy(arr)
|
||||
if buf.ndim == 1:
|
||||
buf = buf.unsqueeze(0) # [T] -> [1, T]
|
||||
elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels:
|
||||
buf = buf.transpose(0, 1).contiguous() # [T, C] -> [C, T]
|
||||
elif buf.shape[0] != n_channels:
|
||||
buf = buf.reshape(-1, n_channels).t().contiguous() # fallback to [C, T]
|
||||
frames.append(buf)
|
||||
|
||||
if not frames:
|
||||
raise ValueError("Decoded zero audio frames.")
|
||||
|
||||
wav = torch.cat(frames, dim=1) # [C, T]
|
||||
wav = _f32_pcm(wav)
|
||||
return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr}
|
||||
|
||||
|
||||
def resize_mask_to_image(
|
||||
mask: torch.Tensor,
|
||||
image: torch.Tensor,
|
||||
upscale_method="nearest-exact",
|
||||
crop="disabled",
|
||||
allow_gradient=True,
|
||||
add_channel_dim=False,
|
||||
):
|
||||
"""Resize mask to be the same dimensions as an image, while maintaining proper format for API calls."""
|
||||
_, height, width, _ = image.shape
|
||||
mask = mask.unsqueeze(-1)
|
||||
mask = mask.movedim(-1, 1)
|
||||
mask = common_upscale(mask, width=width, height=height, upscale_method=upscale_method, crop=crop)
|
||||
mask = mask.movedim(1, -1)
|
||||
if not add_channel_dim:
|
||||
mask = mask.squeeze(-1)
|
||||
if not allow_gradient:
|
||||
mask = (mask > 0.5).float()
|
||||
return mask
|
||||
|
||||
|
||||
def text_filepath_to_base64_string(filepath: str) -> str:
|
||||
"""Converts a text file to a base64 string."""
|
||||
with open(filepath, "rb") as f:
|
||||
file_content = f.read()
|
||||
return base64.b64encode(file_content).decode("utf-8")
|
||||
|
||||
|
||||
def text_filepath_to_data_uri(filepath: str) -> str:
|
||||
"""Converts a text file to a data URI."""
|
||||
base64_string = text_filepath_to_base64_string(filepath)
|
||||
mime_type, _ = mimetypes.guess_type(filepath)
|
||||
if mime_type is None:
|
||||
mime_type = "application/octet-stream"
|
||||
return f"data:{mime_type};base64,{base64_string}"
|
||||
comfy_api_nodes/util/download_helpers.py (new file, 262 lines)
@@ -0,0 +1,262 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import IO, Optional, Union
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import aiohttp
|
||||
import torch
|
||||
from aiohttp.client_exceptions import ClientError, ContentTypeError
|
||||
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy_api.latest import IO as COMFY_IO
|
||||
|
||||
from . import request_logger
|
||||
from ._helpers import (
|
||||
default_base_url,
|
||||
get_auth_header,
|
||||
is_processing_interrupted,
|
||||
sleep_with_interrupt,
|
||||
)
|
||||
from .client import _diagnose_connectivity
|
||||
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
|
||||
from .conversions import bytesio_to_image_tensor
|
||||
|
||||
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
|
||||
|
||||
|
||||
async def download_url_to_bytesio(
|
||||
url: str,
|
||||
dest: Optional[Union[BytesIO, IO[bytes], str, Path]],
|
||||
*,
|
||||
timeout: Optional[float] = None,
|
||||
max_retries: int = 5,
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff: float = 2.0,
|
||||
cls: type[COMFY_IO.ComfyNode] = None,
|
||||
) -> None:
|
||||
"""Stream-download a URL to `dest`.
|
||||
|
||||
`dest` must be one of:
|
||||
- a BytesIO (rewound to 0 after write),
|
||||
- a file-like object opened in binary write mode (must implement .write()),
|
||||
- a filesystem path (str | pathlib.Path), which will be opened with 'wb'.
|
||||
|
||||
If `url` starts with `/proxy/`, `cls` must be provided so the URL can be expanded
|
||||
to an absolute URL and authentication headers can be applied.
|
||||
|
||||
Raises:
|
||||
ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception (HTTP and other errors)
|
||||
"""
|
||||
if not isinstance(dest, (str, Path)) and not hasattr(dest, "write"):
|
||||
raise ValueError("dest must be a path (str|Path) or a binary-writable object providing .write().")
|
||||
|
||||
attempt = 0
|
||||
delay = retry_delay
|
||||
headers: dict[str, str] = {}
|
||||
|
||||
parsed_url = urlparse(url)
|
||||
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
|
||||
if cls is None:
|
||||
raise ValueError("For relative 'cloud' paths, the `cls` parameter is required.")
|
||||
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
|
||||
headers = get_auth_header(cls)
|
||||
|
||||
while True:
|
||||
attempt += 1
|
||||
op_id = _generate_operation_id("GET", url, attempt)
|
||||
timeout_cfg = aiohttp.ClientTimeout(total=timeout)
|
||||
|
||||
is_path_sink = isinstance(dest, (str, Path))
|
||||
fhandle = None
|
||||
session: Optional[aiohttp.ClientSession] = None
|
||||
stop_evt: Optional[asyncio.Event] = None
|
||||
monitor_task: Optional[asyncio.Task] = None
|
||||
req_task: Optional[asyncio.Task] = None
|
||||
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
request_logger.log_request_response(operation_id=op_id, request_method="GET", request_url=url)
|
||||
|
||||
session = aiohttp.ClientSession(timeout=timeout_cfg)
|
||||
stop_evt = asyncio.Event()
|
||||
|
||||
async def _monitor():
|
||||
try:
|
||||
while not stop_evt.is_set():
|
||||
if is_processing_interrupted():
|
||||
return
|
||||
await asyncio.sleep(1.0)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
|
||||
monitor_task = asyncio.create_task(_monitor())
|
||||
|
||||
req_task = asyncio.create_task(session.get(url, headers=headers))
|
||||
done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED)
|
||||
|
||||
if monitor_task in done and req_task in pending:
|
||||
req_task.cancel()
|
||||
with contextlib.suppress(Exception):
|
||||
await req_task
|
||||
raise ProcessingInterrupted("Task cancelled")
|
||||
|
||||
try:
|
||||
resp = await req_task
|
||||
except asyncio.CancelledError:
|
||||
raise ProcessingInterrupted("Task cancelled") from None
|
||||
|
||||
async with resp:
|
||||
if resp.status >= 400:
|
||||
with contextlib.suppress(Exception):
|
||||
try:
|
||||
body = await resp.json()
|
||||
except (ContentTypeError, ValueError):
|
||||
text = await resp.text()
|
||||
body = text if len(text) <= 4096 else f"[text {len(text)} bytes]"
|
||||
request_logger.log_request_response(
|
||||
operation_id=op_id,
|
||||
request_method="GET",
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=body,
|
||||
error_message=f"HTTP {resp.status}",
|
||||
)
|
||||
|
||||
if resp.status in _RETRY_STATUS and attempt <= max_retries:
|
||||
await sleep_with_interrupt(delay, cls, None, None, None)
|
||||
delay *= retry_backoff
|
||||
continue
|
||||
raise Exception(f"Failed to download (HTTP {resp.status}).")
|
||||
|
||||
if is_path_sink:
|
||||
p = Path(str(dest))
|
||||
with contextlib.suppress(Exception):
|
||||
p.parent.mkdir(parents=True, exist_ok=True)
|
||||
fhandle = open(p, "wb")
|
||||
sink = fhandle
|
||||
else:
|
||||
sink = dest # BytesIO or file-like
|
||||
|
||||
written = 0
|
||||
while True:
|
||||
try:
|
||||
chunk = await asyncio.wait_for(resp.content.read(1024 * 1024), timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
chunk = b""
|
||||
except asyncio.CancelledError:
|
||||
raise ProcessingInterrupted("Task cancelled") from None
|
||||
|
||||
if is_processing_interrupted():
|
||||
raise ProcessingInterrupted("Task cancelled")
|
||||
|
||||
if not chunk:
|
||||
if resp.content.at_eof():
|
||||
break
|
||||
continue
|
||||
|
||||
sink.write(chunk)
|
||||
written += len(chunk)
|
||||
|
||||
if isinstance(dest, BytesIO):
|
||||
with contextlib.suppress(Exception):
|
||||
dest.seek(0)
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
request_logger.log_request_response(
|
||||
operation_id=op_id,
|
||||
request_method="GET",
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=f"[streamed {written} bytes to dest]",
|
||||
)
|
||||
return
|
||||
except asyncio.CancelledError:
|
||||
raise ProcessingInterrupted("Task cancelled") from None
|
||||
except (ClientError, OSError) as e:
|
||||
if attempt <= max_retries:
|
||||
with contextlib.suppress(Exception):
|
||||
request_logger.log_request_response(
|
||||
operation_id=op_id,
|
||||
request_method="GET",
|
||||
request_url=url,
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
await sleep_with_interrupt(delay, cls, None, None, None)
|
||||
delay *= retry_backoff
|
||||
continue
|
||||
|
||||
diag = await _diagnose_connectivity()
|
||||
if not diag["internet_accessible"]:
|
||||
raise LocalNetworkError(
|
||||
"Unable to connect to the network. Please check your internet connection and try again."
|
||||
) from e
|
||||
raise ApiServerError("The remote service appears unreachable at this time.") from e
|
||||
finally:
|
||||
if stop_evt is not None:
|
||||
stop_evt.set()
|
||||
if monitor_task:
|
||||
monitor_task.cancel()
|
||||
with contextlib.suppress(Exception):
|
||||
await monitor_task
|
||||
if req_task and not req_task.done():
|
||||
req_task.cancel()
|
||||
with contextlib.suppress(Exception):
|
||||
await req_task
|
||||
if session:
|
||||
with contextlib.suppress(Exception):
|
||||
await session.close()
|
||||
if fhandle:
|
||||
with contextlib.suppress(Exception):
|
||||
fhandle.flush()
|
||||
fhandle.close()
|
||||
|
||||
|
||||
async def download_url_to_image_tensor(
|
||||
url: str,
|
||||
*,
|
||||
timeout: float = None,
|
||||
cls: type[COMFY_IO.ComfyNode] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Downloads an image from a URL and returns a [B, H, W, C] tensor."""
|
||||
result = BytesIO()
|
||||
await download_url_to_bytesio(url, result, timeout=timeout, cls=cls)
|
||||
return bytesio_to_image_tensor(result)
|
||||
|
||||
|
||||
async def download_url_to_video_output(
|
||||
video_url: str,
|
||||
*,
|
||||
timeout: float = None,
|
||||
max_retries: int = 5,
|
||||
cls: type[COMFY_IO.ComfyNode] = None,
|
||||
) -> VideoFromFile:
|
||||
"""Downloads a video from a URL and returns a `VIDEO` output."""
|
||||
result = BytesIO()
|
||||
await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls)
|
||||
return VideoFromFile(result)
|
||||
|
||||
|
||||
async def download_url_as_bytesio(
|
||||
url: str,
|
||||
*,
|
||||
timeout: float = None,
|
||||
cls: type[COMFY_IO.ComfyNode] = None,
|
||||
) -> BytesIO:
|
||||
"""Downloads content from a URL and returns a new BytesIO (rewound to 0)."""
|
||||
result = BytesIO()
|
||||
await download_url_to_bytesio(url, result, timeout=timeout, cls=cls)
|
||||
return result
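The wrappers above always buffer the download in memory; the core helper also accepts a filesystem path as the sink. A hedged sketch (the destination path is illustrative):

from pathlib import Path


async def download_to_disk(url: str) -> Path:
    dest = Path("output") / "api_result.bin"  # illustrative destination
    # Streams the response to `dest` in 1 MiB chunks, retrying retryable HTTP
    # statuses with exponential backoff and honouring user interruption.
    # For relative `/proxy/...` URLs the `cls` argument would also be required.
    await download_url_to_bytesio(url, dest, timeout=120.0, max_retries=3)
    return dest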
|
||||
|
||||
|
||||
def _generate_operation_id(method: str, url: str, attempt: int) -> str:
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "download").strip("/").replace("/", "_")
|
||||
except Exception:
|
||||
slug = "download"
|
||||
return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}"
|
||||
@@ -1,11 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import datetime
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import hashlib
|
||||
from typing import Any
|
||||
|
||||
import folder_paths
|
||||
comfy_api_nodes/util/upload_helpers.py (new file, 338 lines)
@@ -0,0 +1,338 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from typing import Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from comfy_api.latest import IO, Input
|
||||
from comfy_api.util import VideoCodec, VideoContainer
|
||||
|
||||
from . import request_logger
|
||||
from ._helpers import is_processing_interrupted, sleep_with_interrupt
|
||||
from .client import (
|
||||
ApiEndpoint,
|
||||
_diagnose_connectivity,
|
||||
_display_time_progress,
|
||||
sync_op,
|
||||
)
|
||||
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
|
||||
from .conversions import (
|
||||
audio_ndarray_to_bytesio,
|
||||
audio_tensor_to_contiguous_ndarray,
|
||||
tensor_to_bytesio,
|
||||
)
|
||||
|
||||
|
||||
class UploadRequest(BaseModel):
|
||||
file_name: str = Field(..., description="Filename to upload")
|
||||
content_type: Optional[str] = Field(
|
||||
None,
|
||||
description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
|
||||
)
|
||||
|
||||
|
||||
class UploadResponse(BaseModel):
|
||||
download_url: str = Field(..., description="URL to GET uploaded file")
|
||||
upload_url: str = Field(..., description="URL to PUT file to upload")
|
||||
|
||||
|
||||
async def upload_images_to_comfyapi(
|
||||
cls: type[IO.ComfyNode],
|
||||
image: torch.Tensor,
|
||||
*,
|
||||
max_images: int = 8,
|
||||
mime_type: Optional[str] = None,
|
||||
wait_label: Optional[str] = "Uploading",
|
||||
) -> list[str]:
|
||||
"""
|
||||
Uploads images to ComfyUI API and returns download URLs.
|
||||
To upload multiple images, stack them in the batch dimension first.
|
||||
"""
|
||||
# For batched inputs, upload up to max_images items; a single image is uploaded once.
|
||||
download_urls: list[str] = []
|
||||
is_batch = len(image.shape) > 3
|
||||
batch_len = image.shape[0] if is_batch else 1
|
||||
|
||||
for idx in range(min(batch_len, max_images)):
|
||||
tensor = image[idx] if is_batch else image
|
||||
img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
|
||||
url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, wait_label)
|
||||
download_urls.append(url)
|
||||
return download_urls
|
||||
|
||||
|
||||
async def upload_audio_to_comfyapi(
|
||||
cls: type[IO.ComfyNode],
|
||||
audio: Input.Audio,
|
||||
*,
|
||||
container_format: str = "mp4",
|
||||
codec_name: str = "aac",
|
||||
mime_type: str = "audio/mp4",
|
||||
filename: str = "uploaded_audio.mp4",
|
||||
) -> str:
|
||||
"""
|
||||
Uploads a single audio input to ComfyUI API and returns its download URL.
|
||||
Encodes the raw waveform into the specified format before uploading.
|
||||
"""
|
||||
sample_rate: int = audio["sample_rate"]
|
||||
waveform: torch.Tensor = audio["waveform"]
|
||||
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
|
||||
audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name)
|
||||
return await upload_file_to_comfyapi(cls, audio_bytes_io, filename, mime_type)
|
||||
|
||||
|
||||
async def upload_video_to_comfyapi(
|
||||
cls: type[IO.ComfyNode],
|
||||
video: Input.Video,
|
||||
*,
|
||||
container: VideoContainer = VideoContainer.MP4,
|
||||
codec: VideoCodec = VideoCodec.H264,
|
||||
max_duration: Optional[int] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Uploads a single video to ComfyUI API and returns its download URL.
|
||||
Uses the specified container and codec for saving the video before upload.
|
||||
"""
|
||||
if max_duration is not None:
|
||||
try:
|
||||
actual_duration = video.get_duration()
|
||||
if actual_duration > max_duration:
|
||||
raise ValueError(
|
||||
f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)."
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error("Error getting video duration: %s", str(e))
|
||||
raise ValueError(f"Could not verify video duration from source: {e}") from e
|
||||
|
||||
upload_mime_type = f"video/{container.value.lower()}"
|
||||
filename = f"uploaded_video.{container.value.lower()}"
|
||||
|
||||
# Convert VideoInput to BytesIO using specified container/codec
|
||||
video_bytes_io = BytesIO()
|
||||
video.save_to(video_bytes_io, format=container, codec=codec)
|
||||
video_bytes_io.seek(0)
|
||||
|
||||
return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type)
|
||||
|
||||
|
||||
async def upload_file_to_comfyapi(
|
||||
cls: type[IO.ComfyNode],
|
||||
file_bytes_io: BytesIO,
|
||||
filename: str,
|
||||
upload_mime_type: Optional[str],
|
||||
wait_label: Optional[str] = "Uploading",
|
||||
) -> str:
|
||||
"""Uploads a single file to ComfyUI API and returns its download URL."""
|
||||
if upload_mime_type is None:
|
||||
request_object = UploadRequest(file_name=filename)
|
||||
else:
|
||||
request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
|
||||
create_resp = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path="/customers/storage", method="POST"),
|
||||
data=request_object,
|
||||
response_model=UploadResponse,
|
||||
final_label_on_success=None,
|
||||
monitor_progress=False,
|
||||
)
|
||||
await upload_file(
|
||||
cls,
|
||||
create_resp.upload_url,
|
||||
file_bytes_io,
|
||||
content_type=upload_mime_type,
|
||||
wait_label=wait_label,
|
||||
)
|
||||
return create_resp.download_url
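For orientation (not part of the diff): the helpers above implement a two-step flow, first a POST to /customers/storage for a pre-signed URL pair, then a PUT of the raw bytes to upload_url; callers only ever see download_url. A hedged caller-side sketch:

async def upload_png_and_get_url(cls, image: torch.Tensor) -> str:
    # upload_images_to_comfyapi handles the storage request and the signed PUT;
    # the returned URL is what remote generation endpoints consume.
    urls = await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png")
    return urls[0]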
|
||||
|
||||
|
||||
async def upload_file(
|
||||
cls: type[IO.ComfyNode],
|
||||
upload_url: str,
|
||||
file: Union[BytesIO, str],
|
||||
*,
|
||||
content_type: Optional[str] = None,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff: float = 2.0,
|
||||
wait_label: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Upload a file to a signed URL (e.g., S3 pre-signed PUT) with retries, Comfy progress display, and interruption.
|
||||
|
||||
Args:
|
||||
cls: Node class (provides auth context + UI progress hooks).
|
||||
upload_url: Pre-signed PUT URL.
|
||||
file: BytesIO or path string.
|
||||
content_type: Explicit MIME type. If None, we *suppress* Content-Type.
|
||||
max_retries: Maximum retry attempts.
|
||||
retry_delay: Initial delay in seconds.
|
||||
retry_backoff: Exponential backoff factor.
|
||||
wait_label: Progress label shown in Comfy UI.
|
||||
|
||||
Raises:
|
||||
ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception
|
||||
"""
|
||||
if isinstance(file, BytesIO):
|
||||
with contextlib.suppress(Exception):
|
||||
file.seek(0)
|
||||
data = file.read()
|
||||
elif isinstance(file, str):
|
||||
with open(file, "rb") as f:
|
||||
data = f.read()
|
||||
else:
|
||||
raise ValueError("file must be a BytesIO or a filesystem path string")
|
||||
|
||||
headers: dict[str, str] = {}
|
||||
skip_auto_headers: set[str] = set()
|
||||
if content_type:
|
||||
headers["Content-Type"] = content_type
|
||||
else:
|
||||
skip_auto_headers.add("Content-Type") # Don't let aiohttp add Content-Type, it can break the signed request
|
||||
|
||||
attempt = 0
|
||||
delay = retry_delay
|
||||
start_ts = time.monotonic()
|
||||
op_uuid = uuid.uuid4().hex[:8]
|
||||
while True:
|
||||
attempt += 1
|
||||
operation_id = _generate_operation_id("PUT", upload_url, attempt, op_uuid)
|
||||
timeout = aiohttp.ClientTimeout(total=None)
|
||||
stop_evt = asyncio.Event()
|
||||
|
||||
async def _monitor():
|
||||
try:
|
||||
while not stop_evt.is_set():
|
||||
if is_processing_interrupted():
|
||||
return
|
||||
if wait_label:
|
||||
_display_time_progress(cls, wait_label, int(time.monotonic() - start_ts), None)
|
||||
await asyncio.sleep(1.0)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
|
||||
monitor_task = asyncio.create_task(_monitor())
|
||||
sess: Optional[aiohttp.ClientSession] = None
|
||||
try:
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
request_headers=headers or None,
|
||||
request_params=None,
|
||||
request_data=f"[File data {len(data)} bytes]",
|
||||
)
|
||||
except Exception as e:
|
||||
logging.debug("[DEBUG] upload request logging failed: %s", e)
|
||||
|
||||
sess = aiohttp.ClientSession(timeout=timeout)
|
||||
req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers)
|
||||
req_task = asyncio.create_task(req)
|
||||
|
||||
done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED)
|
||||
|
||||
if monitor_task in done and req_task in pending:
|
||||
req_task.cancel()
|
||||
raise ProcessingInterrupted("Upload cancelled")
|
||||
|
||||
try:
|
||||
resp = await req_task
|
||||
except asyncio.CancelledError:
|
||||
raise ProcessingInterrupted("Upload cancelled") from None
|
||||
|
||||
async with resp:
|
||||
if resp.status >= 400:
|
||||
with contextlib.suppress(Exception):
|
||||
try:
|
||||
body = await resp.json()
|
||||
except Exception:
|
||||
body = await resp.text()
|
||||
msg = f"Upload failed with status {resp.status}"
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=body,
|
||||
error_message=msg,
|
||||
)
|
||||
if resp.status in {408, 429, 500, 502, 503, 504} and attempt <= max_retries:
|
||||
await sleep_with_interrupt(
|
||||
delay,
|
||||
cls,
|
||||
wait_label,
|
||||
start_ts,
|
||||
None,
|
||||
display_callback=_display_time_progress if wait_label else None,
|
||||
)
|
||||
delay *= retry_backoff
|
||||
continue
|
||||
raise Exception(f"Failed to upload (HTTP {resp.status}).")
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content="File uploaded successfully.",
|
||||
)
|
||||
except Exception as e:
|
||||
logging.debug("[DEBUG] upload response logging failed: %s", e)
|
||||
return
|
||||
except asyncio.CancelledError:
|
||||
raise ProcessingInterrupted("Task cancelled") from None
|
||||
except (aiohttp.ClientError, OSError) as e:
|
||||
if attempt <= max_retries:
|
||||
with contextlib.suppress(Exception):
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
request_headers=headers or None,
|
||||
request_data=f"[File data {len(data)} bytes]",
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
await sleep_with_interrupt(
|
||||
delay,
|
||||
cls,
|
||||
wait_label,
|
||||
start_ts,
|
||||
None,
|
||||
display_callback=_display_time_progress if wait_label else None,
|
||||
)
|
||||
delay *= retry_backoff
|
||||
continue
|
||||
|
||||
diag = await _diagnose_connectivity()
|
||||
if not diag["internet_accessible"]:
|
||||
raise LocalNetworkError(
|
||||
"Unable to connect to the network. Please check your internet connection and try again."
|
||||
) from e
|
||||
raise ApiServerError("The API service appears unreachable at this time.") from e
|
||||
finally:
|
||||
stop_evt.set()
|
||||
if monitor_task:
|
||||
monitor_task.cancel()
|
||||
with contextlib.suppress(Exception):
|
||||
await monitor_task
|
||||
if sess:
|
||||
with contextlib.suppress(Exception):
|
||||
await sess.close()
|
||||
|
||||
|
||||
def _generate_operation_id(method: str, url: str, attempt: int, op_uuid: str) -> str:
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "upload").strip("/").replace("/", "_")
|
||||
except Exception:
|
||||
slug = "upload"
|
||||
return f"{method}_{slug}_{op_uuid}_try{attempt}"
|
||||
@@ -2,6 +2,8 @@ import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from comfy_api.input.video_types import VideoInput
|
||||
from comfy_api.latest import Input
|
||||
|
||||
|
||||
@@ -28,76 +30,69 @@ def validate_image_dimensions(
|
||||
if max_width is not None and width > max_width:
|
||||
raise ValueError(f"Image width must be at most {max_width}px, got {width}px")
|
||||
if min_height is not None and height < min_height:
|
||||
raise ValueError(
|
||||
f"Image height must be at least {min_height}px, got {height}px"
|
||||
)
|
||||
raise ValueError(f"Image height must be at least {min_height}px, got {height}px")
|
||||
if max_height is not None and height > max_height:
|
||||
raise ValueError(f"Image height must be at most {max_height}px, got {height}px")
|
||||
|
||||
|
||||
def validate_image_aspect_ratio(
|
||||
image: torch.Tensor,
|
||||
min_aspect_ratio: Optional[float] = None,
|
||||
max_aspect_ratio: Optional[float] = None,
|
||||
):
|
||||
width, height = get_image_dimensions(image)
|
||||
aspect_ratio = width / height
|
||||
|
||||
if min_aspect_ratio is not None and aspect_ratio < min_aspect_ratio:
|
||||
raise ValueError(
|
||||
f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}"
|
||||
)
|
||||
if max_aspect_ratio is not None and aspect_ratio > max_aspect_ratio:
|
||||
raise ValueError(
|
||||
f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}"
|
||||
)
|
||||
|
||||
|
||||
def validate_image_aspect_ratio_range(
|
||||
image: torch.Tensor,
|
||||
min_ratio: tuple[float, float], # e.g. (1, 4)
|
||||
max_ratio: tuple[float, float], # e.g. (4, 1)
|
||||
min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4)
|
||||
max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1)
|
||||
*,
|
||||
strict: bool = True, # True -> (min, max); False -> [min, max]
|
||||
strict: bool = True, # True -> (min, max); False -> [min, max]
|
||||
) -> float:
|
||||
a1, b1 = min_ratio
|
||||
a2, b2 = max_ratio
|
||||
if a1 <= 0 or b1 <= 0 or a2 <= 0 or b2 <= 0:
|
||||
raise ValueError("Ratios must be positive, like (1, 4) or (4, 1).")
|
||||
lo, hi = (a1 / b1), (a2 / b2)
|
||||
if lo > hi:
|
||||
lo, hi = hi, lo
|
||||
a1, b1, a2, b2 = a2, b2, a1, b1 # swap only for error text
|
||||
"""Validates that image aspect ratio is within min and max. If a bound is None, that side is not checked."""
|
||||
w, h = get_image_dimensions(image)
|
||||
if w <= 0 or h <= 0:
|
||||
raise ValueError(f"Invalid image dimensions: {w}x{h}")
|
||||
ar = w / h
|
||||
ok = (lo < ar < hi) if strict else (lo <= ar <= hi)
|
||||
if not ok:
|
||||
op = "<" if strict else "≤"
|
||||
raise ValueError(f"Image aspect ratio {ar:.6g} is outside allowed range: {a1}:{b1} {op} ratio {op} {a2}:{b2}")
|
||||
_assert_ratio_bounds(ar, min_ratio=min_ratio, max_ratio=max_ratio, strict=strict)
|
||||
return ar
|
||||
|
||||
|
||||
def validate_aspect_ratio_closeness(
|
||||
start_img,
|
||||
end_img,
|
||||
min_rel: float,
|
||||
max_rel: float,
|
||||
def validate_images_aspect_ratio_closeness(
|
||||
first_image: torch.Tensor,
|
||||
second_image: torch.Tensor,
|
||||
min_rel: float, # e.g. 0.8
|
||||
max_rel: float, # e.g. 1.25
|
||||
*,
|
||||
strict: bool = False, # True => exclusive, False => inclusive
|
||||
) -> None:
|
||||
w1, h1 = get_image_dimensions(start_img)
|
||||
w2, h2 = get_image_dimensions(end_img)
|
||||
strict: bool = False, # True -> (min, max); False -> [min, max]
|
||||
) -> float:
|
||||
"""
|
||||
Validates that the two images' aspect ratios are 'close'.
|
||||
The closeness factor is C = max(ar1, ar2) / min(ar1, ar2) (C >= 1).
|
||||
We require C <= limit, where limit = max(max_rel, 1.0 / min_rel).
|
||||
|
||||
Returns the computed closeness factor C.
|
||||
"""
|
||||
w1, h1 = get_image_dimensions(first_image)
|
||||
w2, h2 = get_image_dimensions(second_image)
|
||||
if min(w1, h1, w2, h2) <= 0:
|
||||
raise ValueError("Invalid image dimensions")
|
||||
ar1 = w1 / h1
|
||||
ar2 = w2 / h2
|
||||
# Normalize so it is symmetric (no need to check both ar1/ar2 and ar2/ar1)
|
||||
closeness = max(ar1, ar2) / min(ar1, ar2)
|
||||
limit = max(max_rel, 1.0 / min_rel) # for 0.8..1.25 this is 1.25
|
||||
limit = max(max_rel, 1.0 / min_rel)
|
||||
if (closeness >= limit) if strict else (closeness > limit):
|
||||
raise ValueError(f"Aspect ratios must be close: start/end={ar1/ar2:.4f}, allowed range {min_rel}–{max_rel}.")
|
||||
raise ValueError(
|
||||
f"Aspect ratios must be close: ar1/ar2={ar1/ar2:.2g}, "
|
||||
f"allowed range {min_rel}–{max_rel} (limit {limit:.2g})."
|
||||
)
|
||||
return closeness
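A worked example of the closeness rule (editorial note), using the suggested bounds min_rel=0.8 and max_rel=1.25, so limit = max(1.25, 1/0.8) = 1.25:

# 16:9 vs 4:3   -> C = (16/9) / (4/3)   ≈ 1.333 > 1.25  -> raises ValueError
# 16:9 vs 16:10 -> C = (16/9) / (16/10) ≈ 1.111 ≤ 1.25  -> returns ≈ 1.111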
|
||||
|
||||
|
||||
def validate_aspect_ratio_string(
|
||||
aspect_ratio: str,
|
||||
min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4)
|
||||
max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1)
|
||||
*,
|
||||
strict: bool = False, # True -> (min, max); False -> [min, max]
|
||||
) -> float:
|
||||
"""Parses 'X:Y' and validates it against optional bounds. Returns the numeric ratio."""
|
||||
ar = _parse_aspect_ratio_string(aspect_ratio)
|
||||
_assert_ratio_bounds(ar, min_ratio=min_ratio, max_ratio=max_ratio, strict=strict)
|
||||
return ar
|
||||
|
||||
|
||||
def validate_video_dimensions(
|
||||
@@ -118,9 +113,7 @@ def validate_video_dimensions(
|
||||
if max_width is not None and width > max_width:
|
||||
raise ValueError(f"Video width must be at most {max_width}px, got {width}px")
|
||||
if min_height is not None and height < min_height:
|
||||
raise ValueError(
|
||||
f"Video height must be at least {min_height}px, got {height}px"
|
||||
)
|
||||
raise ValueError(f"Video height must be at least {min_height}px, got {height}px")
|
||||
if max_height is not None and height > max_height:
|
||||
raise ValueError(f"Video height must be at most {max_height}px, got {height}px")
|
||||
|
||||
@@ -138,13 +131,9 @@ def validate_video_duration(
|
||||
|
||||
epsilon = 0.0001
|
||||
if min_duration is not None and min_duration - epsilon > duration:
|
||||
raise ValueError(
|
||||
f"Video duration must be at least {min_duration}s, got {duration}s"
|
||||
)
|
||||
raise ValueError(f"Video duration must be at least {min_duration}s, got {duration}s")
|
||||
if max_duration is not None and duration > max_duration + epsilon:
|
||||
raise ValueError(
|
||||
f"Video duration must be at most {max_duration}s, got {duration}s"
|
||||
)
|
||||
raise ValueError(f"Video duration must be at most {max_duration}s, got {duration}s")
|
||||
|
||||
|
||||
def get_number_of_images(images):
|
||||
@@ -165,3 +154,77 @@ def validate_audio_duration(
|
||||
raise ValueError(f"Audio duration must be at least {min_duration}s, got {dur + eps:.2f}s")
|
||||
if max_duration is not None and dur - eps > max_duration:
|
||||
raise ValueError(f"Audio duration must be at most {max_duration}s, got {dur - eps:.2f}s")
|
||||
|
||||
|
||||
def validate_string(
|
||||
string: str,
|
||||
strip_whitespace=True,
|
||||
field_name="prompt",
|
||||
min_length=None,
|
||||
max_length=None,
|
||||
):
|
||||
if string is None:
|
||||
raise Exception(f"Field '{field_name}' cannot be empty.")
|
||||
if strip_whitespace:
|
||||
string = string.strip()
|
||||
if min_length and len(string) < min_length:
|
||||
raise Exception(
|
||||
f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long."
|
||||
)
|
||||
if max_length and len(string) > max_length:
|
||||
raise Exception(
|
||||
f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long."
|
||||
)
|
||||
|
||||
|
||||
def validate_container_format_is_mp4(video: VideoInput) -> None:
|
||||
"""Validates video container format is MP4."""
|
||||
container_format = video.get_container_format()
|
||||
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
|
||||
raise ValueError(f"Only MP4 container format supported. Got: {container_format}")
|
||||
|
||||
|
||||
def _ratio_from_tuple(r: tuple[float, float]) -> float:
|
||||
a, b = r
|
||||
if a <= 0 or b <= 0:
|
||||
raise ValueError(f"Ratios must be positive, got {a}:{b}.")
|
||||
return a / b
|
||||
|
||||
|
||||
def _assert_ratio_bounds(
|
||||
ar: float,
|
||||
*,
|
||||
min_ratio: Optional[tuple[float, float]] = None,
|
||||
max_ratio: Optional[tuple[float, float]] = None,
|
||||
strict: bool = True,
|
||||
) -> None:
|
||||
"""Validate a numeric aspect ratio against optional min/max ratio bounds."""
|
||||
lo = _ratio_from_tuple(min_ratio) if min_ratio is not None else None
|
||||
hi = _ratio_from_tuple(max_ratio) if max_ratio is not None else None
|
||||
|
||||
if lo is not None and hi is not None and lo > hi:
|
||||
lo, hi = hi, lo # normalize order if caller swapped them
|
||||
|
||||
if lo is not None:
|
||||
if (ar <= lo) if strict else (ar < lo):
|
||||
op = "<" if strict else "≤"
|
||||
raise ValueError(f"Aspect ratio `{ar:.2g}` must be {op} {lo:.2g}.")
|
||||
if hi is not None:
|
||||
if (ar >= hi) if strict else (ar > hi):
|
||||
op = "<" if strict else "≤"
|
||||
raise ValueError(f"Aspect ratio `{ar:.2g}` must be {op} {hi:.2g}.")
|
||||
|
||||
|
||||
def _parse_aspect_ratio_string(ar_str: str) -> float:
|
||||
"""Parse 'X:Y' with integer parts into a positive float ratio X/Y."""
|
||||
parts = ar_str.split(":")
|
||||
if len(parts) != 2:
|
||||
raise ValueError(f"Aspect ratio must be 'X:Y' (e.g., 16:9), got '{ar_str}'.")
|
||||
try:
|
||||
a = int(parts[0].strip())
|
||||
b = int(parts[1].strip())
|
||||
except ValueError as exc:
|
||||
raise ValueError(f"Aspect ratio must contain integers separated by ':', got '{ar_str}'.") from exc
|
||||
if a <= 0 or b <= 0:
|
||||
raise ValueError(f"Aspect ratio parts must be positive integers, got {a}:{b}.")
|
||||
return a / b
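A short usage sketch (not part of the diff) for the string-based validator defined above, using 1:4 and 4:1 as the allowed bounds:

ratio = validate_aspect_ratio_string("16:9", min_ratio=(1, 4), max_ratio=(4, 1))
# ratio == 16 / 9 ≈ 1.78; a value such as "21:5" (4.2) or a malformed "0:1"
# would raise ValueError instead.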
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
import bisect
|
||||
import gc
|
||||
import itertools
|
||||
import psutil
|
||||
import time
|
||||
import torch
|
||||
from typing import Sequence, Mapping, Dict
|
||||
from comfy_execution.graph import DynamicPrompt
|
||||
from abc import ABC, abstractmethod
|
||||
@@ -48,7 +53,7 @@ class Unhashable:
|
||||
def to_hashable(obj):
|
||||
# So that we don't infinitely recurse since frozenset and tuples
|
||||
# are Sequences.
|
||||
if isinstance(obj, (int, float, str, bool, type(None))):
|
||||
if isinstance(obj, (int, float, str, bool, bytes, type(None))):
|
||||
return obj
|
||||
elif isinstance(obj, Mapping):
|
||||
return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
|
||||
@@ -188,6 +193,9 @@ class BasicCache:
|
||||
self._clean_cache()
|
||||
self._clean_subcaches()
|
||||
|
||||
def poll(self, **kwargs):
|
||||
pass
|
||||
|
||||
def _set_immediate(self, node_id, value):
|
||||
assert self.initialized
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
@@ -265,6 +273,29 @@ class HierarchicalCache(BasicCache):
|
||||
assert cache is not None
|
||||
return await cache._ensure_subcache(node_id, children_ids)
|
||||
|
||||
class NullCache:
|
||||
|
||||
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
pass
|
||||
|
||||
def all_node_ids(self):
|
||||
return []
|
||||
|
||||
def clean_unused(self):
|
||||
pass
|
||||
|
||||
def poll(self, **kwargs):
|
||||
pass
|
||||
|
||||
def get(self, node_id):
|
||||
return None
|
||||
|
||||
def set(self, node_id, value):
|
||||
pass
|
||||
|
||||
async def ensure_subcache_for(self, node_id, children_ids):
|
||||
return self
|
||||
|
||||
class LRUCache(BasicCache):
|
||||
def __init__(self, key_class, max_size=100):
|
||||
super().__init__(key_class)
|
||||
@@ -318,155 +349,75 @@ class LRUCache(BasicCache):
|
||||
return self
|
||||
|
||||
|
||||
class DependencyAwareCache(BasicCache):
|
||||
"""
|
||||
A cache implementation that tracks dependencies between nodes and manages
|
||||
their execution and caching accordingly. It extends the BasicCache class.
|
||||
Nodes are removed from this cache once all of their descendants have been
|
||||
executed.
|
||||
"""
|
||||
#Iterating the cache for usage analysis might be expensive, so if we trigger make sure
|
||||
#to take a chunk out to give breathing space on high-node / low-ram-per-node flows.
|
||||
|
||||
RAM_CACHE_HYSTERESIS = 1.1
|
||||
|
||||
#This is kinda in GB but not really. It needs to be non-zero for the below heuristic
|
||||
#and as long as Multi GB models dwarf this it will approximate OOM scoring OK
|
||||
|
||||
RAM_CACHE_DEFAULT_RAM_USAGE = 0.1
|
||||
|
||||
#Exponential bias towards evicting older workflows so garbage will be taken out
|
||||
#in constantly changing setups.
|
||||
|
||||
RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
|
||||
|
||||
class RAMPressureCache(LRUCache):
|
||||
|
||||
def __init__(self, key_class):
|
||||
"""
|
||||
Initialize the DependencyAwareCache.
|
||||
|
||||
Args:
|
||||
key_class: The class used for generating cache keys.
|
||||
"""
|
||||
super().__init__(key_class)
|
||||
self.descendants = {} # Maps node_id -> set of descendant node_ids
|
||||
self.ancestors = {} # Maps node_id -> set of ancestor node_ids
|
||||
self.executed_nodes = set() # Tracks nodes that have been executed
|
||||
|
||||
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
"""
|
||||
Clear the entire cache and rebuild the dependency graph.
|
||||
|
||||
Args:
|
||||
dynprompt: The dynamic prompt object containing node information.
|
||||
node_ids: List of node IDs to initialize the cache for.
|
||||
is_changed_cache: Flag indicating if the cache has changed.
|
||||
"""
|
||||
# Clear all existing cache data
|
||||
self.cache.clear()
|
||||
self.subcaches.clear()
|
||||
self.descendants.clear()
|
||||
self.ancestors.clear()
|
||||
self.executed_nodes.clear()
|
||||
|
||||
# Call the parent method to initialize the cache with the new prompt
|
||||
await super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
||||
|
||||
# Rebuild the dependency graph
|
||||
self._build_dependency_graph(dynprompt, node_ids)
|
||||
|
||||
def _build_dependency_graph(self, dynprompt, node_ids):
|
||||
"""
|
||||
Build the dependency graph for all nodes.
|
||||
|
||||
Args:
|
||||
dynprompt: The dynamic prompt object containing node information.
|
||||
node_ids: List of node IDs to build the graph for.
|
||||
"""
|
||||
self.descendants.clear()
|
||||
self.ancestors.clear()
|
||||
for node_id in node_ids:
|
||||
self.descendants[node_id] = set()
|
||||
self.ancestors[node_id] = set()
|
||||
|
||||
for node_id in node_ids:
|
||||
inputs = dynprompt.get_node(node_id)["inputs"]
|
||||
for input_data in inputs.values():
|
||||
if is_link(input_data): # Check if the input is a link to another node
|
||||
ancestor_id = input_data[0]
|
||||
self.descendants[ancestor_id].add(node_id)
|
||||
self.ancestors[node_id].add(ancestor_id)
|
||||
|
||||
def set(self, node_id, value):
|
||||
"""
|
||||
Mark a node as executed and store its value in the cache.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to store.
|
||||
value: The value to store for the node.
|
||||
"""
|
||||
self._set_immediate(node_id, value)
|
||||
self.executed_nodes.add(node_id)
|
||||
self._cleanup_ancestors(node_id)
|
||||
|
||||
def get(self, node_id):
|
||||
"""
|
||||
Retrieve the cached value for a node.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to retrieve.
|
||||
|
||||
Returns:
|
||||
The cached value for the node.
|
||||
"""
|
||||
return self._get_immediate(node_id)
|
||||
|
||||
async def ensure_subcache_for(self, node_id, children_ids):
|
||||
"""
|
||||
Ensure a subcache exists for a node and update dependencies.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the parent node.
|
||||
children_ids: List of child node IDs to associate with the parent node.
|
||||
|
||||
Returns:
|
||||
The subcache object for the node.
|
||||
"""
|
||||
subcache = await super()._ensure_subcache(node_id, children_ids)
|
||||
for child_id in children_ids:
|
||||
self.descendants[node_id].add(child_id)
|
||||
self.ancestors[child_id].add(node_id)
|
||||
return subcache
|
||||
|
||||
def _cleanup_ancestors(self, node_id):
|
||||
"""
|
||||
Check if ancestors of a node can be removed from the cache.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node whose ancestors are to be checked.
|
||||
"""
|
||||
for ancestor_id in self.ancestors.get(node_id, []):
|
||||
if ancestor_id in self.executed_nodes:
|
||||
# Remove ancestor if all its descendants have been executed
|
||||
if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]):
|
||||
self._remove_node(ancestor_id)
|
||||
|
||||
def _remove_node(self, node_id):
|
||||
"""
|
||||
Remove a node from the cache.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to remove.
|
||||
"""
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
if cache_key in self.cache:
|
||||
del self.cache[cache_key]
|
||||
subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
||||
if subcache_key in self.subcaches:
|
||||
del self.subcaches[subcache_key]
|
||||
super().__init__(key_class, 0)
|
||||
self.timestamps = {}
|
||||
|
||||
def clean_unused(self):
|
||||
"""
|
||||
Clean up unused nodes. This is a no-op for this cache implementation.
|
||||
"""
|
||||
pass
|
||||
self._clean_subcaches()
|
||||
|
||||
def recursive_debug_dump(self):
|
||||
"""
|
||||
Dump the cache and dependency graph for debugging.
|
||||
def set(self, node_id, value):
|
||||
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
|
||||
super().set(node_id, value)
|
||||
|
||||
Returns:
|
||||
A list containing the cache state and dependency graph.
|
||||
"""
|
||||
result = super().recursive_debug_dump()
|
||||
result.append({
|
||||
"descendants": self.descendants,
|
||||
"ancestors": self.ancestors,
|
||||
"executed_nodes": list(self.executed_nodes),
|
||||
})
|
||||
return result
|
||||
def get(self, node_id):
|
||||
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
|
||||
return super().get(node_id)
|
||||
|
||||
def poll(self, ram_headroom):
|
||||
def _ram_gb():
|
||||
return psutil.virtual_memory().available / (1024**3)
|
||||
|
||||
if _ram_gb() > ram_headroom:
|
||||
return
|
||||
gc.collect()
|
||||
if _ram_gb() > ram_headroom:
|
||||
return
|
||||
|
||||
clean_list = []
|
||||
|
||||
for key, (outputs, _), in self.cache.items():
|
||||
oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])
|
||||
|
||||
ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
|
||||
def scan_list_for_ram_usage(outputs):
|
||||
nonlocal ram_usage
|
||||
if outputs is None:
|
||||
return
|
||||
for output in outputs:
|
||||
if isinstance(output, list):
|
||||
scan_list_for_ram_usage(output)
|
||||
elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
|
||||
#score Tensors at a 50% discount for RAM usage as they are likely to
|
||||
#be high value intermediates
|
||||
ram_usage += (output.numel() * output.element_size()) * 0.5
|
||||
elif hasattr(output, "get_ram_usage"):
|
||||
ram_usage += output.get_ram_usage()
|
||||
scan_list_for_ram_usage(outputs)
|
||||
|
||||
oom_score *= ram_usage
|
||||
#In the case where we have no information on the node ram usage at all,
|
||||
#break OOM score ties on the last touch timestamp (pure LRU)
|
||||
bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
|
||||
|
||||
while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list:
|
||||
_, _, key = clean_list.pop()
|
||||
del self.cache[key]
|
||||
gc.collect()
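A worked illustration of the eviction score used above (editorial note, values are hypothetical): for an entry last touched two generations ago whose cached outputs hold a CPU tensor of N bytes,

# age_factor = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** 2 = 1.3 ** 2 = 1.69
# ram_usage  = RAM_CACHE_DEFAULT_RAM_USAGE + 0.5 * N      # tensors counted at 50%
# oom_score  = age_factor * ram_usage
#
# Entries are popped from the top of the (oom_score, timestamp, key) list, i.e.
# highest score first with pure-LRU tie-breaking, until available RAM exceeds
# ram_headroom * RAM_CACHE_HYSTERESIS (a 10% overshoot to avoid re-triggering).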
|
||||
|
||||
@@ -153,8 +153,9 @@ class TopologicalSort:
|
||||
continue
|
||||
_, _, input_info = self.get_input_info(unique_id, input_name)
|
||||
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
|
||||
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
|
||||
node_ids.append(from_node_id)
|
||||
if (include_lazy or not is_lazy):
|
||||
if not self.is_cached(from_node_id):
|
||||
node_ids.append(from_node_id)
|
||||
links.append((from_node_id, from_socket, unique_id))
|
||||
|
||||
for link in links:
|
||||
@@ -194,10 +195,40 @@ class ExecutionList(TopologicalSort):
|
||||
super().__init__(dynprompt)
|
||||
self.output_cache = output_cache
|
||||
self.staged_node_id = None
|
||||
self.execution_cache = {}
|
||||
self.execution_cache_listeners = {}
|
||||
|
||||
def is_cached(self, node_id):
|
||||
return self.output_cache.get(node_id) is not None
|
||||
|
||||
def cache_link(self, from_node_id, to_node_id):
|
||||
if not to_node_id in self.execution_cache:
|
||||
self.execution_cache[to_node_id] = {}
|
||||
self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id)
|
||||
if not from_node_id in self.execution_cache_listeners:
|
||||
self.execution_cache_listeners[from_node_id] = set()
|
||||
self.execution_cache_listeners[from_node_id].add(to_node_id)
|
||||
|
||||
def get_cache(self, from_node_id, to_node_id):
|
||||
if not to_node_id in self.execution_cache:
|
||||
return None
|
||||
value = self.execution_cache[to_node_id].get(from_node_id)
|
||||
if value is None:
|
||||
return None
|
||||
#Write back to the main cache on touch.
|
||||
self.output_cache.set(from_node_id, value)
|
||||
return value
|
||||
|
||||
def cache_update(self, node_id, value):
|
||||
if node_id in self.execution_cache_listeners:
|
||||
for to_node_id in self.execution_cache_listeners[node_id]:
|
||||
if to_node_id in self.execution_cache:
|
||||
self.execution_cache[to_node_id][node_id] = value
|
||||
|
||||
def add_strong_link(self, from_node_id, from_socket, to_node_id):
|
||||
super().add_strong_link(from_node_id, from_socket, to_node_id)
|
||||
self.cache_link(from_node_id, to_node_id)
|
||||
|
||||
async def stage_node_execution(self):
|
||||
assert self.staged_node_id is None
|
||||
if self.is_empty():
|
||||
@@ -277,6 +308,8 @@ class ExecutionList(TopologicalSort):
|
||||
def complete_node_execution(self):
|
||||
node_id = self.staged_node_id
|
||||
self.pop_node(node_id)
|
||||
self.execution_cache.pop(node_id, None)
|
||||
self.execution_cache_listeners.pop(node_id, None)
|
||||
self.staged_node_id = None
|
||||
|
||||
def get_nodes_in_cycle(self):
|
||||
|
||||
@@ -142,9 +142,10 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
|
||||
for key, value in metadata.items():
|
||||
output_container.metadata[key] = value
|
||||
|
||||
layout = 'mono' if waveform.shape[0] == 1 else 'stereo'
|
||||
# Set up the output stream with appropriate properties
|
||||
if format == "opus":
|
||||
out_stream = output_container.add_stream("libopus", rate=sample_rate)
|
||||
out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout)
|
||||
if quality == "64k":
|
||||
out_stream.bit_rate = 64000
|
||||
elif quality == "96k":
|
||||
@@ -156,7 +157,7 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
|
||||
elif quality == "320k":
|
||||
out_stream.bit_rate = 320000
|
||||
elif format == "mp3":
|
||||
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
|
||||
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout)
|
||||
if quality == "V0":
|
||||
#TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
|
||||
out_stream.codec_context.qscale = 1
|
||||
@@ -165,9 +166,9 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
|
||||
elif quality == "320k":
|
||||
out_stream.bit_rate = 320000
|
||||
else: #format == "flac":
|
||||
out_stream = output_container.add_stream("flac", rate=sample_rate)
|
||||
out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout)
|
||||
|
||||
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo')
|
||||
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout=layout)
|
||||
frame.sample_rate = sample_rate
|
||||
frame.pts = 0
|
||||
output_container.mux(out_stream.encode(frame))
|
||||
|
||||
@@ -1,20 +1,26 @@
|
||||
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
|
||||
import nodes
|
||||
import comfy.utils
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
class SetUnionControlNetType:
|
||||
class SetUnionControlNetType(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"control_net": ("CONTROL_NET", ),
|
||||
"type": (["auto"] + list(UNION_CONTROLNET_TYPES.keys()),)
|
||||
}}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SetUnionControlNetType",
|
||||
category="conditioning/controlnet",
|
||||
inputs=[
|
||||
io.ControlNet.Input("control_net"),
|
||||
io.Combo.Input("type", options=["auto"] + list(UNION_CONTROLNET_TYPES.keys())),
|
||||
],
|
||||
outputs=[
|
||||
io.ControlNet.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
CATEGORY = "conditioning/controlnet"
|
||||
RETURN_TYPES = ("CONTROL_NET",)
|
||||
|
||||
FUNCTION = "set_controlnet_type"
|
||||
|
||||
def set_controlnet_type(self, control_net, type):
|
||||
@classmethod
|
||||
def execute(cls, control_net, type) -> io.NodeOutput:
|
||||
control_net = control_net.copy()
|
||||
type_number = UNION_CONTROLNET_TYPES.get(type, -1)
|
||||
if type_number >= 0:
|
||||
@@ -22,27 +28,36 @@ class SetUnionControlNetType:
|
||||
else:
|
||||
control_net.set_extra_arg("control_type", [])
|
||||
|
||||
return (control_net,)
|
||||
return io.NodeOutput(control_net)
|
||||
|
||||
class ControlNetInpaintingAliMamaApply(nodes.ControlNetApplyAdvanced):
|
||||
set_controlnet_type = execute # TODO: remove
|
||||
|
||||
|
||||
class ControlNetInpaintingAliMamaApply(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"control_net": ("CONTROL_NET", ),
|
||||
"vae": ("VAE", ),
|
||||
"image": ("IMAGE", ),
|
||||
"mask": ("MASK", ),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
||||
}}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ControlNetInpaintingAliMamaApply",
|
||||
category="conditioning/controlnet",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.ControlNet.Input("control_net"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Image.Input("image"),
|
||||
io.Mask.Input("mask"),
|
||||
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
|
||||
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
],
|
||||
)
|
||||
|
||||
FUNCTION = "apply_inpaint_controlnet"
|
||||
|
||||
CATEGORY = "conditioning/controlnet"
|
||||
|
||||
def apply_inpaint_controlnet(self, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent):
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent) -> io.NodeOutput:
|
||||
extra_concat = []
|
||||
if control_net.concat_mask:
|
||||
mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
|
||||
@@ -50,11 +65,20 @@ class ControlNetInpaintingAliMamaApply(nodes.ControlNetApplyAdvanced):
|
||||
image = image * mask_apply.movedim(1, -1).repeat(1, 1, 1, image.shape[3])
|
||||
extra_concat = [mask]
|
||||
|
||||
return self.apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat)
|
||||
result = nodes.ControlNetApplyAdvanced().apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat)
|
||||
return io.NodeOutput(result[0], result[1])
|
||||
|
||||
apply_inpaint_controlnet = execute # TODO: remove
|
||||
|
||||
|
||||
class ControlNetExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
SetUnionControlNetType,
|
||||
ControlNetInpaintingAliMamaApply,
|
||||
]
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"SetUnionControlNetType": SetUnionControlNetType,
|
||||
"ControlNetInpaintingAliMamaApply": ControlNetInpaintingAliMamaApply,
|
||||
}
|
||||
|
||||
async def comfy_entrypoint() -> ControlNetExtension:
|
||||
return ControlNetExtension()
|
||||
|
||||
Some files were not shown because too many files have changed in this diff.