mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-13 11:40:02 +00:00)

Compare commits: v0.3.47...node-memor

82 Commits

| Author | SHA1 | Date |
|---|---|---|
| | 6c611b0b99 | |
| | cd54d502fc | |
| | 63571c6c3d | |
| | bae0c31a68 | |
| | 34b1f51f4a | |
| | bd2ab73976 | |
| | da2efeaec6 | |
| | 7f3b9b16c6 | |
| | d4e353a94e | |
| | ed43784b0d | |
| | 0f2b8525bc | |
| | 20a84166d0 | |
| | ed2e33c69a | |
| | 1702e6df16 | |
| | c308a8840a | |
| | 027c63f63a | |
| | e08ecfbd8a | |
| | 4e5c230f6a | |
| | f0d5d0111f | |
| | ad19a069f6 | |
| | 5d65d6753b | |
| | deebee4ff6 | |
| | fa570cbf59 | |
| | 644b23ac0b | |
| | 72fd4d22b6 | |
| | e4f7ea105f | |
| | c991a5da65 | |
| | 9df8792d4b | |
| | 3da5a07510 | |
| | afa0a45206 | |
| | 615eb52049 | |
| | d5c1954d5c | |
| | e400f26c8f | |
| | 5ca8e2fac3 | |
| | 3294782d19 | |
| | 898d88e10e | |
| | 560d38f34c | |
| | e1d4f36d8d | |
| | 1e3ae1eed8 | |
| | f4231a80b1 | |
| | 2208aa616d | |
| | 629b173837 | |
| | fa340add55 | |
| | 966f3a5206 | |
| | 0552de7c7d | |
| | 5828607ccf | |
| | 735bb4bdb1 | |
| | bf2a1b5b1e | |
| | 42974a448c | |
| | 05df2df489 | |
| | 37d620a6b8 | |
| | 32691b16f4 | |
| | 4c3e57b0ae | |
| | 9126c0cfe4 | |
| | d8c51ba15a | |
| | 32a95bba8a | |
| | da1ad9b516 | |
| | d044a24398 | |
| | 5be6fd09ff | |
| | f69609bbd6 | |
| | c012400240 | |
| | 03895dea7c | |
| | 84f9759424 | |
| | 7991341e89 | |
| | 140ffc7fdc | |
| | 182f90b5ec | |
| | aebac22193 | |
| | 13aaa66ec2 | |
| | 5f582a9757 | |
| | fbcc23945d | |
| | 3dfefc88d0 | |
| | bff60b5cfc | |
| | 1e638a140b | |
| | 4696d74305 | |
| | 5ee381c058 | |
| | 4887743a2a | |
| | 97b8a2c26a | |
| | 97eb256a35 | |
| | 61b08d4ba6 | |
| | da9dab7edd | |
| | d2aaef029c | |
| | 0a3d062e06 | |

1 .gitattributes vendored

@@ -1,2 +1,3 @@
/web/assets/** linguist-generated
/web/** linguist-vendored
comfy_api_nodes/apis/__init__.py linguist-generated

2 .github/ISSUE_TEMPLATE/bug-report.yml vendored

@@ -22,7 +22,7 @@ body:
description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
options:
- label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
required: true
required: false
- type: textarea
attributes:
label: Expected Behavior

2 .github/ISSUE_TEMPLATE/user-support.yml vendored

@@ -18,7 +18,7 @@ body:
description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
options:
- label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
required: true
required: false
- type: textarea
attributes:
label: Your question

17 .github/workflows/stable-release.yml vendored

@@ -12,17 +12,17 @@ on:
description: 'CUDA version'
required: true
type: string
default: "128"
default: "129"
python_minor:
description: 'Python minor version'
required: true
type: string
default: "12"
default: "13"
python_patch:
description: 'Python patch version'
required: true
type: string
default: "10"
default: "6"

jobs:

@@ -66,8 +66,13 @@ jobs:
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
./python.exe get-pip.py
./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/*
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
cd ..
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth

rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
rm ./Lib/site-packages/torch/lib/libprotoc.lib
rm ./Lib/site-packages/torch/lib/libprotobuf.lib

cd ..

git clone --depth 1 https://github.com/comfyanonymous/taesd
cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/

@@ -85,7 +90,7 @@ jobs:

cd ..

"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z

cd ComfyUI_windows_portable

@@ -17,19 +17,19 @@ on:
description: 'cuda version'
required: true
type: string
default: "128"
default: "129"

python_minor:
description: 'python minor version'
required: true
type: string
default: "12"
default: "13"

python_patch:
description: 'python patch version'
required: true
type: string
default: "10"
default: "6"
# push:
# branches:
# - master

12 .github/workflows/windows_release_package.yml vendored

@@ -7,19 +7,19 @@ on:
description: 'cuda version'
required: true
type: string
default: "128"
default: "129"

python_minor:
description: 'python minor version'
required: true
type: string
default: "12"
default: "13"

python_patch:
description: 'python patch version'
required: true
type: string
default: "10"
default: "6"
# push:
# branches:
# - master

@@ -64,6 +64,10 @@ jobs:
./python.exe get-pip.py
./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/*
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth

rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
rm ./Lib/site-packages/torch/lib/libprotoc.lib
rm ./Lib/site-packages/torch/lib/libprotobuf.lib
cd ..

git clone --depth 1 https://github.com/comfyanonymous/taesd

@@ -82,7 +86,7 @@ jobs:

cd ..

"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
mv ComfyUI_windows_portable.7z ComfyUI/new_ComfyUI_windows_portable_nvidia_cu${{ inputs.cu }}_or_cpu.7z

cd ComfyUI_windows_portable

27 CODEOWNERS

@@ -5,20 +5,21 @@
# Inlined the team members for now.

# Maintainers
*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill

# Python web server
/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill
/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill
/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill

# Node developers
/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill
/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill
/comfy_api_nodes/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill

35 README.md

@@ -39,7 +39,7 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a
## Get Started

#### [Desktop Application](https://www.comfy.org/download)
- The easiest way to get started.
- The easiest way to get started.
- Available on Windows & macOS.

#### [Windows Portable Package](#installing)

@@ -66,6 +66,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
- [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
- [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
- [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/)
- Image Editing Models
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)

@@ -111,7 +112,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git

## Release Process

ComfyUI follows a weekly release cycle every Friday, with three interconnected repositories:
ComfyUI follows a weekly release cycle targeting Friday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:

1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
- Releases a new stable version (e.g., v0.7.0)

@@ -202,7 +203,7 @@ Put your VAE in: models/vae
### AMD GPUs (Linux only)
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:

```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.3```
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```

This is the command to install the nightly with ROCm 6.4 which might have some performance improvements:

@@ -210,33 +211,25 @@ This is the command to install the nightly with ROCm 6.4 which might have some p

### Intel GPUs (Windows and Linux)

(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip (currently available in PyTorch nightly builds). More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)

1. To install PyTorch nightly, use the following command:
(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)

1. To install PyTorch xpu, use the following command:

```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/xpu```

This is the command to install the Pytorch xpu nightly which might have some performance improvements:

```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu```

2. Launch ComfyUI by running `python main.py`


(Option 2) Alternatively, Intel GPUs supported by Intel Extension for PyTorch (IPEX) can leverage IPEX for improved performance.

1. For Intel® Arc™ A-Series Graphics utilizing IPEX, create a conda environment and use the commands below:

```
conda install libuv
pip install torch==2.3.1.post0+cxx11.abi torchvision==0.18.1.post0+cxx11.abi torchaudio==2.3.1.post0+cxx11.abi intel-extension-for-pytorch==2.3.110.post0+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/
```

For other supported Intel GPUs with IPEX, visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information.

Additional discussion and help can be found [here](https://github.com/comfyanonymous/ComfyUI/discussions/476).
1. visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information.

### NVIDIA

Nvidia users should install stable pytorch using this command:

```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128```
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu129```

This is the command to install pytorch nightly instead which might have performance improvements.

@@ -351,7 +344,7 @@ Generate a self-signed certificate (not appropriate for shared/production use) a

Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app will now be accessible with `https://...` instead of `http://...`.

> Note: Windows users can use [alexisrolland/docker-openssl](https://github.com/alexisrolland/docker-openssl) or one of the [3rd party binary distributions](https://wiki.openssl.org/index.php/Binaries) to run the command example above.
> Note: Windows users can use [alexisrolland/docker-openssl](https://github.com/alexisrolland/docker-openssl) or one of the [3rd party binary distributions](https://wiki.openssl.org/index.php/Binaries) to run the command example above.
<br/><br/>If you use a container, note that the volume mount `-v` can be a relative path so `... -v ".\:/openssl-certs" ...` would create the key & cert files in the current directory of your command prompt or powershell terminal.

## Support and dev channel

@@ -130,10 +130,21 @@ class ModelFileManager:

for file_name in filenames:
try:
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
result.append(relative_path)
except:
logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
full_path = os.path.join(dirpath, file_name)
relative_path = os.path.relpath(full_path, directory)

# Get file metadata
file_info = {
"name": relative_path,
"pathIndex": pathIndex,
"modified": os.path.getmtime(full_path), # Add modification time
"created": os.path.getctime(full_path), # Add creation time
"size": os.path.getsize(full_path) # Add file size
}
result.append(file_info)

except Exception as e:
logging.warning(f"Warning: Unable to access {file_name}. Error: {e}. Skipping this file.")
continue

for d in subdirs:

@@ -144,7 +155,7 @@ class ModelFileManager:
logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
continue

return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()
return result, dirs, time.perf_counter()

def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
dirname = os.path.dirname(filepath)

@@ -20,13 +20,15 @@ class FileInfo(TypedDict):
path: str
size: int
modified: int
created: int


def get_file_info(path: str, relative_to: str) -> FileInfo:
return {
"path": os.path.relpath(path, relative_to).replace(os.sep, '/'),
"size": os.path.getsize(path),
"modified": os.path.getmtime(path)
"modified": os.path.getmtime(path),
"created": os.path.getctime(path)
}

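Both hunks above gather the same file metadata through `os.path`. A minimal standalone sketch of that pattern (illustrative only, not the ComfyUI code; the function name here is made up):

```python
import os

def file_metadata(path: str, relative_to: str) -> dict:
    """Collect the same fields the hunks above attach to each listed file."""
    return {
        "path": os.path.relpath(path, relative_to).replace(os.sep, "/"),
        "size": os.path.getsize(path),       # bytes
        "modified": os.path.getmtime(path),  # POSIX timestamp (float)
        # os.path.getctime is creation time on Windows, but the time of the
        # last inode metadata change on most Unix systems.
        "created": os.path.getctime(path),
    }
```
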
@@ -132,6 +132,8 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am

parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")

parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")

parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")

parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")

@@ -1,6 +1,7 @@
import torch
import math
import comfy.utils
import logging


class CONDRegular:

@@ -10,12 +11,15 @@ class CONDRegular:
def _copy_with(self, cond):
return self.__class__(cond)

def process_cond(self, batch_size, device, **kwargs):
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device))
def process_cond(self, batch_size, **kwargs):
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size))

def can_concat(self, other):
if self.cond.shape != other.cond.shape:
return False
if self.cond.device != other.cond.device:
logging.warning("WARNING: conds not on same device, skipping concat.")
return False
return True

def concat(self, others):

@@ -29,14 +33,14 @@ class CONDRegular:


class CONDNoiseShape(CONDRegular):
def process_cond(self, batch_size, device, area, **kwargs):
def process_cond(self, batch_size, area, **kwargs):
data = self.cond
if area is not None:
dims = len(area) // 2
for i in range(dims):
data = data.narrow(i + 2, area[i + dims], area[i])

return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device))
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size))


class CONDCrossAttn(CONDRegular):

@@ -51,6 +55,9 @@ class CONDCrossAttn(CONDRegular):
diff = mult_min // min(s1[1], s2[1])
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
return False
if self.cond.device != other.cond.device:
logging.warning("WARNING: conds not on same device: skipping concat.")
return False
return True

def concat(self, others):

@@ -73,7 +80,7 @@ class CONDConstant(CONDRegular):
def __init__(self, cond):
self.cond = cond

def process_cond(self, batch_size, device, **kwargs):
def process_cond(self, batch_size, **kwargs):
return self._copy_with(self.cond)

def can_concat(self, other):

@@ -92,10 +99,10 @@ class CONDList(CONDRegular):
def __init__(self, cond):
self.cond = cond

def process_cond(self, batch_size, device, **kwargs):
def process_cond(self, batch_size, **kwargs):
out = []
for c in self.cond:
out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device))
out.append(comfy.utils.repeat_to_batch_size(c, batch_size))

return self._copy_with(out)

540 comfy/context_windows.py Normal file

@@ -0,0 +1,540 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Callable
import torch
import numpy as np
import collections
from dataclasses import dataclass
from abc import ABC, abstractmethod
import logging

import comfy.model_management
import comfy.patcher_extension
if TYPE_CHECKING:
    from comfy.model_base import BaseModel
    from comfy.model_patcher import ModelPatcher
    from comfy.controlnet import ControlBase


class ContextWindowABC(ABC):
    def __init__(self):
        ...

    @abstractmethod
    def get_tensor(self, full: torch.Tensor) -> torch.Tensor:
        """
        Get torch.Tensor applicable to current window.
        """
        raise NotImplementedError("Not implemented.")

    @abstractmethod
    def add_window(self, full: torch.Tensor, to_add: torch.Tensor) -> torch.Tensor:
        """
        Apply torch.Tensor of window to the full tensor, in place. Returns reference to updated full tensor, not a copy.
        """
        raise NotImplementedError("Not implemented.")

class ContextHandlerABC(ABC):
    def __init__(self):
        ...

    @abstractmethod
    def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
        raise NotImplementedError("Not implemented.")

    @abstractmethod
    def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: ContextWindowABC, device=None) -> list:
        raise NotImplementedError("Not implemented.")

    @abstractmethod
    def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
        raise NotImplementedError("Not implemented.")


class IndexListContextWindow(ContextWindowABC):
    def __init__(self, index_list: list[int], dim: int=0):
        self.index_list = index_list
        self.context_length = len(index_list)
        self.dim = dim

    def get_tensor(self, full: torch.Tensor, device=None, dim=None) -> torch.Tensor:
        if dim is None:
            dim = self.dim
        if dim == 0 and full.shape[dim] == 1:
            return full
        idx = [slice(None)] * dim + [self.index_list]
        return full[idx].to(device)

    def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor:
        if dim is None:
            dim = self.dim
        idx = [slice(None)] * dim + [self.index_list]
        full[idx] += to_add
        return full


class IndexListCallbacks:
    EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
    COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results"
    EXECUTE_START = "execute_start"
    EXECUTE_CLEANUP = "execute_cleanup"

    def init_callbacks(self):
        return {}


@dataclass
class ContextSchedule:
    name: str
    func: Callable

@dataclass
class ContextFuseMethod:
    name: str
    func: Callable

ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
class IndexListContextHandler(ContextHandlerABC):
    def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop=False, dim=0):
        self.context_schedule = context_schedule
        self.fuse_method = fuse_method
        self.context_length = context_length
        self.context_overlap = context_overlap
        self.context_stride = context_stride
        self.closed_loop = closed_loop
        self.dim = dim
        self._step = 0

        self.callbacks = {}

    def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
        # for now, assume first dim is batch - should have stored on BaseModel in actual implementation
        if x_in.size(self.dim) > self.context_length:
            logging.info(f"Using context windows {self.context_length} for {x_in.size(self.dim)} frames.")
            return True
        return False

    def prepare_control_objects(self, control: ControlBase, device=None) -> ControlBase:
        if control.previous_controlnet is not None:
            self.prepare_control_objects(control.previous_controlnet, device)
        return control

    def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: IndexListContextWindow, device=None) -> list:
        if cond_in is None:
            return None
        # reuse or resize cond items to match context requirements
        resized_cond = []
        # cond object is a list containing a dict - outer list is irrelevant, so just loop through it
        for actual_cond in cond_in:
            resized_actual_cond = actual_cond.copy()
            # now we are in the inner dict - "pooled_output" is a tensor, "control" is a ControlBase object, "model_conds" is dictionary
            for key in actual_cond:
                try:
                    cond_item = actual_cond[key]
                    if isinstance(cond_item, torch.Tensor):
                        # check that tensor is the expected length - x.size(0)
                        if self.dim < cond_item.ndim and cond_item.size(self.dim) == x_in.size(self.dim):
                            # if so, it's subsetting time - tell controls the expected indeces so they can handle them
                            actual_cond_item = window.get_tensor(cond_item)
                            resized_actual_cond[key] = actual_cond_item.to(device)
                        else:
                            resized_actual_cond[key] = cond_item.to(device)
                    # look for control
                    elif key == "control":
                        resized_actual_cond[key] = self.prepare_control_objects(cond_item, device)
                    elif isinstance(cond_item, dict):
                        new_cond_item = cond_item.copy()
                        # when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
                        for cond_key, cond_value in new_cond_item.items():
                            if isinstance(cond_value, torch.Tensor):
                                if cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim):
                                    new_cond_item[cond_key] = window.get_tensor(cond_value, device)
                            # if has cond that is a Tensor, check if needs to be subset
                            elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
                                if cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim):
                                    new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device))
                            elif cond_key == "num_video_frames": # for SVD
                                new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
                                new_cond_item[cond_key].cond = window.context_length
                        resized_actual_cond[key] = new_cond_item
                    else:
                        resized_actual_cond[key] = cond_item
                finally:
                    del cond_item # just in case to prevent VRAM issues
            resized_cond.append(resized_actual_cond)
        return resized_cond

    def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
        mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001)
        matches = torch.nonzero(mask)
        if torch.numel(matches) == 0:
            raise Exception("No sample_sigmas matched current timestep; something went wrong.")
        self._step = int(matches[0].item())

    def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
        full_length = x_in.size(self.dim) # TODO: choose dim based on model
        context_windows = self.context_schedule.func(full_length, self, model_options)
        context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows]
        return context_windows

    def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
        self.set_step(timestep, model_options)
        context_windows = self.get_context_windows(model, x_in, model_options)
        enumerated_context_windows = list(enumerate(context_windows))

        conds_final = [torch.zeros_like(x_in) for _ in conds]
        if self.fuse_method.name == ContextFuseMethods.RELATIVE:
            counts_final = [torch.ones(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
        else:
            counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
        biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds]

        for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
            callback(self, model, x_in, conds, timestep, model_options)

        for enum_window in enumerated_context_windows:
            results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options)
            for result in results:
                self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep,
                                                    conds_final, counts_final, biases_final)
        try:
            # finalize conds
            if self.fuse_method.name == ContextFuseMethods.RELATIVE:
                # relative is already normalized, so return as is
                del counts_final
                return conds_final
            else:
                # normalize conds via division by context usage counts
                for i in range(len(conds_final)):
                    conds_final[i] /= counts_final[i]
                del counts_final
                return conds_final
        finally:
            for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks):
                callback(self, model, x_in, conds, timestep, model_options)

    def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]],
                                 model_options, device=None, first_device=None):
        results: list[ContextResults] = []
        for window_idx, window in enumerated_context_windows:
            # allow processing to end between context window executions for faster Cancel
            comfy.model_management.throw_exception_if_processing_interrupted()

            for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
                callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)

            # update exposed params
            model_options["transformer_options"]["context_window"] = window
            # get subsections of x, timestep, conds
            sub_x = window.get_tensor(x_in, device)
            sub_timestep = window.get_tensor(timestep, device, dim=0)
            sub_conds = [self.get_resized_cond(cond, x_in, window, device) for cond in conds]

            sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options)
            if device is not None:
                for i in range(len(sub_conds_out)):
                    sub_conds_out[i] = sub_conds_out[i].to(x_in.device)
            results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window))
        return results


    def combine_context_window_results(self, x_in: torch.Tensor, sub_conds_out, sub_conds, window: IndexListContextWindow, window_idx: int, total_windows: int, timestep: torch.Tensor,
                                       conds_final: list[torch.Tensor], counts_final: list[torch.Tensor], biases_final: list[torch.Tensor]):
        if self.fuse_method.name == ContextFuseMethods.RELATIVE:
            for pos, idx in enumerate(window.index_list):
                # bias is the influence of a specific index in relation to the whole context window
                bias = 1 - abs(idx - (window.index_list[0] + window.index_list[-1]) / 2) / ((window.index_list[-1] - window.index_list[0] + 1e-2) / 2)
                bias = max(1e-2, bias)
                # take weighted average relative to total bias of current idx
                for i in range(len(sub_conds_out)):
                    bias_total = biases_final[i][idx]
                    prev_weight = (bias_total / (bias_total + bias))
                    new_weight = (bias / (bias_total + bias))
                    # account for dims of tensors
                    idx_window = [slice(None)] * self.dim + [idx]
                    pos_window = [slice(None)] * self.dim + [pos]
                    # apply new values
                    conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight
                    biases_final[i][idx] = bias_total + bias
        else:
            # add conds and counts based on weights of fuse method
            weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep)
            weights_tensor = match_weights_to_dim(weights, x_in, self.dim, device=x_in.device)
            for i in range(len(sub_conds_out)):
                window.add_window(conds_final[i], sub_conds_out[i] * weights_tensor)
                window.add_window(counts_final[i], weights_tensor)

        for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.COMBINE_CONTEXT_WINDOW_RESULTS, self.callbacks):
            callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final)


def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, *args, **kwargs):
    # limit noise_shape length to context_length for more accurate vram use estimation
    model_options = kwargs.get("model_options", None)
    if model_options is None:
        raise Exception("model_options not found in prepare_sampling_wrapper; this should never happen, something went wrong.")
    handler: IndexListContextHandler = model_options.get("context_handler", None)
    if handler is not None:
        noise_shape = list(noise_shape)
        noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length)
    return executor(model, noise_shape, *args, **kwargs)


def create_prepare_sampling_wrapper(model: ModelPatcher):
    model.add_wrapper_with_key(
        comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING,
        "ContextWindows_prepare_sampling",
        _prepare_sampling_wrapper
    )


def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
    total_dims = len(x_in.shape)
    weights_tensor = torch.Tensor(weights).to(device=device)
    for _ in range(dim):
        weights_tensor = weights_tensor.unsqueeze(0)
    for _ in range(total_dims - dim - 1):
        weights_tensor = weights_tensor.unsqueeze(-1)
    return weights_tensor

def get_shape_for_dim(x_in: torch.Tensor, dim: int) -> list[int]:
    total_dims = len(x_in.shape)
    shape = []
    for _ in range(dim):
        shape.append(1)
    shape.append(x_in.shape[dim])
    for _ in range(total_dims - dim - 1):
        shape.append(1)
    return shape

class ContextSchedules:
    UNIFORM_LOOPED = "looped_uniform"
    UNIFORM_STANDARD = "standard_uniform"
    STATIC_STANDARD = "standard_static"
    BATCHED = "batched"


# from https://github.com/neggles/animatediff-cli/blob/main/src/animatediff/pipelines/context.py
def create_windows_uniform_looped(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
    windows = []
    if num_frames < handler.context_length:
        windows.append(list(range(num_frames)))
        return windows

    context_stride = min(handler.context_stride, int(np.ceil(np.log2(num_frames / handler.context_length))) + 1)
    # obtain uniform windows as normal, looping and all
    for context_step in 1 << np.arange(context_stride):
        pad = int(round(num_frames * ordered_halving(handler._step)))
        for j in range(
            int(ordered_halving(handler._step) * context_step) + pad,
            num_frames + pad + (0 if handler.closed_loop else -handler.context_overlap),
            (handler.context_length * context_step - handler.context_overlap),
        ):
            windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)])

    return windows

def create_windows_uniform_standard(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
    # unlike looped, uniform_straight does NOT allow windows that loop back to the beginning;
    # instead, they get shifted to the corresponding end of the frames.
    # in the case that a window (shifted or not) is identical to the previous one, it gets skipped.
    windows = []
    if num_frames <= handler.context_length:
        windows.append(list(range(num_frames)))
        return windows

    context_stride = min(handler.context_stride, int(np.ceil(np.log2(num_frames / handler.context_length))) + 1)
    # first, obtain uniform windows as normal, looping and all
    for context_step in 1 << np.arange(context_stride):
        pad = int(round(num_frames * ordered_halving(handler._step)))
        for j in range(
            int(ordered_halving(handler._step) * context_step) + pad,
            num_frames + pad + (-handler.context_overlap),
            (handler.context_length * context_step - handler.context_overlap),
        ):
            windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)])

    # now that windows are created, shift any windows that loop, and delete duplicate windows
    delete_idxs = []
    win_i = 0
    while win_i < len(windows):
        # if window is rolls over itself, need to shift it
        is_roll, roll_idx = does_window_roll_over(windows[win_i], num_frames)
        if is_roll:
            roll_val = windows[win_i][roll_idx] # roll_val might not be 0 for windows of higher strides
            shift_window_to_end(windows[win_i], num_frames=num_frames)
            # check if next window (cyclical) is missing roll_val
            if roll_val not in windows[(win_i+1) % len(windows)]:
                # need to insert new window here - just insert window starting at roll_val
                windows.insert(win_i+1, list(range(roll_val, roll_val + handler.context_length)))
        # delete window if it's not unique
        for pre_i in range(0, win_i):
            if windows[win_i] == windows[pre_i]:
                delete_idxs.append(win_i)
                break
        win_i += 1

    # reverse delete_idxs so that they will be deleted in an order that doesn't break idx correlation
    delete_idxs.reverse()
    for i in delete_idxs:
        windows.pop(i)

    return windows


def create_windows_static_standard(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
    windows = []
    if num_frames <= handler.context_length:
        windows.append(list(range(num_frames)))
        return windows
    # always return the same set of windows
    delta = handler.context_length - handler.context_overlap
    for start_idx in range(0, num_frames, delta):
        # if past the end of frames, move start_idx back to allow same context_length
        ending = start_idx + handler.context_length
        if ending >= num_frames:
            final_delta = ending - num_frames
            final_start_idx = start_idx - final_delta
            windows.append(list(range(final_start_idx, final_start_idx + handler.context_length)))
            break
        windows.append(list(range(start_idx, start_idx + handler.context_length)))
    return windows


def create_windows_batched(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
    windows = []
    if num_frames <= handler.context_length:
        windows.append(list(range(num_frames)))
        return windows
    # always return the same set of windows;
    # no overlap, just cut up based on context_length;
    # last window size will be different if num_frames % opts.context_length != 0
    for start_idx in range(0, num_frames, handler.context_length):
        windows.append(list(range(start_idx, min(start_idx + handler.context_length, num_frames))))
    return windows


def create_windows_default(num_frames: int, handler: IndexListContextHandler):
    return [list(range(num_frames))]


CONTEXT_MAPPING = {
    ContextSchedules.UNIFORM_LOOPED: create_windows_uniform_looped,
    ContextSchedules.UNIFORM_STANDARD: create_windows_uniform_standard,
    ContextSchedules.STATIC_STANDARD: create_windows_static_standard,
    ContextSchedules.BATCHED: create_windows_batched,
}


def get_matching_context_schedule(context_schedule: str) -> ContextSchedule:
    func = CONTEXT_MAPPING.get(context_schedule, None)
    if func is None:
        raise ValueError(f"Unknown context_schedule '{context_schedule}'.")
    return ContextSchedule(context_schedule, func)


def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None):
    return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs)


def create_weights_flat(length: int, **kwargs) -> list[float]:
    # weight is the same for all
    return [1.0] * length

def create_weights_pyramid(length: int, **kwargs) -> list[float]:
    # weight is based on the distance away from the edge of the context window;
    # based on weighted average concept in FreeNoise paper
    if length % 2 == 0:
        max_weight = length // 2
        weight_sequence = list(range(1, max_weight + 1, 1)) + list(range(max_weight, 0, -1))
    else:
        max_weight = (length + 1) // 2
        weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1))
    return weight_sequence

def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, **kwargs):
    # based on code in Kijai's WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302
    # only expected overlap is given different weights
    weights_torch = torch.ones((length))
    # blend left-side on all except first window
    if min(idxs) > 0:
        ramp_up = torch.linspace(1e-37, 1, handler.context_overlap)
        weights_torch[:handler.context_overlap] = ramp_up
    # blend right-side on all except last window
    if max(idxs) < full_length-1:
        ramp_down = torch.linspace(1, 1e-37, handler.context_overlap)
        weights_torch[-handler.context_overlap:] = ramp_down
    return weights_torch

class ContextFuseMethods:
    FLAT = "flat"
    PYRAMID = "pyramid"
    RELATIVE = "relative"
    OVERLAP_LINEAR = "overlap-linear"

    LIST = [PYRAMID, FLAT, OVERLAP_LINEAR]
    LIST_STATIC = [PYRAMID, RELATIVE, FLAT, OVERLAP_LINEAR]


FUSE_MAPPING = {
    ContextFuseMethods.FLAT: create_weights_flat,
    ContextFuseMethods.PYRAMID: create_weights_pyramid,
    ContextFuseMethods.RELATIVE: create_weights_pyramid,
    ContextFuseMethods.OVERLAP_LINEAR: create_weights_overlap_linear,
}

def get_matching_fuse_method(fuse_method: str) -> ContextFuseMethod:
    func = FUSE_MAPPING.get(fuse_method, None)
    if func is None:
        raise ValueError(f"Unknown fuse_method '{fuse_method}'.")
    return ContextFuseMethod(fuse_method, func)

# Returns fraction that has denominator that is a power of 2
def ordered_halving(val):
    # get binary value, padded with 0s for 64 bits
    bin_str = f"{val:064b}"
    # flip binary value, padding included
    bin_flip = bin_str[::-1]
    # convert binary to int
    as_int = int(bin_flip, 2)
    # divide by 1 << 64, equivalent to 2**64, or 18446744073709551616,
    # or b10000000000000000000000000000000000000000000000000000000000000000 (1 with 64 zero's)
    return as_int / (1 << 64)


def get_missing_indexes(windows: list[list[int]], num_frames: int) -> list[int]:
    all_indexes = list(range(num_frames))
    for w in windows:
        for val in w:
            try:
                all_indexes.remove(val)
            except ValueError:
                pass
    return all_indexes


def does_window_roll_over(window: list[int], num_frames: int) -> tuple[bool, int]:
    prev_val = -1
    for i, val in enumerate(window):
        val = val % num_frames
        if val < prev_val:
            return True, i
        prev_val = val
    return False, -1


def shift_window_to_start(window: list[int], num_frames: int):
    start_val = window[0]
    for i in range(len(window)):
        # 1) subtract each element by start_val to move vals relative to the start of all frames
        # 2) add num_frames and take modulus to get adjusted vals
        window[i] = ((window[i] - start_val) + num_frames) % num_frames


def shift_window_to_end(window: list[int], num_frames: int):
    # 1) shift window to start
    shift_window_to_start(window, num_frames)
    end_val = window[-1]
    end_delta = num_frames - end_val - 1
    for i in range(len(window)):
        # 2) add end_delta to each val to slide windows to end
        window[i] = window[i] + end_delta

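For orientation, the following is a minimal, self-contained sketch of the idea the new module implements: evaluate a function over overlapping index windows along the frame dimension, weight each window's output, accumulate, and normalize by the summed weights. It is illustrative only — the function names and the static-schedule/pyramid-weight pairing are assumptions for the sketch, not the ComfyUI API; the real path goes through `IndexListContextHandler.execute` above.

```python
import torch

def static_windows(num_frames: int, length: int, overlap: int) -> list[list[int]]:
    # roughly create_windows_static_standard: fixed-size windows, last one shifted back
    if num_frames <= length:
        return [list(range(num_frames))]
    windows, step = [], length - overlap
    for start in range(0, num_frames, step):
        if start + length >= num_frames:
            windows.append(list(range(num_frames - length, num_frames)))
            break
        windows.append(list(range(start, start + length)))
    return windows

def pyramid_weights(length: int) -> torch.Tensor:
    # roughly create_weights_pyramid: largest weight in the middle of the window
    half = (length + 1) // 2
    seq = list(range(1, half + 1)) + list(range(length - half, 0, -1))
    return torch.tensor(seq, dtype=torch.float32)

def run_windowed(fn, x: torch.Tensor, length: int = 16, overlap: int = 4, dim: int = 2) -> torch.Tensor:
    out = torch.zeros_like(x)
    counts = torch.zeros(x.shape[dim], device=x.device)
    wshape = [1] * x.ndim
    for idxs in static_windows(x.shape[dim], length, overlap):
        idx = torch.tensor(idxs, device=x.device)
        sub = x.index_select(dim, idx)                        # like window.get_tensor(...)
        w = pyramid_weights(len(idxs)).to(x.device)
        wshape[dim] = len(idxs)
        out.index_add_(dim, idx, fn(sub) * w.view(wshape))    # like window.add_window(...)
        counts.index_add_(0, idx, w)
    wshape[dim] = x.shape[dim]
    return out / counts.view(wshape)                          # normalize by accumulated weights

# toy check: a purely per-frame function is reproduced exactly after window fusing
latent = torch.randn(1, 16, 49, 32, 32)  # (batch, channels, frames, h, w)
assert torch.allclose(run_windowed(lambda s: s * 2.0, latent), latent * 2.0, atol=1e-5)
```
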
@@ -28,6 +28,7 @@ import comfy.model_detection
import comfy.model_patcher
import comfy.ops
import comfy.latent_formats
import comfy.model_base

import comfy.cldm.cldm
import comfy.t2i_adapter.adapter

@@ -43,7 +44,6 @@ if TYPE_CHECKING:

def broadcast_image_to(tensor, target_batch_size, batched_number):
current_batch_size = tensor.shape[0]
#print(current_batch_size, target_batch_size)
if current_batch_size == 1:
return tensor

@@ -265,12 +265,12 @@ class ControlNet(ControlBase):
for c in self.extra_conds:
temp = cond.get(c, None)
if temp is not None:
extra[c] = temp.to(dtype)
extra[c] = comfy.model_base.convert_tensor(temp, dtype, x_noisy.device)

timestep = self.model_sampling_current.timestep(t)
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)

control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=comfy.model_management.cast_to_device(context, x_noisy.device, dtype), **extra)
return self.control_merge(control, control_prev, output_dtype=None)

def copy(self):

@@ -58,7 +58,8 @@ def is_odd(n: int) -> bool:


def nonlinearity(x):
return x * torch.sigmoid(x)
# x * sigmoid(x)
return torch.nn.functional.silu(x)


def Normalize(in_channels, num_groups=32):

@@ -224,19 +224,27 @@ class Flux(nn.Module):
if ref_latents is not None:
h = 0
w = 0
index = 0
index_ref_method = kwargs.get("ref_latents_method", "offset") == "index"
for ref in ref_latents:
h_offset = 0
w_offset = 0
if ref.shape[-2] + h > ref.shape[-1] + w:
w_offset = w
if index_ref_method:
index += 1
h_offset = 0
w_offset = 0
else:
h_offset = h
index = 1
h_offset = 0
w_offset = 0
if ref.shape[-2] + h > ref.shape[-1] + w:
w_offset = w
else:
h_offset = h
h = max(h, ref.shape[-2] + h_offset)
w = max(w, ref.shape[-1] + w_offset)

kontext, kontext_ids = self.process_img(ref, index=1, h_offset=h_offset, w_offset=w_offset)
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
img = torch.cat([img, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
h = max(h, ref.shape[-2] + h_offset)
w = max(w, ref.shape[-1] + w_offset)

txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))

@@ -178,7 +178,7 @@ class FourierEmbedder(nn.Module):

class CrossAttentionProcessor:
def __call__(self, attn, q, k, v):
out = F.scaled_dot_product_attention(q, k, v)
out = comfy.ops.scaled_dot_product_attention(q, k, v)
return out

@@ -448,7 +448,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
mask = mask.unsqueeze(1)

if SDP_BATCH_LIMIT >= b:
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
if not skip_output_reshape:
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)

@@ -461,7 +461,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
if mask.shape[0] > 1:
m = mask[i : i + SDP_BATCH_LIMIT]

out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(
out[i : i + SDP_BATCH_LIMIT] = comfy.ops.scaled_dot_product_attention(
q[i : i + SDP_BATCH_LIMIT],
k[i : i + SDP_BATCH_LIMIT],
v[i : i + SDP_BATCH_LIMIT],

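The two hunks above only swap `torch.nn.functional.scaled_dot_product_attention` for the `comfy.ops` wrapper; the surrounding batch-chunking pattern stays the same. A minimal sketch of that pattern in plain PyTorch (illustrative only; `chunked_sdpa` and its default limit are made up for the sketch, and it assumes q, k and v share the same head dimension):

```python
import torch
import torch.nn.functional as F

def chunked_sdpa(q, k, v, mask=None, batch_limit=128):
    # q, k, v: (batch, heads, seq, dim_head); mask is broadcastable or per-batch
    if q.shape[0] <= batch_limit:
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
    out = torch.empty_like(q)  # assumes v's head dim equals q's
    for i in range(0, q.shape[0], batch_limit):
        # slice the mask only when it actually carries a batch dimension
        m = mask[i:i + batch_limit] if (mask is not None and mask.shape[0] > 1) else mask
        out[i:i + batch_limit] = F.scaled_dot_product_attention(
            q[i:i + batch_limit], k[i:i + batch_limit], v[i:i + batch_limit],
            attn_mask=m, dropout_p=0.0, is_causal=False,
        )
    return out
```
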
@@ -36,7 +36,7 @@ def get_timestep_embedding(timesteps, embedding_dim):

def nonlinearity(x):
# swish
return x*torch.sigmoid(x)
return torch.nn.functional.silu(x)


def Normalize(in_channels, num_groups=32):

@@ -285,7 +285,7 @@ def pytorch_attention(q, k, v):
)

try:
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(orig_shape)
except model_management.OOM_EXCEPTION:
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")

443 comfy/ldm/qwen_image/model.py Normal file

@@ -0,0 +1,443 @@
# https://github.com/QwenLM/Qwen-Image (Apache 2.0)
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
from einops import repeat

from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND
import comfy.ldm.common_dit

class GELU(nn.Module):
    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
        super().__init__()
        self.proj = operations.Linear(dim_in, dim_out, bias=bias, dtype=dtype, device=device)
        self.approximate = approximate

    def forward(self, hidden_states):
        hidden_states = self.proj(hidden_states)
        hidden_states = F.gelu(hidden_states, approximate=self.approximate)
        return hidden_states


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        inner_dim=None,
        bias: bool = True,
        dtype=None, device=None, operations=None
    ):
        super().__init__()
        if inner_dim is None:
            inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        self.net = nn.ModuleList([])
        self.net.append(GELU(dim, inner_dim, approximate="tanh", bias=bias, dtype=dtype, device=device, operations=operations))
        self.net.append(nn.Dropout(dropout))
        self.net.append(operations.Linear(inner_dim, dim_out, bias=bias, dtype=dtype, device=device))

    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states


def apply_rotary_emb(x, freqs_cis):
    if x.shape[1] == 0:
        return x

    t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
    t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
    return t_out.reshape(*x.shape)


class QwenTimestepProjEmbeddings(nn.Module):
    def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None):
        super().__init__()
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
        self.timestep_embedder = TimestepEmbedding(
            in_channels=256,
            time_embed_dim=embedding_dim,
            dtype=dtype,
            device=device,
            operations=operations
        )

    def forward(self, timestep, hidden_states):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))
        return timesteps_emb


class Attention(nn.Module):
    def __init__(
        self,
        query_dim: int,
        dim_head: int = 64,
        heads: int = 8,
        dropout: float = 0.0,
        bias: bool = False,
        eps: float = 1e-5,
        out_bias: bool = True,
        out_dim: int = None,
        out_context_dim: int = None,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.inner_kv_dim = self.inner_dim
        self.heads = heads
        self.dim_head = dim_head
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
        self.dropout = dropout

        # Q/K normalization
        self.norm_q = operations.RMSNorm(dim_head, eps=eps, elementwise_affine=True, dtype=dtype, device=device)
        self.norm_k = operations.RMSNorm(dim_head, eps=eps, elementwise_affine=True, dtype=dtype, device=device)
        self.norm_added_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
        self.norm_added_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)

        # Image stream projections
        self.to_q = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)
        self.to_k = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
        self.to_v = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)

        # Text stream projections
        self.add_q_proj = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)
        self.add_k_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
        self.add_v_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)

        # Output projections
        self.to_out = nn.ModuleList([
            operations.Linear(self.inner_dim, self.out_dim, bias=out_bias, dtype=dtype, device=device),
            nn.Dropout(dropout)
        ])
        self.to_add_out = operations.Linear(self.inner_dim, self.out_context_dim, bias=out_bias, dtype=dtype, device=device)

    def forward(
        self,
        hidden_states: torch.FloatTensor,  # Image stream
        encoder_hidden_states: torch.FloatTensor = None,  # Text stream
        encoder_hidden_states_mask: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        seq_txt = encoder_hidden_states.shape[1]

        img_query = self.to_q(hidden_states).unflatten(-1, (self.heads, -1))
        img_key = self.to_k(hidden_states).unflatten(-1, (self.heads, -1))
        img_value = self.to_v(hidden_states).unflatten(-1, (self.heads, -1))

        txt_query = self.add_q_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
        txt_key = self.add_k_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
        txt_value = self.add_v_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))

        img_query = self.norm_q(img_query)
        img_key = self.norm_k(img_key)
        txt_query = self.norm_added_q(txt_query)
        txt_key = self.norm_added_k(txt_key)

        joint_query = torch.cat([txt_query, img_query], dim=1)
        joint_key = torch.cat([txt_key, img_key], dim=1)
        joint_value = torch.cat([txt_value, img_value], dim=1)

        joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
        joint_key = apply_rotary_emb(joint_key, image_rotary_emb)

        joint_query = joint_query.flatten(start_dim=2)
        joint_key = joint_key.flatten(start_dim=2)
        joint_value = joint_value.flatten(start_dim=2)

        joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask)

        txt_attn_output = joint_hidden_states[:, :seq_txt, :]
        img_attn_output = joint_hidden_states[:, seq_txt:, :]

        img_attn_output = self.to_out[0](img_attn_output)
        img_attn_output = self.to_out[1](img_attn_output)
        txt_attn_output = self.to_add_out(txt_attn_output)
|
||||
|
||||
return img_attn_output, txt_attn_output
|
||||
|
||||
|
||||
class QwenImageTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
eps: float = 1e-6,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
|
||||
self.img_mod = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.img_norm1 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
|
||||
self.img_norm2 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
|
||||
self.img_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.txt_mod = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.txt_norm1 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
|
||||
self.txt_norm2 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
|
||||
self.txt_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.attn = Attention(
|
||||
query_dim=dim,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
bias=True,
|
||||
eps=eps,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
def _modulate(self, x, mod_params):
|
||||
shift, scale, gate = mod_params.chunk(3, dim=-1)
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
encoder_hidden_states_mask: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
img_mod_params = self.img_mod(temb)
|
||||
txt_mod_params = self.txt_mod(temb)
|
||||
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
|
||||
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
|
||||
|
||||
img_normed = self.img_norm1(hidden_states)
|
||||
img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
|
||||
txt_normed = self.txt_norm1(encoder_hidden_states)
|
||||
txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
|
||||
|
||||
img_attn_output, txt_attn_output = self.attn(
|
||||
hidden_states=img_modulated,
|
||||
encoder_hidden_states=txt_modulated,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + img_gate1 * img_attn_output
|
||||
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
|
||||
|
||||
img_normed2 = self.img_norm2(hidden_states)
|
||||
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
|
||||
hidden_states = hidden_states + img_gate2 * self.img_mlp(img_modulated2)
|
||||
|
||||
txt_normed2 = self.txt_norm2(encoder_hidden_states)
|
||||
txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
|
||||
encoder_hidden_states = encoder_hidden_states + txt_gate2 * self.txt_mlp(txt_modulated2)
|
||||
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
|
||||
class LastLayer(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        conditioning_embedding_dim: int,
        elementwise_affine=False,
        eps=1e-6,
        bias=True,
        dtype=None, device=None, operations=None
    ):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = operations.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias, dtype=dtype, device=device)
        self.norm = operations.LayerNorm(embedding_dim, eps, elementwise_affine=False, bias=bias, dtype=dtype, device=device)

    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
        emb = self.linear(self.silu(conditioning_embedding))
        scale, shift = torch.chunk(emb, 2, dim=1)
        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
        return x

class QwenImageTransformer2DModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 64,
|
||||
out_channels: Optional[int] = 16,
|
||||
num_layers: int = 60,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 24,
|
||||
joint_attention_dim: int = 3584,
|
||||
pooled_projection_dim: int = 768,
|
||||
guidance_embeds: bool = False,
|
||||
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
|
||||
image_model=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.patch_size = patch_size
|
||||
self.out_channels = out_channels or in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
|
||||
|
||||
self.time_text_embed = QwenTimestepProjEmbeddings(
|
||||
embedding_dim=self.inner_dim,
|
||||
pooled_projection_dim=pooled_projection_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
|
||||
self.txt_norm = operations.RMSNorm(joint_attention_dim, eps=1e-6, dtype=dtype, device=device)
|
||||
self.img_in = operations.Linear(in_channels, self.inner_dim, dtype=dtype, device=device)
|
||||
self.txt_in = operations.Linear(joint_attention_dim, self.inner_dim, dtype=dtype, device=device)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList([
|
||||
QwenImageTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def process_img(self, x, index=0, h_offset=0, w_offset=0):
|
||||
bs, c, t, h, w = x.shape
|
||||
patch_size = self.patch_size
|
||||
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
|
||||
orig_shape = hidden_states.shape
|
||||
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
|
||||
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
|
||||
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
|
||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||
|
||||
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
|
||||
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
|
||||
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, 0] = img_ids[:, :, 1] + index
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
timesteps,
|
||||
context,
|
||||
attention_mask=None,
|
||||
guidance: torch.Tensor = None,
|
||||
ref_latents=None,
|
||||
transformer_options={},
|
||||
**kwargs
|
||||
):
|
||||
timestep = timesteps
|
||||
encoder_hidden_states = context
|
||||
encoder_hidden_states_mask = attention_mask
|
||||
|
||||
hidden_states, img_ids, orig_shape = self.process_img(x)
|
||||
num_embeds = hidden_states.shape[1]
|
||||
|
||||
if ref_latents is not None:
|
||||
h = 0
|
||||
w = 0
|
||||
index = 0
|
||||
index_ref_method = kwargs.get("ref_latents_method", "index") == "index"
|
||||
for ref in ref_latents:
|
||||
if index_ref_method:
|
||||
index += 1
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
else:
|
||||
index = 1
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
if ref.shape[-2] + h > ref.shape[-1] + w:
|
||||
w_offset = w
|
||||
else:
|
||||
h_offset = h
|
||||
h = max(h, ref.shape[-2] + h_offset)
|
||||
w = max(w, ref.shape[-1] + w_offset)
|
||||
|
||||
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
|
||||
hidden_states = torch.cat([hidden_states, kontext], dim=1)
|
||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||
|
||||
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size), ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size)))
|
||||
txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
|
||||
|
||||
hidden_states = self.img_in(hidden_states)
|
||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||
encoder_hidden_states = self.txt_in(encoder_hidden_states)
|
||||
|
||||
if guidance is not None:
|
||||
guidance = guidance * 1000
|
||||
|
||||
temb = (
|
||||
self.time_text_embed(timestep, hidden_states)
|
||||
if guidance is None
|
||||
else self.time_text_embed(timestep, guidance, hidden_states)
|
||||
)
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb}, {"original_block": block_wrap})
|
||||
hidden_states = out["img"]
|
||||
encoder_hidden_states = out["txt"]
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
|
||||
hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
|
||||
return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
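
# A minimal sketch of the 2x2 patchify/unpatchify round trip used by process_img()
# and the tail of forward() above; the tensor sizes are illustrative (t is assumed 1
# and already folded away, h and w already multiples of the patch size).
def _toy_patchify_roundtrip():
    b, c, h, w = 1, 16, 4, 6
    x = torch.randn(b, c, h, w)
    patches = x.view(b, c, h // 2, 2, w // 2, 2).permute(0, 2, 4, 1, 3, 5).reshape(b, (h // 2) * (w // 2), c * 4)
    restored = patches.view(b, h // 2, w // 2, c, 2, 2).permute(0, 3, 1, 4, 2, 5).reshape(b, c, h, w)
    assert torch.equal(x, restored)
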
@@ -391,6 +391,7 @@ class WanModel(torch.nn.Module):
|
||||
cross_attn_norm=True,
|
||||
eps=1e-6,
|
||||
flf_pos_embed_token_number=None,
|
||||
in_dim_ref_conv=None,
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
@@ -484,6 +485,11 @@ class WanModel(torch.nn.Module):
|
||||
else:
|
||||
self.img_emb = None
|
||||
|
||||
if in_dim_ref_conv is not None:
|
||||
self.ref_conv = operations.Conv2d(in_dim_ref_conv, dim, kernel_size=patch_size[1:], stride=patch_size[1:], device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
else:
|
||||
self.ref_conv = None
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
x,
|
||||
@@ -526,6 +532,13 @@ class WanModel(torch.nn.Module):
|
||||
e = e.reshape(t.shape[0], -1, e.shape[-1])
|
||||
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
||||
|
||||
full_ref = None
|
||||
if self.ref_conv is not None:
|
||||
full_ref = kwargs.get("reference_latent", None)
|
||||
if full_ref is not None:
|
||||
full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2)
|
||||
x = torch.concat((full_ref, x), dim=1)
|
||||
|
||||
# context
|
||||
context = self.text_embedding(context)
|
||||
|
||||
@@ -552,6 +565,9 @@ class WanModel(torch.nn.Module):
|
||||
# head
|
||||
x = self.head(x, e)
|
||||
|
||||
if full_ref is not None:
|
||||
x = x[:, full_ref.shape[1]:]
|
||||
|
||||
# unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
@@ -570,6 +586,9 @@ class WanModel(torch.nn.Module):
|
||||
x = torch.cat([x, time_dim_concat], dim=2)
|
||||
t_len = ((x.shape[2] + (patch_size[0] // 2)) // patch_size[0])
|
||||
|
||||
if self.ref_conv is not None and "reference_latent" in kwargs:
|
||||
t_len += 1
|
||||
|
||||
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
|
||||
@@ -749,7 +768,12 @@ class CameraWanModel(WanModel):
|
||||
operations=None,
|
||||
):
|
||||
|
||||
super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
|
||||
if model_type == 'camera':
|
||||
model_type = 'i2v'
|
||||
else:
|
||||
model_type = 't2v'
|
||||
|
||||
super().__init__(model_type=model_type, patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
|
||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||
|
||||
self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings)
|
||||
@@ -769,8 +793,7 @@ class CameraWanModel(WanModel):
|
||||
# embeddings
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
if self.control_adapter is not None and camera_conditions is not None:
|
||||
x_camera = self.control_adapter(camera_conditions).to(x.dtype)
|
||||
x = x + x_camera
|
||||
x = x + self.control_adapter(camera_conditions).to(x.dtype)
|
||||
grid_sizes = x.shape[2:]
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
|
||||
@@ -24,12 +24,17 @@ class CausalConv3d(ops.Conv3d):
|
||||
self.padding[1], 2 * self.padding[0], 0)
|
||||
self.padding = (0, 0, 0)
|
||||
|
||||
def forward(self, x, cache_x=None):
|
||||
def forward(self, x, cache_x=None, cache_list=None, cache_idx=None):
|
||||
if cache_list is not None:
|
||||
cache_x = cache_list[cache_idx]
|
||||
cache_list[cache_idx] = None
|
||||
|
||||
padding = list(self._padding)
|
||||
if cache_x is not None and self._padding[4] > 0:
|
||||
cache_x = cache_x.to(x.device)
|
||||
x = torch.cat([cache_x, x], dim=2)
|
||||
padding[4] -= cache_x.shape[2]
|
||||
del cache_x
|
||||
x = F.pad(x, padding)
|
||||
|
||||
return super().forward(x)
|
||||
@@ -166,7 +171,7 @@ class ResidualBlock(nn.Module):
|
||||
if in_dim != out_dim else nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
h = self.shortcut(x)
|
||||
old_x = x
|
||||
for layer in self.residual:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
@@ -178,12 +183,12 @@ class ResidualBlock(nn.Module):
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = layer(x, feat_cache[idx])
|
||||
x = layer(x, cache_list=feat_cache, cache_idx=idx)
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x + h
|
||||
return x + self.shortcut(old_x)
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
|
||||
@@ -151,7 +151,7 @@ class ResidualBlock(nn.Module):
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
x = layer(x, feat_cache[idx])
|
||||
x = layer(x, cache_list=feat_cache, cache_idx=idx)
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
|
||||
@@ -293,6 +293,16 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")]
|
||||
key_map["{}".format(key_lora)] = k
|
||||
|
||||
if isinstance(model, comfy.model_base.QwenImage):
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"): #QwenImage lora format
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")]
|
||||
# Direct mapping for transformer_blocks format (QwenImage LoRA format)
|
||||
key_map["{}".format(key_lora)] = k
|
||||
# Support transformer prefix format
|
||||
key_map["transformer.{}".format(key_lora)] = k
|
||||
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format
|
||||
|
||||
return key_map
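
# A minimal sketch of the QwenImage LoRA key mapping added above, applied to a
# single hypothetical state-dict key; the key name is illustrative only.
def _toy_qwen_lora_key_map(k):
    key_map = {}
    if k.startswith("diffusion_model.") and k.endswith(".weight"):
        key_lora = k[len("diffusion_model."):-len(".weight")]
        key_map[key_lora] = k                                         # native transformer_blocks.* names
        key_map["transformer.{}".format(key_lora)] = k                # "transformer." prefixed variant
        key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k  # SimpleTuner lycoris naming
    return key_map

# _toy_qwen_lora_key_map("diffusion_model.transformer_blocks.0.attn.to_q.weight")
# maps "transformer_blocks.0.attn.to_q", "transformer.transformer_blocks.0.attn.to_q"
# and "lycoris_transformer_blocks_0_attn_to_q" to that key.
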
@@ -42,6 +42,7 @@ import comfy.ldm.hidream.model
|
||||
import comfy.ldm.chroma.model
|
||||
import comfy.ldm.ace.model
|
||||
import comfy.ldm.omnigen.omnigen2
|
||||
import comfy.ldm.qwen_image.model
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@@ -106,10 +107,12 @@ def model_sampling(model_config, model_type):
|
||||
return ModelSampling(model_config)
|
||||
|
||||
|
||||
def convert_tensor(extra, dtype):
|
||||
def convert_tensor(extra, dtype, device):
|
||||
if hasattr(extra, "dtype"):
|
||||
if extra.dtype != torch.int and extra.dtype != torch.long:
|
||||
extra = extra.to(dtype)
|
||||
extra = comfy.model_management.cast_to_device(extra, device, dtype)
|
||||
else:
|
||||
extra = comfy.model_management.cast_to_device(extra, device, None)
|
||||
return extra
|
||||
|
||||
|
||||
@@ -160,7 +163,7 @@ class BaseModel(torch.nn.Module):
|
||||
xc = self.model_sampling.calculate_input(sigma, x)
|
||||
|
||||
if c_concat is not None:
|
||||
xc = torch.cat([xc] + [c_concat], dim=1)
|
||||
xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1)
|
||||
|
||||
context = c_crossattn
|
||||
dtype = self.get_dtype()
|
||||
@@ -169,20 +172,21 @@ class BaseModel(torch.nn.Module):
|
||||
dtype = self.manual_cast_dtype
|
||||
|
||||
xc = xc.to(dtype)
|
||||
device = xc.device
|
||||
t = self.model_sampling.timestep(t).float()
|
||||
if context is not None:
|
||||
context = context.to(dtype)
|
||||
context = comfy.model_management.cast_to_device(context, device, dtype)
|
||||
|
||||
extra_conds = {}
|
||||
for o in kwargs:
|
||||
extra = kwargs[o]
|
||||
|
||||
if hasattr(extra, "dtype"):
|
||||
extra = convert_tensor(extra, dtype)
|
||||
extra = convert_tensor(extra, dtype, device)
|
||||
elif isinstance(extra, list):
|
||||
ex = []
|
||||
for ext in extra:
|
||||
ex.append(convert_tensor(ext, dtype))
|
||||
ex.append(convert_tensor(ext, dtype, device))
|
||||
extra = ex
|
||||
extra_conds[o] = extra
|
||||
|
||||
@@ -398,7 +402,7 @@ class SD21UNCLIP(BaseModel):
|
||||
unclip_conditioning = kwargs.get("unclip_conditioning", None)
|
||||
device = kwargs["device"]
|
||||
if unclip_conditioning is None:
|
||||
return torch.zeros((1, self.adm_channels))
|
||||
return torch.zeros((1, self.adm_channels), device=device)
|
||||
else:
|
||||
return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05), kwargs.get("seed", 0) - 10)
|
||||
|
||||
@@ -612,9 +616,11 @@ class IP2P:
|
||||
|
||||
if image is None:
|
||||
image = torch.zeros_like(noise)
|
||||
else:
|
||||
image = image.to(device=device)
|
||||
|
||||
if image.shape[1:] != noise.shape[1:]:
|
||||
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
image = utils.common_upscale(image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
|
||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||
return self.process_ip2p_image_in(image)
|
||||
@@ -693,7 +699,7 @@ class StableCascade_B(BaseModel):
|
||||
#size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched
|
||||
prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device))
|
||||
|
||||
out["effnet"] = comfy.conds.CONDRegular(prior)
|
||||
out["effnet"] = comfy.conds.CONDRegular(prior.to(device=noise.device))
|
||||
out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
|
||||
return out
|
||||
|
||||
@@ -884,6 +890,10 @@ class Flux(BaseModel):
|
||||
for lat in ref_latents:
|
||||
latents.append(self.process_latent_in(lat))
|
||||
out['ref_latents'] = comfy.conds.CONDList(latents)
|
||||
|
||||
ref_latents_method = kwargs.get("reference_latents_method", None)
|
||||
if ref_latents_method is not None:
|
||||
out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method)
|
||||
return out
|
||||
|
||||
def extra_conds_shapes(self, **kwargs):
|
||||
@@ -1118,7 +1128,11 @@ class WAN21(BaseModel):
|
||||
mask = mask.repeat(1, 4, 1, 1, 1)
|
||||
mask = utils.resize_to_batch_size(mask, noise.shape[0])
|
||||
|
||||
return torch.cat((mask, image), dim=1)
|
||||
concat_mask_index = kwargs.get("concat_mask_index", 0)
|
||||
if concat_mask_index != 0:
|
||||
return torch.cat((image[:, :concat_mask_index], mask, image[:, concat_mask_index:]), dim=1)
|
||||
else:
|
||||
return torch.cat((mask, image), dim=1)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
@@ -1134,6 +1148,10 @@ class WAN21(BaseModel):
|
||||
if time_dim_concat is not None:
|
||||
out['time_dim_concat'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_concat))
|
||||
|
||||
reference_latents = kwargs.get("reference_latents", None)
|
||||
if reference_latents is not None:
|
||||
out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])[:, :, 0])
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@@ -1158,10 +1176,10 @@ class WAN21_Vace(WAN21):
|
||||
|
||||
vace_frames_out = []
|
||||
for j in range(len(vace_frames)):
|
||||
vf = vace_frames[j].clone()
|
||||
vf = vace_frames[j].to(device=noise.device, dtype=noise.dtype, copy=True)
|
||||
for i in range(0, vf.shape[1], 16):
|
||||
vf[:, i:i + 16] = self.process_latent_in(vf[:, i:i + 16])
|
||||
vf = torch.cat([vf, mask[j]], dim=1)
|
||||
vf = torch.cat([vf, mask[j].to(device=noise.device, dtype=noise.dtype)], dim=1)
|
||||
vace_frames_out.append(vf)
|
||||
|
||||
vace_frames = torch.stack(vace_frames_out, dim=1)
|
||||
@@ -1303,3 +1321,24 @@ class Omnigen2(BaseModel):
|
||||
if ref_latents is not None:
|
||||
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
|
||||
return out
|
||||
|
||||
class QwenImage(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.qwen_image.model.QwenImageTransformer2DModel)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
ref_latents = kwargs.get("reference_latents", None)
|
||||
if ref_latents is not None:
|
||||
latents = []
|
||||
for lat in ref_latents:
|
||||
latents.append(self.process_latent_in(lat))
|
||||
out['ref_latents'] = comfy.conds.CONDList(latents)
|
||||
|
||||
ref_latents_method = kwargs.get("reference_latents_method", None)
|
||||
if ref_latents_method is not None:
|
||||
out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method)
|
||||
return out
|
||||
|
||||
@@ -364,7 +364,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
|
||||
elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "camera"
|
||||
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "camera"
|
||||
else:
|
||||
dit_config["model_type"] = "camera_2.2"
|
||||
else:
|
||||
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "i2v"
|
||||
@@ -373,6 +376,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix))
|
||||
if flf_weight is not None:
|
||||
dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1]
|
||||
|
||||
ref_conv_weight = state_dict.get('{}ref_conv.weight'.format(key_prefix))
|
||||
if ref_conv_weight is not None:
|
||||
dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1]
|
||||
|
||||
return dit_config
|
||||
|
||||
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
|
||||
@@ -481,6 +489,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["timestep_scale"] = 1000.0
|
||||
return dit_config
|
||||
|
||||
if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "qwen_image"
|
||||
return dit_config
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
return None
|
||||
|
||||
@@ -867,7 +880,7 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
|
||||
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
|
||||
hidden_size = state_dict["x_embedder.bias"].shape[0]
|
||||
sd_map = comfy.utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix)
|
||||
elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
|
||||
elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict and 'pos_embed.proj.weight' in state_dict: #SD3
|
||||
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||
depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
|
||||
sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
|
||||
|
||||
@@ -78,7 +78,6 @@ try:
|
||||
torch_version = torch.version.__version__
|
||||
temp = torch_version.split(".")
|
||||
torch_version_numeric = (int(temp[0]), int(temp[1]))
|
||||
xpu_available = (torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] <= 4)) and torch.xpu.is_available()
|
||||
except:
|
||||
pass
|
||||
|
||||
@@ -102,10 +101,14 @@ if args.directml is not None:
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex # noqa: F401
|
||||
_ = torch.xpu.device_count()
|
||||
xpu_available = xpu_available or torch.xpu.is_available()
|
||||
except:
|
||||
xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available())
|
||||
pass
|
||||
|
||||
try:
|
||||
_ = torch.xpu.device_count()
|
||||
xpu_available = torch.xpu.is_available()
|
||||
except:
|
||||
xpu_available = False
|
||||
|
||||
try:
|
||||
if torch.backends.mps.is_available():
|
||||
@@ -321,9 +324,9 @@ try:
|
||||
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
|
||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
if torch_version_numeric >= (2, 8):
|
||||
if any((a in arch) for a in ["gfx1201"]):
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
# if torch_version_numeric >= (2, 8):
|
||||
# if any((a in arch) for a in ["gfx1201"]):
|
||||
# ENABLE_PYTORCH_ATTENTION = True
|
||||
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
|
||||
if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches
|
||||
SUPPORT_FP8_OPS = True
|
||||
@@ -340,7 +343,7 @@ if ENABLE_PYTORCH_ATTENTION:
|
||||
|
||||
PRIORITIZE_FP16 = False # TODO: remove and replace with something that shows exactly which dtype is faster than the other
|
||||
try:
|
||||
if is_nvidia() and PerformanceFeature.Fp16Accumulation in args.fast:
|
||||
if (is_nvidia() or is_amd()) and PerformanceFeature.Fp16Accumulation in args.fast:
|
||||
torch.backends.cuda.matmul.allow_fp16_accumulation = True
|
||||
PRIORITIZE_FP16 = True # TODO: limit to cards where it actually boosts performance
|
||||
logging.info("Enabled fp16 accumulation.")
|
||||
@@ -579,16 +582,23 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
||||
soft_empty_cache()
|
||||
return unloaded_models
|
||||
|
||||
def get_models_memory_reserve(models):
|
||||
total_reserve = 0
|
||||
for model in models:
|
||||
total_reserve += model.get_model_memory_reserve(convert_to_bytes=True)
|
||||
return total_reserve
|
||||
|
||||
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
|
||||
cleanup_models_gc()
|
||||
global vram_state
|
||||
|
||||
inference_memory = minimum_inference_memory()
|
||||
extra_mem = max(inference_memory, memory_required + extra_reserved_memory())
|
||||
models_memory_reserve = get_models_memory_reserve(models)
|
||||
extra_mem = max(inference_memory + models_memory_reserve, memory_required + extra_reserved_memory() + models_memory_reserve)
|
||||
if minimum_memory_required is None:
|
||||
minimum_memory_required = extra_mem
|
||||
else:
|
||||
minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())
|
||||
minimum_memory_required = max(inference_memory + models_memory_reserve, minimum_memory_required + extra_reserved_memory() + models_memory_reserve)
|
||||
|
||||
models = set(models)
|
||||
|
||||
@@ -946,10 +956,12 @@ def pick_weight_dtype(dtype, fallback_dtype, device=None):
|
||||
return dtype
|
||||
|
||||
def device_supports_non_blocking(device):
|
||||
if args.force_non_blocking:
|
||||
return True
|
||||
if is_device_mps(device):
|
||||
return False #pytorch bug? mps doesn't support non blocking
|
||||
if is_intel_xpu():
|
||||
return True
|
||||
if is_intel_xpu(): #xpu does support non blocking but it is slower on iGPUs for some reason so disable by default until situation changes
|
||||
return False
|
||||
if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
|
||||
return False
|
||||
if directml_enabled:
|
||||
@@ -1282,10 +1294,10 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
return False
|
||||
|
||||
if is_intel_xpu():
|
||||
if torch_version_numeric < (2, 6):
|
||||
if torch_version_numeric < (2, 3):
|
||||
return True
|
||||
else:
|
||||
return torch.xpu.get_device_capability(device)['has_bfloat16_conversions']
|
||||
return torch.xpu.is_bf16_supported()
|
||||
|
||||
if is_ascend_npu():
|
||||
return True
|
||||
|
||||
@@ -24,7 +24,7 @@ import inspect
|
||||
import logging
|
||||
import math
|
||||
import uuid
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -84,6 +84,12 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
|
||||
model_options["disable_cfg1_optimization"] = True
|
||||
return model_options
|
||||
|
||||
def add_model_options_memory_reserve(model_options, memory_reserve_gb: float):
|
||||
if "model_memory_reserve" not in model_options:
|
||||
model_options["model_memory_reserve"] = []
|
||||
model_options["model_memory_reserve"].append(memory_reserve_gb)
|
||||
return model_options
|
||||
|
||||
def create_model_options_clone(orig_model_options: dict):
|
||||
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
|
||||
|
||||
@@ -439,6 +445,17 @@ class ModelPatcher:
|
||||
self.force_cast_weights = True
|
||||
self.patches_uuid = uuid.uuid4() #TODO: optimize by preventing a full model reload for this
|
||||
|
||||
def add_model_memory_reserve(self, memory_reserve_gb: float):
|
||||
"""Adds additional expected memory usage for the model, in gigabytes."""
|
||||
self.model_options = add_model_options_memory_reserve(self.model_options, memory_reserve_gb)
|
||||
|
||||
def get_model_memory_reserve(self, convert_to_bytes: bool = False) -> Union[float, int]:
|
||||
"""Returns the total expected memory usage for the model in gigabytes, or bytes if convert_to_bytes is True."""
|
||||
total_reserve = sum(self.model_options.get("model_memory_reserve", []))
|
||||
if convert_to_bytes:
|
||||
return total_reserve * 1024 * 1024 * 1024
|
||||
return total_reserve
|
||||
|
||||
def add_weight_wrapper(self, name, function):
|
||||
self.weight_wrapper_patches[name] = self.weight_wrapper_patches.get(name, []) + [function]
|
||||
self.patches_uuid = uuid.uuid4()
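
# A minimal sketch of the memory-reserve bookkeeping added above, using a bare
# model_options dict instead of a real ModelPatcher; the gigabyte values are illustrative.
def _toy_memory_reserve():
    model_options = {}
    for reserve_gb in (0.5, 1.25):  # two callers each reserving some extra headroom
        model_options.setdefault("model_memory_reserve", []).append(reserve_gb)
    total_gb = sum(model_options.get("model_memory_reserve", []))
    total_bytes = total_gb * 1024 * 1024 * 1024  # same GiB -> bytes conversion as get_model_memory_reserve()
    return total_gb, total_bytes  # (1.75, 1879048192.0)
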
comfy/ops.py (26 changes)
@@ -24,6 +24,32 @@ import comfy.float
|
||||
import comfy.rmsnorm
|
||||
import contextlib
|
||||
|
||||
|
||||
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
|
||||
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
|
||||
|
||||
|
||||
try:
|
||||
if torch.cuda.is_available():
|
||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||
import inspect
|
||||
if "set_priority" in inspect.signature(sdpa_kernel).parameters:
|
||||
SDPA_BACKEND_PRIORITY = [
|
||||
SDPBackend.FLASH_ATTENTION,
|
||||
SDPBackend.EFFICIENT_ATTENTION,
|
||||
SDPBackend.MATH,
|
||||
]
|
||||
|
||||
SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
|
||||
|
||||
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
|
||||
with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
|
||||
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
|
||||
else:
|
||||
logging.warning("Torch version too old to set sdpa backend priority.")
|
||||
except (ModuleNotFoundError, TypeError):
|
||||
logging.warning("Could not set sdpa backend priority.")
|
||||
|
||||
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
|
||||
|
||||
def cast_to_input(weight, input, non_blocking=False, copy=True):
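
# A condensed sketch of the backend-priority guard added above: it only wraps
# scaled_dot_product_attention in sdpa_kernel(..., set_priority=True) when the
# installed torch exposes that parameter, and otherwise keeps the plain function.
def _pick_sdpa():
    import inspect
    try:
        from torch.nn.attention import SDPBackend, sdpa_kernel
        if "set_priority" in inspect.signature(sdpa_kernel).parameters:
            priority = [SDPBackend.CUDNN_ATTENTION, SDPBackend.FLASH_ATTENTION,
                        SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]

            def sdpa(q, k, v, *args, **kwargs):
                with sdpa_kernel(priority, set_priority=True):
                    return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
            return sdpa
    except (ModuleNotFoundError, TypeError):
        pass
    return torch.nn.functional.scaled_dot_product_attention
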
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import numbers
|
||||
import logging
|
||||
|
||||
RMSNorm = None
|
||||
|
||||
@@ -9,6 +10,7 @@ try:
|
||||
RMSNorm = torch.nn.RMSNorm
|
||||
except:
|
||||
rms_norm_torch = None
|
||||
logging.warning("Please update pytorch to use native RMSNorm")
|
||||
|
||||
|
||||
def rms_norm(x, weight=None, eps=1e-6):
|
||||
|
||||
@@ -149,7 +149,7 @@ def cleanup_models(conds, models):
|
||||
|
||||
cleanup_additional_models(set(control_cleanup))
|
||||
|
||||
def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
|
||||
def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict):
|
||||
'''
|
||||
Registers hooks from conds.
|
||||
'''
|
||||
@@ -158,8 +158,8 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
|
||||
for k in conds:
|
||||
get_hooks_from_cond(conds[k], hooks)
|
||||
# add wrappers and callbacks from ModelPatcher to transformer_options
|
||||
model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers)
|
||||
model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks)
|
||||
comfy.patcher_extension.merge_nested_dicts(model_options["transformer_options"].setdefault("wrappers", {}), model.wrappers, copy_dict1=False)
|
||||
comfy.patcher_extension.merge_nested_dicts(model_options["transformer_options"].setdefault("callbacks", {}), model.callbacks, copy_dict1=False)
|
||||
# begin registering hooks
|
||||
registered = comfy.hooks.HookGroup()
|
||||
target_dict = comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model)
|
||||
|
||||
@@ -16,6 +16,7 @@ import comfy.sampler_helpers
|
||||
import comfy.model_patcher
|
||||
import comfy.patcher_extension
|
||||
import comfy.hooks
|
||||
import comfy.context_windows
|
||||
import scipy.stats
|
||||
import numpy
|
||||
|
||||
@@ -89,7 +90,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
|
||||
conditioning = {}
|
||||
model_conds = conds["model_conds"]
|
||||
for c in model_conds:
|
||||
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
|
||||
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], area=area)
|
||||
|
||||
hooks = conds.get('hooks', None)
|
||||
control = conds.get('control', None)
|
||||
@@ -198,14 +199,20 @@ def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.H
|
||||
hooked_to_run.setdefault(p.hooks, list())
|
||||
hooked_to_run[p.hooks] += [(p, i)]
|
||||
|
||||
def calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
def calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options: dict[str]):
|
||||
handler: comfy.context_windows.ContextHandlerABC = model_options.get("context_handler", None)
|
||||
if handler is None or not handler.should_use_context(model, conds, x_in, timestep, model_options):
|
||||
return _calc_cond_batch_outer(model, conds, x_in, timestep, model_options)
|
||||
return handler.execute(_calc_cond_batch_outer, model, conds, x_in, timestep, model_options)
|
||||
|
||||
def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
|
||||
_calc_cond_batch,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True)
|
||||
)
|
||||
return executor.execute(model, conds, x_in, timestep, model_options)
|
||||
|
||||
def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
out_conds = []
|
||||
out_counts = []
|
||||
# separate conds by matching hooks
|
||||
|
||||
comfy/sd.py (12 changes)
@@ -47,6 +47,7 @@ import comfy.text_encoders.wan
|
||||
import comfy.text_encoders.hidream
|
||||
import comfy.text_encoders.ace
|
||||
import comfy.text_encoders.omnigen2
|
||||
import comfy.text_encoders.qwen_image
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@@ -771,6 +772,7 @@ class CLIPType(Enum):
|
||||
CHROMA = 15
|
||||
ACE = 16
|
||||
OMNIGEN2 = 17
|
||||
QWEN_IMAGE = 18
|
||||
|
||||
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||
@@ -791,6 +793,7 @@ class TEModel(Enum):
|
||||
T5_XXL_OLD = 8
|
||||
GEMMA_2_2B = 9
|
||||
QWEN25_3B = 10
|
||||
QWEN25_7B = 11
|
||||
|
||||
def detect_te_model(sd):
|
||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
||||
@@ -812,7 +815,11 @@ def detect_te_model(sd):
|
||||
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
|
||||
return TEModel.GEMMA_2_2B
|
||||
if 'model.layers.0.self_attn.k_proj.bias' in sd:
|
||||
return TEModel.QWEN25_3B
|
||||
weight = sd['model.layers.0.self_attn.k_proj.bias']
|
||||
if weight.shape[0] == 256:
|
||||
return TEModel.QWEN25_3B
|
||||
if weight.shape[0] == 512:
|
||||
return TEModel.QWEN25_7B
|
||||
if "model.layers.0.post_attention_layernorm.weight" in sd:
|
||||
return TEModel.LLAMA3_8
|
||||
return None
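
# A minimal sketch of the size check added above: the width of the k_proj bias
# distinguishes Qwen2.5 3B (256) from the Qwen2.5 VL 7B used by Qwen Image (512).
# The state dict here is a toy stand-in, not a real checkpoint.
def _toy_detect_qwen(sd):
    bias = sd.get("model.layers.0.self_attn.k_proj.bias")
    if bias is None:
        return None
    if bias.shape[0] == 256:
        return TEModel.QWEN25_3B
    if bias.shape[0] == 512:
        return TEModel.QWEN25_7B
    return None

# _toy_detect_qwen({"model.layers.0.self_attn.k_proj.bias": torch.zeros(512)})  # -> TEModel.QWEN25_7B
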
@@ -917,6 +924,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
elif te_model == TEModel.QWEN25_3B:
|
||||
clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.omnigen2.Omnigen2Tokenizer
|
||||
elif te_model == TEModel.QWEN25_7B:
|
||||
clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
|
||||
else:
|
||||
# clip_l
|
||||
if clip_type == CLIPType.SD3:
|
||||
|
||||
@@ -19,6 +19,7 @@ import comfy.text_encoders.lumina2
|
||||
import comfy.text_encoders.wan
|
||||
import comfy.text_encoders.ace
|
||||
import comfy.text_encoders.omnigen2
|
||||
import comfy.text_encoders.qwen_image
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
@@ -1045,6 +1046,18 @@ class WAN21_Camera(WAN21_T2V):
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.WAN21_Camera(self, image_to_video=False, device=device)
|
||||
return out
|
||||
|
||||
class WAN22_Camera(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
"model_type": "camera_2.2",
|
||||
"in_dim": 36,
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.WAN21_Camera(self, image_to_video=False, device=device)
|
||||
return out
|
||||
|
||||
class WAN21_Vace(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
@@ -1229,7 +1242,36 @@ class Omnigen2(supported_models_base.BASE):
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
|
||||
|
||||
class QwenImage(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "qwen_image",
|
||||
}
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2]
|
||||
sampling_settings = {
|
||||
"multiplier": 1.0,
|
||||
"shift": 1.15,
|
||||
}
|
||||
|
||||
memory_usage_factor = 1.8 #TODO
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Wan21
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.QwenImage(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect))
|
||||
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
@@ -43,6 +43,23 @@ class Qwen25_3BConfig:
|
||||
mlp_activation = "silu"
|
||||
qkv_bias = True
|
||||
|
||||
@dataclass
|
||||
class Qwen25_7BVLI_Config:
|
||||
vocab_size: int = 152064
|
||||
hidden_size: int = 3584
|
||||
intermediate_size: int = 18944
|
||||
num_hidden_layers: int = 28
|
||||
num_attention_heads: int = 28
|
||||
num_key_value_heads: int = 4
|
||||
max_position_embeddings: int = 128000
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 1000000.0
|
||||
transformer_type: str = "llama"
|
||||
head_dim = 128
|
||||
rms_norm_add = False
|
||||
mlp_activation = "silu"
|
||||
qkv_bias = True
|
||||
|
||||
@dataclass
|
||||
class Gemma2_2B_Config:
|
||||
vocab_size: int = 256000
|
||||
@@ -348,6 +365,15 @@ class Qwen25_3B(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen25_7BVLI_Config(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Gemma2_2B(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
|
||||
comfy/text_encoders/qwen_image.py (new file, 71 lines)
@@ -0,0 +1,71 @@
|
||||
from transformers import Qwen2Tokenizer
|
||||
from comfy import sd1_clip
|
||||
import comfy.text_encoders.llama
|
||||
import os
|
||||
import torch
|
||||
import numbers
|
||||
|
||||
class Qwen25_7BVLITokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=3584, embedding_key='qwen25_7b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
|
||||
|
||||
|
||||
class QwenImageTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_7b", tokenizer=Qwen25_7BVLITokenizer)
|
||||
self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None,**kwargs):
|
||||
if llama_template is None:
|
||||
llama_text = self.llama_template.format(text)
|
||||
else:
|
||||
llama_text = llama_template.format(text)
|
||||
return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs)
|
||||
|
||||
|
||||
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
|
||||
class QwenImageTEModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
out, pooled, extra = super().encode_token_weights(token_weight_pairs)
|
||||
tok_pairs = token_weight_pairs["qwen25_7b"][0]
|
||||
count_im_start = 0
|
||||
for i, v in enumerate(tok_pairs):
|
||||
elem = v[0]
|
||||
if not torch.is_tensor(elem):
|
||||
if isinstance(elem, numbers.Integral):
|
||||
if elem == 151644 and count_im_start < 2:
|
||||
template_end = i
|
||||
count_im_start += 1
|
||||
|
||||
if out.shape[1] > (template_end + 3):
|
||||
if tok_pairs[template_end + 1][0] == 872:
|
||||
if tok_pairs[template_end + 2][0] == 198:
|
||||
template_end += 3
|
||||
|
||||
out = out[:, template_end:]
|
||||
|
||||
extra["attention_mask"] = extra["attention_mask"][:, template_end:]
|
||||
if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]):
|
||||
extra.pop("attention_mask") # attention mask is useless if no masked elements
|
||||
|
||||
return out, pooled, extra
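
# A minimal sketch of the template trimming above on a toy token id list: it finds
# the second <|im_start|> (151644), then steps past the "user" and newline tokens
# (872, 198) so only the prompt embeddings are kept. Ids other than 151644/872/198
# below are placeholders.
def _toy_template_end(token_ids):
    template_end = 0
    count_im_start = 0
    for i, tok in enumerate(token_ids):
        if tok == 151644 and count_im_start < 2:
            template_end = i
            count_im_start += 1
    if len(token_ids) > template_end + 3:
        if token_ids[template_end + 1] == 872 and token_ids[template_end + 2] == 198:
            template_end += 3
    return template_end

# _toy_template_end([151644, 8948, 151645, 151644, 872, 198, 1234, 5678])  # -> 6
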
def te(dtype_llama=None, llama_scaled_fp8=None):
|
||||
class QwenImageTEModel_(QwenImageTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return QwenImageTEModel_
|
||||
@@ -96,6 +96,7 @@ class LoRAAdapter(WeightAdapterBase):
|
||||
diffusers3_lora = "{}.lora.up.weight".format(x)
|
||||
mochi_lora = "{}.lora_B".format(x)
|
||||
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
|
||||
qwen_default_lora = "{}.lora_B.default.weight".format(x)
|
||||
A_name = None
|
||||
|
||||
if regular_lora in lora.keys():
|
||||
@@ -122,6 +123,10 @@ class LoRAAdapter(WeightAdapterBase):
|
||||
A_name = transformers_lora
|
||||
B_name = "{}.lora_linear_layer.down.weight".format(x)
|
||||
mid_name = None
|
||||
elif qwen_default_lora in lora.keys():
|
||||
A_name = qwen_default_lora
|
||||
B_name = "{}.lora_A.default.weight".format(x)
|
||||
mid_name = None
|
||||
|
||||
if A_name is not None:
|
||||
mid = None
|
||||
|
||||
comfy_api/generate_api_stubs.py (new file, 86 lines)
@@ -0,0 +1,86 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Script to generate .pyi stub files for the synchronous API wrappers.
|
||||
This allows generating stubs without running the full ComfyUI application.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import importlib
|
||||
|
||||
# Add ComfyUI to path so we can import modules
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from comfy_api.internal.async_to_sync import AsyncToSyncConverter
|
||||
from comfy_api.version_list import supported_versions
|
||||
|
||||
|
||||
def generate_stubs_for_module(module_name: str) -> None:
|
||||
"""Generate stub files for a specific module that exports ComfyAPI and ComfyAPISync."""
|
||||
try:
|
||||
# Import the module
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
# Check if module has ComfyAPISync (the sync wrapper)
|
||||
if hasattr(module, "ComfyAPISync"):
|
||||
# Module already has a sync class
|
||||
api_class = getattr(module, "ComfyAPI", None)
|
||||
sync_class = getattr(module, "ComfyAPISync")
|
||||
|
||||
if api_class:
|
||||
# Generate the stub file
|
||||
AsyncToSyncConverter.generate_stub_file(api_class, sync_class)
|
||||
logging.info(f"Generated stub file for {module_name}")
|
||||
else:
|
||||
logging.warning(
|
||||
f"Module {module_name} has ComfyAPISync but no ComfyAPI"
|
||||
)
|
||||
|
||||
elif hasattr(module, "ComfyAPI"):
|
||||
# Module only has async API, need to create sync wrapper first
|
||||
from comfy_api.internal.async_to_sync import create_sync_class
|
||||
|
||||
api_class = getattr(module, "ComfyAPI")
|
||||
sync_class = create_sync_class(api_class)
|
||||
|
||||
# Generate the stub file
|
||||
AsyncToSyncConverter.generate_stub_file(api_class, sync_class)
|
||||
logging.info(f"Generated stub file for {module_name}")
|
||||
else:
|
||||
logging.warning(
|
||||
f"Module {module_name} does not export ComfyAPI or ComfyAPISync"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to generate stub for {module_name}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function to generate all API stub files."""
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
logging.info("Starting stub generation...")
|
||||
|
||||
# Dynamically get module names from supported_versions
|
||||
api_modules = []
|
||||
for api_class in supported_versions:
|
||||
# Extract module name from the class
|
||||
module_name = api_class.__module__
|
||||
if module_name not in api_modules:
|
||||
api_modules.append(module_name)
|
||||
|
||||
logging.info(f"Found {len(api_modules)} API modules: {api_modules}")
|
||||
|
||||
# Generate stubs for each module
|
||||
for module_name in api_modules:
|
||||
generate_stubs_for_module(module_name)
|
||||
|
||||
logging.info("Stub generation complete!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
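A brief usage note for the stub generator above, hedged since only the script itself appears in this diff: it is meant to be run from the repository root (the sys.path.insert resolves ComfyUI imports relative to the script location), and the module list comes from comfy_api.version_list.supported_versions at runtime. The module name passed below is illustrative.

# Hypothetical invocation:
#   python comfy_api/generate_api_stubs.py
# or, if imported as a package module, regenerate stubs for a single API module:
import logging
from comfy_api.generate_api_stubs import generate_stubs_for_module

logging.basicConfig(level=logging.INFO)
generate_stubs_for_module("comfy_api.latest")  # module name is an assumption for illustration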
@@ -1,8 +1,16 @@
|
||||
from .basic_types import ImageInput, AudioInput
|
||||
from .video_types import VideoInput
|
||||
# This file only exists for backwards compatibility.
|
||||
from comfy_api.latest._input import (
|
||||
ImageInput,
|
||||
AudioInput,
|
||||
MaskInput,
|
||||
LatentInput,
|
||||
VideoInput,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ImageInput",
|
||||
"AudioInput",
|
||||
"MaskInput",
|
||||
"LatentInput",
|
||||
"VideoInput",
|
||||
]
|
||||
|
||||
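Assuming the shim above is comfy_api/input/__init__.py (the per-file header was lost in this view), the old import path keeps resolving to the same objects now defined under comfy_api.latest; a quick, illustrative check:

# Old-style imports continue to work through the compatibility re-export.
from comfy_api.input import VideoInput
from comfy_api.latest._input import VideoInput as LatestVideoInput

assert VideoInput is LatestVideoInput  # same class object, just re-exported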
@@ -1,20 +1,14 @@
|
||||
import torch
|
||||
from typing import TypedDict
|
||||
|
||||
ImageInput = torch.Tensor
|
||||
"""
|
||||
An image in format [B, H, W, C] where B is the batch size, H and W are the height and width, and C is the number of channels.
|
||||
"""
|
||||
|
||||
class AudioInput(TypedDict):
|
||||
"""
|
||||
TypedDict representing audio input.
|
||||
"""
|
||||
|
||||
waveform: torch.Tensor
|
||||
"""
|
||||
Tensor in the format [B, C, T] where B is the batch size, C is the number of channels, and T is the number of audio samples.
|
||||
"""
|
||||
|
||||
sample_rate: int
|
||||
# This file only exists for backwards compatibility.
|
||||
from comfy_api.latest._input.basic_types import (
|
||||
ImageInput,
|
||||
AudioInput,
|
||||
MaskInput,
|
||||
LatentInput,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ImageInput",
|
||||
"AudioInput",
|
||||
"MaskInput",
|
||||
"LatentInput",
|
||||
]
|
||||
|
||||
@@ -1,85 +1,6 @@
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Union
|
||||
import io
|
||||
import av
|
||||
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
|
||||
# This file only exists for backwards compatibility.
|
||||
from comfy_api.latest._input.video_types import VideoInput
|
||||
|
||||
class VideoInput(ABC):
|
||||
"""
|
||||
Abstract base class for video input types.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_components(self) -> VideoComponents:
|
||||
"""
|
||||
Abstract method to get the video components (images, audio, and frame rate).
|
||||
|
||||
Returns:
|
||||
VideoComponents containing images, audio, and frame rate
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save_to(
|
||||
self,
|
||||
path: str,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
):
|
||||
"""
|
||||
Abstract method to save the video input to a file.
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_stream_source(self) -> Union[str, io.BytesIO]:
|
||||
"""
|
||||
Get a streamable source for the video. This allows processing without
|
||||
loading the entire video into memory.
|
||||
|
||||
Returns:
|
||||
Either a file path (str) or a BytesIO object that can be opened with av.
|
||||
|
||||
Default implementation creates a BytesIO buffer, but subclasses should
|
||||
override this for better performance when possible.
|
||||
"""
|
||||
buffer = io.BytesIO()
|
||||
self.save_to(buffer)
|
||||
buffer.seek(0)
|
||||
return buffer
|
||||
|
||||
# Provide a default implementation, but subclasses can provide optimized versions
|
||||
# if possible.
|
||||
def get_dimensions(self) -> tuple[int, int]:
|
||||
"""
|
||||
Returns the dimensions of the video input.
|
||||
|
||||
Returns:
|
||||
Tuple of (width, height)
|
||||
"""
|
||||
components = self.get_components()
|
||||
return components.images.shape[2], components.images.shape[1]
|
||||
|
||||
def get_duration(self) -> float:
|
||||
"""
|
||||
Returns the duration of the video in seconds.
|
||||
|
||||
Returns:
|
||||
Duration in seconds
|
||||
"""
|
||||
components = self.get_components()
|
||||
frame_count = components.images.shape[0]
|
||||
return float(frame_count / components.frame_rate)
|
||||
|
||||
def get_container_format(self) -> str:
|
||||
"""
|
||||
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
|
||||
|
||||
Returns:
|
||||
Container format as string
|
||||
"""
|
||||
# Default implementation - subclasses should override for better performance
|
||||
source = self.get_stream_source()
|
||||
with av.open(source, mode="r") as container:
|
||||
return container.format.name
|
||||
__all__ = [
|
||||
"VideoInput",
|
||||
]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from .video_types import VideoFromFile, VideoFromComponents
|
||||
# This file only exists for backwards compatibility.
|
||||
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
|
||||
|
||||
__all__ = [
|
||||
# Implementations
|
||||
"VideoFromFile",
|
||||
"VideoFromComponents",
|
||||
]
|
||||
|
||||
@@ -1,324 +1,2 @@
|
||||
from __future__ import annotations
|
||||
from av.container import InputContainer
|
||||
from av.subtitles.stream import SubtitleStream
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
from comfy_api.input import AudioInput
|
||||
import av
|
||||
import io
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
from comfy_api.input import VideoInput
|
||||
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
|
||||
|
||||
|
||||
def container_to_output_format(container_format: str | None) -> str | None:
|
||||
"""
|
||||
A container's `format` may be a comma-separated list of formats.
|
||||
E.g., iso container's `format` may be `mov,mp4,m4a,3gp,3g2,mj2`.
|
||||
However, writing to a file/stream with `av.open` requires a single format,
|
||||
or `None` to auto-detect.
|
||||
"""
|
||||
if not container_format:
|
||||
return None # Auto-detect
|
||||
|
||||
if "," not in container_format:
|
||||
return container_format
|
||||
|
||||
formats = container_format.split(",")
|
||||
return formats[0]
|
||||
|
||||
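A short illustration of container_to_output_format above, using the iso container string from its docstring; these calls simply trace the function as written:

# The first entry of a comma-separated container format is used for writing.
assert container_to_output_format("mov,mp4,m4a,3gp,3g2,mj2") == "mov"
assert container_to_output_format("mp4") == "mp4"
assert container_to_output_format(None) is None  # falsy input means auto-detect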
|
||||
def get_open_write_kwargs(
|
||||
dest: str | io.BytesIO, container_format: str, to_format: str | None
|
||||
) -> dict:
|
||||
"""Get kwargs for writing a `VideoFromFile` to a file/stream with `av.open`"""
|
||||
open_kwargs = {
|
||||
"mode": "w",
|
||||
# If isobmff, preserve custom metadata tags (workflow, prompt, extra_pnginfo)
|
||||
"options": {"movflags": "use_metadata_tags"},
|
||||
}
|
||||
|
||||
is_write_to_buffer = isinstance(dest, io.BytesIO)
|
||||
if is_write_to_buffer:
|
||||
# Set output format explicitly, since it cannot be inferred from file extension
|
||||
if to_format == VideoContainer.AUTO:
|
||||
to_format = container_format.lower()
|
||||
elif isinstance(to_format, str):
|
||||
to_format = to_format.lower()
|
||||
open_kwargs["format"] = container_to_output_format(to_format)
|
||||
|
||||
return open_kwargs
|
||||
|
||||
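And a sketch of get_open_write_kwargs for the in-memory case: when the destination is a BytesIO buffer there is no file extension, so the output format is set explicitly from the source container format. The values below just trace the code above.

import io
from comfy_api.util import VideoContainer

kwargs = get_open_write_kwargs(
    io.BytesIO(),
    container_format="mov,mp4,m4a,3gp,3g2,mj2",
    to_format=VideoContainer.AUTO,
)
assert kwargs == {
    "mode": "w",
    "options": {"movflags": "use_metadata_tags"},
    "format": "mov",  # first entry of the comma-separated iso format list
}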
|
||||
class VideoFromFile(VideoInput):
|
||||
"""
|
||||
Class representing video input from a file.
|
||||
"""
|
||||
|
||||
def __init__(self, file: str | io.BytesIO):
|
||||
"""
|
||||
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
|
||||
containing the file contents.
|
||||
"""
|
||||
self.__file = file
|
||||
|
||||
def get_stream_source(self) -> str | io.BytesIO:
|
||||
"""
|
||||
Return the underlying file source for efficient streaming.
|
||||
This avoids unnecessary memory copies when the source is already a file path.
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
return self.__file
|
||||
|
||||
def get_dimensions(self) -> tuple[int, int]:
|
||||
"""
|
||||
Returns the dimensions of the video input.
|
||||
|
||||
Returns:
|
||||
Tuple of (width, height)
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
for stream in container.streams:
|
||||
if stream.type == 'video':
|
||||
assert isinstance(stream, av.VideoStream)
|
||||
return stream.width, stream.height
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
|
||||
def get_duration(self) -> float:
|
||||
"""
|
||||
Returns the duration of the video in seconds.
|
||||
|
||||
Returns:
|
||||
Duration in seconds
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
if container.duration is not None:
|
||||
return float(container.duration / av.time_base)
|
||||
|
||||
# Fallback: calculate from frame count and frame rate
|
||||
video_stream = next(
|
||||
(s for s in container.streams if s.type == "video"), None
|
||||
)
|
||||
if video_stream and video_stream.frames and video_stream.average_rate:
|
||||
return float(video_stream.frames / video_stream.average_rate)
|
||||
|
||||
# Last resort: decode frames to count them
|
||||
if video_stream and video_stream.average_rate:
|
||||
frame_count = 0
|
||||
container.seek(0)
|
||||
for packet in container.demux(video_stream):
|
||||
for _ in packet.decode():
|
||||
frame_count += 1
|
||||
if frame_count > 0:
|
||||
return float(frame_count / video_stream.average_rate)
|
||||
|
||||
raise ValueError(f"Could not determine duration for file '{self.__file}'")
|
||||
|
||||
def get_container_format(self) -> str:
|
||||
"""
|
||||
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
|
||||
|
||||
Returns:
|
||||
Container format as string
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
return container.format.name
|
||||
|
||||
def get_components_internal(self, container: InputContainer) -> VideoComponents:
|
||||
# Get video frames
|
||||
frames = []
|
||||
for frame in container.decode(video=0):
|
||||
img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3)
|
||||
img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3)
|
||||
frames.append(img)
|
||||
|
||||
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)
|
||||
|
||||
# Get frame rate
|
||||
video_stream = next(s for s in container.streams if s.type == 'video')
|
||||
frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)
|
||||
|
||||
# Get audio if available
|
||||
audio = None
|
||||
try:
|
||||
container.seek(0) # Reset the container to the beginning
|
||||
for stream in container.streams:
|
||||
if stream.type != 'audio':
|
||||
continue
|
||||
assert isinstance(stream, av.AudioStream)
|
||||
audio_frames = []
|
||||
for packet in container.demux(stream):
|
||||
for frame in packet.decode():
|
||||
assert isinstance(frame, av.AudioFrame)
|
||||
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
|
||||
if len(audio_frames) > 0:
|
||||
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
|
||||
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
|
||||
audio = AudioInput({
|
||||
"waveform": audio_tensor,
|
||||
"sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
|
||||
})
|
||||
except StopIteration:
|
||||
pass # No audio stream
|
||||
|
||||
metadata = container.metadata
|
||||
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
|
||||
|
||||
def get_components(self) -> VideoComponents:
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
return self.get_components_internal(container)
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
|
||||
def save_to(
|
||||
self,
|
||||
path: str | io.BytesIO,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
):
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
container_format = container.format.name
|
||||
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
|
||||
reuse_streams = True
|
||||
if format != VideoContainer.AUTO and format not in container_format.split(","):
|
||||
reuse_streams = False
|
||||
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
|
||||
reuse_streams = False
|
||||
|
||||
if not reuse_streams:
|
||||
components = self.get_components_internal(container)
|
||||
video = VideoFromComponents(components)
|
||||
return video.save_to(
|
||||
path,
|
||||
format=format,
|
||||
codec=codec,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
streams = container.streams
|
||||
|
||||
open_kwargs = get_open_write_kwargs(path, container_format, format)
|
||||
with av.open(path, **open_kwargs) as output_container:
|
||||
# Copy over the original metadata
|
||||
for key, value in container.metadata.items():
|
||||
if metadata is None or key not in metadata:
|
||||
output_container.metadata[key] = value
|
||||
|
||||
# Add our new metadata
|
||||
if metadata is not None:
|
||||
for key, value in metadata.items():
|
||||
if isinstance(value, str):
|
||||
output_container.metadata[key] = value
|
||||
else:
|
||||
output_container.metadata[key] = json.dumps(value)
|
||||
|
||||
# Add streams to the new container
|
||||
stream_map = {}
|
||||
for stream in streams:
|
||||
if isinstance(stream, (av.VideoStream, av.AudioStream, SubtitleStream)):
|
||||
out_stream = output_container.add_stream_from_template(template=stream, opaque=True)
|
||||
stream_map[stream] = out_stream
|
||||
|
||||
# Write packets to the new container
|
||||
for packet in container.demux():
|
||||
if packet.stream in stream_map and packet.dts is not None:
|
||||
packet.stream = stream_map[packet.stream]
|
||||
output_container.mux(packet)
|
||||
|
||||
class VideoFromComponents(VideoInput):
|
||||
"""
|
||||
Class representing video input from tensors.
|
||||
"""
|
||||
|
||||
def __init__(self, components: VideoComponents):
|
||||
self.__components = components
|
||||
|
||||
def get_components(self) -> VideoComponents:
|
||||
return VideoComponents(
|
||||
images=self.__components.images,
|
||||
audio=self.__components.audio,
|
||||
frame_rate=self.__components.frame_rate
|
||||
)
|
||||
|
||||
def save_to(
|
||||
self,
|
||||
path: str,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
):
|
||||
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
|
||||
raise ValueError("Only MP4 format is supported for now")
|
||||
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
|
||||
raise ValueError("Only H264 codec is supported for now")
|
||||
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output:
|
||||
# Add metadata before writing any streams
|
||||
if metadata is not None:
|
||||
for key, value in metadata.items():
|
||||
output.metadata[key] = json.dumps(value)
|
||||
|
||||
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
|
||||
# Create a video stream
|
||||
video_stream = output.add_stream('h264', rate=frame_rate)
|
||||
video_stream.width = self.__components.images.shape[2]
|
||||
video_stream.height = self.__components.images.shape[1]
|
||||
video_stream.pix_fmt = 'yuv420p'
|
||||
|
||||
# Create an audio stream
|
||||
audio_sample_rate = 1
|
||||
audio_stream: Optional[av.AudioStream] = None
|
||||
if self.__components.audio:
|
||||
audio_sample_rate = int(self.__components.audio['sample_rate'])
|
||||
audio_stream = output.add_stream('aac', rate=audio_sample_rate)
|
||||
audio_stream.sample_rate = audio_sample_rate
|
||||
audio_stream.format = 'fltp'
|
||||
|
||||
# Encode video
|
||||
for i, frame in enumerate(self.__components.images):
|
||||
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
|
||||
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
|
||||
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
|
||||
packet = video_stream.encode(frame)
|
||||
output.mux(packet)
|
||||
|
||||
# Flush video
|
||||
packet = video_stream.encode(None)
|
||||
output.mux(packet)
|
||||
|
||||
if audio_stream and self.__components.audio:
|
||||
# Encode audio
|
||||
samples_per_frame = int(audio_sample_rate / frame_rate)
|
||||
num_frames = self.__components.audio['waveform'].shape[2] // samples_per_frame
|
||||
for i in range(num_frames):
|
||||
start = i * samples_per_frame
|
||||
end = start + samples_per_frame
|
||||
# TODO(Feature) - Add support for stereo audio
|
||||
chunk = (
|
||||
self.__components.audio["waveform"][0, 0, start:end]
|
||||
.unsqueeze(0)
|
||||
.contiguous()
|
||||
.numpy()
|
||||
)
|
||||
audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono')
|
||||
audio_frame.sample_rate = audio_sample_rate
|
||||
audio_frame.pts = i * samples_per_frame
|
||||
for packet in audio_stream.encode(audio_frame):
|
||||
output.mux(packet)
|
||||
|
||||
# Flush audio
|
||||
for packet in audio_stream.encode(None):
|
||||
output.mux(packet)
|
||||
|
||||
# This file only exists for backwards compatibility.
|
||||
from comfy_api.latest._input_impl.video_types import * # noqa: F403
|
||||
|
||||
comfy_api/internal/__init__.py (new file, 150 lines)
@@ -0,0 +1,150 @@
|
||||
# Internal infrastructure for ComfyAPI
|
||||
from .api_registry import (
|
||||
ComfyAPIBase as ComfyAPIBase,
|
||||
ComfyAPIWithVersion as ComfyAPIWithVersion,
|
||||
register_versions as register_versions,
|
||||
get_all_versions as get_all_versions,
|
||||
)
|
||||
|
||||
import asyncio
|
||||
from dataclasses import asdict
|
||||
from typing import Callable, Optional
|
||||
|
||||
|
||||
def first_real_override(cls: type, name: str, *, base: type=None) -> Optional[Callable]:
|
||||
"""Return the *callable* override of `name` visible on `cls`, or None if every
|
||||
implementation up to (and including) `base` is the placeholder defined on `base`.
|
||||
|
||||
If base is not provided, it will assume cls has a GET_BASE_CLASS
|
||||
"""
|
||||
if base is None:
|
||||
if not hasattr(cls, "GET_BASE_CLASS"):
|
||||
raise ValueError("base is required if cls does not have a GET_BASE_CLASS; is this a valid ComfyNode subclass?")
|
||||
base = cls.GET_BASE_CLASS()
|
||||
base_attr = getattr(base, name, None)
|
||||
if base_attr is None:
|
||||
return None
|
||||
base_func = base_attr.__func__
|
||||
for c in cls.mro(): # NodeB, NodeA, ComfyNode, object …
|
||||
if c is base: # reached the placeholder – we're done
|
||||
break
|
||||
if name in c.__dict__: # first class that *defines* the attr
|
||||
func = getattr(c, name).__func__
|
||||
if func is not base_func: # real override
|
||||
return getattr(cls, name) # bound to *cls*
|
||||
return None
|
||||
|
||||
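A minimal sketch of the override detection above; the node classes are invented for illustration and use classmethods so that the __func__ comparison applies, as it does for real ComfyNode subclasses:

class _Base:
    @classmethod
    def execute(cls):
        ...  # placeholder implementation

class _NodeA(_Base):
    @classmethod
    def execute(cls):
        return "real work"

assert first_real_override(_NodeA, "execute", base=_Base) is not None  # real override found
assert first_real_override(_Base, "execute", base=_Base) is None       # only the placeholder exists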
|
||||
class _ComfyNodeInternal:
|
||||
"""Class that all V3-based APIs inherit from for ComfyNode.
|
||||
|
||||
This is intended to only be referenced within execution.py, as it has to handle all V3 APIs going forward."""
|
||||
@classmethod
|
||||
def GET_NODE_INFO_V1(cls):
|
||||
...
|
||||
|
||||
|
||||
class _NodeOutputInternal:
|
||||
"""Class that all V3-based APIs inherit from for NodeOutput.
|
||||
|
||||
This is intended to only be referenced within execution.py, as it has to handle all V3 APIs going forward."""
|
||||
...
|
||||
|
||||
|
||||
def as_pruned_dict(dataclass_obj):
|
||||
'''Return dict of dataclass object with pruned None values.'''
|
||||
return prune_dict(asdict(dataclass_obj))
|
||||
|
||||
def prune_dict(d: dict):
|
||||
return {k: v for k,v in d.items() if v is not None}
|
||||
|
||||
|
||||
def is_class(obj):
|
||||
'''
|
||||
Returns True if obj is a class (type).
|
||||
Returns False if obj is a class instance.
|
||||
'''
|
||||
return isinstance(obj, type)
|
||||
|
||||
|
||||
def copy_class(cls: type) -> type:
|
||||
'''
|
||||
Copy a class and its attributes.
|
||||
'''
|
||||
if cls is None:
|
||||
return None
|
||||
cls_dict = {
|
||||
k: v for k, v in cls.__dict__.items()
|
||||
if k not in ('__dict__', '__weakref__', '__module__', '__doc__')
|
||||
}
|
||||
# new class
|
||||
new_cls = type(
|
||||
cls.__name__,
|
||||
(cls,),
|
||||
cls_dict
|
||||
)
|
||||
# metadata preservation
|
||||
new_cls.__module__ = cls.__module__
|
||||
new_cls.__doc__ = cls.__doc__
|
||||
return new_cls
|
||||
|
||||
|
||||
class classproperty(object):
|
||||
def __init__(self, f):
|
||||
self.f = f
|
||||
def __get__(self, obj, owner):
|
||||
return self.f(owner)
|
||||
|
||||
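classproperty above behaves like a read-only property evaluated against the owning class; a small illustration with an invented class:

class _Config:
    _name = "default"

    @classproperty
    def name(cls):
        return cls._name

assert _Config.name == "default"    # resolved on the class itself
assert _Config().name == "default"  # and on instances, still via the class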
|
||||
# NOTE: this was ai generated and validated by hand
|
||||
def shallow_clone_class(cls, new_name=None):
|
||||
'''
|
||||
Shallow clone a class while preserving super() functionality.
|
||||
'''
|
||||
new_name = new_name or f"{cls.__name__}Clone"
|
||||
# Include the original class in the bases to maintain proper inheritance
|
||||
new_bases = (cls,) + cls.__bases__
|
||||
return type(new_name, new_bases, dict(cls.__dict__))
|
||||
|
||||
# NOTE: this was ai generated and validated by hand
|
||||
def lock_class(cls):
|
||||
'''
|
||||
Lock a class so that its top-level attributes cannot be modified.
|
||||
'''
|
||||
# Locked instance __setattr__
|
||||
def locked_instance_setattr(self, name, value):
|
||||
raise AttributeError(
|
||||
f"Cannot set attribute '{name}' on immutable instance of {type(self).__name__}"
|
||||
)
|
||||
# Locked metaclass
|
||||
class LockedMeta(type(cls)):
|
||||
def __setattr__(cls_, name, value):
|
||||
raise AttributeError(
|
||||
f"Cannot modify class attribute '{name}' on locked class '{cls_.__name__}'"
|
||||
)
|
||||
# Rebuild class with locked behavior
|
||||
locked_dict = dict(cls.__dict__)
|
||||
locked_dict['__setattr__'] = locked_instance_setattr
|
||||
|
||||
return LockedMeta(cls.__name__, cls.__bases__, locked_dict)
|
||||
|
||||
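A brief check of lock_class above with an invented class; assignment to class attributes is intercepted by the locked metaclass (instance assignment is blocked the same way through the injected __setattr__):

class _Settings:
    mode = "fast"

_Locked = lock_class(_Settings)
try:
    _Locked.mode = "slow"
except AttributeError as exc:
    print(exc)          # Cannot modify class attribute 'mode' on locked class '_Settings'
print(_Locked.mode)     # reads are unaffected: "fast"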
|
||||
def make_locked_method_func(type_obj, func, class_clone):
|
||||
"""
|
||||
Returns a function that, when called with **inputs, will execute:
|
||||
getattr(type_obj, func).__func__(lock_class(class_clone), **inputs)
|
||||
|
||||
Supports both synchronous and asynchronous methods.
|
||||
"""
|
||||
locked_class = lock_class(class_clone)
|
||||
method = getattr(type_obj, func).__func__
|
||||
|
||||
# Check if the original method is async
|
||||
if asyncio.iscoroutinefunction(method):
|
||||
async def wrapped_async_func(**inputs):
|
||||
return await method(locked_class, **inputs)
|
||||
return wrapped_async_func
|
||||
else:
|
||||
def wrapped_func(**inputs):
|
||||
return method(locked_class, **inputs)
|
||||
return wrapped_func
|
||||
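make_locked_method_func above appears intended for invoking node methods against a locked class clone; a self-contained sketch (the node class is invented, and execute is a classmethod so that __func__ exists):

class _MyNode:
    @classmethod
    def execute(cls, x):
        return x * 2

fn = make_locked_method_func(_MyNode, "execute", _MyNode)
assert fn(x=21) == 42   # runs execute with the locked class clone bound as cls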
comfy_api/internal/api_registry.py (new file, 39 lines)
@@ -0,0 +1,39 @@
|
||||
from typing import Type, List, NamedTuple
|
||||
from comfy_api.internal.singleton import ProxiedSingleton
|
||||
from packaging import version as packaging_version
|
||||
|
||||
|
||||
class ComfyAPIBase(ProxiedSingleton):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
class ComfyAPIWithVersion(NamedTuple):
|
||||
version: str
|
||||
api_class: Type[ComfyAPIBase]
|
||||
|
||||
|
||||
def parse_version(version_str: str) -> packaging_version.Version:
|
||||
"""
|
||||
Parses a version string into a packaging_version.Version object.
|
||||
Raises ValueError if the version string is invalid.
|
||||
"""
|
||||
if version_str == "latest":
|
||||
return packaging_version.parse("9999999.9999999.9999999")
|
||||
return packaging_version.parse(version_str)
|
||||
|
||||
|
||||
registered_versions: List[ComfyAPIWithVersion] = []
|
||||
|
||||
|
||||
def register_versions(versions: List[ComfyAPIWithVersion]):
|
||||
versions.sort(key=lambda x: parse_version(x.version))
|
||||
global registered_versions
|
||||
registered_versions = versions
|
||||
|
||||
|
||||
def get_all_versions() -> List[ComfyAPIWithVersion]:
|
||||
"""
|
||||
Returns a list of all registered ComfyAPI versions.
|
||||
"""
|
||||
return registered_versions
|
||||
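The registry above sorts versions ascending, and parse_version maps the sentinel "latest" to an effectively infinite version so it always sorts last; a sketch with a placeholder class standing in for real ComfyAPIBase subclasses:

class _PlaceholderAPI:  # stands in for a real ComfyAPIBase subclass
    pass

register_versions([
    ComfyAPIWithVersion(version="latest", api_class=_PlaceholderAPI),
    ComfyAPIWithVersion(version="0.0.2", api_class=_PlaceholderAPI),
])
assert [v.version for v in get_all_versions()] == ["0.0.2", "latest"]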
comfy_api/internal/async_to_sync.py (new file, 987 lines)
@@ -0,0 +1,987 @@
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import contextvars
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import textwrap
|
||||
import threading
|
||||
from enum import Enum
|
||||
from typing import Optional, Type, get_origin, get_args
|
||||
|
||||
|
||||
class TypeTracker:
|
||||
"""Tracks types discovered during stub generation for automatic import generation."""
|
||||
|
||||
def __init__(self):
|
||||
self.discovered_types = {} # type_name -> (module, qualname)
|
||||
self.builtin_types = {
|
||||
"Any",
|
||||
"Dict",
|
||||
"List",
|
||||
"Optional",
|
||||
"Tuple",
|
||||
"Union",
|
||||
"Set",
|
||||
"Sequence",
|
||||
"cast",
|
||||
"NamedTuple",
|
||||
"str",
|
||||
"int",
|
||||
"float",
|
||||
"bool",
|
||||
"None",
|
||||
"bytes",
|
||||
"object",
|
||||
"type",
|
||||
"dict",
|
||||
"list",
|
||||
"tuple",
|
||||
"set",
|
||||
}
|
||||
self.already_imported = (
|
||||
set()
|
||||
) # Track types already imported to avoid duplicates
|
||||
|
||||
def track_type(self, annotation):
|
||||
"""Track a type annotation and record its module/import info."""
|
||||
if annotation is None or annotation is type(None):
|
||||
return
|
||||
|
||||
# Skip builtins and typing module types we already import
|
||||
type_name = getattr(annotation, "__name__", None)
|
||||
if type_name and (
|
||||
type_name in self.builtin_types or type_name in self.already_imported
|
||||
):
|
||||
return
|
||||
|
||||
# Get module and qualname
|
||||
module = getattr(annotation, "__module__", None)
|
||||
qualname = getattr(annotation, "__qualname__", type_name or "")
|
||||
|
||||
# Skip types from typing module (they're already imported)
|
||||
if module == "typing":
|
||||
return
|
||||
|
||||
# Skip UnionType and GenericAlias from types module as they're handled specially
|
||||
if module == "types" and type_name in ("UnionType", "GenericAlias"):
|
||||
return
|
||||
|
||||
if module and module not in ["builtins", "__main__"]:
|
||||
# Store the type info
|
||||
if type_name:
|
||||
self.discovered_types[type_name] = (module, qualname)
|
||||
|
||||
def get_imports(self, main_module_name: str) -> list[str]:
|
||||
"""Generate import statements for all discovered types."""
|
||||
imports = []
|
||||
imports_by_module = {}
|
||||
|
||||
for type_name, (module, qualname) in sorted(self.discovered_types.items()):
|
||||
# Skip types from the main module (they're already imported)
|
||||
if main_module_name and module == main_module_name:
|
||||
continue
|
||||
|
||||
if module not in imports_by_module:
|
||||
imports_by_module[module] = []
|
||||
if type_name not in imports_by_module[module]: # Avoid duplicates
|
||||
imports_by_module[module].append(type_name)
|
||||
|
||||
# Generate import statements
|
||||
for module, types in sorted(imports_by_module.items()):
|
||||
if len(types) == 1:
|
||||
imports.append(f"from {module} import {types[0]}")
|
||||
else:
|
||||
imports.append(f"from {module} import {', '.join(sorted(set(types)))}")
|
||||
|
||||
return imports
|
||||
|
||||
|
||||
class AsyncToSyncConverter:
|
||||
"""
|
||||
Provides utilities to convert async classes to sync classes with proper type hints.
|
||||
"""
|
||||
|
||||
_thread_pool: Optional[concurrent.futures.ThreadPoolExecutor] = None
|
||||
_thread_pool_lock = threading.Lock()
|
||||
_thread_pool_initialized = False
|
||||
|
||||
@classmethod
|
||||
def get_thread_pool(cls, max_workers=None) -> concurrent.futures.ThreadPoolExecutor:
|
||||
"""Get or create the shared thread pool with proper thread-safe initialization."""
|
||||
# Fast path - check if already initialized without acquiring lock
|
||||
if cls._thread_pool_initialized:
|
||||
assert cls._thread_pool is not None, "Thread pool should be initialized"
|
||||
return cls._thread_pool
|
||||
|
||||
# Slow path - acquire lock and create pool if needed
|
||||
with cls._thread_pool_lock:
|
||||
if not cls._thread_pool_initialized:
|
||||
cls._thread_pool = concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=max_workers, thread_name_prefix="async_to_sync_"
|
||||
)
|
||||
cls._thread_pool_initialized = True
|
||||
|
||||
# This should never be None at this point, but add assertion for type checker
|
||||
assert cls._thread_pool is not None
|
||||
return cls._thread_pool
|
||||
|
||||
@classmethod
|
||||
def run_async_in_thread(cls, coro_func, *args, **kwargs):
|
||||
"""
|
||||
Run an async function in a separate thread from the thread pool.
|
||||
Blocks until the async function completes.
|
||||
Properly propagates contextvars between threads and manages event loops.
|
||||
"""
|
||||
# Capture current context - this includes all context variables
|
||||
context = contextvars.copy_context()
|
||||
|
||||
# Store the result and any exception that occurs
|
||||
result_container: dict = {"result": None, "exception": None}
|
||||
|
||||
# Function that runs in the thread pool
|
||||
def run_in_thread():
|
||||
# Create new event loop for this thread
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
# Create the coroutine within the context
|
||||
async def run_with_context():
|
||||
# The coroutine function might access context variables
|
||||
return await coro_func(*args, **kwargs)
|
||||
|
||||
# Run the coroutine with the captured context
|
||||
# This ensures all context variables are available in the async function
|
||||
result = context.run(loop.run_until_complete, run_with_context())
|
||||
result_container["result"] = result
|
||||
except Exception as e:
|
||||
# Store the exception to re-raise in the calling thread
|
||||
result_container["exception"] = e
|
||||
finally:
|
||||
# Ensure event loop is properly closed to prevent warnings
|
||||
try:
|
||||
# Cancel any remaining tasks
|
||||
pending = asyncio.all_tasks(loop)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
|
||||
# Run the loop briefly to handle cancellations
|
||||
if pending:
|
||||
loop.run_until_complete(
|
||||
asyncio.gather(*pending, return_exceptions=True)
|
||||
)
|
||||
except Exception:
|
||||
pass # Ignore errors during cleanup
|
||||
|
||||
# Close the event loop
|
||||
loop.close()
|
||||
|
||||
# Clear the event loop from the thread
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
# Submit to thread pool and wait for result
|
||||
thread_pool = cls.get_thread_pool()
|
||||
future = thread_pool.submit(run_in_thread)
|
||||
future.result() # Wait for completion
|
||||
|
||||
# Re-raise any exception that occurred in the thread
|
||||
if result_container["exception"] is not None:
|
||||
raise result_container["exception"]
|
||||
|
||||
return result_container["result"]
|
||||
|
||||
@classmethod
|
||||
def create_sync_class(cls, async_class: Type, thread_pool_size=10) -> Type:
|
||||
"""
|
||||
Creates a new class with synchronous versions of all async methods.
|
||||
|
||||
Args:
|
||||
async_class: The async class to convert
|
||||
thread_pool_size: Size of thread pool to use
|
||||
|
||||
Returns:
|
||||
A new class with sync versions of all async methods
|
||||
"""
|
||||
sync_class_name = "ComfyAPISyncStub"
|
||||
cls.get_thread_pool(thread_pool_size)
|
||||
|
||||
# Create a proper class with docstrings and proper base classes
|
||||
sync_class_dict = {
|
||||
"__doc__": async_class.__doc__,
|
||||
"__module__": async_class.__module__,
|
||||
"__qualname__": sync_class_name,
|
||||
"__orig_class__": async_class, # Store original class for typing references
|
||||
}
|
||||
|
||||
# Create __init__ method
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._async_instance = async_class(*args, **kwargs)
|
||||
|
||||
# Handle annotated class attributes (like execution: Execution)
|
||||
# Get all annotations from the class hierarchy
|
||||
all_annotations = {}
|
||||
for base_class in reversed(inspect.getmro(async_class)):
|
||||
if hasattr(base_class, "__annotations__"):
|
||||
all_annotations.update(base_class.__annotations__)
|
||||
|
||||
# For each annotated attribute, check if it needs to be created or wrapped
|
||||
for attr_name, attr_type in all_annotations.items():
|
||||
if hasattr(self._async_instance, attr_name):
|
||||
# Attribute exists on the instance
|
||||
attr = getattr(self._async_instance, attr_name)
|
||||
# Check if this attribute needs a sync wrapper
|
||||
if hasattr(attr, "__class__"):
|
||||
from comfy_api.internal.singleton import ProxiedSingleton
|
||||
|
||||
if isinstance(attr, ProxiedSingleton):
|
||||
# Create a sync version of this attribute
|
||||
try:
|
||||
sync_attr_class = cls.create_sync_class(attr.__class__)
|
||||
# Create instance of the sync wrapper with the async instance
|
||||
sync_attr = object.__new__(sync_attr_class) # type: ignore
|
||||
sync_attr._async_instance = attr
|
||||
setattr(self, attr_name, sync_attr)
|
||||
except Exception:
|
||||
# If we can't create a sync version, keep the original
|
||||
setattr(self, attr_name, attr)
|
||||
else:
|
||||
# Not async, just copy the reference
|
||||
setattr(self, attr_name, attr)
|
||||
else:
|
||||
# Attribute doesn't exist, but is annotated - create it
|
||||
# This handles cases like execution: Execution
|
||||
if isinstance(attr_type, type):
|
||||
# Check if the type is defined as an inner class
|
||||
if hasattr(async_class, attr_type.__name__):
|
||||
inner_class = getattr(async_class, attr_type.__name__)
|
||||
from comfy_api.internal.singleton import ProxiedSingleton
|
||||
|
||||
# Create an instance of the inner class
|
||||
try:
|
||||
# For ProxiedSingleton classes, get or create the singleton instance
|
||||
if issubclass(inner_class, ProxiedSingleton):
|
||||
async_instance = inner_class.get_instance()
|
||||
else:
|
||||
async_instance = inner_class()
|
||||
|
||||
# Create sync wrapper
|
||||
sync_attr_class = cls.create_sync_class(inner_class)
|
||||
sync_attr = object.__new__(sync_attr_class) # type: ignore
|
||||
sync_attr._async_instance = async_instance
|
||||
setattr(self, attr_name, sync_attr)
|
||||
# Also set on the async instance for consistency
|
||||
setattr(self._async_instance, attr_name, async_instance)
|
||||
except Exception as e:
|
||||
logging.warning(
|
||||
f"Failed to create instance for {attr_name}: {e}"
|
||||
)
|
||||
|
||||
# Handle other instance attributes that might not be annotated
|
||||
for name, attr in inspect.getmembers(self._async_instance):
|
||||
if name.startswith("_") or hasattr(self, name):
|
||||
continue
|
||||
|
||||
# If attribute is an instance of a class, and that class is defined in the original class
|
||||
# we need to check if it needs a sync wrapper
|
||||
if isinstance(attr, object) and not isinstance(
|
||||
attr, (str, int, float, bool, list, dict, tuple)
|
||||
):
|
||||
from comfy_api.internal.singleton import ProxiedSingleton
|
||||
|
||||
if isinstance(attr, ProxiedSingleton):
|
||||
# Create a sync version of this nested class
|
||||
try:
|
||||
sync_attr_class = cls.create_sync_class(attr.__class__)
|
||||
# Create instance of the sync wrapper with the async instance
|
||||
sync_attr = object.__new__(sync_attr_class) # type: ignore
|
||||
sync_attr._async_instance = attr
|
||||
setattr(self, name, sync_attr)
|
||||
except Exception:
|
||||
# If we can't create a sync version, keep the original
|
||||
setattr(self, name, attr)
|
||||
|
||||
sync_class_dict["__init__"] = __init__
|
||||
|
||||
# Process methods from the async class
|
||||
for name, method in inspect.getmembers(
|
||||
async_class, predicate=inspect.isfunction
|
||||
):
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
|
||||
# Extract the actual return type from a coroutine
|
||||
if inspect.iscoroutinefunction(method):
|
||||
# Create sync version of async method with proper signature
|
||||
@functools.wraps(method)
|
||||
def sync_method(self, *args, _method_name=name, **kwargs):
|
||||
async_method = getattr(self._async_instance, _method_name)
|
||||
return AsyncToSyncConverter.run_async_in_thread(
|
||||
async_method, *args, **kwargs
|
||||
)
|
||||
|
||||
# Add to the class dict
|
||||
sync_class_dict[name] = sync_method
|
||||
else:
|
||||
# For regular methods, create a proxy method
|
||||
@functools.wraps(method)
|
||||
def proxy_method(self, *args, _method_name=name, **kwargs):
|
||||
method = getattr(self._async_instance, _method_name)
|
||||
return method(*args, **kwargs)
|
||||
|
||||
# Add to the class dict
|
||||
sync_class_dict[name] = proxy_method
|
||||
|
||||
# Handle property access
|
||||
for name, prop in inspect.getmembers(
|
||||
async_class, lambda x: isinstance(x, property)
|
||||
):
|
||||
|
||||
def make_property(name, prop_obj):
|
||||
def getter(self):
|
||||
value = getattr(self._async_instance, name)
|
||||
if inspect.iscoroutinefunction(value):
|
||||
|
||||
def sync_fn(*args, **kwargs):
|
||||
return AsyncToSyncConverter.run_async_in_thread(
|
||||
value, *args, **kwargs
|
||||
)
|
||||
|
||||
return sync_fn
|
||||
return value
|
||||
|
||||
def setter(self, value):
|
||||
setattr(self._async_instance, name, value)
|
||||
|
||||
return property(getter, setter if prop_obj.fset else None)
|
||||
|
||||
sync_class_dict[name] = make_property(name, prop)
|
||||
|
||||
# Create the class
|
||||
sync_class = type(sync_class_name, (object,), sync_class_dict)
|
||||
|
||||
return sync_class
|
||||
|
||||
@classmethod
|
||||
def _format_type_annotation(
|
||||
cls, annotation, type_tracker: Optional[TypeTracker] = None
|
||||
) -> str:
|
||||
"""Convert a type annotation to its string representation for stub files."""
|
||||
if (
|
||||
annotation is inspect.Parameter.empty
|
||||
or annotation is inspect.Signature.empty
|
||||
):
|
||||
return "Any"
|
||||
|
||||
# Handle None type
|
||||
if annotation is type(None):
|
||||
return "None"
|
||||
|
||||
# Track the type if we have a tracker
|
||||
if type_tracker:
|
||||
type_tracker.track_type(annotation)
|
||||
|
||||
# Try using typing.get_origin/get_args for Python 3.8+
|
||||
try:
|
||||
origin = get_origin(annotation)
|
||||
args = get_args(annotation)
|
||||
|
||||
if origin is not None:
|
||||
# Track the origin type
|
||||
if type_tracker:
|
||||
type_tracker.track_type(origin)
|
||||
|
||||
# Get the origin name
|
||||
origin_name = getattr(origin, "__name__", str(origin))
|
||||
if "." in origin_name:
|
||||
origin_name = origin_name.split(".")[-1]
|
||||
|
||||
# Special handling for types.UnionType (Python 3.10+ pipe operator)
|
||||
# Convert to old-style Union for compatibility
|
||||
if str(origin) == "<class 'types.UnionType'>" or origin_name == "UnionType":
|
||||
origin_name = "Union"
|
||||
|
||||
# Format arguments recursively
|
||||
if args:
|
||||
formatted_args = []
|
||||
for arg in args:
|
||||
# Track each type in the union
|
||||
if type_tracker:
|
||||
type_tracker.track_type(arg)
|
||||
formatted_args.append(cls._format_type_annotation(arg, type_tracker))
|
||||
return f"{origin_name}[{', '.join(formatted_args)}]"
|
||||
else:
|
||||
return origin_name
|
||||
except (AttributeError, TypeError):
|
||||
# Fallback for older Python versions or non-generic types
|
||||
pass
|
||||
|
||||
# Handle generic types the old way for compatibility
|
||||
if hasattr(annotation, "__origin__") and hasattr(annotation, "__args__"):
|
||||
origin = annotation.__origin__
|
||||
origin_name = (
|
||||
origin.__name__
|
||||
if hasattr(origin, "__name__")
|
||||
else str(origin).split("'")[1]
|
||||
)
|
||||
|
||||
# Format each type argument
|
||||
args = []
|
||||
for arg in annotation.__args__:
|
||||
args.append(cls._format_type_annotation(arg, type_tracker))
|
||||
|
||||
return f"{origin_name}[{', '.join(args)}]"
|
||||
|
||||
# Handle regular types with __name__
|
||||
if hasattr(annotation, "__name__"):
|
||||
return annotation.__name__
|
||||
|
||||
# Handle special module types (like types from typing module)
|
||||
if hasattr(annotation, "__module__") and hasattr(annotation, "__qualname__"):
|
||||
# For types like typing.Literal, typing.TypedDict, etc.
|
||||
return annotation.__qualname__
|
||||
|
||||
# Last resort: string conversion with cleanup
|
||||
type_str = str(annotation)
|
||||
|
||||
# Clean up common patterns more robustly
|
||||
if type_str.startswith("<class '") and type_str.endswith("'>"):
|
||||
type_str = type_str[8:-2] # Remove "<class '" and "'>"
|
||||
|
||||
# Remove module prefixes for common modules
|
||||
for prefix in ["typing.", "builtins.", "types."]:
|
||||
if type_str.startswith(prefix):
|
||||
type_str = type_str[len(prefix) :]
|
||||
|
||||
# Handle special cases
|
||||
if type_str in ("_empty", "inspect._empty"):
|
||||
return "None"
|
||||
|
||||
# Fix NoneType (this should rarely be needed now)
|
||||
if type_str == "NoneType":
|
||||
return "None"
|
||||
|
||||
return type_str
|
||||
|
||||
@classmethod
|
||||
def _extract_coroutine_return_type(cls, annotation):
|
||||
"""Extract the actual return type from a Coroutine annotation."""
|
||||
if hasattr(annotation, "__args__") and len(annotation.__args__) > 2:
|
||||
# Coroutine[Any, Any, ReturnType] -> extract ReturnType
|
||||
return annotation.__args__[2]
|
||||
return annotation
|
||||
|
||||
@classmethod
|
||||
def _format_parameter_default(cls, default_value) -> str:
|
||||
"""Format a parameter's default value for stub files."""
|
||||
if default_value is inspect.Parameter.empty:
|
||||
return ""
|
||||
elif default_value is None:
|
||||
return " = None"
|
||||
elif isinstance(default_value, bool):
|
||||
return f" = {default_value}"
|
||||
elif default_value == {}:
|
||||
return " = {}"
|
||||
elif default_value == []:
|
||||
return " = []"
|
||||
else:
|
||||
return f" = {default_value}"
|
||||
|
||||
@classmethod
|
||||
def _format_method_parameters(
|
||||
cls,
|
||||
sig: inspect.Signature,
|
||||
skip_self: bool = True,
|
||||
type_hints: Optional[dict] = None,
|
||||
type_tracker: Optional[TypeTracker] = None,
|
||||
) -> str:
|
||||
"""Format method parameters for stub files."""
|
||||
params = []
|
||||
if type_hints is None:
|
||||
type_hints = {}
|
||||
|
||||
for i, (param_name, param) in enumerate(sig.parameters.items()):
|
||||
if i == 0 and param_name == "self" and skip_self:
|
||||
params.append("self")
|
||||
else:
|
||||
# Get type annotation from type hints if available, otherwise from signature
|
||||
annotation = type_hints.get(param_name, param.annotation)
|
||||
type_str = cls._format_type_annotation(annotation, type_tracker)
|
||||
|
||||
# Get default value
|
||||
default_str = cls._format_parameter_default(param.default)
|
||||
|
||||
# Combine parameter parts
|
||||
if annotation is inspect.Parameter.empty:
|
||||
params.append(f"{param_name}: Any{default_str}")
|
||||
else:
|
||||
params.append(f"{param_name}: {type_str}{default_str}")
|
||||
|
||||
return ", ".join(params)
|
||||
|
||||
@classmethod
|
||||
def _generate_method_signature(
|
||||
cls,
|
||||
method_name: str,
|
||||
method,
|
||||
is_async: bool = False,
|
||||
type_tracker: Optional[TypeTracker] = None,
|
||||
) -> str:
|
||||
"""Generate a complete method signature for stub files."""
|
||||
sig = inspect.signature(method)
|
||||
|
||||
# Try to get evaluated type hints to resolve string annotations
|
||||
try:
|
||||
from typing import get_type_hints
|
||||
type_hints = get_type_hints(method)
|
||||
except Exception:
|
||||
# Fallback to empty dict if we can't get type hints
|
||||
type_hints = {}
|
||||
|
||||
# For async methods, extract the actual return type
|
||||
return_annotation = type_hints.get('return', sig.return_annotation)
|
||||
if is_async and inspect.iscoroutinefunction(method):
|
||||
return_annotation = cls._extract_coroutine_return_type(return_annotation)
|
||||
|
||||
# Format parameters with type hints
|
||||
params_str = cls._format_method_parameters(sig, type_hints=type_hints, type_tracker=type_tracker)
|
||||
|
||||
# Format return type
|
||||
return_type = cls._format_type_annotation(return_annotation, type_tracker)
|
||||
if return_annotation is inspect.Signature.empty:
|
||||
return_type = "None"
|
||||
|
||||
return f"def {method_name}({params_str}) -> {return_type}: ..."
|
||||
|
||||
@classmethod
|
||||
def _generate_imports(
|
||||
cls, async_class: Type, type_tracker: TypeTracker
|
||||
) -> list[str]:
|
||||
"""Generate import statements for the stub file."""
|
||||
imports = []
|
||||
|
||||
# Add standard typing imports
|
||||
imports.append(
|
||||
"from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple"
|
||||
)
|
||||
|
||||
# Add imports from the original module
|
||||
if async_class.__module__ != "builtins":
|
||||
module = inspect.getmodule(async_class)
|
||||
additional_types = []
|
||||
|
||||
if module:
|
||||
# Check if module has __all__ defined
|
||||
module_all = getattr(module, "__all__", None)
|
||||
|
||||
for name, obj in sorted(inspect.getmembers(module)):
|
||||
if isinstance(obj, type):
|
||||
# Skip if __all__ is defined and this name isn't in it
|
||||
# unless it's already been tracked as used in type annotations
|
||||
if module_all is not None and name not in module_all:
|
||||
# Check if this type was actually used in annotations
|
||||
if name not in type_tracker.discovered_types:
|
||||
continue
|
||||
|
||||
# Check for NamedTuple
|
||||
if issubclass(obj, tuple) and hasattr(obj, "_fields"):
|
||||
additional_types.append(name)
|
||||
# Mark as already imported
|
||||
type_tracker.already_imported.add(name)
|
||||
# Check for Enum
|
||||
elif issubclass(obj, Enum) and name != "Enum":
|
||||
additional_types.append(name)
|
||||
# Mark as already imported
|
||||
type_tracker.already_imported.add(name)
|
||||
|
||||
if additional_types:
|
||||
type_imports = ", ".join([async_class.__name__] + additional_types)
|
||||
imports.append(f"from {async_class.__module__} import {type_imports}")
|
||||
else:
|
||||
imports.append(
|
||||
f"from {async_class.__module__} import {async_class.__name__}"
|
||||
)
|
||||
|
||||
# Add imports for all discovered types
|
||||
# Pass the main module name to avoid duplicate imports
|
||||
imports.extend(
|
||||
type_tracker.get_imports(main_module_name=async_class.__module__)
|
||||
)
|
||||
|
||||
# Add base module import if needed
|
||||
if hasattr(inspect.getmodule(async_class), "__name__"):
|
||||
module_name = inspect.getmodule(async_class).__name__
|
||||
if "." in module_name:
|
||||
base_module = module_name.split(".")[0]
|
||||
# Only add if not already importing from it
|
||||
if not any(imp.startswith(f"from {base_module}") for imp in imports):
|
||||
imports.append(f"import {base_module}")
|
||||
|
||||
return imports
|
||||
|
||||
@classmethod
|
||||
def _get_class_attributes(cls, async_class: Type) -> list[tuple[str, Type]]:
|
||||
"""Extract class attributes that are classes themselves."""
|
||||
class_attributes = []
|
||||
|
||||
# Look for class attributes that are classes
|
||||
for name, attr in sorted(inspect.getmembers(async_class)):
|
||||
if isinstance(attr, type) and not name.startswith("_"):
|
||||
class_attributes.append((name, attr))
|
||||
elif (
|
||||
hasattr(async_class, "__annotations__")
|
||||
and name in async_class.__annotations__
|
||||
):
|
||||
annotation = async_class.__annotations__[name]
|
||||
if isinstance(annotation, type):
|
||||
class_attributes.append((name, annotation))
|
||||
|
||||
return class_attributes
|
||||
|
||||
@classmethod
|
||||
def _generate_inner_class_stub(
|
||||
cls,
|
||||
name: str,
|
||||
attr: Type,
|
||||
indent: str = " ",
|
||||
type_tracker: Optional[TypeTracker] = None,
|
||||
) -> list[str]:
|
||||
"""Generate stub for an inner class."""
|
||||
stub_lines = []
|
||||
stub_lines.append(f"{indent}class {name}Sync:")
|
||||
|
||||
# Add docstring if available
|
||||
if hasattr(attr, "__doc__") and attr.__doc__:
|
||||
stub_lines.extend(
|
||||
cls._format_docstring_for_stub(attr.__doc__, f"{indent} ")
|
||||
)
|
||||
|
||||
# Add __init__ if it exists
|
||||
if hasattr(attr, "__init__"):
|
||||
try:
|
||||
init_method = getattr(attr, "__init__")
|
||||
init_sig = inspect.signature(init_method)
|
||||
|
||||
# Try to get type hints
|
||||
try:
|
||||
from typing import get_type_hints
|
||||
init_hints = get_type_hints(init_method)
|
||||
except Exception:
|
||||
init_hints = {}
|
||||
|
||||
# Format parameters
|
||||
params_str = cls._format_method_parameters(
|
||||
init_sig, type_hints=init_hints, type_tracker=type_tracker
|
||||
)
|
||||
# Add __init__ docstring if available (before the method)
|
||||
if hasattr(init_method, "__doc__") and init_method.__doc__:
|
||||
stub_lines.extend(
|
||||
cls._format_docstring_for_stub(
|
||||
init_method.__doc__, f"{indent} "
|
||||
)
|
||||
)
|
||||
stub_lines.append(
|
||||
f"{indent} def __init__({params_str}) -> None: ..."
|
||||
)
|
||||
except (ValueError, TypeError):
|
||||
stub_lines.append(
|
||||
f"{indent} def __init__(self, *args, **kwargs) -> None: ..."
|
||||
)
|
||||
|
||||
# Add methods to the inner class
|
||||
has_methods = False
|
||||
for method_name, method in sorted(
|
||||
inspect.getmembers(attr, predicate=inspect.isfunction)
|
||||
):
|
||||
if method_name.startswith("_"):
|
||||
continue
|
||||
|
||||
has_methods = True
|
||||
try:
|
||||
# Add method docstring if available (before the method signature)
|
||||
if method.__doc__:
|
||||
stub_lines.extend(
|
||||
cls._format_docstring_for_stub(method.__doc__, f"{indent} ")
|
||||
)
|
||||
|
||||
method_sig = cls._generate_method_signature(
|
||||
method_name, method, is_async=True, type_tracker=type_tracker
|
||||
)
|
||||
stub_lines.append(f"{indent} {method_sig}")
|
||||
except (ValueError, TypeError):
|
||||
stub_lines.append(
|
||||
f"{indent} def {method_name}(self, *args, **kwargs): ..."
|
||||
)
|
||||
|
||||
if not has_methods:
|
||||
stub_lines.append(f"{indent} pass")
|
||||
|
||||
return stub_lines
|
||||
|
||||
@classmethod
|
||||
def _format_docstring_for_stub(
|
||||
cls, docstring: str, indent: str = " "
|
||||
) -> list[str]:
|
||||
"""Format a docstring for inclusion in a stub file with proper indentation."""
|
||||
if not docstring:
|
||||
return []
|
||||
|
||||
# First, dedent the docstring to remove any existing indentation
|
||||
dedented = textwrap.dedent(docstring).strip()
|
||||
|
||||
# Split into lines
|
||||
lines = dedented.split("\n")
|
||||
|
||||
# Build the properly indented docstring
|
||||
result = []
|
||||
result.append(f'{indent}"""')
|
||||
|
||||
for line in lines:
|
||||
if line.strip(): # Non-empty line
|
||||
result.append(f"{indent}{line}")
|
||||
else: # Empty line
|
||||
result.append("")
|
||||
|
||||
result.append(f'{indent}"""')
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _post_process_stub_content(cls, stub_content: list[str]) -> list[str]:
|
||||
"""Post-process stub content to fix any remaining issues."""
|
||||
processed = []
|
||||
|
||||
for line in stub_content:
|
||||
# Skip processing imports
|
||||
if line.startswith(("from ", "import ")):
|
||||
processed.append(line)
|
||||
continue
|
||||
|
||||
# Fix method signatures missing return types
|
||||
if (
|
||||
line.strip().startswith("def ")
|
||||
and line.strip().endswith(": ...")
|
||||
and ") -> " not in line
|
||||
):
|
||||
# Add -> None for methods without return annotation
|
||||
line = line.replace(": ...", " -> None: ...")
|
||||
|
||||
processed.append(line)
|
||||
|
||||
return processed
|
||||
|
||||
@classmethod
|
||||
def generate_stub_file(cls, async_class: Type, sync_class: Type) -> None:
|
||||
"""
|
||||
Generate a .pyi stub file for the sync class to help IDEs with type checking.
|
||||
"""
|
||||
try:
|
||||
# Only generate stub if we can determine module path
|
||||
if async_class.__module__ == "__main__":
|
||||
return
|
||||
|
||||
module = inspect.getmodule(async_class)
|
||||
if not module:
|
||||
return
|
||||
|
||||
module_path = module.__file__
|
||||
if not module_path:
|
||||
return
|
||||
|
||||
# Create stub file path in a 'generated' subdirectory
|
||||
module_dir = os.path.dirname(module_path)
|
||||
stub_dir = os.path.join(module_dir, "generated")
|
||||
|
||||
# Ensure the generated directory exists
|
||||
os.makedirs(stub_dir, exist_ok=True)
|
||||
|
||||
module_name = os.path.basename(module_path)
|
||||
if module_name.endswith(".py"):
|
||||
module_name = module_name[:-3]
|
||||
|
||||
sync_stub_path = os.path.join(stub_dir, f"{sync_class.__name__}.pyi")
|
||||
|
||||
# Create a type tracker for this stub generation
|
||||
type_tracker = TypeTracker()
|
||||
|
||||
stub_content = []
|
||||
|
||||
# We'll generate imports after processing all methods to capture all types
|
||||
# Leave a placeholder for imports
|
||||
imports_placeholder_index = len(stub_content)
|
||||
stub_content.append("") # Will be replaced with imports later
|
||||
|
||||
# Class definition
|
||||
stub_content.append(f"class {sync_class.__name__}:")
|
||||
|
||||
# Docstring
|
||||
if async_class.__doc__:
|
||||
stub_content.extend(
|
||||
cls._format_docstring_for_stub(async_class.__doc__, " ")
|
||||
)
|
||||
|
||||
# Generate __init__
|
||||
try:
|
||||
init_method = async_class.__init__
|
||||
init_signature = inspect.signature(init_method)
|
||||
|
||||
# Try to get type hints for __init__
|
||||
try:
|
||||
from typing import get_type_hints
|
||||
init_hints = get_type_hints(init_method)
|
||||
except Exception:
|
||||
init_hints = {}
|
||||
|
||||
# Format parameters
|
||||
params_str = cls._format_method_parameters(
|
||||
init_signature, type_hints=init_hints, type_tracker=type_tracker
|
||||
)
|
||||
# Add __init__ docstring if available (before the method)
|
||||
if hasattr(init_method, "__doc__") and init_method.__doc__:
|
||||
stub_content.extend(
|
||||
cls._format_docstring_for_stub(init_method.__doc__, " ")
|
||||
)
|
||||
stub_content.append(f" def __init__({params_str}) -> None: ...")
|
||||
except (ValueError, TypeError):
|
||||
stub_content.append(
|
||||
" def __init__(self, *args, **kwargs) -> None: ..."
|
||||
)
|
||||
|
||||
stub_content.append("") # Add newline after __init__
|
||||
|
||||
# Get class attributes
|
||||
class_attributes = cls._get_class_attributes(async_class)
|
||||
|
||||
# Generate inner classes
|
||||
for name, attr in class_attributes:
|
||||
inner_class_stub = cls._generate_inner_class_stub(
|
||||
name, attr, type_tracker=type_tracker
|
||||
)
|
||||
stub_content.extend(inner_class_stub)
|
||||
stub_content.append("") # Add newline after the inner class
|
||||
|
||||
# Add methods to the main class
|
||||
processed_methods = set() # Keep track of methods we've processed
|
||||
for name, method in sorted(
|
||||
inspect.getmembers(async_class, predicate=inspect.isfunction)
|
||||
):
|
||||
if name.startswith("_") or name in processed_methods:
|
||||
continue
|
||||
|
||||
processed_methods.add(name)
|
||||
|
||||
try:
|
||||
method_sig = cls._generate_method_signature(
|
||||
name, method, is_async=True, type_tracker=type_tracker
|
||||
)
|
||||
|
||||
# Add docstring if available (before the method signature for proper formatting)
|
||||
if method.__doc__:
|
||||
stub_content.extend(
|
||||
cls._format_docstring_for_stub(method.__doc__, " ")
|
||||
)
|
||||
|
||||
stub_content.append(f" {method_sig}")
|
||||
|
||||
stub_content.append("") # Add newline after each method
|
||||
|
||||
except (ValueError, TypeError):
|
||||
# If we can't get the signature, just add a simple stub
|
||||
stub_content.append(f" def {name}(self, *args, **kwargs): ...")
|
||||
stub_content.append("") # Add newline
|
||||
|
||||
# Add properties
|
||||
for name, prop in sorted(
|
||||
inspect.getmembers(async_class, lambda x: isinstance(x, property))
|
||||
):
|
||||
stub_content.append(" @property")
|
||||
stub_content.append(f" def {name}(self) -> Any: ...")
|
||||
if prop.fset:
|
||||
stub_content.append(f" @{name}.setter")
|
||||
stub_content.append(
|
||||
f" def {name}(self, value: Any) -> None: ..."
|
||||
)
|
||||
stub_content.append("") # Add newline after each property
|
||||
|
||||
# Add placeholders for the nested class instances
|
||||
# Check the actual attribute names from class annotations and attributes
|
||||
attribute_mappings = {}
|
||||
|
||||
# First check annotations for typed attributes (including from parent classes)
|
||||
# Collect all annotations from the class hierarchy
|
||||
all_annotations = {}
|
||||
for base_class in reversed(inspect.getmro(async_class)):
|
||||
if hasattr(base_class, "__annotations__"):
|
||||
all_annotations.update(base_class.__annotations__)
|
||||
|
||||
for attr_name, attr_type in sorted(all_annotations.items()):
|
||||
for class_name, class_type in class_attributes:
|
||||
# If the class type matches the annotated type
|
||||
if (
|
||||
attr_type == class_type
|
||||
or (hasattr(attr_type, "__name__") and attr_type.__name__ == class_name)
|
||||
or (isinstance(attr_type, str) and attr_type == class_name)
|
||||
):
|
||||
attribute_mappings[class_name] = attr_name
|
||||
|
||||
# Remove the extra checking - annotations should be sufficient
|
||||
|
||||
# Add the attribute declarations with proper names
|
||||
for class_name, class_type in class_attributes:
|
||||
# Check if there's a mapping from annotation
|
||||
attr_name = attribute_mappings.get(class_name, class_name)
|
||||
# Use the annotation name if it exists, even if the attribute doesn't exist yet
|
||||
# This is because the attribute might be created at runtime
|
||||
stub_content.append(f" {attr_name}: {class_name}Sync")
|
||||
|
||||
stub_content.append("") # Add a final newline
|
||||
|
||||
# Now generate imports with all discovered types
|
||||
imports = cls._generate_imports(async_class, type_tracker)
|
||||
|
||||
# Deduplicate imports while preserving order
|
||||
seen = set()
|
||||
unique_imports = []
|
||||
for imp in imports:
|
||||
if imp not in seen:
|
||||
seen.add(imp)
|
||||
unique_imports.append(imp)
|
||||
else:
|
||||
logging.warning(f"Duplicate import detected: {imp}")
|
||||
|
||||
# Replace the placeholder with actual imports
|
||||
stub_content[imports_placeholder_index : imports_placeholder_index + 1] = (
|
||||
unique_imports
|
||||
)
|
||||
|
||||
# Post-process stub content
|
||||
stub_content = cls._post_process_stub_content(stub_content)
|
||||
|
||||
# Write stub file
|
||||
with open(sync_stub_path, "w") as f:
|
||||
f.write("\n".join(stub_content))
|
||||
|
||||
logging.info(f"Generated stub file: {sync_stub_path}")
|
||||
|
||||
except Exception as e:
|
||||
# If stub generation fails, log the error but don't break the main functionality
|
||||
logging.error(
|
||||
f"Error generating stub file for {sync_class.__name__}: {str(e)}"
|
||||
)
|
||||
import traceback
|
||||
|
||||
logging.error(traceback.format_exc())
|
||||
|
||||
|
||||
def create_sync_class(async_class: Type, thread_pool_size=10) -> Type:
|
||||
"""
|
||||
Creates a sync version of an async class
|
||||
|
||||
Args:
|
||||
async_class: The async class to convert
|
||||
thread_pool_size: Size of thread pool to use
|
||||
|
||||
Returns:
|
||||
A new class with sync versions of all async methods
|
||||
"""
|
||||
return AsyncToSyncConverter.create_sync_class(async_class, thread_pool_size)
|
||||
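To make the conversion above concrete, here is a minimal usage sketch (not part of the diff). The `DemoAsyncAPI` class is hypothetical and only illustrates how `create_sync_class` is expected to wrap awaitable methods so they can be called without an event loop:

import asyncio
from comfy_api.internal.async_to_sync import create_sync_class


class DemoAsyncAPI:
    """Hypothetical async class, used only for illustration."""

    async def fetch_value(self, delay: float = 0.01) -> int:
        await asyncio.sleep(delay)  # stand-in for real async work
        return 42


# Build a synchronous wrapper class; each method blocks on the converter's thread pool.
DemoAsyncAPISync = create_sync_class(DemoAsyncAPI)

api = DemoAsyncAPISync()
print(api.fetch_value())  # prints 42 without any explicit asyncio usage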
33  comfy_api/internal/singleton.py  Normal file
@@ -0,0 +1,33 @@
from typing import Type, TypeVar

class SingletonMetaclass(type):
    T = TypeVar("T", bound="SingletonMetaclass")
    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super(SingletonMetaclass, cls).__call__(
                *args, **kwargs
            )
        return cls._instances[cls]

    def inject_instance(cls: Type[T], instance: T) -> None:
        assert cls not in SingletonMetaclass._instances, (
            "Cannot inject instance after first instantiation"
        )
        SingletonMetaclass._instances[cls] = instance

    def get_instance(cls: Type[T], *args, **kwargs) -> T:
        """
        Gets the singleton instance of the class, creating it if it doesn't exist.
        """
        if cls not in SingletonMetaclass._instances:
            SingletonMetaclass._instances[cls] = super(
                SingletonMetaclass, cls
            ).__call__(*args, **kwargs)
        return cls._instances[cls]


class ProxiedSingleton(object, metaclass=SingletonMetaclass):
    def __init__(self):
        super().__init__()
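A short usage sketch (not part of the diff) of the singleton metaclass above; `DemoRegistry` is a hypothetical subclass used only to show that repeated construction returns the same object:

from comfy_api.internal.singleton import ProxiedSingleton


class DemoRegistry(ProxiedSingleton):
    """Hypothetical singleton, used only for illustration."""

    def __init__(self):
        super().__init__()
        self.items = []


a = DemoRegistry()
b = DemoRegistry()
assert a is b  # the metaclass hands back the same instance for every call

# inject_instance() can seed a pre-built instance (e.g. in tests), but only
# before the first ordinary instantiation, otherwise the assert above fires.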
124  comfy_api/latest/__init__.py  Normal file
@@ -0,0 +1,124 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Type, TYPE_CHECKING
from comfy_api.internal import ComfyAPIBase
from comfy_api.internal.singleton import ProxiedSingleton
from comfy_api.internal.async_to_sync import create_sync_class
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
from comfy_api.latest._io import _IO as io  #noqa: F401
from comfy_api.latest._ui import _UI as ui  #noqa: F401
# from comfy_api.latest._resources import _RESOURCES as resources  #noqa: F401
from comfy_execution.utils import get_executing_context
from comfy_execution.progress import get_progress_state, PreviewImageTuple
from PIL import Image
from comfy.cli_args import args
import numpy as np


class ComfyAPI_latest(ComfyAPIBase):
    VERSION = "latest"
    STABLE = False

    class Execution(ProxiedSingleton):
        async def set_progress(
            self,
            value: float,
            max_value: float,
            node_id: str | None = None,
            preview_image: Image.Image | ImageInput | None = None,
            ignore_size_limit: bool = False,
        ) -> None:
            """
            Update the progress bar displayed in the ComfyUI interface.

            This function allows custom nodes and API calls to report their progress
            back to the user interface, providing visual feedback during long operations.

            Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK
            """
            executing_context = get_executing_context()
            if node_id is None and executing_context is not None:
                node_id = executing_context.node_id
            if node_id is None:
                raise ValueError("node_id must be provided if not in executing context")

            # Convert preview_image to PreviewImageTuple if needed
            to_display: PreviewImageTuple | Image.Image | ImageInput | None = preview_image
            if to_display is not None:
                # First convert to PIL Image if needed
                if isinstance(to_display, ImageInput):
                    # Convert ImageInput (torch.Tensor) to PIL Image
                    # Handle tensor shape [B, H, W, C] -> get first image if batch
                    tensor = to_display
                    if len(tensor.shape) == 4:
                        tensor = tensor[0]

                    # Convert to numpy array and scale to 0-255
                    image_np = (tensor.cpu().numpy() * 255).astype(np.uint8)
                    to_display = Image.fromarray(image_np)

                if isinstance(to_display, Image.Image):
                    # Detect image format from PIL Image
                    image_format = to_display.format if to_display.format else "JPEG"
                    # Use None for preview_size if ignore_size_limit is True
                    preview_size = None if ignore_size_limit else args.preview_size
                    to_display = (image_format, to_display, preview_size)

            get_progress_state().update_progress(
                node_id=node_id,
                value=value,
                max_value=max_value,
                image=to_display,
            )

    execution: Execution


class ComfyExtension(ABC):
    async def on_load(self) -> None:
        """
        Called when an extension is loaded.
        This should be used to initialize any global resources needed by the extension.
        """

    @abstractmethod
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """
        Returns a list of nodes that this extension provides.
        """


class Input:
    Image = ImageInput
    Audio = AudioInput
    Mask = MaskInput
    Latent = LatentInput
    Video = VideoInput


class InputImpl:
    VideoFromFile = VideoFromFile
    VideoFromComponents = VideoFromComponents


class Types:
    VideoCodec = VideoCodec
    VideoContainer = VideoContainer
    VideoComponents = VideoComponents


ComfyAPI = ComfyAPI_latest

# Create a synchronous version of the API
if TYPE_CHECKING:
    import comfy_api.latest.generated.ComfyAPISyncStub  # type: ignore

    ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
ComfyAPISync = create_sync_class(ComfyAPI_latest)

__all__ = [
    "ComfyAPI",
    "ComfyAPISync",
    "Input",
    "InputImpl",
    "Types",
    "ComfyExtension",
]
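A brief usage sketch (not part of the diff) of the progress API defined above. The helper function is hypothetical; it assumes it runs inside an executing node context, which is why node_id can be omitted:

from comfy_api.latest import ComfyAPI

api = ComfyAPI()

async def run_steps(total_steps: int = 20) -> None:
    """Hypothetical helper showing progress reporting from inside a node's execution."""
    for step in range(total_steps):
        # ... perform one unit of work here ...
        await api.execution.set_progress(
            value=step + 1,
            max_value=total_steps,
            # node_id is resolved from the executing context when omitted
        )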
10  comfy_api/latest/_input/__init__.py  Normal file
@@ -0,0 +1,10 @@
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
from .video_types import VideoInput

__all__ = [
    "ImageInput",
    "AudioInput",
    "VideoInput",
    "MaskInput",
    "LatentInput",
]
42  comfy_api/latest/_input/basic_types.py  Normal file
@@ -0,0 +1,42 @@
import torch
from typing import TypedDict, List, Optional

ImageInput = torch.Tensor
"""
An image in format [B, H, W, C] where B is the batch size, C is the number of channels, H is the height, and W is the width.
"""

MaskInput = torch.Tensor
"""
A mask in format [B, H, W] where B is the batch size, H is the height, and W is the width.
"""

class AudioInput(TypedDict):
    """
    TypedDict representing audio input.
    """

    waveform: torch.Tensor
    """
    Tensor in the format [B, C, T] where B is the batch size, C is the number of channels, and T is the number of samples.
    """

    sample_rate: int

class LatentInput(TypedDict):
    """
    TypedDict representing latent input.
    """

    samples: torch.Tensor
    """
    Tensor in the format [B, C, H, W] where B is the batch size, C is the number of channels,
    H is the height, and W is the width.
    """

    noise_mask: Optional[MaskInput]
    """
    Optional noise mask tensor in the same format as samples.
    """

    batch_index: Optional[List[int]]
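Because these types are plain tensor aliases and TypedDicts, constructing conforming values is straightforward. A hedged sketch (not part of the diff); the shapes below are illustrative:

import torch

from comfy_api.latest._input import AudioInput, LatentInput

# A one-second stereo clip at 44.1 kHz: waveform shaped [B, C, T]
audio: AudioInput = {
    "waveform": torch.zeros(1, 2, 44100),
    "sample_rate": 44100,
}

# A single 64x64 latent with 4 channels: samples shaped [B, C, H, W]
latent: LatentInput = {
    "samples": torch.zeros(1, 4, 64, 64),
    "noise_mask": None,
    "batch_index": None,
}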
85  comfy_api/latest/_input/video_types.py  Normal file
@@ -0,0 +1,85 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional, Union
import io
import av
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents


class VideoInput(ABC):
    """
    Abstract base class for video input types.
    """

    @abstractmethod
    def get_components(self) -> VideoComponents:
        """
        Abstract method to get the video components (images, audio, and frame rate).

        Returns:
            VideoComponents containing images, audio, and frame rate
        """
        pass

    @abstractmethod
    def save_to(
        self,
        path: str,
        format: VideoContainer = VideoContainer.AUTO,
        codec: VideoCodec = VideoCodec.AUTO,
        metadata: Optional[dict] = None
    ):
        """
        Abstract method to save the video input to a file.
        """
        pass

    def get_stream_source(self) -> Union[str, io.BytesIO]:
        """
        Get a streamable source for the video. This allows processing without
        loading the entire video into memory.

        Returns:
            Either a file path (str) or a BytesIO object that can be opened with av.

        Default implementation creates a BytesIO buffer, but subclasses should
        override this for better performance when possible.
        """
        buffer = io.BytesIO()
        self.save_to(buffer)
        buffer.seek(0)
        return buffer

    # Provide a default implementation, but subclasses can provide optimized versions
    # if possible.
    def get_dimensions(self) -> tuple[int, int]:
        """
        Returns the dimensions of the video input.

        Returns:
            Tuple of (width, height)
        """
        components = self.get_components()
        return components.images.shape[2], components.images.shape[1]

    def get_duration(self) -> float:
        """
        Returns the duration of the video in seconds.

        Returns:
            Duration in seconds
        """
        components = self.get_components()
        frame_count = components.images.shape[0]
        return float(frame_count / components.frame_rate)

    def get_container_format(self) -> str:
        """
        Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').

        Returns:
            Container format as string
        """
        # Default implementation - subclasses should override for better performance
        source = self.get_stream_source()
        with av.open(source, mode="r") as container:
            return container.format.name
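A minimal subclass sketch (not part of the diff) showing which methods are mandatory and which defaults are inherited; `SolidColorVideo` is a hypothetical class used only for illustration:

from fractions import Fraction
from typing import Optional

import torch

from comfy_api.latest._input import VideoInput
from comfy_api.latest._util import VideoCodec, VideoComponents, VideoContainer


class SolidColorVideo(VideoInput):
    """Hypothetical VideoInput that yields a fixed number of black frames."""

    def __init__(self, width: int, height: int, frames: int, fps: int = 24):
        self._images = torch.zeros(frames, height, width, 3)  # [B, H, W, C]
        self._frame_rate = Fraction(fps)

    def get_components(self) -> VideoComponents:
        return VideoComponents(images=self._images, frame_rate=self._frame_rate)

    def save_to(self, path, format=VideoContainer.AUTO, codec=VideoCodec.AUTO, metadata: Optional[dict] = None):
        raise NotImplementedError("illustration only; a real subclass would encode frames here")


video = SolidColorVideo(640, 480, frames=48)
print(video.get_dimensions(), video.get_duration())  # (640, 480) 2.0, via the inherited defaults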
7  comfy_api/latest/_input_impl/__init__.py  Normal file
@@ -0,0 +1,7 @@
from .video_types import VideoFromFile, VideoFromComponents

__all__ = [
    # Implementations
    "VideoFromFile",
    "VideoFromComponents",
]
324
comfy_api/latest/_input_impl/video_types.py
Normal file
324
comfy_api/latest/_input_impl/video_types.py
Normal file
@@ -0,0 +1,324 @@
|
||||
from __future__ import annotations
|
||||
from av.container import InputContainer
|
||||
from av.subtitles.stream import SubtitleStream
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
from comfy_api.latest._input import AudioInput, VideoInput
|
||||
import av
|
||||
import io
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents
|
||||
|
||||
|
||||
def container_to_output_format(container_format: str | None) -> str | None:
|
||||
"""
|
||||
A container's `format` may be a comma-separated list of formats.
|
||||
E.g., iso container's `format` may be `mov,mp4,m4a,3gp,3g2,mj2`.
|
||||
However, writing to a file/stream with `av.open` requires a single format,
|
||||
or `None` to auto-detect.
|
||||
"""
|
||||
if not container_format:
|
||||
return None # Auto-detect
|
||||
|
||||
if "," not in container_format:
|
||||
return container_format
|
||||
|
||||
formats = container_format.split(",")
|
||||
return formats[0]
|
||||
|
||||
|
||||
def get_open_write_kwargs(
|
||||
dest: str | io.BytesIO, container_format: str, to_format: str | None
|
||||
) -> dict:
|
||||
"""Get kwargs for writing a `VideoFromFile` to a file/stream with `av.open`"""
|
||||
open_kwargs = {
|
||||
"mode": "w",
|
||||
# If isobmff, preserve custom metadata tags (workflow, prompt, extra_pnginfo)
|
||||
"options": {"movflags": "use_metadata_tags"},
|
||||
}
|
||||
|
||||
is_write_to_buffer = isinstance(dest, io.BytesIO)
|
||||
if is_write_to_buffer:
|
||||
# Set output format explicitly, since it cannot be inferred from file extension
|
||||
if to_format == VideoContainer.AUTO:
|
||||
to_format = container_format.lower()
|
||||
elif isinstance(to_format, str):
|
||||
to_format = to_format.lower()
|
||||
open_kwargs["format"] = container_to_output_format(to_format)
|
||||
|
||||
return open_kwargs
|
||||
|
||||
|
||||
class VideoFromFile(VideoInput):
|
||||
"""
|
||||
Class representing video input from a file.
|
||||
"""
|
||||
|
||||
def __init__(self, file: str | io.BytesIO):
|
||||
"""
|
||||
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
|
||||
containing the file contents.
|
||||
"""
|
||||
self.__file = file
|
||||
|
||||
def get_stream_source(self) -> str | io.BytesIO:
|
||||
"""
|
||||
Return the underlying file source for efficient streaming.
|
||||
This avoids unnecessary memory copies when the source is already a file path.
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
return self.__file
|
||||
|
||||
def get_dimensions(self) -> tuple[int, int]:
|
||||
"""
|
||||
Returns the dimensions of the video input.
|
||||
|
||||
Returns:
|
||||
Tuple of (width, height)
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
for stream in container.streams:
|
||||
if stream.type == 'video':
|
||||
assert isinstance(stream, av.VideoStream)
|
||||
return stream.width, stream.height
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
|
||||
def get_duration(self) -> float:
|
||||
"""
|
||||
Returns the duration of the video in seconds.
|
||||
|
||||
Returns:
|
||||
Duration in seconds
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
if container.duration is not None:
|
||||
return float(container.duration / av.time_base)
|
||||
|
||||
# Fallback: calculate from frame count and frame rate
|
||||
video_stream = next(
|
||||
(s for s in container.streams if s.type == "video"), None
|
||||
)
|
||||
if video_stream and video_stream.frames and video_stream.average_rate:
|
||||
return float(video_stream.frames / video_stream.average_rate)
|
||||
|
||||
# Last resort: decode frames to count them
|
||||
if video_stream and video_stream.average_rate:
|
||||
frame_count = 0
|
||||
container.seek(0)
|
||||
for packet in container.demux(video_stream):
|
||||
for _ in packet.decode():
|
||||
frame_count += 1
|
||||
if frame_count > 0:
|
||||
return float(frame_count / video_stream.average_rate)
|
||||
|
||||
raise ValueError(f"Could not determine duration for file '{self.__file}'")
|
||||
|
||||
def get_container_format(self) -> str:
|
||||
"""
|
||||
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
|
||||
|
||||
Returns:
|
||||
Container format as string
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
return container.format.name
|
||||
|
||||
def get_components_internal(self, container: InputContainer) -> VideoComponents:
|
||||
# Get video frames
|
||||
frames = []
|
||||
for frame in container.decode(video=0):
|
||||
img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3)
|
||||
img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3)
|
||||
frames.append(img)
|
||||
|
||||
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)
|
||||
|
||||
# Get frame rate
|
||||
video_stream = next(s for s in container.streams if s.type == 'video')
|
||||
frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)
|
||||
|
||||
# Get audio if available
|
||||
audio = None
|
||||
try:
|
||||
container.seek(0) # Reset the container to the beginning
|
||||
for stream in container.streams:
|
||||
if stream.type != 'audio':
|
||||
continue
|
||||
assert isinstance(stream, av.AudioStream)
|
||||
audio_frames = []
|
||||
for packet in container.demux(stream):
|
||||
for frame in packet.decode():
|
||||
assert isinstance(frame, av.AudioFrame)
|
||||
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
|
||||
if len(audio_frames) > 0:
|
||||
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
|
||||
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
|
||||
audio = AudioInput({
|
||||
"waveform": audio_tensor,
|
||||
"sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
|
||||
})
|
||||
except StopIteration:
|
||||
pass # No audio stream
|
||||
|
||||
metadata = container.metadata
|
||||
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
|
||||
|
||||
def get_components(self) -> VideoComponents:
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
return self.get_components_internal(container)
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
|
||||
def save_to(
|
||||
self,
|
||||
path: str | io.BytesIO,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
):
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
container_format = container.format.name
|
||||
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
|
||||
reuse_streams = True
|
||||
if format != VideoContainer.AUTO and format not in container_format.split(","):
|
||||
reuse_streams = False
|
||||
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
|
||||
reuse_streams = False
|
||||
|
||||
if not reuse_streams:
|
||||
components = self.get_components_internal(container)
|
||||
video = VideoFromComponents(components)
|
||||
return video.save_to(
|
||||
path,
|
||||
format=format,
|
||||
codec=codec,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
streams = container.streams
|
||||
|
||||
open_kwargs = get_open_write_kwargs(path, container_format, format)
|
||||
with av.open(path, **open_kwargs) as output_container:
|
||||
# Copy over the original metadata
|
||||
for key, value in container.metadata.items():
|
||||
if metadata is None or key not in metadata:
|
||||
output_container.metadata[key] = value
|
||||
|
||||
# Add our new metadata
|
||||
if metadata is not None:
|
||||
for key, value in metadata.items():
|
||||
if isinstance(value, str):
|
||||
output_container.metadata[key] = value
|
||||
else:
|
||||
output_container.metadata[key] = json.dumps(value)
|
||||
|
||||
# Add streams to the new container
|
||||
stream_map = {}
|
||||
for stream in streams:
|
||||
if isinstance(stream, (av.VideoStream, av.AudioStream, SubtitleStream)):
|
||||
out_stream = output_container.add_stream_from_template(template=stream, opaque=True)
|
||||
stream_map[stream] = out_stream
|
||||
|
||||
# Write packets to the new container
|
||||
for packet in container.demux():
|
||||
if packet.stream in stream_map and packet.dts is not None:
|
||||
packet.stream = stream_map[packet.stream]
|
||||
output_container.mux(packet)
|
||||
|
||||
class VideoFromComponents(VideoInput):
|
||||
"""
|
||||
Class representing video input from tensors.
|
||||
"""
|
||||
|
||||
def __init__(self, components: VideoComponents):
|
||||
self.__components = components
|
||||
|
||||
def get_components(self) -> VideoComponents:
|
||||
return VideoComponents(
|
||||
images=self.__components.images,
|
||||
audio=self.__components.audio,
|
||||
frame_rate=self.__components.frame_rate
|
||||
)
|
||||
|
||||
def save_to(
|
||||
self,
|
||||
path: str,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
):
|
||||
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
|
||||
raise ValueError("Only MP4 format is supported for now")
|
||||
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
|
||||
raise ValueError("Only H264 codec is supported for now")
|
||||
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output:
|
||||
# Add metadata before writing any streams
|
||||
if metadata is not None:
|
||||
for key, value in metadata.items():
|
||||
output.metadata[key] = json.dumps(value)
|
||||
|
||||
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
|
||||
# Create a video stream
|
||||
video_stream = output.add_stream('h264', rate=frame_rate)
|
||||
video_stream.width = self.__components.images.shape[2]
|
||||
video_stream.height = self.__components.images.shape[1]
|
||||
video_stream.pix_fmt = 'yuv420p'
|
||||
|
||||
# Create an audio stream
|
||||
audio_sample_rate = 1
|
||||
audio_stream: Optional[av.AudioStream] = None
|
||||
if self.__components.audio:
|
||||
audio_sample_rate = int(self.__components.audio['sample_rate'])
|
||||
audio_stream = output.add_stream('aac', rate=audio_sample_rate)
|
||||
audio_stream.sample_rate = audio_sample_rate
|
||||
audio_stream.format = 'fltp'
|
||||
|
||||
# Encode video
|
||||
for i, frame in enumerate(self.__components.images):
|
||||
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
|
||||
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
|
||||
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
|
||||
packet = video_stream.encode(frame)
|
||||
output.mux(packet)
|
||||
|
||||
# Flush video
|
||||
packet = video_stream.encode(None)
|
||||
output.mux(packet)
|
||||
|
||||
if audio_stream and self.__components.audio:
|
||||
# Encode audio
|
||||
samples_per_frame = int(audio_sample_rate / frame_rate)
|
||||
num_frames = self.__components.audio['waveform'].shape[2] // samples_per_frame
|
||||
for i in range(num_frames):
|
||||
start = i * samples_per_frame
|
||||
end = start + samples_per_frame
|
||||
# TODO(Feature) - Add support for stereo audio
|
||||
chunk = (
|
||||
self.__components.audio["waveform"][0, 0, start:end]
|
||||
.unsqueeze(0)
|
||||
.contiguous()
|
||||
.numpy()
|
||||
)
|
||||
audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono')
|
||||
audio_frame.sample_rate = audio_sample_rate
|
||||
audio_frame.pts = i * samples_per_frame
|
||||
for packet in audio_stream.encode(audio_frame):
|
||||
output.mux(packet)
|
||||
|
||||
# Flush audio
|
||||
for packet in audio_stream.encode(None):
|
||||
output.mux(packet)
|
||||
|
||||
|
||||
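A hedged usage sketch (not part of the diff) of the two VideoInput implementations added in comfy_api/latest/_input_impl/video_types.py above; the file paths are hypothetical:

from comfy_api.latest import InputImpl, Types

# Load from disk, inspect it, and remux to MP4 while attaching extra metadata.
video = InputImpl.VideoFromFile("input/example.mp4")  # hypothetical path
print(video.get_container_format(), video.get_dimensions())
video.save_to(
    "output/example_copy.mp4",
    format=Types.VideoContainer.MP4,
    codec=Types.VideoCodec.H264,
    metadata={"note": "saved through the comfy_api video types"},
)

# Or decode to components (frames, optional audio, frame rate) and re-encode them.
components = video.get_components()
InputImpl.VideoFromComponents(components).save_to("output/reencoded.mp4")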
1618  comfy_api/latest/_io.py  Normal file
File diff suppressed because it is too large.
72  comfy_api/latest/_resources.py  Normal file
@@ -0,0 +1,72 @@
from __future__ import annotations
import comfy.utils
import folder_paths
import logging
from abc import ABC, abstractmethod
from typing import Any
import torch


class ResourceKey(ABC):
    Type = Any
    def __init__(self):
        ...

class TorchDictFolderFilename(ResourceKey):
    '''Key for requesting a torch file via file_name from a folder category.'''
    Type = dict[str, torch.Tensor]
    def __init__(self, folder_name: str, file_name: str):
        self.folder_name = folder_name
        self.file_name = file_name

    def __hash__(self):
        return hash((self.folder_name, self.file_name))

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TorchDictFolderFilename):
            return False
        return self.folder_name == other.folder_name and self.file_name == other.file_name

    def __str__(self):
        return f"{self.folder_name} -> {self.file_name}"

class Resources(ABC):
    def __init__(self):
        ...

    @abstractmethod
    def get(self, key: ResourceKey, default: Any=...) -> Any:
        pass

class ResourcesLocal(Resources):
    def __init__(self):
        super().__init__()
        self.local_resources: dict[ResourceKey, Any] = {}

    def get(self, key: ResourceKey, default: Any=...) -> Any:
        cached = self.local_resources.get(key, None)
        if cached is not None:
            logging.info(f"Using cached resource '{key}'")
            return cached
        logging.info(f"Loading resource '{key}'")
        to_return = None
        if isinstance(key, TorchDictFolderFilename):
            if default is ...:
                to_return = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(key.folder_name, key.file_name), safe_load=True)
            else:
                full_path = folder_paths.get_full_path(key.folder_name, key.file_name)
                if full_path is not None:
                    to_return = comfy.utils.load_torch_file(full_path, safe_load=True)

        if to_return is not None:
            self.local_resources[key] = to_return
            return to_return
        if default is not ...:
            return default
        raise Exception(f"Unsupported resource key type: {type(key)}")


class _RESOURCES:
    ResourceKey = ResourceKey
    TorchDictFolderFilename = TorchDictFolderFilename
    Resources = Resources
    ResourcesLocal = ResourcesLocal
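A short sketch (not part of the diff) of how the resource cache above is expected to be used; the folder category and file name are hypothetical:

from comfy_api.latest._resources import ResourcesLocal, TorchDictFolderFilename

resources = ResourcesLocal()

# The first call loads the state dict from disk; later calls with an equal key hit the cache.
key = TorchDictFolderFilename("loras", "example_lora.safetensors")  # hypothetical file
state_dict = resources.get(key, default=None)
if state_dict is None:
    print("file not found in the 'loras' folder; falling back")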
463
comfy_api/latest/_ui.py
Normal file
463
comfy_api/latest/_ui.py
Normal file
@@ -0,0 +1,463 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from io import BytesIO
|
||||
from typing import Type
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
import torch
|
||||
try:
|
||||
import torchaudio
|
||||
TORCH_AUDIO_AVAILABLE = True
|
||||
except:
|
||||
TORCH_AUDIO_AVAILABLE = False
|
||||
from PIL import Image as PILImage
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
|
||||
import folder_paths
|
||||
|
||||
# used for image preview
|
||||
from comfy.cli_args import args
|
||||
from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput
|
||||
|
||||
|
||||
class SavedResult(dict):
|
||||
def __init__(self, filename: str, subfolder: str, type: FolderType):
|
||||
super().__init__(filename=filename, subfolder=subfolder,type=type.value)
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
return self["filename"]
|
||||
|
||||
@property
|
||||
def subfolder(self) -> str:
|
||||
return self["subfolder"]
|
||||
|
||||
@property
|
||||
def type(self) -> FolderType:
|
||||
return FolderType(self["type"])
|
||||
|
||||
|
||||
class SavedImages(_UIOutput):
|
||||
"""A UI output class to represent one or more saved images, potentially animated."""
|
||||
def __init__(self, results: list[SavedResult], is_animated: bool = False):
|
||||
super().__init__()
|
||||
self.results = results
|
||||
self.is_animated = is_animated
|
||||
|
||||
def as_dict(self) -> dict:
|
||||
data = {"images": self.results}
|
||||
if self.is_animated:
|
||||
data["animated"] = (True,)
|
||||
return data
|
||||
|
||||
|
||||
class SavedAudios(_UIOutput):
|
||||
"""UI wrapper around one or more audio files on disk (FLAC / MP3 / Opus)."""
|
||||
def __init__(self, results: list[SavedResult]):
|
||||
super().__init__()
|
||||
self.results = results
|
||||
|
||||
def as_dict(self) -> dict:
|
||||
return {"audio": self.results}
|
||||
|
||||
|
||||
def _get_directory_by_folder_type(folder_type: FolderType) -> str:
|
||||
if folder_type == FolderType.input:
|
||||
return folder_paths.get_input_directory()
|
||||
if folder_type == FolderType.output:
|
||||
return folder_paths.get_output_directory()
|
||||
return folder_paths.get_temp_directory()
|
||||
|
||||
|
||||
class ImageSaveHelper:
|
||||
"""A helper class with static methods to handle image saving and metadata."""
|
||||
|
||||
@staticmethod
|
||||
def _convert_tensor_to_pil(image_tensor: torch.Tensor) -> PILImage.Image:
|
||||
"""Converts a single torch tensor to a PIL Image."""
|
||||
return PILImage.fromarray(np.clip(255.0 * image_tensor.cpu().numpy(), 0, 255).astype(np.uint8))
|
||||
|
||||
@staticmethod
|
||||
def _create_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
|
||||
"""Creates a PngInfo object with prompt and extra_pnginfo."""
|
||||
if args.disable_metadata or cls is None or not cls.hidden:
|
||||
return None
|
||||
metadata = PngInfo()
|
||||
if cls.hidden.prompt:
|
||||
metadata.add_text("prompt", json.dumps(cls.hidden.prompt))
|
||||
if cls.hidden.extra_pnginfo:
|
||||
for x in cls.hidden.extra_pnginfo:
|
||||
metadata.add_text(x, json.dumps(cls.hidden.extra_pnginfo[x]))
|
||||
return metadata
|
||||
|
||||
@staticmethod
|
||||
def _create_animated_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
|
||||
"""Creates a PngInfo object with prompt and extra_pnginfo for animated PNGs (APNG)."""
|
||||
if args.disable_metadata or cls is None or not cls.hidden:
|
||||
return None
|
||||
metadata = PngInfo()
|
||||
if cls.hidden.prompt:
|
||||
metadata.add(
|
||||
b"comf",
|
||||
"prompt".encode("latin-1", "strict")
|
||||
+ b"\0"
|
||||
+ json.dumps(cls.hidden.prompt).encode("latin-1", "strict"),
|
||||
after_idat=True,
|
||||
)
|
||||
if cls.hidden.extra_pnginfo:
|
||||
for x in cls.hidden.extra_pnginfo:
|
||||
metadata.add(
|
||||
b"comf",
|
||||
x.encode("latin-1", "strict")
|
||||
+ b"\0"
|
||||
+ json.dumps(cls.hidden.extra_pnginfo[x]).encode("latin-1", "strict"),
|
||||
after_idat=True,
|
||||
)
|
||||
return metadata
|
||||
|
||||
@staticmethod
|
||||
def _create_webp_metadata(pil_image: PILImage.Image, cls: Type[ComfyNode] | None) -> PILImage.Exif:
|
||||
"""Creates EXIF metadata bytes for WebP images."""
|
||||
exif_data = pil_image.getexif()
|
||||
if args.disable_metadata or cls is None or cls.hidden is None:
|
||||
return exif_data
|
||||
if cls.hidden.prompt is not None:
|
||||
exif_data[0x0110] = "prompt:{}".format(json.dumps(cls.hidden.prompt)) # EXIF 0x0110 = Model
|
||||
if cls.hidden.extra_pnginfo is not None:
|
||||
inital_exif_tag = 0x010F # EXIF 0x010f = Make
|
||||
for key, value in cls.hidden.extra_pnginfo.items():
|
||||
exif_data[inital_exif_tag] = "{}:{}".format(key, json.dumps(value))
|
||||
inital_exif_tag -= 1
|
||||
return exif_data
|
||||
|
||||
@staticmethod
|
||||
def save_images(
|
||||
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, compress_level = 4,
|
||||
) -> list[SavedResult]:
|
||||
"""Saves a batch of images as individual PNG files."""
|
||||
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||
filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0]
|
||||
)
|
||||
results = []
|
||||
metadata = ImageSaveHelper._create_png_metadata(cls)
|
||||
for batch_number, image_tensor in enumerate(images):
|
||||
img = ImageSaveHelper._convert_tensor_to_pil(image_tensor)
|
||||
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
||||
file = f"{filename_with_batch_num}_{counter:05}_.png"
|
||||
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level)
|
||||
results.append(SavedResult(file, subfolder, folder_type))
|
||||
counter += 1
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def get_save_images_ui(images, filename_prefix: str, cls: Type[ComfyNode] | None, compress_level=4) -> SavedImages:
|
||||
"""Saves a batch of images and returns a UI object for the node output."""
|
||||
return SavedImages(
|
||||
ImageSaveHelper.save_images(
|
||||
images,
|
||||
filename_prefix=filename_prefix,
|
||||
folder_type=FolderType.output,
|
||||
cls=cls,
|
||||
compress_level=compress_level,
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def save_animated_png(
|
||||
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, fps: float, compress_level: int
|
||||
) -> SavedResult:
|
||||
"""Saves a batch of images as a single animated PNG."""
|
||||
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||
filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0]
|
||||
)
|
||||
pil_images = [ImageSaveHelper._convert_tensor_to_pil(img) for img in images]
|
||||
metadata = ImageSaveHelper._create_animated_png_metadata(cls)
|
||||
file = f"{filename}_{counter:05}_.png"
|
||||
save_path = os.path.join(full_output_folder, file)
|
||||
pil_images[0].save(
|
||||
save_path,
|
||||
pnginfo=metadata,
|
||||
compress_level=compress_level,
|
||||
save_all=True,
|
||||
duration=int(1000.0 / fps),
|
||||
append_images=pil_images[1:],
|
||||
)
|
||||
return SavedResult(file, subfolder, folder_type)
|
||||
|
||||
@staticmethod
|
||||
def get_save_animated_png_ui(
|
||||
images, filename_prefix: str, cls: Type[ComfyNode] | None, fps: float, compress_level: int
|
||||
) -> SavedImages:
|
||||
"""Saves an animated PNG and returns a UI object for the node output."""
|
||||
result = ImageSaveHelper.save_animated_png(
|
||||
images,
|
||||
filename_prefix=filename_prefix,
|
||||
folder_type=FolderType.output,
|
||||
cls=cls,
|
||||
fps=fps,
|
||||
compress_level=compress_level,
|
||||
)
|
||||
return SavedImages([result], is_animated=len(images) > 1)
|
||||
|
||||
@staticmethod
|
||||
def save_animated_webp(
|
||||
images,
|
||||
filename_prefix: str,
|
||||
folder_type: FolderType,
|
||||
cls: Type[ComfyNode] | None,
|
||||
fps: float,
|
||||
lossless: bool,
|
||||
quality: int,
|
||||
method: int,
|
||||
) -> SavedResult:
|
||||
"""Saves a batch of images as a single animated WebP."""
|
||||
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||
filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0]
|
||||
)
|
||||
pil_images = [ImageSaveHelper._convert_tensor_to_pil(img) for img in images]
|
||||
pil_exif = ImageSaveHelper._create_webp_metadata(pil_images[0], cls)
|
||||
file = f"{filename}_{counter:05}_.webp"
|
||||
pil_images[0].save(
|
||||
os.path.join(full_output_folder, file),
|
||||
save_all=True,
|
||||
duration=int(1000.0 / fps),
|
||||
append_images=pil_images[1:],
|
||||
exif=pil_exif,
|
||||
lossless=lossless,
|
||||
quality=quality,
|
||||
method=method,
|
||||
)
|
||||
return SavedResult(file, subfolder, folder_type)
|
||||
|
||||
@staticmethod
|
||||
def get_save_animated_webp_ui(
|
||||
images,
|
||||
filename_prefix: str,
|
||||
cls: Type[ComfyNode] | None,
|
||||
fps: float,
|
||||
lossless: bool,
|
||||
quality: int,
|
||||
method: int,
|
||||
) -> SavedImages:
|
||||
"""Saves an animated WebP and returns a UI object for the node output."""
|
||||
result = ImageSaveHelper.save_animated_webp(
|
||||
images,
|
||||
filename_prefix=filename_prefix,
|
||||
folder_type=FolderType.output,
|
||||
cls=cls,
|
||||
fps=fps,
|
||||
lossless=lossless,
|
||||
quality=quality,
|
||||
method=method,
|
||||
)
|
||||
return SavedImages([result], is_animated=len(images) > 1)
|
||||
|
||||
|
||||
class AudioSaveHelper:
|
||||
"""A helper class with static methods to handle audio saving and metadata."""
|
||||
_OPUS_RATES = [8000, 12000, 16000, 24000, 48000]
|
||||
|
||||
@staticmethod
|
||||
def save_audio(
|
||||
audio: dict,
|
||||
filename_prefix: str,
|
||||
folder_type: FolderType,
|
||||
cls: Type[ComfyNode] | None,
|
||||
format: str = "flac",
|
||||
quality: str = "128k",
|
||||
) -> list[SavedResult]:
|
||||
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||
filename_prefix, _get_directory_by_folder_type(folder_type)
|
||||
)
|
||||
|
||||
metadata = {}
|
||||
if not args.disable_metadata and cls is not None:
|
||||
if cls.hidden.prompt is not None:
|
||||
metadata["prompt"] = json.dumps(cls.hidden.prompt)
|
||||
if cls.hidden.extra_pnginfo is not None:
|
||||
for x in cls.hidden.extra_pnginfo:
|
||||
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
|
||||
|
||||
results = []
|
||||
for batch_number, waveform in enumerate(audio["waveform"].cpu()):
|
||||
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
||||
file = f"{filename_with_batch_num}_{counter:05}_.{format}"
|
||||
output_path = os.path.join(full_output_folder, file)
|
||||
|
||||
# Use original sample rate initially
|
||||
sample_rate = audio["sample_rate"]
|
||||
|
||||
# Handle Opus sample rate requirements
|
||||
if format == "opus":
|
||||
if sample_rate > 48000:
|
||||
sample_rate = 48000
|
||||
elif sample_rate not in AudioSaveHelper._OPUS_RATES:
|
||||
# Find the next highest supported rate
|
||||
for rate in sorted(AudioSaveHelper._OPUS_RATES):
|
||||
if rate > sample_rate:
|
||||
sample_rate = rate
|
||||
break
|
||||
if sample_rate not in AudioSaveHelper._OPUS_RATES: # Fallback if still not supported
|
||||
sample_rate = 48000
|
||||
|
||||
# Resample if necessary
|
||||
if sample_rate != audio["sample_rate"]:
|
||||
if not TORCH_AUDIO_AVAILABLE:
|
||||
raise Exception("torchaudio is not available; cannot resample audio.")
|
||||
waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
|
||||
|
||||
# Create output with specified format
|
||||
output_buffer = BytesIO()
|
||||
output_container = av.open(output_buffer, mode="w", format=format)
|
||||
|
||||
# Set metadata on the container
|
||||
for key, value in metadata.items():
|
||||
output_container.metadata[key] = value
|
||||
|
||||
# Set up the output stream with appropriate properties
|
||||
if format == "opus":
|
||||
out_stream = output_container.add_stream("libopus", rate=sample_rate)
|
||||
if quality == "64k":
|
||||
out_stream.bit_rate = 64000
|
||||
elif quality == "96k":
|
||||
out_stream.bit_rate = 96000
|
||||
elif quality == "128k":
|
||||
out_stream.bit_rate = 128000
|
||||
elif quality == "192k":
|
||||
out_stream.bit_rate = 192000
|
||||
elif quality == "320k":
|
||||
out_stream.bit_rate = 320000
|
||||
elif format == "mp3":
|
||||
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
|
||||
if quality == "V0":
|
||||
# TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
|
||||
out_stream.codec_context.qscale = 1
|
||||
elif quality == "128k":
|
||||
out_stream.bit_rate = 128000
|
||||
elif quality == "320k":
|
||||
out_stream.bit_rate = 320000
|
||||
else: # format == "flac":
|
||||
out_stream = output_container.add_stream("flac", rate=sample_rate)
|
||||
|
||||
frame = av.AudioFrame.from_ndarray(
|
||||
waveform.movedim(0, 1).reshape(1, -1).float().numpy(),
|
||||
format="flt",
|
||||
layout="mono" if waveform.shape[0] == 1 else "stereo",
|
||||
)
|
||||
frame.sample_rate = sample_rate
|
||||
frame.pts = 0
|
||||
output_container.mux(out_stream.encode(frame))
|
||||
|
||||
# Flush encoder
|
||||
output_container.mux(out_stream.encode(None))
|
||||
|
||||
# Close containers
|
||||
output_container.close()
|
||||
|
||||
# Write the output to file
|
||||
output_buffer.seek(0)
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(output_buffer.getbuffer())
|
||||
|
||||
results.append(SavedResult(file, subfolder, folder_type))
|
||||
counter += 1
|
||||
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def get_save_audio_ui(
|
||||
audio, filename_prefix: str, cls: Type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
|
||||
) -> SavedAudios:
|
||||
"""Save and instantly wrap for UI."""
|
||||
return SavedAudios(
|
||||
AudioSaveHelper.save_audio(
|
||||
audio,
|
||||
filename_prefix=filename_prefix,
|
||||
folder_type=FolderType.output,
|
||||
cls=cls,
|
||||
format=format,
|
||||
quality=quality,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class PreviewImage(_UIOutput):
|
||||
def __init__(self, image: Image.Type, animated: bool = False, cls: Type[ComfyNode] = None, **kwargs):
|
||||
self.values = ImageSaveHelper.save_images(
|
||||
image,
|
||||
filename_prefix="ComfyUI_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)),
|
||||
folder_type=FolderType.temp,
|
||||
cls=cls,
|
||||
compress_level=1,
|
||||
)
|
||||
self.animated = animated
|
||||
|
||||
def as_dict(self):
|
||||
return {
|
||||
"images": self.values,
|
||||
"animated": (self.animated,)
|
||||
}
|
||||
|
||||
|
||||
class PreviewMask(PreviewImage):
|
||||
def __init__(self, mask: PreviewMask.Type, animated: bool=False, cls: ComfyNode=None, **kwargs):
|
||||
preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
|
||||
super().__init__(preview, animated, cls, **kwargs)
|
||||
|
||||
|
||||
class PreviewAudio(_UIOutput):
|
||||
def __init__(self, audio: dict, cls: Type[ComfyNode] = None, **kwargs):
|
||||
self.values = AudioSaveHelper.save_audio(
|
||||
audio,
|
||||
filename_prefix="ComfyUI_temp_" + "".join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)),
|
||||
folder_type=FolderType.temp,
|
||||
cls=cls,
|
||||
format="flac",
|
||||
quality="128k",
|
||||
)
|
||||
|
||||
def as_dict(self) -> dict:
|
||||
return {"audio": self.values}
|
||||
|
||||
|
||||
class PreviewVideo(_UIOutput):
|
||||
def __init__(self, values: list[SavedResult | dict], **kwargs):
|
||||
self.values = values
|
||||
|
||||
def as_dict(self):
|
||||
return {"images": self.values, "animated": (True,)}
|
||||
|
||||
|
||||
class PreviewUI3D(_UIOutput):
|
||||
def __init__(self, model_file, camera_info, **kwargs):
|
||||
self.model_file = model_file
|
||||
self.camera_info = camera_info
|
||||
|
||||
def as_dict(self):
|
||||
return {"result": [self.model_file, self.camera_info]}
|
||||
|
||||
|
||||
class PreviewText(_UIOutput):
|
||||
def __init__(self, value: str, **kwargs):
|
||||
self.value = value
|
||||
|
||||
def as_dict(self):
|
||||
return {"text": (self.value,)}
|
||||
|
||||
|
||||
class _UI:
|
||||
SavedResult = SavedResult
|
||||
SavedImages = SavedImages
|
||||
SavedAudios = SavedAudios
|
||||
ImageSaveHelper = ImageSaveHelper
|
||||
AudioSaveHelper = AudioSaveHelper
|
||||
PreviewImage = PreviewImage
|
||||
PreviewMask = PreviewMask
|
||||
PreviewAudio = PreviewAudio
|
||||
PreviewVideo = PreviewVideo
|
||||
PreviewUI3D = PreviewUI3D
|
||||
PreviewText = PreviewText
|
||||
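A rough sketch (not part of the diff) of how a node might return the UI helpers defined in comfy_api/latest/_ui.py above. The SaveAndPreview class is hypothetical, its schema definition is omitted, and io.ComfyNode / io.NodeOutput are assumed to come from the _io module whose diff is suppressed earlier:

from comfy_api.latest import io, ui


class SaveAndPreview(io.ComfyNode):
    """Hypothetical node body showing ImageSaveHelper in use."""

    @classmethod
    def execute(cls, images) -> io.NodeOutput:
        # Persist the batch to the output folder and surface it in the UI.
        saved = ui.ImageSaveHelper.get_save_images_ui(images, filename_prefix="ComfyUI", cls=cls)
        return io.NodeOutput(ui=saved)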
8  comfy_api/latest/_util/__init__.py  Normal file
@@ -0,0 +1,8 @@
from .video_types import VideoContainer, VideoCodec, VideoComponents

__all__ = [
    # Utility Types
    "VideoContainer",
    "VideoCodec",
    "VideoComponents",
]
52  comfy_api/latest/_util/video_types.py  Normal file
@@ -0,0 +1,52 @@
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
from fractions import Fraction
from typing import Optional
from comfy_api.latest._input import ImageInput, AudioInput


class VideoCodec(str, Enum):
    AUTO = "auto"
    H264 = "h264"

    @classmethod
    def as_input(cls) -> list[str]:
        """
        Returns a list of codec names that can be used as node input.
        """
        return [member.value for member in cls]

class VideoContainer(str, Enum):
    AUTO = "auto"
    MP4 = "mp4"

    @classmethod
    def as_input(cls) -> list[str]:
        """
        Returns a list of container names that can be used as node input.
        """
        return [member.value for member in cls]

    @classmethod
    def get_extension(cls, value) -> str:
        """
        Returns the file extension for the container.
        """
        if isinstance(value, str):
            value = cls(value)
        if value == VideoContainer.MP4 or value == VideoContainer.AUTO:
            return "mp4"
        return ""

@dataclass
class VideoComponents:
    """
    Dataclass representing the components of a video.
    """

    images: ImageInput
    frame_rate: Fraction
    audio: Optional[AudioInput] = None
    metadata: Optional[dict] = None
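A quick sketch (not part of the diff) of how the enum helpers above are typically used, e.g. to populate a combo input and to pick an output extension:

from comfy_api.latest._util import VideoCodec, VideoContainer

print(VideoContainer.as_input())                          # ['auto', 'mp4'] -- values for a combo widget
print(VideoCodec.as_input())                              # ['auto', 'h264']
print(VideoContainer.get_extension("mp4"))                # 'mp4'
print(VideoContainer.get_extension(VideoContainer.AUTO))  # 'mp4'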
20  comfy_api/latest/generated/ComfyAPISyncStub.pyi  Normal file
@@ -0,0 +1,20 @@
from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple
from comfy_api.latest import ComfyAPI_latest
from PIL.Image import Image
from torch import Tensor

class ComfyAPISyncStub:
    def __init__(self) -> None: ...

    class ExecutionSync:
        def __init__(self) -> None: ...
        """
        Update the progress bar displayed in the ComfyUI interface.

        This function allows custom nodes and API calls to report their progress
        back to the user interface, providing visual feedback during long operations.

        Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK
        """
        def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[Image, Tensor, None] = None, ignore_size_limit: bool = False) -> None: ...

    execution: ExecutionSync
8  comfy_api/util.py  Normal file
@@ -0,0 +1,8 @@
# This file only exists for backwards compatibility.
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents

__all__ = [
    "VideoCodec",
    "VideoContainer",
    "VideoComponents",
]
@@ -1,7 +1,7 @@
-from .video_types import VideoContainer, VideoCodec, VideoComponents
+# This file only exists for backwards compatibility.
+from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents
 
 __all__ = [
-    # Utility Types
     "VideoContainer",
     "VideoCodec",
     "VideoComponents",
@@ -1,51 +1,12 @@
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
from comfy_api.input import ImageInput, AudioInput
|
||||
|
||||
class VideoCodec(str, Enum):
|
||||
AUTO = "auto"
|
||||
H264 = "h264"
|
||||
|
||||
@classmethod
|
||||
def as_input(cls) -> list[str]:
|
||||
"""
|
||||
Returns a list of codec names that can be used as node input.
|
||||
"""
|
||||
return [member.value for member in cls]
|
||||
|
||||
class VideoContainer(str, Enum):
|
||||
AUTO = "auto"
|
||||
MP4 = "mp4"
|
||||
|
||||
@classmethod
|
||||
def as_input(cls) -> list[str]:
|
||||
"""
|
||||
Returns a list of container names that can be used as node input.
|
||||
"""
|
||||
return [member.value for member in cls]
|
||||
|
||||
@classmethod
|
||||
def get_extension(cls, value) -> str:
|
||||
"""
|
||||
Returns the file extension for the container.
|
||||
"""
|
||||
if isinstance(value, str):
|
||||
value = cls(value)
|
||||
if value == VideoContainer.MP4 or value == VideoContainer.AUTO:
|
||||
return "mp4"
|
||||
return ""
|
||||
|
||||
@dataclass
|
||||
class VideoComponents:
|
||||
"""
|
||||
Dataclass representing the components of a video.
|
||||
"""
|
||||
|
||||
images: ImageInput
|
||||
frame_rate: Fraction
|
||||
audio: Optional[AudioInput] = None
|
||||
metadata: Optional[dict] = None
|
||||
# This file only exists for backwards compatibility.
|
||||
from comfy_api.latest._util.video_types import (
|
||||
VideoContainer,
|
||||
VideoCodec,
|
||||
VideoComponents,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"VideoContainer",
|
||||
"VideoCodec",
|
||||
"VideoComponents",
|
||||
]
|
||||
|
||||
42  comfy_api/v0_0_1/__init__.py  Normal file
@@ -0,0 +1,42 @@
from comfy_api.v0_0_2 import (
    ComfyAPIAdapter_v0_0_2,
    Input as Input_v0_0_2,
    InputImpl as InputImpl_v0_0_2,
    Types as Types_v0_0_2,
)
from typing import Type, TYPE_CHECKING
from comfy_api.internal.async_to_sync import create_sync_class


# This version only exists to serve as a template for future version adapters.
# There is no reason anyone should ever use it.
class ComfyAPIAdapter_v0_0_1(ComfyAPIAdapter_v0_0_2):
    VERSION = "0.0.1"
    STABLE = True

class Input(Input_v0_0_2):
    pass

class InputImpl(InputImpl_v0_0_2):
    pass

class Types(Types_v0_0_2):
    pass

ComfyAPI = ComfyAPIAdapter_v0_0_1

# Create a synchronous version of the API
if TYPE_CHECKING:
    from comfy_api.v0_0_1.generated.ComfyAPISyncStub import ComfyAPISyncStub  # type: ignore

    ComfyAPISync: Type[ComfyAPISyncStub]

ComfyAPISync = create_sync_class(ComfyAPIAdapter_v0_0_1)

__all__ = [
    "ComfyAPI",
    "ComfyAPISync",
    "Input",
    "InputImpl",
    "Types",
]
20  comfy_api/v0_0_1/generated/ComfyAPISyncStub.pyi  Normal file
@@ -0,0 +1,20 @@
from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple
from comfy_api.v0_0_1 import ComfyAPIAdapter_v0_0_1
from PIL.Image import Image
from torch import Tensor

class ComfyAPISyncStub:
    def __init__(self) -> None: ...

    class ExecutionSync:
        def __init__(self) -> None: ...
        """
        Update the progress bar displayed in the ComfyUI interface.

        This function allows custom nodes and API calls to report their progress
        back to the user interface, providing visual feedback during long operations.

        Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK
        """
        def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[Image, Tensor, None] = None, ignore_size_limit: bool = False) -> None: ...

    execution: ExecutionSync
45  comfy_api/v0_0_2/__init__.py  Normal file
@@ -0,0 +1,45 @@
from comfy_api.latest import (
    ComfyAPI_latest,
    Input as Input_latest,
    InputImpl as InputImpl_latest,
    Types as Types_latest,
)
from typing import Type, TYPE_CHECKING
from comfy_api.internal.async_to_sync import create_sync_class
from comfy_api.latest import io, ui, ComfyExtension  #noqa: F401


class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest):
    VERSION = "0.0.2"
    STABLE = False


class Input(Input_latest):
    pass


class InputImpl(InputImpl_latest):
    pass


class Types(Types_latest):
    pass


ComfyAPI = ComfyAPIAdapter_v0_0_2

# Create a synchronous version of the API
if TYPE_CHECKING:
    from comfy_api.v0_0_2.generated.ComfyAPISyncStub import ComfyAPISyncStub  # type: ignore

    ComfyAPISync: Type[ComfyAPISyncStub]
ComfyAPISync = create_sync_class(ComfyAPIAdapter_v0_0_2)

__all__ = [
    "ComfyAPI",
    "ComfyAPISync",
    "Input",
    "InputImpl",
    "Types",
    "ComfyExtension",
]
20  comfy_api/v0_0_2/generated/ComfyAPISyncStub.pyi  Normal file
@@ -0,0 +1,20 @@
from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple
from comfy_api.v0_0_2 import ComfyAPIAdapter_v0_0_2
from PIL.Image import Image
from torch import Tensor

class ComfyAPISyncStub:
    def __init__(self) -> None: ...

    class ExecutionSync:
        def __init__(self) -> None: ...
        """
        Update the progress bar displayed in the ComfyUI interface.

        This function allows custom nodes and API calls to report their progress
        back to the user interface, providing visual feedback during long operations.

        Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK
        """
        def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[Image, Tensor, None] = None, ignore_size_limit: bool = False) -> None: ...

    execution: ExecutionSync
comfy_api/version_list.py (new file, 12 lines)
@@ -0,0 +1,12 @@
from comfy_api.latest import ComfyAPI_latest
from comfy_api.v0_0_2 import ComfyAPIAdapter_v0_0_2
from comfy_api.v0_0_1 import ComfyAPIAdapter_v0_0_1
from comfy_api.internal import ComfyAPIBase
from typing import List, Type

supported_versions: List[Type[ComfyAPIBase]] = [
    ComfyAPI_latest,
    ComfyAPIAdapter_v0_0_2,
    ComfyAPIAdapter_v0_0_1,
]
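A hedged sketch of how this list could be consumed; the helper below is hypothetical and not part of the diff:

# Hypothetical helper: pick an adapter class by its VERSION string,
# falling back to the first (latest) entry in supported_versions.
from comfy_api.version_list import supported_versions

def resolve_api_version(requested: str):
    for cls in supported_versions:
        if getattr(cls, "VERSION", None) == requested:
            return cls
    return supported_versions[0]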
@@ -1,4 +1,5 @@
 from __future__ import annotations
+import aiohttp
 import io
 import logging
 import mimetypes
@@ -21,7 +22,6 @@ from server import PromptServer

 import numpy as np
 from PIL import Image
-import requests
 import torch
 import math
 import base64
@@ -30,7 +30,7 @@ from io import BytesIO
 import av


-def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFromFile:
+async def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFromFile:
     """Downloads a video from a URL and returns a `VIDEO` output.

     Args:
@@ -39,7 +39,7 @@ def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFr
     Returns:
         A Comfy node `VIDEO` output.
     """
-    video_io = download_url_to_bytesio(video_url, timeout)
+    video_io = await download_url_to_bytesio(video_url, timeout)
     if video_io is None:
         error_msg = f"Failed to download video from {video_url}"
         logging.error(error_msg)
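Since the helper is now a coroutine, callers must await it; a minimal sketch under the assumption of an async node method (the node class below is illustrative):

# Illustrative only: awaiting the async download helper from a node's api_call.
from comfy_api_nodes.apinode_utils import download_url_to_video_output

class ExampleVideoNode:
    async def api_call(self, video_url: str, **kwargs):
        video = await download_url_to_video_output(video_url, timeout=60)
        return (video,)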
@@ -62,7 +62,7 @@ def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
     return s


-def validate_and_cast_response(
+async def validate_and_cast_response(
     response, timeout: int = None, node_id: Union[str, None] = None
 ) -> torch.Tensor:
     """Validates and casts a response to a torch.Tensor.
@@ -86,35 +86,24 @@ def validate_and_cast_response(
     image_tensors: list[torch.Tensor] = []

     # Process each image in the data array
-    for image_data in data:
-        image_url = image_data.url
-        b64_data = image_data.b64_json
-
-        if not image_url and not b64_data:
-            raise ValueError("No image was generated in the response")
-
-        if b64_data:
-            img_data = base64.b64decode(b64_data)
-            img = Image.open(io.BytesIO(img_data))
-
-        elif image_url:
-            if node_id:
-                PromptServer.instance.send_progress_text(
-                    f"Result URL: {image_url}", node_id
-                )
-            img_response = requests.get(image_url, timeout=timeout)
-            if img_response.status_code != 200:
-                raise ValueError("Failed to download the image")
-            img = Image.open(io.BytesIO(img_response.content))
-
-        img = img.convert("RGBA")
-
-        # Convert to numpy array, normalize to float32 between 0 and 1
-        img_array = np.array(img).astype(np.float32) / 255.0
-        img_tensor = torch.from_numpy(img_array)
-
-        # Add to list of tensors
-        image_tensors.append(img_tensor)
+    async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
+        for img_data in data:
+            img_bytes: bytes
+            if img_data.b64_json:
+                img_bytes = base64.b64decode(img_data.b64_json)
+            elif img_data.url:
+                if node_id:
+                    PromptServer.instance.send_progress_text(f"Result URL: {img_data.url}", node_id)
+                async with session.get(img_data.url) as resp:
+                    if resp.status != 200:
+                        raise ValueError("Failed to download generated image")
+                    img_bytes = await resp.read()
+            else:
+                raise ValueError("Invalid image payload – neither URL nor base64 data present.")
+
+            pil_img = Image.open(BytesIO(img_bytes)).convert("RGBA")
+            arr = np.asarray(pil_img).astype(np.float32) / 255.0
+            image_tensors.append(torch.from_numpy(arr))

     return torch.stack(image_tensors, dim=0)

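As a reminder of the output convention used here, each downloaded payload becomes an RGBA float32 tensor in [0, 1] with shape [H, W, C], and the results are stacked into a [B, H, W, C] batch; a standalone sketch of that conversion (illustrative, not part of the diff):

# Standalone sketch of the bytes -> [B, H, W, C] float tensor conversion above.
from io import BytesIO
import numpy as np
import torch
from PIL import Image

def bytes_to_batched_tensor(payloads: list[bytes]) -> torch.Tensor:
    tensors = []
    for raw in payloads:
        pil_img = Image.open(BytesIO(raw)).convert("RGBA")
        arr = np.asarray(pil_img).astype(np.float32) / 255.0  # [H, W, 4] in [0, 1]
        tensors.append(torch.from_numpy(arr))
    return torch.stack(tensors, dim=0)  # [B, H, W, 4]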
@@ -175,7 +164,7 @@ def mimetype_to_extension(mime_type: str) -> str:
     return mime_type.split("/")[-1].lower()


-def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO:
+async def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO:
     """Downloads content from a URL using requests and returns it as BytesIO.

     Args:
@@ -185,9 +174,11 @@ def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO:
     Returns:
         BytesIO object containing the downloaded content.
     """
-    response = requests.get(url, stream=True, timeout=timeout)
-    response.raise_for_status()  # Raises HTTPError for bad responses (4XX or 5XX)
-    return BytesIO(response.content)
+    timeout_cfg = aiohttp.ClientTimeout(total=timeout) if timeout else None
+    async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
+        async with session.get(url) as resp:
+            resp.raise_for_status()  # Raises HTTPError for bad responses (4XX or 5XX)
+            return BytesIO(await resp.read())


 def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor:
@@ -210,15 +201,15 @@ def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch
     return torch.from_numpy(image_array).unsqueeze(0)


-def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor:
+async def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor:
     """Downloads an image from a URL and returns a [B, H, W, C] tensor."""
-    image_bytesio = download_url_to_bytesio(url, timeout)
+    image_bytesio = await download_url_to_bytesio(url, timeout)
     return bytesio_to_image_tensor(image_bytesio)


-def process_image_response(response: requests.Response) -> torch.Tensor:
+def process_image_response(response_content: bytes | str) -> torch.Tensor:
     """Uses content from a Response object and converts it to a torch.Tensor"""
-    return bytesio_to_image_tensor(BytesIO(response.content))
+    return bytesio_to_image_tensor(BytesIO(response_content))


 def _tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image:
@@ -336,10 +327,10 @@ def text_filepath_to_data_uri(filepath: str) -> str:
     return f"data:{mime_type};base64,{base64_string}"


-def upload_file_to_comfyapi(
+async def upload_file_to_comfyapi(
     file_bytes_io: BytesIO,
     filename: str,
-    upload_mime_type: str,
+    upload_mime_type: Optional[str],
     auth_kwargs: Optional[dict[str, str]] = None,
 ) -> str:
     """
@@ -354,7 +345,10 @@ def upload_file_to_comfyapi(
     Returns:
         The download URL for the uploaded file.
     """
-    request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
+    if upload_mime_type is None:
+        request_object = UploadRequest(file_name=filename)
+    else:
+        request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
     operation = SynchronousOperation(
         endpoint=ApiEndpoint(
             path="/customers/storage",
@@ -366,12 +360,8 @@
         auth_kwargs=auth_kwargs,
     )

-    response: UploadResponse = operation.execute()
-    upload_response = ApiClient.upload_file(
-        response.upload_url, file_bytes_io, content_type=upload_mime_type
-    )
-    upload_response.raise_for_status()
-
+    response: UploadResponse = await operation.execute()
+    await ApiClient.upload_file(response.upload_url, file_bytes_io, content_type=upload_mime_type)
     return response.download_url


@@ -399,7 +389,7 @@ def video_to_base64_string(
     return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")


-def upload_video_to_comfyapi(
+async def upload_video_to_comfyapi(
     video: VideoInput,
     auth_kwargs: Optional[dict[str, str]] = None,
     container: VideoContainer = VideoContainer.MP4,
@@ -439,9 +429,7 @@
     video.save_to(video_bytes_io, format=container, codec=codec)
     video_bytes_io.seek(0)

-    return upload_file_to_comfyapi(
-        video_bytes_io, filename, upload_mime_type, auth_kwargs
-    )
+    return await upload_file_to_comfyapi(video_bytes_io, filename, upload_mime_type, auth_kwargs)


 def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray:
@@ -501,7 +489,7 @@ def audio_ndarray_to_bytesio(
     return audio_bytes_io


-def upload_audio_to_comfyapi(
+async def upload_audio_to_comfyapi(
     audio: AudioInput,
     auth_kwargs: Optional[dict[str, str]] = None,
     container_format: str = "mp4",
@@ -527,7 +515,7 @@ def upload_audio_to_comfyapi(
         audio_data_np, sample_rate, container_format, codec_name
     )

-    return upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
+    return await upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)


 def audio_to_base64_string(
@@ -544,7 +532,7 @@
     return base64.b64encode(audio_bytes).decode("utf-8")


-def upload_images_to_comfyapi(
+async def upload_images_to_comfyapi(
     image: torch.Tensor,
     max_images=8,
     auth_kwargs: Optional[dict[str, str]] = None,
@@ -561,55 +549,15 @@
         mime_type: Optional MIME type for the image.
     """
-    # if batch, try to upload each file if max_images is greater than 0
-    idx_image = 0
     download_urls: list[str] = []
     is_batch = len(image.shape) > 3
-    batch_length = 1
-    if is_batch:
-        batch_length = image.shape[0]
-    while True:
-        curr_image = image
-        if len(image.shape) > 3:
-            curr_image = image[idx_image]
-        # get BytesIO version of image
-        img_binary = tensor_to_bytesio(curr_image, mime_type=mime_type)
-        # first, request upload/download urls from comfy API
-        if not mime_type:
-            request_object = UploadRequest(file_name=img_binary.name)
-        else:
-            request_object = UploadRequest(
-                file_name=img_binary.name, content_type=mime_type
-            )
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/customers/storage",
-                method=HttpMethod.POST,
-                request_model=UploadRequest,
-                response_model=UploadResponse,
-            ),
-            request=request_object,
-            auth_kwargs=auth_kwargs,
-        )
-        response = operation.execute()
+    batch_len = image.shape[0] if is_batch else 1

-        upload_response = ApiClient.upload_file(
-            response.upload_url, img_binary, content_type=mime_type
-        )
-        # verify success
-        try:
-            upload_response.raise_for_status()
-        except requests.exceptions.HTTPError as e:
-            raise ValueError(f"Could not upload one or more images: {e}") from e
-        # add download_url to list
-        download_urls.append(response.download_url)
-
-        idx_image += 1
-        # stop uploading additional files if done
-        if is_batch and max_images > 0:
-            if idx_image >= max_images:
-                break
-        if idx_image >= batch_length:
-            break
+    for idx in range(min(batch_len, max_images)):
+        tensor = image[idx] if is_batch else image
+        img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
+        url = await upload_file_to_comfyapi(img_io, img_io.name, mime_type, auth_kwargs)
+        download_urls.append(url)
     return download_urls

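A hedged usage sketch for the rewritten batch upload (the surrounding function is illustrative): the helper now uploads at most max_images frames of a [B, H, W, C] batch and returns their download URLs.

# Illustrative only: uploading a batched image tensor from async node code.
from comfy_api_nodes.apinode_utils import upload_images_to_comfyapi

async def upload_batch(image, auth_kwargs=None):
    # image: [B, H, W, C] or [H, W, C] torch.Tensor
    urls = await upload_images_to_comfyapi(image, max_images=4, auth_kwargs=auth_kwargs)
    return urls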
comfy_api_nodes/apis/__init__.py (generated, 2656 lines)
File diff suppressed because it is too large
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import os
 import datetime
 import json
@@ -127,7 +127,7 @@ class TripoTextToModelRequest(BaseModel):
     type: TripoTaskType = Field(TripoTaskType.TEXT_TO_MODEL, description='Type of task')
     prompt: str = Field(..., description='The text prompt describing the model to generate', max_length=1024)
     negative_prompt: Optional[str] = Field(None, description='The negative text prompt', max_length=1024)
-    model_version: Optional[TripoModelVersion] = TripoModelVersion.V2_5
+    model_version: Optional[TripoModelVersion] = TripoModelVersion.v2_5_20250123
     face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
     texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model')
     pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model')
@@ -1,3 +1,4 @@
+import asyncio
 import io
 from inspect import cleandoc
 from typing import Union, Optional
@@ -28,7 +29,7 @@ from comfy_api_nodes.apinode_utils import (

 import numpy as np
 from PIL import Image
-import requests
+import aiohttp
 import torch
 import base64
 import time
@@ -44,18 +45,18 @@ def convert_mask_to_image(mask: torch.Tensor):
|
||||
return mask
|
||||
|
||||
|
||||
def handle_bfl_synchronous_operation(
|
||||
async def handle_bfl_synchronous_operation(
|
||||
operation: SynchronousOperation,
|
||||
timeout_bfl_calls=360,
|
||||
node_id: Union[str, None] = None,
|
||||
):
|
||||
response_api: BFLFluxProGenerateResponse = operation.execute()
|
||||
return _poll_until_generated(
|
||||
response_api: BFLFluxProGenerateResponse = await operation.execute()
|
||||
return await _poll_until_generated(
|
||||
response_api.polling_url, timeout=timeout_bfl_calls, node_id=node_id
|
||||
)
|
||||
|
||||
|
||||
def _poll_until_generated(
|
||||
async def _poll_until_generated(
|
||||
polling_url: str, timeout=360, node_id: Union[str, None] = None
|
||||
):
|
||||
# used bfl-comfy-nodes to verify code implementation:
|
||||
@@ -66,55 +67,56 @@ def _poll_until_generated(
|
||||
retry_404_seconds = 2
|
||||
retry_202_seconds = 2
|
||||
retry_pending_seconds = 1
|
||||
request = requests.Request(method=HttpMethod.GET, url=polling_url)
|
||||
# NOTE: should True loop be replaced with checking if workflow has been interrupted?
|
||||
while True:
|
||||
if node_id:
|
||||
time_elapsed = time.time() - start_time
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Generating ({time_elapsed:.0f}s)", node_id
|
||||
)
|
||||
|
||||
response = requests.Session().send(request.prepare())
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
if result["status"] == BFLStatus.ready:
|
||||
img_url = result["result"]["sample"]
|
||||
if node_id:
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Result URL: {img_url}", node_id
|
||||
)
|
||||
img_response = requests.get(img_url)
|
||||
return process_image_response(img_response)
|
||||
elif result["status"] in [
|
||||
BFLStatus.request_moderated,
|
||||
BFLStatus.content_moderated,
|
||||
]:
|
||||
status = result["status"]
|
||||
raise Exception(
|
||||
f"BFL API did not return an image due to: {status}."
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# NOTE: should True loop be replaced with checking if workflow has been interrupted?
|
||||
while True:
|
||||
if node_id:
|
||||
time_elapsed = time.time() - start_time
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Generating ({time_elapsed:.0f}s)", node_id
|
||||
)
|
||||
elif result["status"] == BFLStatus.error:
|
||||
raise Exception(f"BFL API encountered an error: {result}.")
|
||||
elif result["status"] == BFLStatus.pending:
|
||||
time.sleep(retry_pending_seconds)
|
||||
continue
|
||||
elif response.status_code == 404:
|
||||
if retries_404 < max_retries_404:
|
||||
retries_404 += 1
|
||||
time.sleep(retry_404_seconds)
|
||||
continue
|
||||
raise Exception(
|
||||
f"BFL API could not find task after {max_retries_404} tries."
|
||||
)
|
||||
elif response.status_code == 202:
|
||||
time.sleep(retry_202_seconds)
|
||||
elif time.time() - start_time > timeout:
|
||||
raise Exception(
|
||||
f"BFL API experienced a timeout; could not return request under {timeout} seconds."
|
||||
)
|
||||
else:
|
||||
raise Exception(f"BFL API encountered an error: {response.json()}")
|
||||
|
||||
async with session.get(polling_url) as response:
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
if result["status"] == BFLStatus.ready:
|
||||
img_url = result["result"]["sample"]
|
||||
if node_id:
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Result URL: {img_url}", node_id
|
||||
)
|
||||
async with session.get(img_url) as img_resp:
|
||||
return process_image_response(await img_resp.content.read())
|
||||
elif result["status"] in [
|
||||
BFLStatus.request_moderated,
|
||||
BFLStatus.content_moderated,
|
||||
]:
|
||||
status = result["status"]
|
||||
raise Exception(
|
||||
f"BFL API did not return an image due to: {status}."
|
||||
)
|
||||
elif result["status"] == BFLStatus.error:
|
||||
raise Exception(f"BFL API encountered an error: {result}.")
|
||||
elif result["status"] == BFLStatus.pending:
|
||||
await asyncio.sleep(retry_pending_seconds)
|
||||
continue
|
||||
elif response.status == 404:
|
||||
if retries_404 < max_retries_404:
|
||||
retries_404 += 1
|
||||
await asyncio.sleep(retry_404_seconds)
|
||||
continue
|
||||
raise Exception(
|
||||
f"BFL API could not find task after {max_retries_404} tries."
|
||||
)
|
||||
elif response.status == 202:
|
||||
await asyncio.sleep(retry_202_seconds)
|
||||
elif time.time() - start_time > timeout:
|
||||
raise Exception(
|
||||
f"BFL API experienced a timeout; could not return request under {timeout} seconds."
|
||||
)
|
||||
else:
|
||||
raise Exception(f"BFL API encountered an error: {response.json()}")
|
||||
|
||||
def convert_image_to_base64(image: torch.Tensor):
|
||||
scaled_image = downscale_image_tensor(image, total_pixels=2048 * 2048)
|
||||
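The rewritten polling loop above combines asyncio.sleep with an elapsed-time check; a generic, hedged sketch of the same pattern (URL, interval, and status value are placeholders, not the BFL API):

# Generic async polling sketch: retry with asyncio.sleep until ready or timed out.
import asyncio
import time
import aiohttp

async def poll_json(url: str, timeout: float = 360.0, interval: float = 1.0) -> dict:
    start = time.time()
    async with aiohttp.ClientSession() as session:
        while True:
            if time.time() - start > timeout:
                raise TimeoutError(f"No result from {url} within {timeout} seconds")
            async with session.get(url) as resp:
                if resp.status == 200:
                    result = await resp.json()
                    if result.get("status") == "Ready":  # placeholder status value
                        return result
            await asyncio.sleep(interval)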
@@ -222,7 +224,7 @@ class FluxProUltraImageNode(ComfyNodeABC):
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/BFL"
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
aspect_ratio: str,
|
||||
@@ -266,7 +268,7 @@ class FluxProUltraImageNode(ComfyNodeABC):
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
@@ -354,7 +356,7 @@ class FluxKontextProImageNode(ComfyNodeABC):
|
||||
|
||||
BFL_PATH = "/proxy/bfl/flux-kontext-pro/generate"
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
aspect_ratio: str,
|
||||
@@ -397,7 +399,7 @@ class FluxKontextProImageNode(ComfyNodeABC):
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
@@ -489,7 +491,7 @@ class FluxProImageNode(ComfyNodeABC):
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/BFL"
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_upsampling,
|
||||
@@ -524,7 +526,7 @@ class FluxProImageNode(ComfyNodeABC):
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
@@ -632,7 +634,7 @@ class FluxProExpandNode(ComfyNodeABC):
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/BFL"
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
prompt: str,
|
||||
@@ -670,7 +672,7 @@ class FluxProExpandNode(ComfyNodeABC):
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
@@ -744,7 +746,7 @@ class FluxProFillNode(ComfyNodeABC):
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/BFL"
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
@@ -780,7 +782,7 @@ class FluxProFillNode(ComfyNodeABC):
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
@@ -879,7 +881,7 @@ class FluxProCannyNode(ComfyNodeABC):
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/BFL"
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
control_image: torch.Tensor,
|
||||
prompt: str,
|
||||
@@ -929,7 +931,7 @@ class FluxProCannyNode(ComfyNodeABC):
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
@@ -1008,7 +1010,7 @@ class FluxProDepthNode(ComfyNodeABC):
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/BFL"
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
control_image: torch.Tensor,
|
||||
prompt: str,
|
||||
@@ -1045,7 +1047,7 @@ class FluxProDepthNode(ComfyNodeABC):
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
API Nodes for Gemini Multimodal LLM Usage via Remote API
|
||||
See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
import os
|
||||
from enum import Enum
|
||||
@@ -301,7 +303,7 @@ class GeminiNode(ComfyNodeABC):
|
||||
"""
|
||||
return GeminiPart(text=text)
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
model: GeminiModel,
|
||||
@@ -330,7 +332,7 @@ class GeminiNode(ComfyNodeABC):
|
||||
parts.extend(files)
|
||||
|
||||
# Create response
|
||||
response = SynchronousOperation(
|
||||
response = await SynchronousOperation(
|
||||
endpoint=get_gemini_endpoint(model),
|
||||
request=GeminiGenerateContentRequest(
|
||||
contents=[
|
||||
|
||||
@@ -212,7 +212,7 @@ V3_RESOLUTIONS= [
|
||||
"1536x640"
|
||||
]
|
||||
|
||||
def download_and_process_images(image_urls):
|
||||
async def download_and_process_images(image_urls):
|
||||
"""Helper function to download and process multiple images from URLs"""
|
||||
|
||||
# Initialize list to store image tensors
|
||||
@@ -220,7 +220,7 @@ def download_and_process_images(image_urls):
|
||||
|
||||
for image_url in image_urls:
|
||||
# Using functions from apinode_utils.py to handle downloading and processing
|
||||
image_bytesio = download_url_to_bytesio(image_url) # Download image content to BytesIO
|
||||
image_bytesio = await download_url_to_bytesio(image_url) # Download image content to BytesIO
|
||||
img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB") # Convert to torch.Tensor with RGB mode
|
||||
image_tensors.append(img_tensor)
|
||||
|
||||
@@ -328,7 +328,7 @@ class IdeogramV1(ComfyNodeABC):
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
API_NODE = True
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt,
|
||||
turbo=False,
|
||||
@@ -367,7 +367,7 @@ class IdeogramV1(ComfyNodeABC):
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
response = operation.execute()
|
||||
response = await operation.execute()
|
||||
|
||||
if not response.data or len(response.data) == 0:
|
||||
raise Exception("No images were generated in the response")
|
||||
@@ -378,7 +378,7 @@ class IdeogramV1(ComfyNodeABC):
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
|
||||
display_image_urls_on_node(image_urls, unique_id)
|
||||
return (download_and_process_images(image_urls),)
|
||||
return (await download_and_process_images(image_urls),)
|
||||
|
||||
|
||||
class IdeogramV2(ComfyNodeABC):
|
||||
@@ -487,7 +487,7 @@ class IdeogramV2(ComfyNodeABC):
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
API_NODE = True
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt,
|
||||
turbo=False,
|
||||
@@ -543,7 +543,7 @@ class IdeogramV2(ComfyNodeABC):
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
response = operation.execute()
|
||||
response = await operation.execute()
|
||||
|
||||
if not response.data or len(response.data) == 0:
|
||||
raise Exception("No images were generated in the response")
|
||||
@@ -554,7 +554,7 @@ class IdeogramV2(ComfyNodeABC):
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
|
||||
display_image_urls_on_node(image_urls, unique_id)
|
||||
return (download_and_process_images(image_urls),)
|
||||
return (await download_and_process_images(image_urls),)
|
||||
|
||||
class IdeogramV3(ComfyNodeABC):
|
||||
"""
|
||||
@@ -653,7 +653,7 @@ class IdeogramV3(ComfyNodeABC):
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
API_NODE = True
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt,
|
||||
image=None,
|
||||
@@ -774,7 +774,7 @@ class IdeogramV3(ComfyNodeABC):
|
||||
)
|
||||
|
||||
# Execute the operation and process response
|
||||
response = operation.execute()
|
||||
response = await operation.execute()
|
||||
|
||||
if not response.data or len(response.data) == 0:
|
||||
raise Exception("No images were generated in the response")
|
||||
@@ -785,7 +785,7 @@ class IdeogramV3(ComfyNodeABC):
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
|
||||
display_image_urls_on_node(image_urls, unique_id)
|
||||
return (download_and_process_images(image_urls),)
|
||||
return (await download_and_process_images(image_urls),)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
|
||||
@@ -109,7 +109,7 @@ class KlingApiError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def poll_until_finished(
|
||||
async def poll_until_finished(
|
||||
auth_kwargs: dict[str, str],
|
||||
api_endpoint: ApiEndpoint[Any, R],
|
||||
result_url_extractor: Optional[Callable[[R], str]] = None,
|
||||
@@ -117,7 +117,7 @@ def poll_until_finished(
|
||||
node_id: Optional[str] = None,
|
||||
) -> R:
|
||||
"""Polls the Kling API endpoint until the task reaches a terminal state, then returns the response."""
|
||||
return PollingOperation(
|
||||
return await PollingOperation(
|
||||
poll_endpoint=api_endpoint,
|
||||
completed_statuses=[
|
||||
KlingTaskStatus.succeed.value,
|
||||
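These Kling nodes all share the same create-then-poll shape; a condensed, hedged sketch of that flow (field names follow the code above, everything else is illustrative):

# Illustrative only: the create -> poll shape shared by these nodes.
async def run_task(initial_operation, get_response, auth_kwargs, node_id=None):
    task_creation_response = await initial_operation.execute()
    task_id = task_creation_response.data.task_id
    final_response = await get_response(task_id, auth_kwargs=auth_kwargs, node_id=node_id)
    return final_response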
@@ -278,18 +278,18 @@ def get_images_urls_from_response(response) -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
def video_result_to_node_output(
|
||||
async def video_result_to_node_output(
|
||||
video: KlingVideoResult,
|
||||
) -> tuple[VideoFromFile, str, str]:
|
||||
"""Converts a KlingVideoResult to a tuple of (VideoFromFile, str, str) to be used as a ComfyUI node output."""
|
||||
return (
|
||||
download_url_to_video_output(video.url),
|
||||
await download_url_to_video_output(str(video.url)),
|
||||
str(video.id),
|
||||
str(video.duration),
|
||||
)
|
||||
|
||||
|
||||
def image_result_to_node_output(
|
||||
async def image_result_to_node_output(
|
||||
images: list[KlingImageResult],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -297,9 +297,9 @@ def image_result_to_node_output(
|
||||
If multiple images are returned, they will be stacked along the batch dimension.
|
||||
"""
|
||||
if len(images) == 1:
|
||||
return download_url_to_image_tensor(images[0].url)
|
||||
return await download_url_to_image_tensor(str(images[0].url))
|
||||
else:
|
||||
return torch.cat([download_url_to_image_tensor(image.url) for image in images])
|
||||
return torch.cat([await download_url_to_image_tensor(str(image.url)) for image in images])
|
||||
|
||||
|
||||
class KlingNodeBase(ComfyNodeABC):
|
||||
@@ -467,10 +467,10 @@ class KlingTextToVideoNode(KlingNodeBase):
|
||||
RETURN_NAMES = ("VIDEO", "video_id", "duration")
|
||||
DESCRIPTION = "Kling Text to Video Node"
|
||||
|
||||
def get_response(
|
||||
async def get_response(
|
||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||
) -> KlingText2VideoResponse:
|
||||
return poll_until_finished(
|
||||
return await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_TEXT_TO_VIDEO}/{task_id}",
|
||||
@@ -483,7 +483,7 @@ class KlingTextToVideoNode(KlingNodeBase):
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
@@ -519,17 +519,17 @@ class KlingTextToVideoNode(KlingNodeBase):
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
task_creation_response = initial_operation.execute()
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
|
||||
task_id = task_creation_response.data.task_id
|
||||
final_response = self.get_response(
|
||||
final_response = await self.get_response(
|
||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||
)
|
||||
validate_video_result_response(final_response)
|
||||
|
||||
video = get_video_from_response(final_response)
|
||||
return video_result_to_node_output(video)
|
||||
return await video_result_to_node_output(video)
|
||||
|
||||
|
||||
class KlingCameraControlT2VNode(KlingTextToVideoNode):
|
||||
@@ -581,7 +581,7 @@ class KlingCameraControlT2VNode(KlingTextToVideoNode):
|
||||
|
||||
DESCRIPTION = "Transform text into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original text."
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
@@ -591,7 +591,7 @@ class KlingCameraControlT2VNode(KlingTextToVideoNode):
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
return super().api_call(
|
||||
return await super().api_call(
|
||||
model_name=KlingVideoGenModelName.kling_v1,
|
||||
cfg_scale=cfg_scale,
|
||||
mode=KlingVideoGenMode.std,
|
||||
@@ -670,10 +670,10 @@ class KlingImage2VideoNode(KlingNodeBase):
|
||||
RETURN_NAMES = ("VIDEO", "video_id", "duration")
|
||||
DESCRIPTION = "Kling Image to Video Node"
|
||||
|
||||
def get_response(
|
||||
async def get_response(
|
||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||
) -> KlingImage2VideoResponse:
|
||||
return poll_until_finished(
|
||||
return await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}",
|
||||
@@ -686,7 +686,7 @@ class KlingImage2VideoNode(KlingNodeBase):
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
start_frame: torch.Tensor,
|
||||
prompt: str,
|
||||
@@ -733,17 +733,17 @@ class KlingImage2VideoNode(KlingNodeBase):
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
task_creation_response = initial_operation.execute()
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.data.task_id
|
||||
|
||||
final_response = self.get_response(
|
||||
final_response = await self.get_response(
|
||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||
)
|
||||
validate_video_result_response(final_response)
|
||||
|
||||
video = get_video_from_response(final_response)
|
||||
return video_result_to_node_output(video)
|
||||
return await video_result_to_node_output(video)
|
||||
|
||||
|
||||
class KlingCameraControlI2VNode(KlingImage2VideoNode):
|
||||
@@ -798,7 +798,7 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode):
|
||||
|
||||
DESCRIPTION = "Transform still images into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original image."
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
start_frame: torch.Tensor,
|
||||
prompt: str,
|
||||
@@ -809,7 +809,7 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode):
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
return super().api_call(
|
||||
return await super().api_call(
|
||||
model_name=KlingVideoGenModelName.kling_v1_5,
|
||||
start_frame=start_frame,
|
||||
cfg_scale=cfg_scale,
|
||||
@@ -897,7 +897,7 @@ class KlingStartEndFrameNode(KlingImage2VideoNode):
|
||||
|
||||
DESCRIPTION = "Generate a video sequence that transitions between your provided start and end images. The node creates all frames in between, producing a smooth transformation from the first frame to the last."
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
start_frame: torch.Tensor,
|
||||
end_frame: torch.Tensor,
|
||||
@@ -912,7 +912,7 @@ class KlingStartEndFrameNode(KlingImage2VideoNode):
|
||||
mode, duration, model_name = KlingStartEndFrameNode.get_mode_string_mapping()[
|
||||
mode
|
||||
]
|
||||
return super().api_call(
|
||||
return await super().api_call(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
model_name=model_name,
|
||||
@@ -964,10 +964,10 @@ class KlingVideoExtendNode(KlingNodeBase):
|
||||
RETURN_NAMES = ("VIDEO", "video_id", "duration")
|
||||
DESCRIPTION = "Kling Video Extend Node. Extend videos made by other Kling nodes. The video_id is created by using other Kling Nodes."
|
||||
|
||||
def get_response(
|
||||
async def get_response(
|
||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||
) -> KlingVideoExtendResponse:
|
||||
return poll_until_finished(
|
||||
return await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_VIDEO_EXTEND}/{task_id}",
|
||||
@@ -980,7 +980,7 @@ class KlingVideoExtendNode(KlingNodeBase):
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
@@ -1006,17 +1006,17 @@ class KlingVideoExtendNode(KlingNodeBase):
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
task_creation_response = initial_operation.execute()
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.data.task_id
|
||||
|
||||
final_response = self.get_response(
|
||||
final_response = await self.get_response(
|
||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||
)
|
||||
validate_video_result_response(final_response)
|
||||
|
||||
video = get_video_from_response(final_response)
|
||||
return video_result_to_node_output(video)
|
||||
return await video_result_to_node_output(video)
|
||||
|
||||
|
||||
class KlingVideoEffectsBase(KlingNodeBase):
|
||||
@@ -1025,10 +1025,10 @@ class KlingVideoEffectsBase(KlingNodeBase):
|
||||
RETURN_TYPES = ("VIDEO", "STRING", "STRING")
|
||||
RETURN_NAMES = ("VIDEO", "video_id", "duration")
|
||||
|
||||
def get_response(
|
||||
async def get_response(
|
||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||
) -> KlingVideoEffectsResponse:
|
||||
return poll_until_finished(
|
||||
return await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_VIDEO_EFFECTS}/{task_id}",
|
||||
@@ -1041,7 +1041,7 @@ class KlingVideoEffectsBase(KlingNodeBase):
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
dual_character: bool,
|
||||
effect_scene: KlingDualCharacterEffectsScene | KlingSingleImageEffectsScene,
|
||||
@@ -1084,17 +1084,17 @@ class KlingVideoEffectsBase(KlingNodeBase):
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
task_creation_response = initial_operation.execute()
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.data.task_id
|
||||
|
||||
final_response = self.get_response(
|
||||
final_response = await self.get_response(
|
||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||
)
|
||||
validate_video_result_response(final_response)
|
||||
|
||||
video = get_video_from_response(final_response)
|
||||
return video_result_to_node_output(video)
|
||||
return await video_result_to_node_output(video)
|
||||
|
||||
|
||||
class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase):
|
||||
@@ -1142,7 +1142,7 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase):
|
||||
RETURN_TYPES = ("VIDEO", "STRING")
|
||||
RETURN_NAMES = ("VIDEO", "duration")
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
image_left: torch.Tensor,
|
||||
image_right: torch.Tensor,
|
||||
@@ -1153,7 +1153,7 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase):
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
video, _, duration = super().api_call(
|
||||
video, _, duration = await super().api_call(
|
||||
dual_character=True,
|
||||
effect_scene=effect_scene,
|
||||
model_name=model_name,
|
||||
@@ -1208,7 +1208,7 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase):
|
||||
|
||||
DESCRIPTION = "Achieve different special effects when generating a video based on the effect_scene."
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
effect_scene: KlingSingleImageEffectsScene,
|
||||
@@ -1217,7 +1217,7 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase):
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
return super().api_call(
|
||||
return await super().api_call(
|
||||
dual_character=False,
|
||||
effect_scene=effect_scene,
|
||||
model_name=model_name,
|
||||
@@ -1253,11 +1253,11 @@ class KlingLipSyncBase(KlingNodeBase):
|
||||
f"Text is too long. Maximum length is {MAX_PROMPT_LENGTH_LIP_SYNC} characters."
|
||||
)
|
||||
|
||||
def get_response(
|
||||
async def get_response(
|
||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||
) -> KlingLipSyncResponse:
|
||||
"""Polls the Kling API endpoint until the task reaches a terminal state."""
|
||||
return poll_until_finished(
|
||||
return await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_LIP_SYNC}/{task_id}",
|
||||
@@ -1270,7 +1270,7 @@ class KlingLipSyncBase(KlingNodeBase):
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
video: VideoInput,
|
||||
audio: Optional[AudioInput] = None,
|
||||
@@ -1287,12 +1287,12 @@ class KlingLipSyncBase(KlingNodeBase):
|
||||
self.validate_lip_sync_video(video)
|
||||
|
||||
# Upload video to Comfy API and get download URL
|
||||
video_url = upload_video_to_comfyapi(video, auth_kwargs=kwargs)
|
||||
video_url = await upload_video_to_comfyapi(video, auth_kwargs=kwargs)
|
||||
logging.info("Uploaded video to Comfy API. URL: %s", video_url)
|
||||
|
||||
# Upload the audio file to Comfy API and get download URL
|
||||
if audio:
|
||||
audio_url = upload_audio_to_comfyapi(audio, auth_kwargs=kwargs)
|
||||
audio_url = await upload_audio_to_comfyapi(audio, auth_kwargs=kwargs)
|
||||
logging.info("Uploaded audio to Comfy API. URL: %s", audio_url)
|
||||
else:
|
||||
audio_url = None
|
||||
@@ -1319,17 +1319,17 @@ class KlingLipSyncBase(KlingNodeBase):
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
task_creation_response = initial_operation.execute()
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.data.task_id
|
||||
|
||||
final_response = self.get_response(
|
||||
final_response = await self.get_response(
|
||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||
)
|
||||
validate_video_result_response(final_response)
|
||||
|
||||
video = get_video_from_response(final_response)
|
||||
return video_result_to_node_output(video)
|
||||
return await video_result_to_node_output(video)
|
||||
|
||||
|
||||
class KlingLipSyncAudioToVideoNode(KlingLipSyncBase):
|
||||
@@ -1357,7 +1357,7 @@ class KlingLipSyncAudioToVideoNode(KlingLipSyncBase):
|
||||
|
||||
DESCRIPTION = "Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length."
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
video: VideoInput,
|
||||
audio: AudioInput,
|
||||
@@ -1365,7 +1365,7 @@ class KlingLipSyncAudioToVideoNode(KlingLipSyncBase):
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
return super().api_call(
|
||||
return await super().api_call(
|
||||
video=video,
|
||||
audio=audio,
|
||||
voice_language=voice_language,
|
||||
@@ -1469,7 +1469,7 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase):
|
||||
|
||||
DESCRIPTION = "Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length."
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
video: VideoInput,
|
||||
text: str,
|
||||
@@ -1479,7 +1479,7 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase):
|
||||
**kwargs,
|
||||
):
|
||||
voice_id, voice_language = KlingLipSyncTextToVideoNode.get_voice_config()[voice]
|
||||
return super().api_call(
|
||||
return await super().api_call(
|
||||
video=video,
|
||||
text=text,
|
||||
voice_language=voice_language,
|
||||
@@ -1533,10 +1533,10 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase):
|
||||
|
||||
DESCRIPTION = "Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human. You can merge multiple clothing item pictures into one image with a white background."
|
||||
|
||||
def get_response(
|
||||
async def get_response(
|
||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||
) -> KlingVirtualTryOnResponse:
|
||||
return poll_until_finished(
|
||||
return await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}",
|
||||
@@ -1549,7 +1549,7 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase):
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
human_image: torch.Tensor,
|
||||
cloth_image: torch.Tensor,
|
||||
@@ -1572,17 +1572,17 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase):
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
task_creation_response = initial_operation.execute()
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.data.task_id
|
||||
|
||||
final_response = self.get_response(
|
||||
final_response = await self.get_response(
|
||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||
)
|
||||
validate_image_result_response(final_response)
|
||||
|
||||
images = get_images_from_response(final_response)
|
||||
return (image_result_to_node_output(images),)
|
||||
return (await image_result_to_node_output(images),)
|
||||
|
||||
|
||||
class KlingImageGenerationNode(KlingImageGenerationBase):
|
||||
@@ -1655,13 +1655,13 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
|
||||
|
||||
DESCRIPTION = "Kling Image Generation Node. Generate an image from a text prompt with an optional reference image."
|
||||
|
||||
def get_response(
|
||||
async def get_response(
|
||||
self,
|
||||
task_id: str,
|
||||
auth_kwargs: Optional[dict[str, str]],
|
||||
node_id: Optional[str] = None,
|
||||
) -> KlingImageGenerationsResponse:
|
||||
return poll_until_finished(
|
||||
return await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_IMAGE_GENERATIONS}/{task_id}",
|
||||
@@ -1674,7 +1674,7 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
model_name: KlingImageGenModelName,
|
||||
prompt: str,
|
||||
@@ -1690,7 +1690,11 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
|
||||
):
|
||||
self.validate_prompt(prompt, negative_prompt)
|
||||
|
||||
if image is not None:
|
||||
if image is None:
|
||||
image_type = None
|
||||
elif model_name == KlingImageGenModelName.kling_v1:
|
||||
raise ValueError(f"The model {KlingImageGenModelName.kling_v1.value} does not support reference images.")
|
||||
else:
|
||||
image = tensor_to_base64_string(image)
|
||||
|
||||
initial_operation = SynchronousOperation(
|
||||
@@ -1714,17 +1718,17 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
task_creation_response = initial_operation.execute()
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.data.task_id
|
||||
|
||||
final_response = self.get_response(
|
||||
final_response = await self.get_response(
|
||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||
)
|
||||
validate_image_result_response(final_response)
|
||||
|
||||
images = get_images_from_response(final_response)
|
||||
return (image_result_to_node_output(images),)
|
||||
return (await image_result_to_node_output(images),)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
|
||||
@@ -38,7 +38,7 @@ from comfy_api_nodes.apinode_utils import (
|
||||
)
|
||||
from server import PromptServer
|
||||
|
||||
import requests
|
||||
import aiohttp
|
||||
import torch
|
||||
from io import BytesIO
|
||||
|
||||
@@ -217,7 +217,7 @@ class LumaImageGenerationNode(ComfyNodeABC):
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
@@ -234,19 +234,19 @@ class LumaImageGenerationNode(ComfyNodeABC):
|
||||
# handle image_luma_ref
|
||||
api_image_ref = None
|
||||
if image_luma_ref is not None:
|
||||
api_image_ref = self._convert_luma_refs(
|
||||
api_image_ref = await self._convert_luma_refs(
|
||||
image_luma_ref, max_refs=4, auth_kwargs=kwargs,
|
||||
)
|
||||
# handle style_luma_ref
|
||||
api_style_ref = None
|
||||
if style_image is not None:
|
||||
api_style_ref = self._convert_style_image(
|
||||
api_style_ref = await self._convert_style_image(
|
||||
style_image, weight=style_image_weight, auth_kwargs=kwargs,
|
||||
)
|
||||
# handle character_ref images
|
||||
character_ref = None
|
||||
if character_image is not None:
|
||||
download_urls = upload_images_to_comfyapi(
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
character_image, max_images=4, auth_kwargs=kwargs,
|
||||
)
|
||||
character_ref = LumaCharacterRef(
|
||||
@@ -270,7 +270,7 @@ class LumaImageGenerationNode(ComfyNodeABC):
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api: LumaGeneration = operation.execute()
|
||||
response_api: LumaGeneration = await operation.execute()
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
@@ -286,19 +286,20 @@ class LumaImageGenerationNode(ComfyNodeABC):
|
||||
node_id=unique_id,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_poll = operation.execute()
|
||||
response_poll = await operation.execute()
|
||||
|
||||
img_response = requests.get(response_poll.assets.image)
|
||||
img = process_image_response(img_response)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.assets.image) as img_response:
|
||||
img = process_image_response(await img_response.content.read())
|
||||
return (img,)
|
||||
|
||||
def _convert_luma_refs(
|
||||
async def _convert_luma_refs(
|
||||
self, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None
|
||||
):
|
||||
luma_urls = []
|
||||
ref_count = 0
|
||||
for ref in luma_ref.refs:
|
||||
download_urls = upload_images_to_comfyapi(
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
ref.image, max_images=1, auth_kwargs=auth_kwargs
|
||||
)
|
||||
luma_urls.append(download_urls[0])
|
||||
@@ -307,13 +308,13 @@ class LumaImageGenerationNode(ComfyNodeABC):
|
||||
break
|
||||
return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs)
|
||||
|
||||
def _convert_style_image(
|
||||
async def _convert_style_image(
|
||||
self, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None
|
||||
):
|
||||
chain = LumaReferenceChain(
|
||||
first_ref=LumaReference(image=style_image, weight=weight)
|
||||
)
|
||||
return self._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)
|
||||
return await self._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)
|
||||
|
||||
|
||||
class LumaImageModifyNode(ComfyNodeABC):
|
||||
@@ -370,7 +371,7 @@ class LumaImageModifyNode(ComfyNodeABC):
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
@@ -381,7 +382,7 @@ class LumaImageModifyNode(ComfyNodeABC):
|
||||
**kwargs,
|
||||
):
|
||||
# first, upload image
|
||||
download_urls = upload_images_to_comfyapi(
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
image, max_images=1, auth_kwargs=kwargs,
|
||||
)
|
||||
image_url = download_urls[0]
|
||||
@@ -402,7 +403,7 @@ class LumaImageModifyNode(ComfyNodeABC):
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api: LumaGeneration = operation.execute()
|
||||
response_api: LumaGeneration = await operation.execute()
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
@@ -418,10 +419,11 @@ class LumaImageModifyNode(ComfyNodeABC):
|
||||
node_id=unique_id,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_poll = operation.execute()
|
||||
response_poll = await operation.execute()
|
||||
|
||||
img_response = requests.get(response_poll.assets.image)
|
||||
img = process_image_response(img_response)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.assets.image) as img_response:
|
||||
img = process_image_response(await img_response.content.read())
|
||||
return (img,)
|
||||
|
||||
|
||||
@@ -494,7 +496,7 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
@@ -529,7 +531,7 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api: LumaGeneration = operation.execute()
|
||||
response_api: LumaGeneration = await operation.execute()
|
||||
|
||||
if unique_id:
|
||||
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)
|
||||
@@ -549,10 +551,11 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
|
||||
estimated_duration=LUMA_T2V_AVERAGE_DURATION,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_poll = operation.execute()
|
||||
response_poll = await operation.execute()
|
||||
|
||||
vid_response = requests.get(response_poll.assets.video)
|
||||
return (VideoFromFile(BytesIO(vid_response.content)),)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.assets.video) as vid_response:
|
||||
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
|
||||
|
||||
|
||||
class LumaImageToVideoGenerationNode(ComfyNodeABC):
|
||||
@@ -626,7 +629,7 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
@@ -644,7 +647,7 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
|
||||
raise Exception(
|
||||
"At least one of first_image and last_image requires an input."
|
||||
)
|
||||
keyframes = self._convert_to_keyframes(first_image, last_image, auth_kwargs=kwargs)
|
||||
keyframes = await self._convert_to_keyframes(first_image, last_image, auth_kwargs=kwargs)
|
||||
duration = duration if model != LumaVideoModel.ray_1_6 else None
|
||||
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
|
||||
|
||||
@@ -667,7 +670,7 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api: LumaGeneration = operation.execute()
|
||||
response_api: LumaGeneration = await operation.execute()
|
||||
|
||||
if unique_id:
|
||||
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)
|
||||
@@ -687,12 +690,13 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
|
||||
estimated_duration=LUMA_I2V_AVERAGE_DURATION,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_poll = operation.execute()
|
||||
response_poll = await operation.execute()
|
||||
|
||||
vid_response = requests.get(response_poll.assets.video)
|
||||
return (VideoFromFile(BytesIO(vid_response.content)),)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.assets.video) as vid_response:
|
||||
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
|
||||
|
||||
def _convert_to_keyframes(
|
||||
async def _convert_to_keyframes(
|
||||
self,
|
||||
first_image: torch.Tensor = None,
|
||||
last_image: torch.Tensor = None,
|
||||
@@ -703,12 +707,12 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
|
||||
frame0 = None
|
||||
frame1 = None
|
||||
if first_image is not None:
|
||||
download_urls = upload_images_to_comfyapi(
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
first_image, max_images=1, auth_kwargs=auth_kwargs,
|
||||
)
|
||||
frame0 = LumaImageReference(type="image", url=download_urls[0])
|
||||
if last_image is not None:
|
||||
download_urls = upload_images_to_comfyapi(
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
last_image, max_images=1, auth_kwargs=auth_kwargs,
|
||||
)
|
||||
frame1 = LumaImageReference(type="image", url=download_urls[0])
|
||||
|
@@ -86,7 +86,7 @@ class MinimaxTextToVideoNode:
API_NODE = True
OUTPUT_NODE = True

def generate_video(
async def generate_video(
self,
prompt_text,
seed=0,
@@ -104,12 +104,12 @@ class MinimaxTextToVideoNode:
# upload image, if passed in
image_url = None
if image is not None:
image_url = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)[0]
image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs))[0]

# TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model
subject_reference = None
if subject is not None:
subject_url = upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=kwargs)[0]
subject_url = (await upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=kwargs))[0]
subject_reference = [SubjectReferenceItem(image=subject_url)]


@@ -130,7 +130,7 @@ class MinimaxTextToVideoNode:
),
auth_kwargs=kwargs,
)
response = video_generate_operation.execute()
response = await video_generate_operation.execute()

task_id = response.task_id
if not task_id:
@@ -151,7 +151,7 @@ class MinimaxTextToVideoNode:
node_id=unique_id,
auth_kwargs=kwargs,
)
task_result = video_generate_operation.execute()
task_result = await video_generate_operation.execute()

file_id = task_result.file_id
if file_id is None:
@@ -167,7 +167,7 @@ class MinimaxTextToVideoNode:
request=EmptyRequest(),
auth_kwargs=kwargs,
)
file_result = file_retrieve_operation.execute()
file_result = await file_retrieve_operation.execute()

file_url = file_result.file.download_url
if file_url is None:
@@ -182,7 +182,7 @@ class MinimaxTextToVideoNode:
message = f"Result URL: {file_url}"
PromptServer.instance.send_progress_text(message, unique_id)

video_io = download_url_to_bytesio(file_url)
video_io = await download_url_to_bytesio(file_url)
if video_io is None:
error_msg = f"Failed to download video from {file_url}"
logging.error(error_msg)

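The Minimax hunk above awaits the image and subject uploads one after the other. Purely as an illustration of what the awaitable helper allows, and not something this commit does, the two uploads could be scheduled concurrently, assuming upload_images_to_comfyapi is the coroutine shown in the diff and is already imported:

import asyncio

async def upload_refs_concurrently(image, subject, **kwargs):
    # Both calls return coroutines; gather runs them concurrently and
    # preserves the order of results.
    image_urls, subject_urls = await asyncio.gather(
        upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs),
        upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=kwargs),
    )
    return image_urls[0], subject_urls[0]
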
@@ -1,6 +1,5 @@
import logging
from typing import Any, Callable, Optional, TypeVar
import random
import torch
from comfy_api_nodes.util.validation_utils import (
get_image_dimensions,
@@ -95,14 +94,14 @@ def get_video_url_from_response(response) -> Optional[str]:
return None


def poll_until_finished(
async def poll_until_finished(
auth_kwargs: dict[str, str],
api_endpoint: ApiEndpoint[Any, R],
result_url_extractor: Optional[Callable[[R], str]] = None,
node_id: Optional[str] = None,
) -> R:
"""Polls the Moonvalley API endpoint until the task reaches a terminal state, then returns the response."""
return PollingOperation(
return await PollingOperation(
poll_endpoint=api_endpoint,
completed_statuses=[
"completed",
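For readers skimming the diff, the helper above wraps what is conceptually a status-polling loop. A stdlib-only sketch of that idea, where fetch_status, the terminal status strings, and the interval are placeholders rather than the ComfyUI or Moonvalley API:

import asyncio
from typing import Awaitable, Callable

async def poll_status(
    fetch_status: Callable[[], Awaitable[str]],
    terminal: frozenset = frozenset({"completed", "failed", "cancelled"}),
    interval: float = 2.0,
) -> str:
    # Keep querying until a terminal status comes back, then return it.
    while True:
        status = await fetch_status()
        if status in terminal:
            return status
        await asyncio.sleep(interval)
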
@@ -208,20 +207,29 @@ def _get_video_dimensions(video: VideoInput) -> tuple[int, int]:
def _validate_video_dimensions(width: int, height: int) -> None:
"""Validates video dimensions meet Moonvalley V2V requirements."""
supported_resolutions = {
(1920, 1080), (1080, 1920), (1152, 1152),
(1536, 1152), (1152, 1536)
(1920, 1080),
(1080, 1920),
(1152, 1152),
(1536, 1152),
(1152, 1536),
}

if (width, height) not in supported_resolutions:
supported_list = ', '.join([f'{w}x{h}' for w, h in sorted(supported_resolutions)])
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
supported_list = ", ".join(
[f"{w}x{h}" for w, h in sorted(supported_resolutions)]
)
raise ValueError(
f"Resolution {width}x{height} not supported. Supported: {supported_list}"
)


def _validate_container_format(video: VideoInput) -> None:
"""Validates video container format is MP4."""
container_format = video.get_container_format()
if container_format not in ['mp4', 'mov,mp4,m4a,3gp,3g2,mj2']:
raise ValueError(f"Only MP4 container format supported. Got: {container_format}")
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
raise ValueError(
f"Only MP4 container format supported. Got: {container_format}"
)


def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
@@ -244,7 +252,6 @@ def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
return video




def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
"""
Returns a new VideoInput object trimmed from the beginning to the specified duration,
@@ -302,7 +309,9 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
# Calculate target frame count that's divisible by 16
fps = input_container.streams.video[0].average_rate
estimated_frames = int(duration_sec * fps)
target_frames = (estimated_frames // 16) * 16 # Round down to nearest multiple of 16
target_frames = (
estimated_frames // 16
) * 16 # Round down to nearest multiple of 16

if target_frames == 0:
raise ValueError("Video too short: need at least 16 frames for Moonvalley")
@@ -394,10 +403,10 @@ class BaseMoonvalleyVideoNode:
|
||||
else:
|
||||
return control_map["Motion Transfer"]
|
||||
|
||||
def get_response(
|
||||
async def get_response(
|
||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||
) -> MoonvalleyPromptResponse:
|
||||
return poll_until_finished(
|
||||
return await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{API_PROMPTS_ENDPOINT}/{task_id}",
|
||||
@@ -424,7 +433,7 @@ class BaseMoonvalleyVideoNode:
|
||||
MoonvalleyTextToVideoInferenceParams,
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="low-poly, flat shader, bad rigging, stiff animation, uncanny eyes, low-quality textures, looping glitch, cheap effect, overbloom, bloom spam, default lighting, game asset, stiff face, ugly specular, AI artifacts",
|
||||
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, wobbly, weird, low quality, plastic, stock footage, video camera, boring",
|
||||
),
|
||||
"resolution": (
|
||||
IO.COMBO,
|
||||
@@ -441,12 +450,11 @@ class BaseMoonvalleyVideoNode:
|
||||
"tooltip": "Resolution of the output video",
|
||||
},
|
||||
),
|
||||
# "length": (IO.COMBO,{"options":['5s','10s'], "default": '5s'}),
|
||||
"prompt_adherence": model_field_to_node_input(
|
||||
IO.FLOAT,
|
||||
MoonvalleyTextToVideoInferenceParams,
|
||||
"guidance_scale",
|
||||
default=7.0,
|
||||
default=10.0,
|
||||
step=1,
|
||||
min=1,
|
||||
max=20,
|
||||
@@ -455,13 +463,12 @@ class BaseMoonvalleyVideoNode:
|
||||
IO.INT,
|
||||
MoonvalleyTextToVideoInferenceParams,
|
||||
"seed",
|
||||
default=random.randint(0, 2**32 - 1),
|
||||
default=9,
|
||||
min=0,
|
||||
max=4294967295,
|
||||
step=1,
|
||||
display="number",
|
||||
tooltip="Random seed value",
|
||||
control_after_generate=True,
|
||||
),
|
||||
"steps": model_field_to_node_input(
|
||||
IO.INT,
|
||||
@@ -507,7 +514,7 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
|
||||
RETURN_NAMES = ("video",)
|
||||
DESCRIPTION = "Moonvalley Marey Image to Video Node"
|
||||
|
||||
def generate(
|
||||
async def generate(
|
||||
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
|
||||
):
|
||||
image = kwargs.get("image", None)
|
||||
@@ -532,8 +539,10 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
|
||||
# Get MIME type from tensor - assuming PNG format for image tensors
|
||||
mime_type = "image/png"
|
||||
|
||||
image_url = upload_images_to_comfyapi(
|
||||
image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type
|
||||
image_url = (
|
||||
await upload_images_to_comfyapi(
|
||||
image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type
|
||||
)
|
||||
)[0]
|
||||
|
||||
request = MoonvalleyTextToVideoRequest(
|
||||
@@ -549,14 +558,14 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
|
||||
request=request,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
task_creation_response = initial_operation.execute()
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.id
|
||||
|
||||
final_response = self.get_response(
|
||||
final_response = await self.get_response(
|
||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||
)
|
||||
video = download_url_to_video_output(final_response.output_url)
|
||||
video = await download_url_to_video_output(final_response.output_url)
|
||||
return (video,)
|
||||
|
||||
|
||||
@@ -570,17 +579,39 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": model_field_to_node_input(
|
||||
IO.STRING, MoonvalleyVideoToVideoRequest, "prompt_text",
|
||||
multiline=True
|
||||
IO.STRING,
|
||||
MoonvalleyVideoToVideoRequest,
|
||||
"prompt_text",
|
||||
multiline=True,
|
||||
),
|
||||
"negative_prompt": model_field_to_node_input(
|
||||
IO.STRING,
|
||||
MoonvalleyVideoToVideoInferenceParams,
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="low-poly, flat shader, bad rigging, stiff animation, uncanny eyes, low-quality textures, looping glitch, cheap effect, overbloom, bloom spam, default lighting, game asset, stiff face, ugly specular, AI artifacts"
|
||||
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, wobbly, weird, low quality, plastic, stock footage, video camera, boring",
|
||||
),
|
||||
"seed": model_field_to_node_input(
|
||||
IO.INT,
|
||||
MoonvalleyVideoToVideoInferenceParams,
|
||||
"seed",
|
||||
default=9,
|
||||
min=0,
|
||||
max=4294967295,
|
||||
step=1,
|
||||
display="number",
|
||||
tooltip="Random seed value",
|
||||
control_after_generate=False,
|
||||
),
|
||||
"prompt_adherence": model_field_to_node_input(
|
||||
IO.FLOAT,
|
||||
MoonvalleyVideoToVideoInferenceParams,
|
||||
"guidance_scale",
|
||||
default=10.0,
|
||||
step=1,
|
||||
min=1,
|
||||
max=20,
|
||||
),
|
||||
"seed": model_field_to_node_input(IO.INT,MoonvalleyVideoToVideoInferenceParams, "seed", default=random.randint(0, 2**32 - 1), min=0, max=4294967295, step=1, display="number", tooltip="Random seed value", control_after_generate=True),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
@@ -588,7 +619,14 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
"optional": {
|
||||
"video": (IO.VIDEO, {"default": "", "multiline": False, "tooltip": "The reference video used to generate the output video. Must be at least 5 seconds long. Videos longer than 5s will be automatically trimmed. Only MP4 format supported."}),
|
||||
"video": (
|
||||
IO.VIDEO,
|
||||
{
|
||||
"default": "",
|
||||
"multiline": False,
|
||||
"tooltip": "The reference video used to generate the output video. Must be at least 5 seconds long. Videos longer than 5s will be automatically trimmed. Only MP4 format supported.",
|
||||
},
|
||||
),
|
||||
"control_type": (
|
||||
["Motion Transfer", "Pose Transfer"],
|
||||
{"default": "Motion Transfer"},
|
||||
@@ -602,17 +640,24 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
|
||||
"max": 100,
|
||||
"tooltip": "Only used if control_type is 'Motion Transfer'",
|
||||
},
|
||||
)
|
||||
}
|
||||
),
|
||||
"image": model_field_to_node_input(
|
||||
IO.IMAGE,
|
||||
MoonvalleyTextToVideoRequest,
|
||||
"image_url",
|
||||
tooltip="The reference image used to generate the video",
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("VIDEO",)
|
||||
RETURN_NAMES = ("video",)
|
||||
|
||||
def generate(
|
||||
async def generate(
|
||||
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
|
||||
):
|
||||
video = kwargs.get("video")
|
||||
image = kwargs.get("image", None)
|
||||
|
||||
if not video:
|
||||
raise MoonvalleyApiError("video is required")
|
@@ -620,8 +665,16 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
video_url = ""
if video:
validated_video = validate_video_to_video_input(video)
video_url = upload_video_to_comfyapi(validated_video, auth_kwargs=kwargs)
video_url = await upload_video_to_comfyapi(
validated_video, auth_kwargs=kwargs
)
mime_type = "image/png"

if image is not None:
validate_input_image(image, with_frame_conditioning=True)
image_url = await upload_images_to_comfyapi(
image=image, auth_kwargs=kwargs, max_images=1, mime_type=mime_type
)
control_type = kwargs.get("control_type")
motion_intensity = kwargs.get("motion_intensity")

@@ -631,12 +684,12 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
# Only include motion_intensity for Motion Transfer
control_params = {}
if control_type == "Motion Transfer" and motion_intensity is not None:
control_params['motion_intensity'] = motion_intensity
control_params["motion_intensity"] = motion_intensity

inference_params=MoonvalleyVideoToVideoInferenceParams(
inference_params = MoonvalleyVideoToVideoInferenceParams(
negative_prompt=negative_prompt,
seed=kwargs.get("seed"),
control_params=control_params
control_params=control_params,
)

control = self.parseControlParameter(control_type)
@@ -647,6 +700,7 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
prompt_text=prompt,
inference_params=inference_params,
)
request.image_url = image_url if image is not None else None

initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
@@ -658,15 +712,15 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
|
||||
request=request,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
task_creation_response = initial_operation.execute()
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.id
|
||||
|
||||
final_response = self.get_response(
|
||||
final_response = await self.get_response(
|
||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||
)
|
||||
|
||||
video = download_url_to_video_output(final_response.output_url)
|
||||
video = await download_url_to_video_output(final_response.output_url)
|
||||
|
||||
return (video,)
|
||||
|
||||
@@ -688,21 +742,21 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode):
|
||||
del input_types["optional"][param]
|
||||
return input_types
|
||||
|
||||
def generate(
|
||||
async def generate(
|
||||
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
|
||||
):
|
||||
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
width_height = self.parseWidthHeightFromRes(kwargs.get("resolution"))
|
||||
|
||||
inference_params=MoonvalleyTextToVideoInferenceParams(
|
||||
negative_prompt=negative_prompt,
|
||||
steps=kwargs.get("steps"),
|
||||
seed=kwargs.get("seed"),
|
||||
guidance_scale=kwargs.get("prompt_adherence"),
|
||||
num_frames=128,
|
||||
width=width_height.get("width"),
|
||||
height=width_height.get("height"),
|
||||
)
|
||||
inference_params = MoonvalleyTextToVideoInferenceParams(
|
||||
negative_prompt=negative_prompt,
|
||||
steps=kwargs.get("steps"),
|
||||
seed=kwargs.get("seed"),
|
||||
guidance_scale=kwargs.get("prompt_adherence"),
|
||||
num_frames=128,
|
||||
width=width_height.get("width"),
|
||||
height=width_height.get("height"),
|
||||
)
|
||||
request = MoonvalleyTextToVideoRequest(
|
||||
prompt_text=prompt, inference_params=inference_params
|
||||
)
|
||||
@@ -717,15 +771,15 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode):
|
||||
request=request,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
task_creation_response = initial_operation.execute()
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.id
|
||||
|
||||
final_response = self.get_response(
|
||||
final_response = await self.get_response(
|
||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||
)
|
||||
|
||||
video = download_url_to_video_output(final_response.output_url)
|
||||
video = await download_url_to_video_output(final_response.output_url)
|
||||
return (video,)
|
||||
|
||||
|
||||
|
||||
@@ -163,7 +163,7 @@ class OpenAIDalle2(ComfyNodeABC):
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
API_NODE = True
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt,
|
||||
seed=0,
|
||||
@@ -233,9 +233,9 @@ class OpenAIDalle2(ComfyNodeABC):
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
response = operation.execute()
|
||||
response = await operation.execute()
|
||||
|
||||
img_tensor = validate_and_cast_response(response, node_id=unique_id)
|
||||
img_tensor = await validate_and_cast_response(response, node_id=unique_id)
|
||||
return (img_tensor,)
|
||||
|
||||
|
||||
@@ -311,7 +311,7 @@ class OpenAIDalle3(ComfyNodeABC):
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
API_NODE = True
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt,
|
||||
seed=0,
|
||||
@@ -343,9 +343,9 @@ class OpenAIDalle3(ComfyNodeABC):
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
response = operation.execute()
|
||||
response = await operation.execute()
|
||||
|
||||
img_tensor = validate_and_cast_response(response, node_id=unique_id)
|
||||
img_tensor = await validate_and_cast_response(response, node_id=unique_id)
|
||||
return (img_tensor,)
|
||||
|
||||
|
||||
@@ -446,7 +446,7 @@ class OpenAIGPTImage1(ComfyNodeABC):
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
API_NODE = True
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt,
|
||||
seed=0,
|
@@ -464,8 +464,6 @@ class OpenAIGPTImage1(ComfyNodeABC):
path = "/proxy/openai/images/generations"
content_type = "application/json"
request_class = OpenAIImageGenerationRequest
img_binaries = []
mask_binary = None
files = []

if image is not None:
@@ -484,14 +482,11 @@ class OpenAIGPTImage1(ComfyNodeABC):
img_byte_arr = io.BytesIO()
img.save(img_byte_arr, format="PNG")
img_byte_arr.seek(0)
img_binary = img_byte_arr
img_binary.name = f"image_{i}.png"

img_binaries.append(img_binary)
if batch_size == 1:
files.append(("image", img_binary))
files.append(("image", (f"image_{i}.png", img_byte_arr, "image/png")))
else:
files.append(("image[]", img_binary))
files.append(("image[]", (f"image_{i}.png", img_byte_arr, "image/png")))

if mask is not None:
if image is None:
@@ -511,9 +506,7 @@ class OpenAIGPTImage1(ComfyNodeABC):
mask_img_byte_arr = io.BytesIO()
mask_img.save(mask_img_byte_arr, format="PNG")
mask_img_byte_arr.seek(0)
mask_binary = mask_img_byte_arr
mask_binary.name = "mask.png"
files.append(("mask", mask_binary))
files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png")))

# Build the operation
operation = SynchronousOperation(
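The edit above replaces ad-hoc .name attributes on BytesIO objects with explicit (filename, file object, content type) triples in the files list. A small illustrative builder for one such part, assuming Pillow is available; the function name png_file_part is not part of the node:

import io
from PIL import Image

def png_file_part(img: Image.Image, index: int = 0):
    # Encode the image as PNG in memory and wrap it in the
    # (field, (filename, fileobj, content_type)) shape used by the files list above.
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    buf.seek(0)
    return ("image", (f"image_{index}.png", buf, "image/png"))
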
@@ -537,9 +530,9 @@ class OpenAIGPTImage1(ComfyNodeABC):
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
response = operation.execute()
|
||||
response = await operation.execute()
|
||||
|
||||
img_tensor = validate_and_cast_response(response, node_id=unique_id)
|
||||
img_tensor = await validate_and_cast_response(response, node_id=unique_id)
|
||||
return (img_tensor,)
|
||||
|
||||
|
||||
@@ -623,7 +616,7 @@ class OpenAIChatNode(OpenAITextNode):
|
||||
|
||||
DESCRIPTION = "Generate text responses from an OpenAI model."
|
||||
|
||||
def get_result_response(
|
||||
async def get_result_response(
|
||||
self,
|
||||
response_id: str,
|
||||
include: Optional[list[Includable]] = None,
|
||||
@@ -639,7 +632,7 @@ class OpenAIChatNode(OpenAITextNode):
|
||||
creation above for more information.
|
||||
|
||||
"""
|
||||
return PollingOperation(
|
||||
return await PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"{RESPONSES_ENDPOINT}/{response_id}",
|
||||
method=HttpMethod.GET,
|
||||
@@ -784,7 +777,7 @@ class OpenAIChatNode(OpenAITextNode):
|
||||
|
||||
self.history[session_id] = new_history
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
persist_context: bool,
|
||||
@@ -815,7 +808,7 @@ class OpenAIChatNode(OpenAITextNode):
|
||||
previous_response_id = None
|
||||
|
||||
# Create response
|
||||
create_response = SynchronousOperation(
|
||||
create_response = await SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=RESPONSES_ENDPOINT,
|
||||
method=HttpMethod.POST,
|
||||
@@ -848,7 +841,7 @@ class OpenAIChatNode(OpenAITextNode):
|
||||
response_id = create_response.id
|
||||
|
||||
# Get result output
|
||||
result_response = self.get_result_response(response_id, auth_kwargs=kwargs)
|
||||
result_response = await self.get_result_response(response_id, auth_kwargs=kwargs)
|
||||
output_text = self.parse_output_text_from_response(result_response)
|
||||
|
||||
# Update history
|
||||
|
||||
@@ -122,7 +122,7 @@ class PikaNodeBase(ComfyNodeABC):
|
||||
FUNCTION = "api_call"
|
||||
RETURN_TYPES = ("VIDEO",)
|
||||
|
||||
def poll_for_task_status(
|
||||
async def poll_for_task_status(
|
||||
self,
|
||||
task_id: str,
|
||||
auth_kwargs: Optional[dict[str, str]] = None,
|
||||
@@ -152,9 +152,9 @@ class PikaNodeBase(ComfyNodeABC):
|
||||
node_id=node_id,
|
||||
estimated_duration=60
|
||||
)
|
||||
return polling_operation.execute()
|
||||
return await polling_operation.execute()
|
||||
|
||||
def execute_task(
|
||||
async def execute_task(
|
||||
self,
|
||||
initial_operation: SynchronousOperation[R, PikaGenerateResponse],
|
||||
auth_kwargs: Optional[dict[str, str]] = None,
|
||||
@@ -169,14 +169,14 @@ class PikaNodeBase(ComfyNodeABC):
|
||||
Returns:
|
||||
A tuple containing the video file as a VIDEO output.
|
||||
"""
|
||||
initial_response = initial_operation.execute()
|
||||
initial_response = await initial_operation.execute()
|
||||
if not is_valid_initial_response(initial_response):
|
||||
error_msg = f"Pika initial request failed. Code: {initial_response.code}, Message: {initial_response.message}, Data: {initial_response.data}"
|
||||
logging.error(error_msg)
|
||||
raise PikaApiError(error_msg)
|
||||
|
||||
task_id = initial_response.video_id
|
||||
final_response = self.poll_for_task_status(task_id, auth_kwargs)
|
||||
final_response = await self.poll_for_task_status(task_id, auth_kwargs)
|
||||
if not is_valid_video_response(final_response):
|
||||
error_msg = (
|
||||
f"Pika task {task_id} succeeded but no video data found in response."
|
||||
@@ -187,7 +187,7 @@ class PikaNodeBase(ComfyNodeABC):
|
||||
video_url = str(final_response.url)
|
||||
logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)
|
||||
|
||||
return (download_url_to_video_output(video_url),)
|
||||
return (await download_url_to_video_output(video_url),)
|
||||
|
||||
|
||||
class PikaImageToVideoV2_2(PikaNodeBase):
|
||||
@@ -212,7 +212,7 @@ class PikaImageToVideoV2_2(PikaNodeBase):
|
||||
|
||||
DESCRIPTION = "Sends an image and prompt to the Pika API v2.2 to generate a video."
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
prompt_text: str,
|
||||
@@ -251,7 +251,7 @@ class PikaImageToVideoV2_2(PikaNodeBase):
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
|
||||
|
||||
class PikaTextToVideoNodeV2_2(PikaNodeBase):
|
||||
@@ -281,7 +281,7 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase):
|
||||
|
||||
DESCRIPTION = "Sends a text prompt to the Pika API v2.2 to generate a video."
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
@@ -311,7 +311,7 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase):
|
||||
content_type="application/x-www-form-urlencoded",
|
||||
)
|
||||
|
||||
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
|
||||
|
||||
class PikaScenesV2_2(PikaNodeBase):
|
||||
@@ -361,7 +361,7 @@ class PikaScenesV2_2(PikaNodeBase):
|
||||
|
||||
DESCRIPTION = "Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them."
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
@@ -420,7 +420,7 @@ class PikaScenesV2_2(PikaNodeBase):
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
|
||||
|
||||
class PikAdditionsNode(PikaNodeBase):
|
||||
@@ -462,7 +462,7 @@ class PikAdditionsNode(PikaNodeBase):
|
||||
|
||||
DESCRIPTION = "Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result."
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
video: VideoInput,
|
||||
image: torch.Tensor,
|
||||
@@ -481,10 +481,10 @@ class PikAdditionsNode(PikaNodeBase):
|
||||
image_bytes_io = tensor_to_bytesio(image)
|
||||
image_bytes_io.seek(0)
|
||||
|
||||
pika_files = [
|
||||
("video", ("video.mp4", video_bytes_io, "video/mp4")),
|
||||
("image", ("image.png", image_bytes_io, "image/png")),
|
||||
]
|
||||
pika_files = {
|
||||
"video": ("video.mp4", video_bytes_io, "video/mp4"),
|
||||
"image": ("image.png", image_bytes_io, "image/png"),
|
||||
}
|
||||
|
||||
# Prepare non-file data
|
||||
pika_request_data = PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
|
||||
@@ -506,7 +506,7 @@ class PikAdditionsNode(PikaNodeBase):
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
|
||||
|
||||
class PikaSwapsNode(PikaNodeBase):
|
||||
@@ -558,7 +558,7 @@ class PikaSwapsNode(PikaNodeBase):
|
||||
DESCRIPTION = "Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates."
|
||||
RETURN_TYPES = ("VIDEO",)
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
video: VideoInput,
|
||||
image: torch.Tensor,
|
||||
@@ -587,11 +587,11 @@ class PikaSwapsNode(PikaNodeBase):
|
||||
image_bytes_io = tensor_to_bytesio(image)
|
||||
image_bytes_io.seek(0)
|
||||
|
||||
pika_files = [
|
||||
("video", ("video.mp4", video_bytes_io, "video/mp4")),
|
||||
("image", ("image.png", image_bytes_io, "image/png")),
|
||||
("modifyRegionMask", ("mask.png", mask_bytes_io, "image/png")),
|
||||
]
|
||||
pika_files = {
|
||||
"video": ("video.mp4", video_bytes_io, "video/mp4"),
|
||||
"image": ("image.png", image_bytes_io, "image/png"),
|
||||
"modifyRegionMask": ("mask.png", mask_bytes_io, "image/png"),
|
||||
}
|
||||
|
||||
# Prepare non-file data
|
||||
pika_request_data = PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
|
||||
@@ -613,7 +613,7 @@ class PikaSwapsNode(PikaNodeBase):
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
|
||||
|
||||
class PikaffectsNode(PikaNodeBase):
|
||||
@@ -664,7 +664,7 @@ class PikaffectsNode(PikaNodeBase):
|
||||
|
||||
DESCRIPTION = "Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear"
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
pikaffect: str,
|
||||
@@ -693,7 +693,7 @@ class PikaffectsNode(PikaNodeBase):
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
|
||||
|
||||
class PikaStartEndFrameNode2_2(PikaNodeBase):
|
||||
@@ -718,7 +718,7 @@ class PikaStartEndFrameNode2_2(PikaNodeBase):
|
||||
|
||||
DESCRIPTION = "Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them."
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
image_start: torch.Tensor,
|
||||
image_end: torch.Tensor,
|
||||
@@ -732,10 +732,7 @@ class PikaStartEndFrameNode2_2(PikaNodeBase):
|
||||
) -> tuple[VideoFromFile]:
|
||||
|
||||
pika_files = [
|
||||
(
|
||||
"keyFrames",
|
||||
("image_start.png", tensor_to_bytesio(image_start), "image/png"),
|
||||
),
|
||||
("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
|
||||
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
|
||||
]
|
||||
|
||||
@@ -758,7 +755,7 @@ class PikaStartEndFrameNode2_2(PikaNodeBase):
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
|
||||
@@ -30,7 +30,7 @@ from comfy.comfy_types.node_typing import IO, ComfyNodeABC
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
|
||||
import torch
|
||||
import requests
|
||||
import aiohttp
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ def get_video_url_from_response(
|
||||
return str(response.Resp.url)
|
||||
|
||||
|
||||
def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
|
||||
async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
|
||||
# first, upload image to Pixverse and get image id to use in actual generation call
|
||||
files = {"image": tensor_to_bytesio(image)}
|
||||
operation = SynchronousOperation(
|
||||
@@ -62,7 +62,7 @@ def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
response_upload: PixverseImageUploadResponse = operation.execute()
|
||||
response_upload: PixverseImageUploadResponse = await operation.execute()
|
||||
|
||||
if response_upload.Resp is None:
|
||||
raise Exception(
|
||||
@@ -164,7 +164,7 @@ class PixverseTextToVideoNode(ComfyNodeABC):
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
aspect_ratio: str,
|
||||
@@ -205,7 +205,7 @@ class PixverseTextToVideoNode(ComfyNodeABC):
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
response_api = await operation.execute()
|
||||
|
||||
if response_api.Resp is None:
|
||||
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
||||
@@ -229,11 +229,11 @@ class PixverseTextToVideoNode(ComfyNodeABC):
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
estimated_duration=AVERAGE_DURATION_T2V,
|
||||
)
|
||||
response_poll = operation.execute()
|
||||
response_poll = await operation.execute()
|
||||
|
||||
vid_response = requests.get(response_poll.Resp.url)
|
||||
|
||||
return (VideoFromFile(BytesIO(vid_response.content)),)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.Resp.url) as vid_response:
|
||||
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
|
||||
|
||||
|
||||
class PixverseImageToVideoNode(ComfyNodeABC):
|
||||
@@ -302,7 +302,7 @@ class PixverseImageToVideoNode(ComfyNodeABC):
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
prompt: str,
|
||||
@@ -316,7 +316,7 @@ class PixverseImageToVideoNode(ComfyNodeABC):
|
||||
**kwargs,
|
||||
):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
img_id = upload_image_to_pixverse(image, auth_kwargs=kwargs)
|
||||
img_id = await upload_image_to_pixverse(image, auth_kwargs=kwargs)
|
||||
|
||||
# 1080p is limited to 5 seconds duration
|
||||
# only normal motion_mode supported for 1080p or for non-5 second duration
|
||||
@@ -345,7 +345,7 @@ class PixverseImageToVideoNode(ComfyNodeABC):
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
response_api = await operation.execute()
|
||||
|
||||
if response_api.Resp is None:
|
||||
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
||||
@@ -369,10 +369,11 @@ class PixverseImageToVideoNode(ComfyNodeABC):
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
estimated_duration=AVERAGE_DURATION_I2V,
|
||||
)
|
||||
response_poll = operation.execute()
|
||||
response_poll = await operation.execute()
|
||||
|
||||
vid_response = requests.get(response_poll.Resp.url)
|
||||
return (VideoFromFile(BytesIO(vid_response.content)),)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.Resp.url) as vid_response:
|
||||
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
|
||||
|
||||
|
||||
class PixverseTransitionVideoNode(ComfyNodeABC):
|
||||
@@ -436,7 +437,7 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
first_frame: torch.Tensor,
|
||||
last_frame: torch.Tensor,
|
||||
@@ -450,8 +451,8 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
|
||||
**kwargs,
|
||||
):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
first_frame_id = upload_image_to_pixverse(first_frame, auth_kwargs=kwargs)
|
||||
last_frame_id = upload_image_to_pixverse(last_frame, auth_kwargs=kwargs)
|
||||
first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=kwargs)
|
||||
last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=kwargs)
|
||||
|
||||
# 1080p is limited to 5 seconds duration
|
||||
# only normal motion_mode supported for 1080p or for non-5 second duration
|
||||
@@ -480,7 +481,7 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
response_api = await operation.execute()
|
||||
|
||||
if response_api.Resp is None:
|
||||
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
||||
@@ -504,10 +505,11 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
estimated_duration=AVERAGE_DURATION_T2V,
|
||||
)
|
||||
response_poll = operation.execute()
|
||||
response_poll = await operation.execute()
|
||||
|
||||
vid_response = requests.get(response_poll.Resp.url)
|
||||
return (VideoFromFile(BytesIO(vid_response.content)),)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.Resp.url) as vid_response:
|
||||
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
|
||||
@@ -37,7 +37,7 @@ from io import BytesIO
|
||||
from PIL import UnidentifiedImageError
|
||||
|
||||
|
||||
def handle_recraft_file_request(
|
||||
async def handle_recraft_file_request(
|
||||
image: torch.Tensor,
|
||||
path: str,
|
||||
mask: torch.Tensor=None,
|
||||
@@ -71,13 +71,13 @@ def handle_recraft_file_request(
|
||||
auth_kwargs=auth_kwargs,
|
||||
multipart_parser=recraft_multipart_parser,
|
||||
)
|
||||
response: RecraftImageGenerationResponse = operation.execute()
|
||||
response: RecraftImageGenerationResponse = await operation.execute()
|
||||
all_bytesio = []
|
||||
if response.image is not None:
|
||||
all_bytesio.append(download_url_to_bytesio(response.image.url, timeout=timeout))
|
||||
all_bytesio.append(await download_url_to_bytesio(response.image.url, timeout=timeout))
|
||||
else:
|
||||
for data in response.data:
|
||||
all_bytesio.append(download_url_to_bytesio(data.url, timeout=timeout))
|
||||
all_bytesio.append(await download_url_to_bytesio(data.url, timeout=timeout))
|
||||
|
||||
return all_bytesio
|
||||
|
||||
@@ -395,7 +395,7 @@ class RecraftTextToImageNode:
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
size: str,
|
||||
@@ -439,7 +439,7 @@ class RecraftTextToImageNode:
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response: RecraftImageGenerationResponse = operation.execute()
|
||||
response: RecraftImageGenerationResponse = await operation.execute()
|
||||
images = []
|
||||
urls = []
|
||||
for data in response.data:
|
||||
@@ -451,7 +451,7 @@ class RecraftTextToImageNode:
|
||||
f"Result URL: {urls_string}", unique_id
|
||||
)
|
||||
image = bytesio_to_image_tensor(
|
||||
download_url_to_bytesio(data.url, timeout=1024)
|
||||
await download_url_to_bytesio(data.url, timeout=1024)
|
||||
)
|
||||
if len(image.shape) < 4:
|
||||
image = image.unsqueeze(0)
|
||||
@@ -538,7 +538,7 @@ class RecraftImageToImageNode:
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
prompt: str,
|
||||
@@ -578,7 +578,7 @@ class RecraftImageToImageNode:
|
||||
total = image.shape[0]
|
||||
pbar = ProgressBar(total)
|
||||
for i in range(total):
|
||||
sub_bytes = handle_recraft_file_request(
|
||||
sub_bytes = await handle_recraft_file_request(
|
||||
image=image[i],
|
||||
path="/proxy/recraft/images/imageToImage",
|
||||
request=request,
|
||||
@@ -654,7 +654,7 @@ class RecraftImageInpaintingNode:
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
@@ -690,7 +690,7 @@ class RecraftImageInpaintingNode:
|
||||
total = image.shape[0]
|
||||
pbar = ProgressBar(total)
|
||||
for i in range(total):
|
||||
sub_bytes = handle_recraft_file_request(
|
||||
sub_bytes = await handle_recraft_file_request(
|
||||
image=image[i],
|
||||
mask=mask[i:i+1],
|
||||
path="/proxy/recraft/images/inpaint",
|
||||
@@ -779,7 +779,7 @@ class RecraftTextToVectorNode:
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
substyle: str,
|
||||
@@ -821,7 +821,7 @@ class RecraftTextToVectorNode:
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response: RecraftImageGenerationResponse = operation.execute()
|
||||
response: RecraftImageGenerationResponse = await operation.execute()
|
||||
svg_data = []
|
||||
urls = []
|
||||
for data in response.data:
|
||||
@@ -831,7 +831,7 @@ class RecraftTextToVectorNode:
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Result URL: {' '.join(urls)}", unique_id
|
||||
)
|
||||
svg_data.append(download_url_to_bytesio(data.url, timeout=1024))
|
||||
svg_data.append(await download_url_to_bytesio(data.url, timeout=1024))
|
||||
|
||||
return (SVG(svg_data),)
|
||||
|
||||
@@ -861,7 +861,7 @@ class RecraftVectorizeImageNode:
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
**kwargs,
|
||||
@@ -870,7 +870,7 @@ class RecraftVectorizeImageNode:
|
||||
total = image.shape[0]
|
||||
pbar = ProgressBar(total)
|
||||
for i in range(total):
|
||||
sub_bytes = handle_recraft_file_request(
|
||||
sub_bytes = await handle_recraft_file_request(
|
||||
image=image[i],
|
||||
path="/proxy/recraft/images/vectorize",
|
||||
auth_kwargs=kwargs,
|
||||
@@ -942,7 +942,7 @@ class RecraftReplaceBackgroundNode:
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
prompt: str,
|
||||
@@ -973,7 +973,7 @@ class RecraftReplaceBackgroundNode:
|
||||
total = image.shape[0]
|
||||
pbar = ProgressBar(total)
|
||||
for i in range(total):
|
||||
sub_bytes = handle_recraft_file_request(
|
||||
sub_bytes = await handle_recraft_file_request(
|
||||
image=image[i],
|
||||
path="/proxy/recraft/images/replaceBackground",
|
||||
request=request,
|
||||
@@ -1011,7 +1011,7 @@ class RecraftRemoveBackgroundNode:
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
**kwargs,
|
||||
@@ -1020,7 +1020,7 @@ class RecraftRemoveBackgroundNode:
|
||||
total = image.shape[0]
|
||||
pbar = ProgressBar(total)
|
||||
for i in range(total):
|
||||
sub_bytes = handle_recraft_file_request(
|
||||
sub_bytes = await handle_recraft_file_request(
|
||||
image=image[i],
|
||||
path="/proxy/recraft/images/removeBackground",
|
||||
auth_kwargs=kwargs,
|
||||
@@ -1062,7 +1062,7 @@ class RecraftCrispUpscaleNode:
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
**kwargs,
|
||||
@@ -1071,7 +1071,7 @@ class RecraftCrispUpscaleNode:
|
||||
total = image.shape[0]
|
||||
pbar = ProgressBar(total)
|
||||
for i in range(total):
|
||||
sub_bytes = handle_recraft_file_request(
|
||||
sub_bytes = await handle_recraft_file_request(
|
||||
image=image[i],
|
||||
path=self.RECRAFT_PATH,
|
||||
auth_kwargs=kwargs,
|
||||
|
||||
@@ -9,11 +9,10 @@ from __future__ import annotations
|
||||
from inspect import cleandoc
|
||||
from comfy.comfy_types.node_typing import IO
|
||||
import folder_paths as comfy_paths
|
||||
import requests
|
||||
import aiohttp
|
||||
import os
|
||||
import datetime
|
||||
import shutil
|
||||
import time
|
||||
import asyncio
|
||||
import io
|
||||
import logging
|
||||
import math
|
||||
@@ -66,7 +65,6 @@ def create_task_error(response: Rodin3DGenerateResponse):
|
||||
return hasattr(response, "error")
|
||||
|
||||
|
||||
|
||||
class Rodin3DAPI:
|
||||
"""
|
||||
Generate 3D Assets using Rodin API
|
||||
@@ -123,8 +121,8 @@ class Rodin3DAPI:
|
||||
else:
|
||||
return "Generating"
|
||||
|
||||
def CreateGenerateTask(self, images=None, seed=1, material="PBR", quality="medium", tier="Regular", mesh_mode="Quad", **kwargs):
|
||||
if images == None:
|
||||
async def create_generate_task(self, images=None, seed=1, material="PBR", quality="medium", tier="Regular", mesh_mode="Quad", **kwargs):
|
||||
if images is None:
|
||||
raise Exception("Rodin 3D generate requires at least 1 image.")
|
||||
if len(images) >= 5:
|
||||
raise Exception("Rodin 3D generate requires up to 5 image.")
|
||||
@@ -155,7 +153,7 @@ class Rodin3DAPI:
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
response = operation.execute()
|
||||
response = await operation.execute()
|
||||
|
||||
if create_task_error(response):
|
||||
error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}"
|
||||
@@ -168,7 +166,7 @@ class Rodin3DAPI:
|
||||
logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}")
|
||||
return task_uuid, subscription_key
|
||||
|
||||
def poll_for_task_status(self, subscription_key, **kwargs) -> Rodin3DCheckStatusResponse:
|
||||
async def poll_for_task_status(self, subscription_key, **kwargs) -> Rodin3DCheckStatusResponse:
|
||||
|
||||
path = "/proxy/rodin/api/v2/status"
|
||||
|
||||
@@ -191,11 +189,9 @@ class Rodin3DAPI:
|
||||
|
||||
logging.info("[ Rodin3D API - CheckStatus ] Generate Start!")
|
||||
|
||||
return poll_operation.execute()
|
||||
return await poll_operation.execute()
|
||||
|
||||
|
||||
|
||||
def GetRodinDownloadList(self, uuid, **kwargs) -> Rodin3DDownloadResponse:
|
||||
async def get_rodin_download_list(self, uuid, **kwargs) -> Rodin3DDownloadResponse:
|
||||
logging.info("[ Rodin3D API - Downloading ] Generate Successfully!")
|
||||
|
||||
path = "/proxy/rodin/api/v2/download"
|
||||
@@ -212,53 +208,59 @@ class Rodin3DAPI:
|
||||
auth_kwargs=kwargs
|
||||
)
|
||||
|
||||
return operation.execute()
|
||||
return await operation.execute()
|
||||
|
||||
def GetQualityAndMode(self, PolyCount):
|
||||
if PolyCount == "200K-Triangle":
|
||||
def get_quality_mode(self, poly_count):
|
||||
if poly_count == "200K-Triangle":
|
||||
mesh_mode = "Raw"
|
||||
quality = "medium"
|
||||
else:
|
||||
mesh_mode = "Quad"
|
||||
if PolyCount == "4K-Quad":
|
||||
if poly_count == "4K-Quad":
|
||||
quality = "extra-low"
|
||||
elif PolyCount == "8K-Quad":
|
||||
elif poly_count == "8K-Quad":
|
||||
quality = "low"
|
||||
elif PolyCount == "18K-Quad":
|
||||
elif poly_count == "18K-Quad":
|
||||
quality = "medium"
|
||||
elif PolyCount == "50K-Quad":
|
||||
elif poly_count == "50K-Quad":
|
||||
quality = "high"
|
||||
else:
|
||||
quality = "medium"
|
||||
|
||||
return mesh_mode, quality
|
||||
|
def DownLoadFiles(self, Url_List):
Save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
os.makedirs(Save_path, exist_ok=True)
async def download_files(self, url_list):
save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
os.makedirs(save_path, exist_ok=True)
model_file_path = None
for Item in Url_List.list:
url = Item.url
file_name = Item.name
file_path = os.path.join(Save_path, file_name)
if file_path.endswith(".glb"):
model_file_path = file_path
logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}")
max_retries = 5
for attempt in range(max_retries):
try:
with requests.get(url, stream=True) as r:
r.raise_for_status()
with open(file_path, "wb") as f:
shutil.copyfileobj(r.raw, f)
break
except Exception as e:
logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}")
if attempt < max_retries - 1:
logging.info("Retrying...")
time.sleep(2)
else:
logging.info(f"[ Rodin3D API - download_files ] Failed to download {file_path} after {max_retries} attempts.")
async with aiohttp.ClientSession() as session:
for i in url_list.list:
url = i.url
file_name = i.name
file_path = os.path.join(save_path, file_name)
if file_path.endswith(".glb"):
model_file_path = file_path
logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}")
max_retries = 5
for attempt in range(max_retries):
try:
async with session.get(url) as resp:
resp.raise_for_status()
with open(file_path, "wb") as f:
async for chunk in resp.content.iter_chunked(32 * 1024):
f.write(chunk)
break
except Exception as e:
logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}")
if attempt < max_retries - 1:
logging.info("Retrying...")
await asyncio.sleep(2)
else:
logging.info(
"[ Rodin3D API - download_files ] Failed to download %s after %s attempts.",
file_path,
max_retries,
)

return model_file_path
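The rewritten loop above combines per-file iteration, a retry counter, and chunked writes. Distilled here into a single-file helper for readability; the function name and the boolean return value are illustrative, not part of the Rodin node:

import asyncio
import aiohttp

async def download_with_retries(session: aiohttp.ClientSession, url: str, file_path: str, max_retries: int = 5) -> bool:
    # Stream the response to disk in 32 KiB chunks, pausing briefly and
    # retrying on failure, as the per-file loop above does.
    for attempt in range(max_retries):
        try:
            async with session.get(url) as resp:
                resp.raise_for_status()
                with open(file_path, "wb") as f:
                    async for chunk in resp.content.iter_chunked(32 * 1024):
                        f.write(chunk)
            return True
        except Exception:
            if attempt < max_retries - 1:
                await asyncio.sleep(2)
    return False
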
@@ -285,7 +287,7 @@ class Rodin3D_Regular(Rodin3DAPI):
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
Images,
|
||||
Seed,
|
||||
@@ -298,14 +300,17 @@ class Rodin3D_Regular(Rodin3DAPI):
|
||||
m_images = []
|
||||
for i in range(num_images):
|
||||
m_images.append(Images[i])
|
||||
mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
|
||||
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
|
||||
self.poll_for_task_status(subscription_key, **kwargs)
|
||||
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
|
||||
model = self.DownLoadFiles(Download_List)
|
||||
mesh_mode, quality = self.get_quality_mode(Polygon_count)
|
||||
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
|
||||
quality=quality, tier=tier, mesh_mode=mesh_mode,
|
||||
**kwargs)
|
||||
await self.poll_for_task_status(subscription_key, **kwargs)
|
||||
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
|
||||
model = await self.download_files(download_list)
|
||||
|
||||
return (model,)
|
||||
|
||||
|
||||
class Rodin3D_Detail(Rodin3DAPI):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@@ -328,7 +333,7 @@ class Rodin3D_Detail(Rodin3DAPI):
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
Images,
|
||||
Seed,
|
||||
@@ -341,14 +346,17 @@ class Rodin3D_Detail(Rodin3DAPI):
|
||||
m_images = []
|
||||
for i in range(num_images):
|
||||
m_images.append(Images[i])
|
||||
mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
|
||||
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
|
||||
self.poll_for_task_status(subscription_key, **kwargs)
|
||||
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
|
||||
model = self.DownLoadFiles(Download_List)
|
||||
mesh_mode, quality = self.get_quality_mode(Polygon_count)
|
||||
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
|
||||
quality=quality, tier=tier, mesh_mode=mesh_mode,
|
||||
**kwargs)
|
||||
await self.poll_for_task_status(subscription_key, **kwargs)
|
||||
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
|
||||
model = await self.download_files(download_list)
|
||||
|
||||
return (model,)
|
||||
|
||||
|
||||
class Rodin3D_Smooth(Rodin3DAPI):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@@ -371,7 +379,7 @@ class Rodin3D_Smooth(Rodin3DAPI):
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
Images,
|
||||
Seed,
|
||||
@@ -384,14 +392,17 @@ class Rodin3D_Smooth(Rodin3DAPI):
|
||||
m_images = []
|
||||
for i in range(num_images):
|
||||
m_images.append(Images[i])
|
||||
mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
|
||||
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
|
||||
self.poll_for_task_status(subscription_key, **kwargs)
|
||||
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
|
||||
model = self.DownLoadFiles(Download_List)
|
||||
mesh_mode, quality = self.get_quality_mode(Polygon_count)
|
||||
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
|
||||
quality=quality, tier=tier, mesh_mode=mesh_mode,
|
||||
**kwargs)
|
||||
await self.poll_for_task_status(subscription_key, **kwargs)
|
||||
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
|
||||
model = await self.download_files(download_list)
|
||||
|
||||
return (model,)
|
||||
|
||||
|
||||
class Rodin3D_Sketch(Rodin3DAPI):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@@ -423,7 +434,7 @@ class Rodin3D_Sketch(Rodin3DAPI):
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
Images,
|
||||
Seed,
|
||||
@@ -437,10 +448,12 @@ class Rodin3D_Sketch(Rodin3DAPI):
|
||||
material_type = "PBR"
|
||||
quality = "medium"
|
||||
mesh_mode = "Quad"
|
||||
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=material_type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
|
||||
self.poll_for_task_status(subscription_key, **kwargs)
|
||||
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
|
||||
model = self.DownLoadFiles(Download_List)
|
||||
task_uuid, subscription_key = await self.create_generate_task(
|
||||
images=m_images, seed=Seed, material=material_type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs
|
||||
)
|
||||
await self.poll_for_task_status(subscription_key, **kwargs)
|
||||
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
|
||||
model = await self.download_files(download_list)
|
||||
|
||||
return (model,)
|
||||
|
||||
|
||||
@@ -99,14 +99,14 @@ def validate_input_image(image: torch.Tensor) -> bool:
|
||||
return image.shape[2] < 8000 and image.shape[1] < 8000
|
||||
|
||||
|
||||
def poll_until_finished(
|
||||
async def poll_until_finished(
|
||||
auth_kwargs: dict[str, str],
|
||||
api_endpoint: ApiEndpoint[Any, TaskStatusResponse],
|
||||
estimated_duration: Optional[int] = None,
|
||||
node_id: Optional[str] = None,
|
||||
) -> TaskStatusResponse:
|
||||
"""Polls the Runway API endpoint until the task reaches a terminal state, then returns the response."""
|
||||
return PollingOperation(
|
||||
return await PollingOperation(
|
||||
poll_endpoint=api_endpoint,
|
||||
completed_statuses=[
|
||||
TaskStatus.SUCCEEDED.value,
|
||||
@@ -115,7 +115,7 @@ def poll_until_finished(
|
||||
TaskStatus.FAILED.value,
|
||||
TaskStatus.CANCELLED.value,
|
||||
],
|
||||
status_extractor=lambda response: (response.status.value),
|
||||
status_extractor=lambda response: response.status.value,
|
||||
auth_kwargs=auth_kwargs,
|
||||
result_url_extractor=get_video_url_from_task_status,
|
||||
estimated_duration=estimated_duration,
|
||||
@@ -167,11 +167,11 @@ class RunwayVideoGenNode(ComfyNodeABC):
|
||||
)
|
||||
return True
|
||||
|
||||
def get_response(
|
||||
async def get_response(
|
||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||
) -> RunwayImageToVideoResponse:
|
||||
"""Poll the task status until it is finished then get the response."""
|
||||
return poll_until_finished(
|
||||
return await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
|
||||
@@ -183,7 +183,7 @@ class RunwayVideoGenNode(ComfyNodeABC):
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
def generate_video(
|
||||
async def generate_video(
|
||||
self,
|
||||
request: RunwayImageToVideoRequest,
|
||||
auth_kwargs: dict[str, str],
|
||||
@@ -200,15 +200,15 @@ class RunwayVideoGenNode(ComfyNodeABC):
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
|
||||
initial_response = initial_operation.execute()
|
||||
initial_response = await initial_operation.execute()
|
||||
self.validate_task_created(initial_response)
|
||||
task_id = initial_response.id
|
||||
|
||||
final_response = self.get_response(task_id, auth_kwargs, node_id)
|
||||
final_response = await self.get_response(task_id, auth_kwargs, node_id)
|
||||
self.validate_response(final_response)
|
||||
|
||||
video_url = get_video_url_from_task_status(final_response)
|
||||
return (download_url_to_video_output(video_url),)
|
||||
return (await download_url_to_video_output(video_url),)
|
||||
|
||||
|
||||
class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode):
|
||||
@@ -250,7 +250,7 @@ class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode):
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
start_frame: torch.Tensor,
|
||||
@@ -265,7 +265,7 @@ class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode):
|
||||
validate_input_image(start_frame)
|
||||
|
||||
# Upload image
|
||||
download_urls = upload_images_to_comfyapi(
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
start_frame,
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
@@ -274,7 +274,7 @@ class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode):
|
||||
if len(download_urls) != 1:
|
||||
raise RunwayApiError("Failed to upload one or more images to comfy api.")
|
||||
|
||||
return self.generate_video(
|
||||
return await self.generate_video(
|
||||
RunwayImageToVideoRequest(
|
||||
promptText=prompt,
|
||||
seed=seed,
|
||||
@@ -333,7 +333,7 @@ class RunwayImageToVideoNodeGen4(RunwayVideoGenNode):
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
start_frame: torch.Tensor,
|
||||
@@ -348,7 +348,7 @@ class RunwayImageToVideoNodeGen4(RunwayVideoGenNode):
|
||||
validate_input_image(start_frame)
|
||||
|
||||
# Upload image
|
||||
download_urls = upload_images_to_comfyapi(
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
start_frame,
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
@@ -357,7 +357,7 @@ class RunwayImageToVideoNodeGen4(RunwayVideoGenNode):
|
||||
if len(download_urls) != 1:
|
||||
raise RunwayApiError("Failed to upload one or more images to comfy api.")
|
||||
|
||||
return self.generate_video(
|
||||
return await self.generate_video(
|
||||
RunwayImageToVideoRequest(
|
||||
promptText=prompt,
|
||||
seed=seed,
|
||||
@@ -382,10 +382,10 @@ class RunwayFirstLastFrameNode(RunwayVideoGenNode):
|
||||
|
||||
DESCRIPTION = "Upload first and last keyframes, draft a prompt, and generate a video. More complex transitions, such as cases where the Last frame is completely different from the First frame, may benefit from the longer 10s duration. This would give the generation more time to smoothly transition between the two inputs. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3."
|
||||
|
||||
def get_response(
|
||||
async def get_response(
|
||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||
) -> RunwayImageToVideoResponse:
|
||||
return poll_until_finished(
|
||||
return await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
|
||||
@@ -437,7 +437,7 @@ class RunwayFirstLastFrameNode(RunwayVideoGenNode):
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
start_frame: torch.Tensor,
|
||||
@@ -455,7 +455,7 @@ class RunwayFirstLastFrameNode(RunwayVideoGenNode):
|
||||
|
||||
# Upload images
|
||||
stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame)
|
||||
download_urls = upload_images_to_comfyapi(
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
stacked_input_images,
|
||||
max_images=2,
|
||||
mime_type="image/png",
|
||||
@@ -464,7 +464,7 @@ class RunwayFirstLastFrameNode(RunwayVideoGenNode):
|
||||
if len(download_urls) != 2:
|
||||
raise RunwayApiError("Failed to upload one or more images to comfy api.")
|
||||
|
||||
return self.generate_video(
|
||||
return await self.generate_video(
|
||||
RunwayImageToVideoRequest(
|
||||
promptText=prompt,
|
||||
seed=seed,
|
||||
@@ -543,11 +543,11 @@ class RunwayTextToImageNode(ComfyNodeABC):
|
||||
)
|
||||
return True
|
||||
|
||||
def get_response(
|
||||
async def get_response(
|
||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||
) -> TaskStatusResponse:
|
||||
"""Poll the task status until it is finished then get the response."""
|
||||
return poll_until_finished(
|
||||
return await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
|
||||
@@ -559,7 +559,7 @@ class RunwayTextToImageNode(ComfyNodeABC):
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
ratio: str,
|
||||
@@ -574,7 +574,7 @@ class RunwayTextToImageNode(ComfyNodeABC):
|
||||
reference_images = None
|
||||
if reference_image is not None:
|
||||
validate_input_image(reference_image)
|
||||
download_urls = upload_images_to_comfyapi(
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
reference_image,
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
@@ -605,19 +605,19 @@ class RunwayTextToImageNode(ComfyNodeABC):
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
initial_response = initial_operation.execute()
|
||||
initial_response = await initial_operation.execute()
|
||||
self.validate_task_created(initial_response)
|
||||
task_id = initial_response.id
|
||||
|
||||
# Poll for completion
|
||||
final_response = self.get_response(
|
||||
final_response = await self.get_response(
|
||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||
)
|
||||
self.validate_response(final_response)
|
||||
|
||||
# Download and return image
|
||||
image_url = get_image_url_from_task_status(final_response)
|
||||
return (download_url_to_image_tensor(image_url),)
|
||||
return (await download_url_to_image_tensor(image_url),)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
|
||||
@@ -124,7 +124,7 @@ class StabilityStableImageUltraNode:
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(self, prompt: str, aspect_ratio: str, style_preset: str, seed: int,
|
||||
async def api_call(self, prompt: str, aspect_ratio: str, style_preset: str, seed: int,
|
||||
negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
|
||||
**kwargs):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
@@ -163,7 +163,7 @@ class StabilityStableImageUltraNode:
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
response_api = await operation.execute()
|
||||
|
||||
if response_api.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.")
|
||||
@@ -257,7 +257,7 @@ class StabilityStableImageSD_3_5Node:
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(self, model: str, prompt: str, aspect_ratio: str, style_preset: str, seed: int, cfg_scale: float,
|
||||
async def api_call(self, model: str, prompt: str, aspect_ratio: str, style_preset: str, seed: int, cfg_scale: float,
|
||||
negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
|
||||
**kwargs):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
@@ -302,7 +302,7 @@ class StabilityStableImageSD_3_5Node:
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
response_api = await operation.execute()
|
||||
|
||||
if response_api.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.")
|
||||
@@ -374,7 +374,7 @@ class StabilityUpscaleConservativeNode:
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(self, image: torch.Tensor, prompt: str, creativity: float, seed: int, negative_prompt: str=None,
|
||||
async def api_call(self, image: torch.Tensor, prompt: str, creativity: float, seed: int, negative_prompt: str=None,
|
||||
**kwargs):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
|
||||
@@ -403,7 +403,7 @@ class StabilityUpscaleConservativeNode:
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
response_api = await operation.execute()
|
||||
|
||||
if response_api.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.")
|
||||
@@ -480,7 +480,7 @@ class StabilityUpscaleCreativeNode:
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(self, image: torch.Tensor, prompt: str, creativity: float, style_preset: str, seed: int, negative_prompt: str=None,
|
||||
async def api_call(self, image: torch.Tensor, prompt: str, creativity: float, style_preset: str, seed: int, negative_prompt: str=None,
|
||||
**kwargs):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
|
||||
@@ -512,7 +512,7 @@ class StabilityUpscaleCreativeNode:
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
response_api = await operation.execute()
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
@@ -527,7 +527,7 @@ class StabilityUpscaleCreativeNode:
|
||||
status_extractor=lambda x: get_async_dummy_status(x),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_poll: StabilityResultsGetResponse = operation.execute()
|
||||
response_poll: StabilityResultsGetResponse = await operation.execute()
|
||||
|
||||
if response_poll.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.")
|
||||
@@ -563,8 +563,7 @@ class StabilityUpscaleFastNode:
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(self, image: torch.Tensor,
|
||||
**kwargs):
|
||||
async def api_call(self, image: torch.Tensor, **kwargs):
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read()
|
||||
|
||||
files = {
|
||||
@@ -583,7 +582,7 @@ class StabilityUpscaleFastNode:
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
response_api = await operation.execute()
|
||||
|
||||
if response_api.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.")
|
||||
|
||||
@@ -37,8 +37,8 @@ from comfy_api_nodes.apinode_utils import (
|
||||
)
|
||||
|
||||
|
||||
def upload_image_to_tripo(image, **kwargs):
|
||||
urls = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)
|
||||
async def upload_image_to_tripo(image, **kwargs):
|
||||
urls = await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)
|
||||
return TripoFileReference(TripoUrlReference(url=urls[0], type="jpeg"))
|
||||
|
||||
def get_model_url_from_response(response: TripoTaskResponse) -> str:
|
||||
@@ -49,7 +49,7 @@ def get_model_url_from_response(response: TripoTaskResponse) -> str:
|
||||
raise RuntimeError(f"Failed to get model url from response: {response}")
|
||||
|
||||
|
||||
def poll_until_finished(
|
||||
async def poll_until_finished(
|
||||
kwargs: dict[str, str],
|
||||
response: TripoTaskResponse,
|
||||
) -> tuple[str, str]:
|
||||
@@ -57,7 +57,7 @@ def poll_until_finished(
|
||||
if response.code != 0:
|
||||
raise RuntimeError(f"Failed to generate mesh: {response.error}")
|
||||
task_id = response.data.task_id
|
||||
response_poll = PollingOperation(
|
||||
response_poll = await PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/tripo/v2/openapi/task/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
@@ -80,7 +80,7 @@ def poll_until_finished(
|
||||
).execute()
|
||||
if response_poll.data.status == TripoTaskStatus.SUCCESS:
|
||||
url = get_model_url_from_response(response_poll)
|
||||
bytesio = download_url_to_bytesio(url)
|
||||
bytesio = await download_url_to_bytesio(url)
|
||||
# Save the downloaded model file
|
||||
model_file = f"tripo_model_{task_id}.glb"
|
||||
with open(os.path.join(get_output_directory(), model_file), "wb") as f:
|
||||
@@ -88,6 +88,7 @@ def poll_until_finished(
|
||||
return model_file, task_id
|
||||
raise RuntimeError(f"Failed to generate mesh: {response_poll}")
|
||||
|
||||
|
||||
class TripoTextToModelNode:
|
||||
"""
|
||||
Generates 3D models synchronously based on a text prompt using Tripo's API.
|
||||
@@ -126,11 +127,11 @@ class TripoTextToModelNode:
|
||||
API_NODE = True
|
||||
OUTPUT_NODE = True
|
||||
|
||||
def generate_mesh(self, prompt, negative_prompt=None, model_version=None, style=None, texture=None, pbr=None, image_seed=None, model_seed=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs):
|
||||
async def generate_mesh(self, prompt, negative_prompt=None, model_version=None, style=None, texture=None, pbr=None, image_seed=None, model_seed=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs):
|
||||
style_enum = None if style == "None" else style
|
||||
if not prompt:
|
||||
raise RuntimeError("Prompt is required")
|
||||
response = SynchronousOperation(
|
||||
response = await SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/tripo/v2/openapi/task",
|
||||
method=HttpMethod.POST,
|
||||
@@ -155,7 +156,8 @@ class TripoTextToModelNode:
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
return poll_until_finished(kwargs, response)
|
||||
return await poll_until_finished(kwargs, response)
|
||||
|
||||
|
||||
class TripoImageToModelNode:
|
||||
"""
|
||||
@@ -195,12 +197,12 @@ class TripoImageToModelNode:
|
||||
API_NODE = True
|
||||
OUTPUT_NODE = True
|
||||
|
||||
def generate_mesh(self, image, model_version=None, style=None, texture=None, pbr=None, model_seed=None, orientation=None, texture_alignment=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs):
|
||||
async def generate_mesh(self, image, model_version=None, style=None, texture=None, pbr=None, model_seed=None, orientation=None, texture_alignment=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs):
|
||||
style_enum = None if style == "None" else style
|
||||
if image is None:
|
||||
raise RuntimeError("Image is required")
|
||||
tripo_file = upload_image_to_tripo(image, **kwargs)
|
||||
response = SynchronousOperation(
|
||||
tripo_file = await upload_image_to_tripo(image, **kwargs)
|
||||
response = await SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/tripo/v2/openapi/task",
|
||||
method=HttpMethod.POST,
|
||||
@@ -225,7 +227,8 @@ class TripoImageToModelNode:
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
return poll_until_finished(kwargs, response)
|
||||
return await poll_until_finished(kwargs, response)
|
||||
|
||||
|
||||
class TripoMultiviewToModelNode:
|
||||
"""
|
||||
@@ -267,7 +270,7 @@ class TripoMultiviewToModelNode:
|
||||
API_NODE = True
|
||||
OUTPUT_NODE = True
|
||||
|
||||
def generate_mesh(self, image, image_left=None, image_back=None, image_right=None, model_version=None, orientation=None, texture=None, pbr=None, model_seed=None, texture_seed=None, texture_quality=None, texture_alignment=None, face_limit=None, quad=None, **kwargs):
|
||||
async def generate_mesh(self, image, image_left=None, image_back=None, image_right=None, model_version=None, orientation=None, texture=None, pbr=None, model_seed=None, texture_seed=None, texture_quality=None, texture_alignment=None, face_limit=None, quad=None, **kwargs):
|
||||
if image is None:
|
||||
raise RuntimeError("front image for multiview is required")
|
||||
images = []
|
||||
@@ -282,11 +285,11 @@ class TripoMultiviewToModelNode:
|
||||
for image_name in ["image", "image_left", "image_back", "image_right"]:
|
||||
image_ = image_dict[image_name]
|
||||
if image_ is not None:
|
||||
tripo_file = upload_image_to_tripo(image_, **kwargs)
|
||||
tripo_file = await upload_image_to_tripo(image_, **kwargs)
|
||||
images.append(tripo_file)
|
||||
else:
|
||||
images.append(TripoFileEmptyReference())
|
||||
response = SynchronousOperation(
|
||||
response = await SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/tripo/v2/openapi/task",
|
||||
method=HttpMethod.POST,
|
||||
@@ -309,7 +312,8 @@ class TripoMultiviewToModelNode:
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
return poll_until_finished(kwargs, response)
|
||||
return await poll_until_finished(kwargs, response)
|
||||
|
||||
|
||||
class TripoTextureNode:
|
||||
@classmethod
|
||||
@@ -340,8 +344,8 @@ class TripoTextureNode:
|
||||
OUTPUT_NODE = True
|
||||
AVERAGE_DURATION = 80
|
||||
|
||||
def generate_mesh(self, model_task_id, texture=None, pbr=None, texture_seed=None, texture_quality=None, texture_alignment=None, **kwargs):
|
||||
response = SynchronousOperation(
|
||||
async def generate_mesh(self, model_task_id, texture=None, pbr=None, texture_seed=None, texture_quality=None, texture_alignment=None, **kwargs):
|
||||
response = await SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/tripo/v2/openapi/task",
|
||||
method=HttpMethod.POST,
|
||||
@@ -358,7 +362,7 @@ class TripoTextureNode:
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
return poll_until_finished(kwargs, response)
|
||||
return await poll_until_finished(kwargs, response)
|
||||
|
||||
|
||||
class TripoRefineNode:
|
||||
@@ -387,8 +391,8 @@ class TripoRefineNode:
|
||||
OUTPUT_NODE = True
|
||||
AVERAGE_DURATION = 240
|
||||
|
||||
def generate_mesh(self, model_task_id, **kwargs):
|
||||
response = SynchronousOperation(
|
||||
async def generate_mesh(self, model_task_id, **kwargs):
|
||||
response = await SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/tripo/v2/openapi/task",
|
||||
method=HttpMethod.POST,
|
||||
@@ -400,7 +404,7 @@ class TripoRefineNode:
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
return poll_until_finished(kwargs, response)
|
||||
return await poll_until_finished(kwargs, response)
|
||||
|
||||
|
||||
class TripoRigNode:
|
||||
@@ -425,8 +429,8 @@ class TripoRigNode:
|
||||
OUTPUT_NODE = True
|
||||
AVERAGE_DURATION = 180
|
||||
|
||||
def generate_mesh(self, original_model_task_id, **kwargs):
|
||||
response = SynchronousOperation(
|
||||
async def generate_mesh(self, original_model_task_id, **kwargs):
|
||||
response = await SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/tripo/v2/openapi/task",
|
||||
method=HttpMethod.POST,
|
||||
@@ -440,7 +444,8 @@ class TripoRigNode:
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
return poll_until_finished(kwargs, response)
|
||||
return await poll_until_finished(kwargs, response)
|
||||
|
||||
|
||||
class TripoRetargetNode:
|
||||
@classmethod
|
||||
@@ -475,8 +480,8 @@ class TripoRetargetNode:
|
||||
OUTPUT_NODE = True
|
||||
AVERAGE_DURATION = 30
|
||||
|
||||
def generate_mesh(self, animation, original_model_task_id, **kwargs):
|
||||
response = SynchronousOperation(
|
||||
async def generate_mesh(self, animation, original_model_task_id, **kwargs):
|
||||
response = await SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/tripo/v2/openapi/task",
|
||||
method=HttpMethod.POST,
|
||||
@@ -491,7 +496,8 @@ class TripoRetargetNode:
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
return poll_until_finished(kwargs, response)
|
||||
return await poll_until_finished(kwargs, response)
|
||||
|
||||
|
||||
class TripoConversionNode:
|
||||
@classmethod
|
||||
@@ -529,10 +535,10 @@ class TripoConversionNode:
|
||||
OUTPUT_NODE = True
|
||||
AVERAGE_DURATION = 30
|
||||
|
||||
def generate_mesh(self, original_model_task_id, format, quad, face_limit, texture_size, texture_format, **kwargs):
|
||||
async def generate_mesh(self, original_model_task_id, format, quad, face_limit, texture_size, texture_format, **kwargs):
|
||||
if not original_model_task_id:
|
||||
raise RuntimeError("original_model_task_id is required")
|
||||
response = SynchronousOperation(
|
||||
response = await SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/tripo/v2/openapi/task",
|
||||
method=HttpMethod.POST,
|
||||
@@ -549,7 +555,8 @@ class TripoConversionNode:
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
return poll_until_finished(kwargs, response)
|
||||
return await poll_until_finished(kwargs, response)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"TripoTextToModelNode": TripoTextToModelNode,
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
import io
|
||||
import logging
|
||||
import base64
|
||||
import requests
|
||||
import aiohttp
|
||||
import torch
|
||||
from typing import Optional
|
||||
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
|
||||
from comfy_api.input_impl.video_types import VideoFromFile
|
||||
from comfy_api_nodes.apis import (
|
||||
Veo2GenVidRequest,
|
||||
Veo2GenVidResponse,
|
||||
Veo2GenVidPollRequest,
|
||||
Veo2GenVidPollResponse
|
||||
VeoGenVidRequest,
|
||||
VeoGenVidResponse,
|
||||
VeoGenVidPollRequest,
|
||||
VeoGenVidPollResponse
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
@@ -35,7 +35,7 @@ def convert_image_to_base64(image: torch.Tensor):
|
||||
return tensor_to_base64_string(scaled_image)
|
||||
|
||||
|
||||
def get_video_url_from_response(poll_response: Veo2GenVidPollResponse) -> Optional[str]:
|
||||
def get_video_url_from_response(poll_response: VeoGenVidPollResponse) -> Optional[str]:
|
||||
if (
|
||||
poll_response.response
|
||||
and hasattr(poll_response.response, "videos")
|
||||
@@ -130,6 +130,14 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
"default": None,
|
||||
"tooltip": "Optional reference image to guide video generation",
|
||||
}),
|
||||
"model": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["veo-2.0-generate-001"],
|
||||
"default": "veo-2.0-generate-001",
|
||||
"tooltip": "Veo 2 model to use for video generation",
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
@@ -141,10 +149,10 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
RETURN_TYPES = (IO.VIDEO,)
|
||||
FUNCTION = "generate_video"
|
||||
CATEGORY = "api node/video/Veo"
|
||||
DESCRIPTION = "Generates videos from text prompts using Google's Veo API"
|
||||
DESCRIPTION = "Generates videos from text prompts using Google's Veo 2 API"
|
||||
API_NODE = True
|
||||
|
||||
def generate_video(
|
||||
async def generate_video(
|
||||
self,
|
||||
prompt,
|
||||
aspect_ratio="16:9",
|
||||
@@ -154,6 +162,8 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
person_generation="ALLOW",
|
||||
seed=0,
|
||||
image=None,
|
||||
model="veo-2.0-generate-001",
|
||||
generate_audio=False,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -188,23 +198,26 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
parameters["negativePrompt"] = negative_prompt
|
||||
if seed > 0:
|
||||
parameters["seed"] = seed
|
||||
# Only add generateAudio for Veo 3 models
|
||||
if "veo-3.0" in model:
|
||||
parameters["generateAudio"] = generate_audio
|
||||
|
||||
# Initial request to start video generation
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/veo/generate",
|
||||
path=f"/proxy/veo/{model}/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=Veo2GenVidRequest,
|
||||
response_model=Veo2GenVidResponse
|
||||
request_model=VeoGenVidRequest,
|
||||
response_model=VeoGenVidResponse
|
||||
),
|
||||
request=Veo2GenVidRequest(
|
||||
request=VeoGenVidRequest(
|
||||
instances=instances,
|
||||
parameters=parameters
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
initial_response = initial_operation.execute()
|
||||
initial_response = await initial_operation.execute()
|
||||
operation_name = initial_response.name
|
||||
|
||||
logging.info(f"Veo generation started with operation name: {operation_name}")
|
||||
@@ -223,16 +236,16 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
# Define the polling operation
|
||||
poll_operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path="/proxy/veo/poll",
|
||||
path=f"/proxy/veo/{model}/poll",
|
||||
method=HttpMethod.POST,
|
||||
request_model=Veo2GenVidPollRequest,
|
||||
response_model=Veo2GenVidPollResponse
|
||||
request_model=VeoGenVidPollRequest,
|
||||
response_model=VeoGenVidPollResponse
|
||||
),
|
||||
completed_statuses=["completed"],
|
||||
failed_statuses=[], # No failed statuses, we'll handle errors after polling
|
||||
status_extractor=status_extractor,
|
||||
progress_extractor=progress_extractor,
|
||||
request=Veo2GenVidPollRequest(
|
||||
request=VeoGenVidPollRequest(
|
||||
operationName=operation_name
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
@@ -243,7 +256,7 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
)
|
||||
|
||||
# Execute the polling operation
|
||||
poll_response = poll_operation.execute()
|
||||
poll_response = await poll_operation.execute()
|
||||
|
||||
# Now check for errors in the final response
|
||||
# Check for error in poll response
|
||||
@@ -268,7 +281,6 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
raise Exception(error_message)
|
||||
|
||||
# Extract video data
|
||||
video_data = None
|
||||
if poll_response.response and hasattr(poll_response.response, 'videos') and poll_response.response.videos and len(poll_response.response.videos) > 0:
|
||||
video = poll_response.response.videos[0]
|
||||
|
||||
@@ -278,9 +290,9 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
video_data = base64.b64decode(video.bytesBase64Encoded)
|
||||
elif hasattr(video, 'gcsUri') and video.gcsUri:
|
||||
# Download from URL
|
||||
video_url = video.gcsUri
|
||||
video_response = requests.get(video_url)
|
||||
video_data = video_response.content
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(video.gcsUri) as video_response:
|
||||
video_data = await video_response.content.read()
|
||||
else:
|
||||
raise Exception("Video returned but no data or URL was provided")
|
||||
else:
|
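The gcsUri branch above replaces the blocking requests.get call with aiohttp. A self-contained sketch of that non-blocking download pattern (the URL is a placeholder, not a real endpoint):

import asyncio
import aiohttp

async def download_bytes(url: str) -> bytes:
    # Mirrors the session.get(...) / content.read() pattern used in the diff.
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            response.raise_for_status()           # surface HTTP errors early
            return await response.content.read()  # read the full body into memory

if __name__ == "__main__":
    data = asyncio.run(download_bytes("https://example.com/video.mp4"))  # placeholder URL
    print(f"downloaded {len(data)} bytes")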
||||
@@ -298,11 +310,64 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
return (VideoFromFile(video_io),)
|
||||
|
||||
|
||||
# Register the node
|
||||
class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
||||
"""
|
||||
Generates videos from text prompts using Google's Veo 3 API.
|
||||
|
||||
Supported models:
|
||||
- veo-3.0-generate-001
|
||||
- veo-3.0-fast-generate-001
|
||||
|
||||
This node extends the base Veo node with Veo 3 specific features including
|
||||
audio generation and fixed 8-second duration.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
parent_input = super().INPUT_TYPES()
|
||||
|
||||
# Update model options for Veo 3
|
||||
parent_input["optional"]["model"] = (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["veo-3.0-generate-001", "veo-3.0-fast-generate-001"],
|
||||
"default": "veo-3.0-generate-001",
|
||||
"tooltip": "Veo 3 model to use for video generation",
|
||||
},
|
||||
)
|
||||
|
||||
# Add generateAudio parameter
|
||||
parent_input["optional"]["generate_audio"] = (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "Generate audio for the video. Supported by all Veo 3 models.",
|
||||
}
|
||||
)
|
||||
|
||||
# Update duration constraints for Veo 3 (only 8 seconds supported)
|
||||
parent_input["optional"]["duration_seconds"] = (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 8,
|
||||
"min": 8,
|
||||
"max": 8,
|
||||
"step": 1,
|
||||
"display": "number",
|
||||
"tooltip": "Duration of the output video in seconds (Veo 3 only supports 8 seconds)",
|
||||
},
|
||||
)
|
||||
|
||||
return parent_input
|
||||
|
||||
|
||||
# Register the nodes
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"VeoVideoGenerationNode": VeoVideoGenerationNode,
|
||||
"Veo3VideoGenerationNode": Veo3VideoGenerationNode,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"VeoVideoGenerationNode": "Google Veo2 Video Generation",
|
||||
"VeoVideoGenerationNode": "Google Veo 2 Video Generation",
|
||||
"Veo3VideoGenerationNode": "Google Veo 3 Video Generation",
|
||||
}
|
||||
|
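Between the base node and this subclass, request construction now varies by model: the proxy path embeds the model name and generateAudio is only attached for Veo 3. A hedged sketch of that branching, using only the fields visible in this diff (the prompt key inside instances is an assumption, not confirmed by the source):

def build_veo_request(model: str, prompt: str, generate_audio: bool,
                      negative_prompt: str = "", seed: int = 0):
    # "prompt" as the instance field name is assumed for illustration only.
    instances = [{"prompt": prompt}]
    parameters = {}
    if negative_prompt:
        parameters["negativePrompt"] = negative_prompt
    if seed > 0:
        parameters["seed"] = seed
    if "veo-3.0" in model:
        # Only the Veo 3 family understands generateAudio, as in the node code above.
        parameters["generateAudio"] = generate_audio
    # The endpoint path now embeds the model instead of a fixed /proxy/veo/generate.
    return f"/proxy/veo/{model}/generate", {"instances": instances, "parameters": parameters}

path, payload = build_veo_request("veo-3.0-fast-generate-001", "a red fox at dawn", True)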
||||
@@ -4,9 +4,12 @@ from typing import Type, Literal
|
||||
import nodes
|
||||
import asyncio
|
||||
import inspect
|
||||
from comfy_execution.graph_utils import is_link
|
||||
from comfy_execution.graph_utils import is_link, ExecutionBlocker
|
||||
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
|
||||
|
||||
# NOTE: ExecutionBlocker code got moved to graph_utils.py to prevent torch being imported too soon during unit tests
|
||||
ExecutionBlocker = ExecutionBlocker
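As a usage illustration of the behavior the ExecutionBlocker docstring describes, the sketch below shows a node returning ExecutionBlocker to stop downstream consumers; the node class itself is invented and is not part of this changeset:

from comfy_execution.graph_utils import ExecutionBlocker

class GateImageOutput:
    # Hypothetical example node, for illustration only.
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"image": ("IMAGE",), "enabled": ("BOOLEAN", {"default": True})}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "gate"
    CATEGORY = "example"

    def gate(self, image, enabled):
        if not enabled:
            # Downstream users of this output are blocked with this message;
            # returning ExecutionBlocker(None) would block silently instead.
            return (ExecutionBlocker("Output disabled by user"),)
        return (image,)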
|
||||
|
||||
class DependencyCycleError(Exception):
|
||||
pass
|
||||
|
||||
@@ -294,21 +297,3 @@ class ExecutionList(TopologicalSort):
|
||||
del blocked_by[node_id]
|
||||
to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
|
||||
return list(blocked_by.keys())
|
||||
|
||||
class ExecutionBlocker:
|
||||
"""
|
||||
Return this from a node and any users will be blocked with the given error message.
|
||||
If the message is None, execution will be blocked silently instead.
|
||||
Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's
|
||||
possible, a lazy input will be more efficient and have a better user experience.
|
||||
This functionality is useful in two cases:
|
||||
1. You want to conditionally prevent an output node from executing. (Particularly a built-in node
|
||||
like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using
|
||||
lazy evaluation to let it conditionally disable itself.)
|
||||
2. You have a node with multiple possible outputs, some of which are invalid and should not be used.
|
||||
(I would recommend not making nodes like this in the future -- instead, make multiple nodes with
|
||||
different outputs. Unfortunately, there are several popular existing nodes using this pattern.)
|
||||
"""
|
||||
def __init__(self, message):
|
||||
self.message = message
|
||||
|
||||
|
||||
@@ -137,3 +137,19 @@ def add_graph_prefix(graph, outputs, prefix):
|
||||
|
||||
return new_graph, tuple(new_outputs)
|
||||
|
||||
class ExecutionBlocker:
|
||||
"""
|
||||
Return this from a node and any users will be blocked with the given error message.
|
||||
If the message is None, execution will be blocked silently instead.
|
||||
Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's
|
||||
possible, a lazy input will be more efficient and have a better user experience.
|
||||
This functionality is useful in two cases:
|
||||
1. You want to conditionally prevent an output node from executing. (Particularly a built-in node
|
||||
like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using
|
||||
lazy evaluation to let it conditionally disable itself.)
|
||||
2. You have a node with multiple possible outputs, some of which are invalid and should not be used.
|
||||
(I would recommend not making nodes like this in the future -- instead, make multiple nodes with
|
||||
different outputs. Unfortunately, there are several popular existing nodes using this pattern.)
|
||||
"""
|
||||
def __init__(self, message):
|
||||
self.message = message
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from typing import TypedDict, Dict, Optional
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TypedDict, Dict, Optional, Tuple
|
||||
from typing_extensions import override
|
||||
from PIL import Image
|
||||
from enum import Enum
|
||||
@@ -10,6 +12,7 @@ if TYPE_CHECKING:
|
||||
from protocol import BinaryEventTypes
|
||||
from comfy_api import feature_flags
|
||||
|
||||
PreviewImageTuple = Tuple[str, Image.Image, Optional[int]]
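For reference, a small illustration of the tuple shape this alias describes; the concrete values are placeholders, and treating the first element as a PIL-style format name is an assumption based on how previews are emitted elsewhere:

from typing import Optional, Tuple
from PIL import Image

PreviewImageTuple = Tuple[str, Image.Image, Optional[int]]

# (format, image, optional max preview size) -- placeholder values
preview: PreviewImageTuple = ("JPEG", Image.new("RGB", (64, 64)), 512)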
|
||||
|
||||
class NodeState(Enum):
|
||||
Pending = "pending"
|
||||
@@ -52,7 +55,7 @@ class ProgressHandler(ABC):
|
||||
max_value: float,
|
||||
state: NodeProgressState,
|
||||
prompt_id: str,
|
||||
image: Optional[Image.Image] = None,
|
||||
image: PreviewImageTuple | None = None,
|
||||
):
|
||||
"""Called when a node's progress is updated"""
|
||||
pass
|
||||
@@ -103,7 +106,7 @@ class CLIProgressHandler(ProgressHandler):
|
||||
max_value: float,
|
||||
state: NodeProgressState,
|
||||
prompt_id: str,
|
||||
image: Optional[Image.Image] = None,
|
||||
image: PreviewImageTuple | None = None,
|
||||
):
|
||||
# Handle case where start_handler wasn't called
|
||||
if node_id not in self.progress_bars:
|
||||
@@ -196,7 +199,7 @@ class WebUIProgressHandler(ProgressHandler):
|
||||
max_value: float,
|
||||
state: NodeProgressState,
|
||||
prompt_id: str,
|
||||
image: Optional[Image.Image] = None,
|
||||
image: PreviewImageTuple | None = None,
|
||||
):
|
||||
# Send progress state of all nodes
|
||||
if self.registry:
|
||||
@@ -231,7 +234,6 @@ class WebUIProgressHandler(ProgressHandler):
|
||||
if self.registry:
|
||||
self._send_progress_state(prompt_id, self.registry.nodes)
|
||||
|
||||
|
||||
class ProgressRegistry:
|
||||
"""
|
||||
Registry that maintains node progress state and notifies registered handlers.
|
||||
@@ -285,7 +287,7 @@ class ProgressRegistry:
|
||||
handler.start_handler(node_id, entry, self.prompt_id)
|
||||
|
||||
def update_progress(
|
||||
self, node_id: str, value: float, max_value: float, image: Optional[Image.Image]
|
||||
self, node_id: str, value: float, max_value: float, image: PreviewImageTuple | None = None
|
||||
) -> None:
|
||||
"""Update progress for a node"""
|
||||
entry = self.ensure_entry(node_id)
|
||||
@@ -317,7 +319,7 @@ class ProgressRegistry:
|
||||
handler.reset()
|
||||
|
||||
# Global registry instance
|
||||
global_progress_registry: ProgressRegistry = None
|
||||
global_progress_registry: ProgressRegistry | None = None
|
||||
|
||||
def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None:
|
||||
global global_progress_registry
|
||||
|
||||
@@ -346,6 +346,24 @@ class LoadAudio:
|
||||
return "Invalid audio file: {}".format(audio)
|
||||
return True
|
||||
|
||||
class RecordAudio:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"audio": ("AUDIO_RECORD", {})}}
|
||||
|
||||
CATEGORY = "audio"
|
||||
|
||||
RETURN_TYPES = ("AUDIO", )
|
||||
FUNCTION = "load"
|
||||
|
||||
def load(self, audio):
|
||||
audio_path = folder_paths.get_annotated_filepath(audio)
|
||||
|
||||
waveform, sample_rate = torchaudio.load(audio_path)
|
||||
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
|
||||
return (audio, )
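For reference, a short sketch of the AUDIO value this node returns, the same dictionary layout used by the other audio nodes: a [batch, channels, samples] waveform plus its sample rate.

import torch

def make_audio(waveform: torch.Tensor, sample_rate: int) -> dict:
    # torchaudio.load() yields a [channels, samples] tensor; unsqueeze adds the batch dim.
    return {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}

silence = make_audio(torch.zeros(2, 48000), 48000)  # one second of stereo silence
print(silence["waveform"].shape)  # torch.Size([1, 2, 48000])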
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"EmptyLatentAudio": EmptyLatentAudio,
|
||||
"VAEEncodeAudio": VAEEncodeAudio,
|
||||
@@ -356,6 +374,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"LoadAudio": LoadAudio,
|
||||
"PreviewAudio": PreviewAudio,
|
||||
"ConditioningStableAudio": ConditioningStableAudio,
|
||||
"RecordAudio": RecordAudio,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
@@ -367,4 +386,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"SaveAudio": "Save Audio (FLAC)",
|
||||
"SaveAudioMP3": "Save Audio (MP3)",
|
||||
"SaveAudioOpus": "Save Audio (Opus)",
|
||||
"RecordAudio": "Record Audio",
|
||||
}
|
||||
|
||||
comfy_extras/nodes_context_windows.py (new file, 89 lines)
@@ -0,0 +1,89 @@
|
||||
from __future__ import annotations
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
import comfy.context_windows
|
||||
import nodes
|
||||
|
||||
|
||||
class ContextWindowsManualNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="ContextWindowsManual",
|
||||
display_name="Context Windows (Manual)",
|
||||
category="context",
|
||||
description="Manually set context windows.",
|
||||
inputs=[
|
||||
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
|
||||
io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window."),
|
||||
io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window."),
|
||||
io.Combo.Input("context_schedule", options=[
|
||||
comfy.context_windows.ContextSchedules.STATIC_STANDARD,
|
||||
comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
|
||||
comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
|
||||
comfy.context_windows.ContextSchedules.BATCHED,
|
||||
], tooltip="The stride of the context window."),
|
||||
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
|
||||
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
|
||||
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
||||
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(tooltip="The model with context windows applied during sampling."),
|
||||
],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int) -> io.Model:
|
||||
model = model.clone()
|
||||
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
|
||||
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
|
||||
fuse_method=comfy.context_windows.get_matching_fuse_method(fuse_method),
|
||||
context_length=context_length,
|
||||
context_overlap=context_overlap,
|
||||
context_stride=context_stride,
|
||||
closed_loop=closed_loop,
|
||||
dim=dim)
|
||||
# make memory usage calculation only take into account the context window latents
|
||||
comfy.context_windows.create_prepare_sampling_wrapper(model)
|
||||
return io.NodeOutput(model)
|
||||
|
||||
class WanContextWindowsManualNode(ContextWindowsManualNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
schema = super().define_schema()
|
||||
schema.node_id = "WanContextWindowsManual"
|
||||
schema.display_name = "WAN Context Windows (Manual)"
|
||||
schema.description = "Manually set context windows for WAN-like models (dim=2)."
|
||||
schema.inputs = [
|
||||
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
|
||||
io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=4, default=81, tooltip="The length of the context window."),
|
||||
io.Int.Input("context_overlap", min=0, default=30, tooltip="The overlap of the context window."),
|
||||
io.Combo.Input("context_schedule", options=[
|
||||
comfy.context_windows.ContextSchedules.STATIC_STANDARD,
|
||||
comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
|
||||
comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
|
||||
comfy.context_windows.ContextSchedules.BATCHED,
|
||||
], tooltip="The stride of the context window."),
|
||||
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
|
||||
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
|
||||
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
||||
]
|
||||
return schema
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str) -> io.Model:
|
||||
context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1
|
||||
context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0
|
||||
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2)
|
||||
|
||||
|
||||
class ContextWindowsExtension(ComfyExtension):
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
ContextWindowsManualNode,
|
||||
WanContextWindowsManualNode,
|
||||
]
|
||||
|
||||
def comfy_entrypoint():
|
||||
return ContextWindowsExtension()
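The WAN variant above converts the user-facing frame counts into latent frames before handing them to the shared handler. A worked example of that conversion (the 4x temporal packing factor is taken from the ((x - 1) // 4) + 1 expressions in the node; the helper name is invented):

def to_latent_frames(pixel_frames: int, minimum: int = 1) -> int:
    # Same arithmetic as WanContextWindowsManualNode.execute above.
    return max(((pixel_frames - 1) // 4) + 1, minimum)

assert to_latent_frames(81) == 21            # default context_length: 81 frames -> 21 latent frames
assert to_latent_frames(30, minimum=0) == 8  # default context_overlap: 30 frames -> 8 latent frames
assert to_latent_frames(1) == 1              # never shorter than one latent frame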
|
||||
@@ -100,9 +100,28 @@ class FluxKontextImageScale:
|
||||
return (image, )
|
||||
|
||||
|
||||
class FluxKontextMultiReferenceLatentMethod:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"conditioning": ("CONDITIONING", ),
|
||||
"reference_latents_method": (("offset", "index"), ),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "append"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
CATEGORY = "advanced/conditioning/flux"
|
||||
|
||||
def append(self, conditioning, reference_latents_method):
|
||||
c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method})
|
||||
return (c, )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"CLIPTextEncodeFlux": CLIPTextEncodeFlux,
|
||||
"FluxGuidance": FluxGuidance,
|
||||
"FluxDisableGuidance": FluxDisableGuidance,
|
||||
"FluxKontextImageScale": FluxKontextImageScale,
|
||||
"FluxKontextMultiReferenceLatentMethod": FluxKontextMultiReferenceLatentMethod,
|
||||
}
|
||||
|
||||
comfy_extras/nodes_memory_reserve.py (new file, 33 lines)
@@ -0,0 +1,33 @@
|
||||
from comfy_api.latest import io, ComfyExtension
|
||||
|
||||
class MemoryReserveNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="ReserveAdditionalMemory",
|
||||
display_name="Reserve Additional Memory",
|
||||
description="Adds additional expected memory usage for the model, in gigabytes.",
|
||||
category="advanced/debug/model",
|
||||
inputs=[
|
||||
io.Model.Input("model", tooltip="The model to add memory reserve to."),
|
||||
io.Float.Input("memory_reserve_gb", min=0.0, default=0.0, max=2048.0, step=0.1, tooltip="The additional expected memory usage for the model, in gigabytes."),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(tooltip="The model with the additional memory reserve."),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: io.Model.Type, memory_reserve_gb: float) -> io.NodeOutput:
|
||||
model = model.clone()
|
||||
model.add_model_memory_reserve(memory_reserve_gb)
|
||||
return io.NodeOutput(model)
|
||||
|
||||
class MemoryReserveExtension(ComfyExtension):
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
MemoryReserveNode,
|
||||
]
|
||||
|
||||
def comfy_entrypoint():
|
||||
return MemoryReserveExtension()
|
||||
@@ -314,6 +314,29 @@ class ModelMergeCosmosPredict2_14B(comfy_extras.nodes_model_merging.ModelMergeBl
|
||||
|
||||
return {"required": arg_dict}
|
||||
|
||||
class ModelMergeQwenImage(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||
CATEGORY = "advanced/model_merging/model_specific"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
arg_dict = { "model1": ("MODEL",),
|
||||
"model2": ("MODEL",)}
|
||||
|
||||
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
||||
|
||||
arg_dict["pos_embeds."] = argument
|
||||
arg_dict["img_in."] = argument
|
||||
arg_dict["txt_norm."] = argument
|
||||
arg_dict["txt_in."] = argument
|
||||
arg_dict["time_text_embed."] = argument
|
||||
|
||||
for i in range(60):
|
||||
arg_dict["transformer_blocks.{}.".format(i)] = argument
|
||||
|
||||
arg_dict["proj_out."] = argument
|
||||
|
||||
return {"required": arg_dict}
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ModelMergeSD1": ModelMergeSD1,
|
||||
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
|
||||
@@ -329,4 +352,5 @@ NODE_CLASS_MAPPINGS = {
|
||||
"ModelMergeWAN2_1": ModelMergeWAN2_1,
|
||||
"ModelMergeCosmosPredict2_2B": ModelMergeCosmosPredict2_2B,
|
||||
"ModelMergeCosmosPredict2_14B": ModelMergeCosmosPredict2_14B,
|
||||
"ModelMergeQwenImage": ModelMergeQwenImage,
|
||||
}
|
||||
|
||||
@@ -8,9 +8,7 @@ import json
|
||||
from typing import Optional, Literal
|
||||
from fractions import Fraction
|
||||
from comfy.comfy_types import IO, FileLocator, ComfyNodeABC
|
||||
from comfy_api.input import ImageInput, AudioInput, VideoInput
|
||||
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
|
||||
from comfy_api.input_impl import VideoFromFile, VideoFromComponents
|
||||
from comfy_api.latest import Input, InputImpl, Types
|
||||
from comfy.cli_args import args
|
||||
|
||||
class SaveWEBM:
|
||||
@@ -91,8 +89,8 @@ class SaveVideo(ComfyNodeABC):
|
||||
"required": {
|
||||
"video": (IO.VIDEO, {"tooltip": "The video to save."}),
|
||||
"filename_prefix": ("STRING", {"default": "video/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}),
|
||||
"format": (VideoContainer.as_input(), {"default": "auto", "tooltip": "The format to save the video as."}),
|
||||
"codec": (VideoCodec.as_input(), {"default": "auto", "tooltip": "The codec to use for the video."}),
|
||||
"format": (Types.VideoContainer.as_input(), {"default": "auto", "tooltip": "The format to save the video as."}),
|
||||
"codec": (Types.VideoCodec.as_input(), {"default": "auto", "tooltip": "The codec to use for the video."}),
|
||||
},
|
||||
"hidden": {
|
||||
"prompt": "PROMPT",
|
||||
@@ -108,7 +106,7 @@ class SaveVideo(ComfyNodeABC):
|
||||
CATEGORY = "image/video"
|
||||
DESCRIPTION = "Saves the input images to your ComfyUI output directory."
|
||||
|
||||
def save_video(self, video: VideoInput, filename_prefix, format, codec, prompt=None, extra_pnginfo=None):
|
||||
def save_video(self, video: Input.Video, filename_prefix, format, codec, prompt=None, extra_pnginfo=None):
|
||||
filename_prefix += self.prefix_append
|
||||
width, height = video.get_dimensions()
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
|
||||
@@ -127,7 +125,7 @@ class SaveVideo(ComfyNodeABC):
|
||||
metadata["prompt"] = prompt
|
||||
if len(metadata) > 0:
|
||||
saved_metadata = metadata
|
||||
file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}"
|
||||
file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}"
|
||||
video.save_to(
|
||||
os.path.join(full_output_folder, file),
|
||||
format=format,
|
||||
@@ -163,9 +161,9 @@ class CreateVideo(ComfyNodeABC):
|
||||
CATEGORY = "image/video"
|
||||
DESCRIPTION = "Create a video from images."
|
||||
|
||||
def create_video(self, images: ImageInput, fps: float, audio: Optional[AudioInput] = None):
|
||||
return (VideoFromComponents(
|
||||
VideoComponents(
|
||||
def create_video(self, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None):
|
||||
return (InputImpl.VideoFromComponents(
|
||||
Types.VideoComponents(
|
||||
images=images,
|
||||
audio=audio,
|
||||
frame_rate=Fraction(fps),
|
||||
@@ -187,7 +185,7 @@ class GetVideoComponents(ComfyNodeABC):
|
||||
CATEGORY = "image/video"
|
||||
DESCRIPTION = "Extracts all components from a video: frames, audio, and framerate."
|
||||
|
||||
def get_components(self, video: VideoInput):
|
||||
def get_components(self, video: Input.Video):
|
||||
components = video.get_components()
|
||||
|
||||
return (components.images, components.audio, float(components.frame_rate))
|
||||
@@ -208,7 +206,7 @@ class LoadVideo(ComfyNodeABC):
|
||||
FUNCTION = "load_video"
|
||||
def load_video(self, file):
|
||||
video_path = folder_paths.get_annotated_filepath(file)
|
||||
return (VideoFromFile(video_path),)
|
||||
return (InputImpl.VideoFromFile(video_path),)
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(cls, file):
|
||||
@@ -239,3 +237,4 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"GetVideoComponents": "Get Video Components",
|
||||
"LoadVideo": "Load Video",
|
||||
}
|
||||
|
||||
|
||||
@@ -9,29 +9,35 @@ import comfy.clip_vision
|
||||
import json
|
||||
import numpy as np
|
||||
from typing import Tuple
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
class WanImageToVideo:
|
||||
class WanImageToVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"vae": ("VAE", ),
|
||||
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
},
|
||||
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
|
||||
"start_image": ("IMAGE", ),
|
||||
}}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanImageToVideo",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
|
||||
io.Image.Input("start_image", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent"),
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent")
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None):
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None) -> io.NodeOutput:
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
if start_image is not None:
|
||||
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
@@ -51,32 +57,36 @@ class WanImageToVideo:
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return (positive, negative, out_latent)
|
||||
return io.NodeOutput(positive, negative, out_latent)
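A worked example (illustration only) of the empty latent allocated by these WAN nodes: 16 channels, (length - 1) // 4 + 1 latent frames, and spatial dimensions divided by 8.

import torch

batch_size, width, height, length = 1, 832, 480, 81
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8])
print(latent.shape)  # torch.Size([1, 16, 21, 60, 104])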
|
||||
|
||||
|
||||
class WanFunControlToVideo:
|
||||
class WanFunControlToVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"vae": ("VAE", ),
|
||||
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
},
|
||||
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
|
||||
"start_image": ("IMAGE", ),
|
||||
"control_video": ("IMAGE", ),
|
||||
}}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanFunControlToVideo",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
|
||||
io.Image.Input("start_image", optional=True),
|
||||
io.Image.Input("control_video", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent"),
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent")
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, control_video=None):
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, control_video=None) -> io.NodeOutput:
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
|
||||
@@ -101,32 +111,96 @@ class WanFunControlToVideo:
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return (positive, negative, out_latent)
|
||||
return io.NodeOutput(positive, negative, out_latent)
|
||||
|
||||
class WanFirstLastFrameToVideo:
|
||||
class Wan22FunControlToVideo(io.ComfyNode):
|
||||
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"clip_vision_start_image": ("CLIP_VISION_OUTPUT", ),
"clip_vision_end_image": ("CLIP_VISION_OUTPUT", ),
"start_image": ("IMAGE", ),
"end_image": ("IMAGE", ),
}}
def define_schema(cls):
return io.Schema(
node_id="Wan22FunControlToVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Image.Input("ref_image", optional=True),
io.Image.Input("control_video", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)

RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_image=None, start_image=None, control_video=None) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
concat_latent = concat_latent.repeat(1, 2, 1, 1, 1)
mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))

CATEGORY = "conditioning/video_models"
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
concat_latent_image = vae.encode(start_image[:, :, :, :3])
concat_latent[:,16:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
mask[:, :, :start_image.shape[0] + 3] = 0.0

def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None):
ref_latent = None
if ref_image is not None:
ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
ref_latent = vae.encode(ref_image[:, :, :, :3])

if control_video is not None:
control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
concat_latent_image = vae.encode(control_video[:, :, :, :3])
concat_latent[:,:16,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]

mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": 16})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": 16})

if ref_latent is not None:
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True)

out_latent = {}
out_latent["samples"] = latent
return io.NodeOutput(positive, negative, out_latent)

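All of these nodes size their empty latents the same way: the 16-channel Wan21 latent packs 4 video frames into 1 latent frame (plus one for the leading frame) and downsamples height and width by 8, which is where the repeated ((length - 1) // 4) + 1 and // 8 expressions come from. A quick plain-Python check of that arithmetic with the default schema values (not ComfyUI API, just the formula from the code above):

# Latent sizing used by the Wan nodes above (plain arithmetic, no ComfyUI imports).
length, height, width = 81, 480, 832           # schema defaults
latent_frames = ((length - 1) // 4) + 1        # temporal: 4 frames -> 1 latent frame, => 21
latent_h, latent_w = height // 8, width // 8   # spatial: 8x downscale, => 60 x 104
print(latent_frames, latent_h, latent_w)       # 21 60 104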
class WanFirstLastFrameToVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanFirstLastFrameToVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.ClipVisionOutput.Input("clip_vision_start_image", optional=True),
io.ClipVisionOutput.Input("clip_vision_end_image", optional=True),
io.Image.Input("start_image", optional=True),
io.Image.Input("end_image", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)

@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
@@ -149,6 +223,7 @@ class WanFirstLastFrameToVideo:
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})

clip_vision_output = None
if clip_vision_start_image is not None:
clip_vision_output = clip_vision_start_image

@@ -166,62 +241,70 @@ class WanFirstLastFrameToVideo:

out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent)
return io.NodeOutput(positive, negative, out_latent)


class WanFunInpaintToVideo:
class WanFunInpaintToVideo(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"start_image": ("IMAGE", ),
"end_image": ("IMAGE", ),
}}
def define_schema(cls):
return io.Schema(
node_id="WanFunInpaintToVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
io.Image.Input("start_image", optional=True),
io.Image.Input("end_image", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)

RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"

CATEGORY = "conditioning/video_models"

def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None):
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None) -> io.NodeOutput:
flfv = WanFirstLastFrameToVideo()
return flfv.encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output)
return flfv.execute(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output)


class WanVaceToVideo:
class WanVaceToVideo(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
},
"optional": {"control_video": ("IMAGE", ),
"control_masks": ("MASK", ),
"reference_image": ("IMAGE", ),
}}
def define_schema(cls):
return io.Schema(
node_id="WanVaceToVideo",
category="conditioning/video_models",
is_experimental=True,
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Float.Input("strength", default=1.0, min=0.0, max=1000.0, step=0.01),
io.Image.Input("control_video", optional=True),
io.Mask.Input("control_masks", optional=True),
io.Image.Input("reference_image", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
io.Int.Output(display_name="trim_latent"),
],
)

RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT", "INT")
RETURN_NAMES = ("positive", "negative", "latent", "trim_latent")
FUNCTION = "encode"

CATEGORY = "conditioning/video_models"

EXPERIMENTAL = True

def encode(self, positive, negative, vae, width, height, length, batch_size, strength, control_video=None, control_masks=None, reference_image=None):
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, strength, control_video=None, control_masks=None, reference_image=None) -> io.NodeOutput:
latent_length = ((length - 1) // 4) + 1
if control_video is not None:
control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
@@ -278,52 +361,59 @@ class WanVaceToVideo:
latent = torch.zeros([batch_size, 16, latent_length, height // 8, width // 8], device=comfy.model_management.intermediate_device())
out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent, trim_latent)
return io.NodeOutput(positive, negative, out_latent, trim_latent)

class TrimVideoLatent:
class TrimVideoLatent(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"trim_amount": ("INT", {"default": 0, "min": 0, "max": 99999}),
}}
def define_schema(cls):
return io.Schema(
node_id="TrimVideoLatent",
category="latent/video",
is_experimental=True,
inputs=[
io.Latent.Input("samples"),
io.Int.Input("trim_amount", default=0, min=0, max=99999),
],
outputs=[
io.Latent.Output(),
],
)

RETURN_TYPES = ("LATENT",)
FUNCTION = "op"

CATEGORY = "latent/video"

EXPERIMENTAL = True

def op(self, samples, trim_amount):
@classmethod
def execute(cls, samples, trim_amount) -> io.NodeOutput:
samples_out = samples.copy()

s1 = samples["samples"]
samples_out["samples"] = s1[:, :, trim_amount:]
return (samples_out,)
return io.NodeOutput(samples_out)

class WanCameraImageToVideo:
class WanCameraImageToVideo(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"start_image": ("IMAGE", ),
"camera_conditions": ("WAN_CAMERA_EMBEDDING", ),
}}
def define_schema(cls):
return io.Schema(
node_id="WanCameraImageToVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
io.Image.Input("start_image", optional=True),
io.WanCameraEmbedding.Input("camera_conditions", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)

RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"

CATEGORY = "conditioning/video_models"

def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, camera_conditions=None):
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, camera_conditions=None) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
@@ -332,9 +422,12 @@ class WanCameraImageToVideo:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
concat_latent_image = vae.encode(start_image[:, :, :, :3])
concat_latent[:,:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
mask[:, :, :start_image.shape[0] + 3] = 0.0
mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)

positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent})
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, "concat_mask": mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, "concat_mask": mask})

if camera_conditions is not None:
positive = node_helpers.conditioning_set_values(positive, {'camera_conditions': camera_conditions})
@@ -346,29 +439,34 @@ class WanCameraImageToVideo:

out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent)
return io.NodeOutput(positive, negative, out_latent)

class WanPhantomSubjectToVideo:
class WanPhantomSubjectToVideo(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"images": ("IMAGE", ),
}}
def define_schema(cls):
return io.Schema(
node_id="WanPhantomSubjectToVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Image.Input("images", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative_text"),
io.Conditioning.Output(display_name="negative_img_text"),
io.Latent.Output(display_name="latent"),
],
)

RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative_text", "negative_img_text", "latent")
FUNCTION = "encode"

CATEGORY = "conditioning/video_models"

def encode(self, positive, negative, vae, width, height, length, batch_size, images):
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, images) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
cond2 = negative
if images is not None:
@@ -384,7 +482,7 @@ class WanPhantomSubjectToVideo:

out_latent = {}
out_latent["samples"] = latent
return (positive, cond2, negative, out_latent)
return io.NodeOutput(positive, cond2, negative, out_latent)

def parse_json_tracks(tracks):
"""Parse JSON track data into a standardized format"""
@@ -597,39 +695,41 @@ def patch_motion(

return out_mask_full, out_feature_full

class WanTrackToVideo:
class WanTrackToVideo(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {
"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"tracks": ("STRING", {"multiline": True, "default": "[]"}),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"temperature": ("FLOAT", {"default": 220.0, "min": 1.0, "max": 1000.0, "step": 0.1}),
"topk": ("INT", {"default": 2, "min": 1, "max": 10}),
"start_image": ("IMAGE", ),
},
"optional": {
"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
}}
def define_schema(cls):
return io.Schema(
node_id="WanTrackToVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.String.Input("tracks", multiline=True, default="[]"),
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Float.Input("temperature", default=220.0, min=1.0, max=1000.0, step=0.1),
io.Int.Input("topk", default=2, min=1, max=10),
io.Image.Input("start_image"),
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)

RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"

CATEGORY = "conditioning/video_models"

def encode(self, positive, negative, vae, tracks, width, height, length, batch_size,
temperature, topk, start_image=None, clip_vision_output=None):
@classmethod
def execute(cls, positive, negative, vae, tracks, width, height, length, batch_size,
temperature, topk, start_image=None, clip_vision_output=None) -> io.NodeOutput:

tracks_data = parse_json_tracks(tracks)

if not tracks_data:
return WanImageToVideo().encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, clip_vision_output=clip_vision_output)
return WanImageToVideo().execute(positive, negative, vae, width, height, length, batch_size, start_image=start_image, clip_vision_output=clip_vision_output)

latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8],
device=comfy.model_management.intermediate_device())
@@ -683,34 +783,36 @@ class WanTrackToVideo:

out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent)
return io.NodeOutput(positive, negative, out_latent)


class Wan22ImageToVideoLatent:
class Wan22ImageToVideoLatent(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"vae": ("VAE", ),
"width": ("INT", {"default": 1280, "min": 32, "max": nodes.MAX_RESOLUTION, "step": 32}),
"height": ("INT", {"default": 704, "min": 32, "max": nodes.MAX_RESOLUTION, "step": 32}),
"length": ("INT", {"default": 49, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"start_image": ("IMAGE", ),
}}
def define_schema(cls):
return io.Schema(
node_id="Wan22ImageToVideoLatent",
category="conditioning/inpaint",
inputs=[
io.Vae.Input("vae"),
io.Int.Input("width", default=1280, min=32, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("height", default=704, min=32, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("length", default=49, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Image.Input("start_image", optional=True),
],
outputs=[
io.Latent.Output(),
],
)


RETURN_TYPES = ("LATENT",)
FUNCTION = "encode"

CATEGORY = "conditioning/inpaint"

def encode(self, vae, width, height, length, batch_size, start_image=None):
@classmethod
def execute(cls, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput:
latent = torch.zeros([1, 48, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device())

if start_image is None:
out_latent = {}
out_latent["samples"] = latent
return (out_latent,)
return io.NodeOutput(out_latent)

mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())

@@ -725,18 +827,25 @@ class Wan22ImageToVideoLatent:
latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask)
out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
return (out_latent,)
return io.NodeOutput(out_latent)


NODE_CLASS_MAPPINGS = {
"WanTrackToVideo": WanTrackToVideo,
"WanImageToVideo": WanImageToVideo,
"WanFunControlToVideo": WanFunControlToVideo,
"WanFunInpaintToVideo": WanFunInpaintToVideo,
"WanFirstLastFrameToVideo": WanFirstLastFrameToVideo,
"WanVaceToVideo": WanVaceToVideo,
"TrimVideoLatent": TrimVideoLatent,
"WanCameraImageToVideo": WanCameraImageToVideo,
"WanPhantomSubjectToVideo": WanPhantomSubjectToVideo,
"Wan22ImageToVideoLatent": Wan22ImageToVideoLatent,
}
class WanExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
WanTrackToVideo,
WanImageToVideo,
WanFunControlToVideo,
Wan22FunControlToVideo,
WanFunInpaintToVideo,
WanFirstLastFrameToVideo,
WanVaceToVideo,
TrimVideoLatent,
WanCameraImageToVideo,
WanPhantomSubjectToVideo,
Wan22ImageToVideoLatent,
]

async def comfy_entrypoint() -> WanExtension:
return WanExtension()

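The NODE_CLASS_MAPPINGS dict above is replaced by a ComfyExtension whose get_node_list() returns the V3 node classes, discovered through an async comfy_entrypoint(). A minimal sketch of that same registration pattern for a hypothetical custom node pack follows; the import path for ComfyExtension and io is assumed to be comfy_api.latest (as suggested by the imports in this diff), and the example node itself is invented:

# Hypothetical custom node pack using the V3 extension entrypoint pattern shown above.
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io


class ExampleInvertImage(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="ExampleInvertImage",  # hypothetical node id
            category="image/example",
            inputs=[io.Image.Input("image")],
            outputs=[io.Image.Output()],
        )

    @classmethod
    def execute(cls, image) -> io.NodeOutput:
        # IMAGE tensors are float BHWC in [0, 1], so inversion is simply 1.0 - x.
        return io.NodeOutput(1.0 - image)


class ExampleExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [ExampleInvertImage]


async def comfy_entrypoint() -> ExampleExtension:
    return ExampleExtension()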
@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.47"
__version__ = "0.3.50"

149 execution.py
@@ -7,7 +7,7 @@ import threading
import time
import traceback
from enum import Enum
from typing import List, Literal, NamedTuple, Optional
from typing import List, Literal, NamedTuple, Optional, Union
import asyncio

import torch
@@ -32,6 +32,8 @@ from comfy_execution.graph_utils import GraphBuilder, is_link
from comfy_execution.validation import validate_node_input
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
from comfy_execution.utils import CurrentNodeContext
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
from comfy_api.latest import io


class ExecutionResult(Enum):
@@ -56,7 +58,15 @@ class IsChangedCache:
node = self.dynprompt.get_node(node_id)
class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if not hasattr(class_def, "IS_CHANGED"):
has_is_changed = False
is_changed_name = None
if issubclass(class_def, _ComfyNodeInternal) and first_real_override(class_def, "fingerprint_inputs") is not None:
has_is_changed = True
is_changed_name = "fingerprint_inputs"
elif hasattr(class_def, "IS_CHANGED"):
has_is_changed = True
is_changed_name = "IS_CHANGED"
if not has_is_changed:
self.is_changed[node_id] = False
return self.is_changed[node_id]

@@ -65,9 +75,9 @@ class IsChangedCache:
return self.is_changed[node_id]

# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None)
input_data_all, _, hidden_inputs = get_input_data(node["inputs"], class_def, node_id, None)
try:
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, "IS_CHANGED")
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name)
is_changed = await resolve_map_node_over_list_results(is_changed)
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
except Exception as e:
@@ -126,9 +136,14 @@ class CacheSet:
SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")

def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
valid_inputs = class_def.INPUT_TYPES()
is_v3 = issubclass(class_def, _ComfyNodeInternal)
if is_v3:
valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True)
else:
valid_inputs = class_def.INPUT_TYPES()
input_data_all = {}
missing_keys = {}
hidden_inputs_v3 = {}
for x in inputs:
input_data = inputs[x]
_, input_category, input_info = get_input_info(class_def, x, valid_inputs)
@@ -153,22 +168,37 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
elif input_category is not None:
input_data_all[x] = [input_data]

if "hidden" in valid_inputs:
h = valid_inputs["hidden"]
for x in h:
if h[x] == "PROMPT":
input_data_all[x] = [dynprompt.get_original_prompt() if dynprompt is not None else {}]
if h[x] == "DYNPROMPT":
input_data_all[x] = [dynprompt]
if h[x] == "EXTRA_PNGINFO":
input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
if h[x] == "UNIQUE_ID":
input_data_all[x] = [unique_id]
if h[x] == "AUTH_TOKEN_COMFY_ORG":
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
if h[x] == "API_KEY_COMFY_ORG":
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
return input_data_all, missing_keys
if is_v3:
if schema.hidden:
if io.Hidden.prompt in schema.hidden:
hidden_inputs_v3[io.Hidden.prompt] = dynprompt.get_original_prompt() if dynprompt is not None else {}
if io.Hidden.dynprompt in schema.hidden:
hidden_inputs_v3[io.Hidden.dynprompt] = dynprompt
if io.Hidden.extra_pnginfo in schema.hidden:
hidden_inputs_v3[io.Hidden.extra_pnginfo] = extra_data.get('extra_pnginfo', None)
if io.Hidden.unique_id in schema.hidden:
hidden_inputs_v3[io.Hidden.unique_id] = unique_id
if io.Hidden.auth_token_comfy_org in schema.hidden:
hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None)
if io.Hidden.api_key_comfy_org in schema.hidden:
hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None)
else:
if "hidden" in valid_inputs:
h = valid_inputs["hidden"]
for x in h:
if h[x] == "PROMPT":
input_data_all[x] = [dynprompt.get_original_prompt() if dynprompt is not None else {}]
if h[x] == "DYNPROMPT":
input_data_all[x] = [dynprompt]
if h[x] == "EXTRA_PNGINFO":
input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
if h[x] == "UNIQUE_ID":
input_data_all[x] = [unique_id]
if h[x] == "AUTH_TOKEN_COMFY_ORG":
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
if h[x] == "API_KEY_COMFY_ORG":
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
return input_data_all, missing_keys, hidden_inputs_v3

map_node_over_list = None #Don't hook this please

@@ -184,7 +214,7 @@ async def resolve_map_node_over_list_results(results):
raise exc
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]

async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
# check if node wants the lists
input_is_list = getattr(obj, "INPUT_IS_LIST", False)

@@ -214,7 +244,22 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
if execution_block is None:
if pre_execute_cb is not None and index is not None:
pre_execute_cb(index)
f = getattr(obj, func)
# V3
if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)):
# if is just a class, then assign no resources or state, just create clone
if is_class(obj):
type_obj = obj
obj.VALIDATE_CLASS()
class_clone = obj.PREPARE_CLASS_CLONE(hidden_inputs)
# otherwise, use class instance to populate/reuse some fields
else:
type_obj = type(obj)
type_obj.VALIDATE_CLASS()
class_clone = type_obj.PREPARE_CLASS_CLONE(hidden_inputs)
f = make_locked_method_func(type_obj, func, class_clone)
# V1
else:
f = getattr(obj, func)
if inspect.iscoroutinefunction(f):
async def async_wrapper(f, prompt_id, unique_id, list_index, args):
with CurrentNodeContext(prompt_id, unique_id, list_index):
@@ -266,8 +311,8 @@ def merge_result_data(results, obj):
output.append([o[i] for o in results])
return output

async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
if has_pending_task:
return return_values, {}, False, has_pending_task
@@ -298,6 +343,26 @@ def get_output_from_returns(return_values, obj):
result = tuple([result] * len(obj.RETURN_TYPES))
results.append(result)
subgraph_results.append((None, result))
elif isinstance(r, _NodeOutputInternal):
# V3
if r.ui is not None:
if isinstance(r.ui, dict):
uis.append(r.ui)
else:
uis.append(r.ui.as_dict())
if r.expand is not None:
has_subgraph = True
new_graph = r.expand
result = r.result
if r.block_execution is not None:
result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES))
subgraph_results.append((new_graph, result))
elif r.result is not None:
result = r.result
if r.block_execution is not None:
result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES))
results.append(result)
subgraph_results.append((None, result))
else:
if isinstance(r, ExecutionBlocker):
r = tuple([r] * len(obj.RETURN_TYPES))
@@ -381,7 +446,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
has_subgraph = False
else:
get_progress_state().start_progress(unique_id)
input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
if server.client_id is not None:
server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
@@ -391,8 +456,12 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
obj = class_def()
caches.objects.set(unique_id, obj)

if hasattr(obj, "check_lazy_status"):
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True)
if issubclass(class_def, _ComfyNodeInternal):
lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None
else:
lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
if lazy_status_present:
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs)
required_inputs = await resolve_map_node_over_list_results(required_inputs)
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
@@ -424,7 +493,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
def pre_execute_cb(call_index):
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
if has_pending_tasks:
pending_async_nodes[unique_id] = output_data
unblock = execution_list.add_external_block(unique_id)
@@ -577,8 +646,6 @@ class PromptExecutor:
self.add_message("execution_error", mes, broadcast=False)

def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
asyncio_loop = asyncio.new_event_loop()
asyncio.set_event_loop(asyncio_loop)
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))

async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
@@ -672,8 +739,14 @@ async def validate_inputs(prompt_id, prompt, item, validated):

validate_function_inputs = []
validate_has_kwargs = False
if hasattr(obj_class, "VALIDATE_INPUTS"):
argspec = inspect.getfullargspec(obj_class.VALIDATE_INPUTS)
if issubclass(obj_class, _ComfyNodeInternal):
validate_function_name = "validate_inputs"
validate_function = first_real_override(obj_class, validate_function_name)
else:
validate_function_name = "VALIDATE_INPUTS"
validate_function = getattr(obj_class, validate_function_name, None)
if validate_function is not None:
argspec = inspect.getfullargspec(validate_function)
validate_function_inputs = argspec.args
validate_has_kwargs = argspec.varkw is not None
received_types = {}
@@ -848,7 +921,7 @@ async def validate_inputs(prompt_id, prompt, item, validated):
continue

if len(validate_function_inputs) > 0 or validate_has_kwargs:
input_data_all, _ = get_input_data(inputs, obj_class, unique_id)
input_data_all, _, hidden_inputs = get_input_data(inputs, obj_class, unique_id)
input_filtered = {}
for x in input_data_all:
if x in validate_function_inputs or validate_has_kwargs:
@@ -856,8 +929,7 @@ async def validate_inputs(prompt_id, prompt, item, validated):
if 'input_types' in validate_function_inputs:
input_filtered['input_types'] = [received_types]

#ret = obj_class.VALIDATE_INPUTS(**input_filtered)
ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, "VALIDATE_INPUTS")
ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, hidden_inputs=hidden_inputs)
ret = await resolve_map_node_over_list_results(ret)
for x in input_filtered:
for i, r in enumerate(ret):
@@ -891,7 +963,7 @@ def full_type_name(klass):
return klass.__qualname__
return module + '.' + klass.__qualname__

async def validate_prompt(prompt_id, prompt):
async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[str], None]):
outputs = set()
for x in prompt:
if 'class_type' not in prompt[x]:
@@ -915,7 +987,8 @@ async def validate_prompt(prompt_id, prompt):
return (False, error, [], {})

if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
outputs.add(x)
if partial_execution_list is None or x in partial_execution_list:
outputs.add(x)

if len(outputs) == 0:
error = {

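The schema.hidden handling in get_input_data() above, together with PREPARE_CLASS_CLONE(hidden_inputs) in _async_map_node_over_list(), is what delivers hidden values (prompt, unique node id, auth tokens) to V3 nodes without declaring them as regular inputs. A minimal sketch of a node that opts in; the hidden= schema field and the cls.hidden.unique_id accessor are assumptions inferred from this diff rather than confirmed API, and the node itself is hypothetical:

# Hypothetical V3 node that requests hidden inputs via its schema.
from comfy_api.latest import io


class ExampleHiddenInputs(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="ExampleHiddenInputs",  # hypothetical node id
            category="utils/example",
            inputs=[io.String.Input("text", multiline=True, default="")],
            outputs=[io.String.Output()],
            hidden=[io.Hidden.unique_id, io.Hidden.prompt],  # assumed field name
        )

    @classmethod
    def execute(cls, text) -> io.NodeOutput:
        # get_input_data() fills hidden_inputs_v3 from the schema's hidden list and
        # PREPARE_CLASS_CLONE() exposes it on the class clone; the accessor below is assumed.
        node_id = cls.hidden.unique_id
        return io.NodeOutput(f"{node_id}: {text}")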
Some files were not shown because too many files have changed in this diff.