Compare commits

..

2 Commits

Author SHA1 Message Date
bymyself
700ce6b956 refactor 2025-11-18 12:05:11 -08:00
bymyself
9a67426b3a update templates for monorepo 2025-11-16 00:21:19 -08:00
38 changed files with 398 additions and 1747 deletions

View File

@@ -1,21 +0,0 @@
<!-- API_NODE_PR_CHECKLIST: do not remove -->
## API Node PR Checklist
### Scope
- [ ] **Is API Node Change**
### Pricing & Billing
- [ ] **Need pricing update**
- [ ] **No pricing update**
If **Need pricing update**:
- [ ] Metronome rate cards updated
- [ ] Autobilling tests updated and passing
### QA
- [ ] **QA done**
- [ ] **QA not required**
### Comms
- [ ] Informed **Kosinkadink**

View File

@@ -1,58 +0,0 @@
name: Append API Node PR template
on:
pull_request_target:
types: [opened, reopened, synchronize, ready_for_review]
paths:
- 'comfy_api_nodes/**' # only run if these files changed
permissions:
contents: read
pull-requests: write
jobs:
inject:
runs-on: ubuntu-latest
steps:
- name: Ensure template exists and append to PR body
uses: actions/github-script@v7
with:
script: |
const { owner, repo } = context.repo;
const number = context.payload.pull_request.number;
const templatePath = '.github/PULL_REQUEST_TEMPLATE/api-node.md';
const marker = '<!-- API_NODE_PR_CHECKLIST: do not remove -->';
const { data: pr } = await github.rest.pulls.get({ owner, repo, pull_number: number });
let templateText;
try {
const res = await github.rest.repos.getContent({
owner,
repo,
path: templatePath,
ref: pr.base.ref
});
const buf = Buffer.from(res.data.content, res.data.encoding || 'base64');
templateText = buf.toString('utf8');
} catch (e) {
core.setFailed(`Required PR template not found at "${templatePath}" on ${pr.base.ref}. Please add it to the repo.`);
return;
}
// Enforce the presence of the marker inside the template (for idempotence)
if (!templateText.includes(marker)) {
core.setFailed(`Template at "${templatePath}" does not contain the required marker:\n${marker}\nAdd it so we can detect duplicates safely.`);
return;
}
// If the PR already contains the marker, do not append again.
const body = pr.body || '';
if (body.includes(marker)) {
core.info('Template already present in PR body; nothing to inject.');
return;
}
const newBody = (body ? body + '\n\n' : '') + templateText + '\n';
await github.rest.pulls.update({ owner, repo, pull_number: number, body: newBody });
core.notice('API Node template appended to PR description.');

View File

@@ -14,7 +14,7 @@ jobs:
contents: "write"
packages: "write"
pull-requests: "read"
name: "Release NVIDIA Default (cu130)"
name: "Release NVIDIA Default (cu129)"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
@@ -43,23 +43,6 @@ jobs:
test_release: true
secrets: inherit
release_nvidia_cu126:
permissions:
contents: "write"
packages: "write"
pull-requests: "read"
name: "Release NVIDIA cu126"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "cu126"
python_minor: "12"
python_patch: "10"
rel_name: "nvidia"
rel_extra_name: "_cu126"
test_release: true
secrets: inherit
release_amd_rocm:
permissions:
contents: "write"

View File

@@ -21,15 +21,14 @@ jobs:
fail-fast: false
matrix:
# os: [macos, linux, windows]
# os: [macos, linux]
os: [linux]
python_version: ["3.10", "3.11", "3.12"]
os: [macos, linux]
python_version: ["3.9", "3.10", "3.11", "3.12"]
cuda_version: ["12.1"]
torch_version: ["stable"]
include:
# - os: macos
# runner_label: [self-hosted, macOS]
# flags: "--use-pytorch-cross-attention"
- os: macos
runner_label: [self-hosted, macOS]
flags: "--use-pytorch-cross-attention"
- os: linux
runner_label: [self-hosted, Linux]
flags: ""
@@ -74,15 +73,14 @@ jobs:
strategy:
fail-fast: false
matrix:
# os: [macos, linux]
os: [linux]
os: [macos, linux]
python_version: ["3.11"]
cuda_version: ["12.1"]
torch_version: ["nightly"]
include:
# - os: macos
# runner_label: [self-hosted, macOS]
# flags: "--use-pytorch-cross-attention"
- os: macos
runner_label: [self-hosted, macOS]
flags: "--use-pytorch-cross-attention"
- os: linux
runner_label: [self-hosted, Linux]
flags: ""

View File

@@ -1,168 +0,0 @@
# The Comfy guide to Quantization
## How does quantization work?
Quantization aims to map a high-precision value x_f to a lower precision format with minimal loss in accuracy. These smaller formats then serve to reduce the models memory footprint and increase throughput by using specialized hardware.
When simply converting a value from FP16 to FP8 using the round-nearest method we might hit two issues:
- The dynamic range of FP16 (-65,504, 65,504) far exceeds FP8 formats like E4M3 (-448, 448) or E5M2 (-57,344, 57,344), potentially resulting in clipped values
- The original values are concentrated in a small range (e.g. -1,1) leaving many FP8-bits "unused"
By using a scaling factor, we aim to map these values into the quantized-dtype range, making use of the full spectrum. One of the easiest approaches, and common, is using per-tensor absolute-maximum scaling.
```
absmax = max(abs(tensor))
scale = amax / max_dynamic_range_low_precision
# Quantization
tensor_q = (tensor / scale).to(low_precision_dtype)
# De-Quantization
tensor_dq = tensor_q.to(fp16) * scale
tensor_dq ~ tensor
```
Given that additional information (scaling factor) is needed to "interpret" the quantized values, we describe those as derived datatypes.
## Quantization in Comfy
```
QuantizedTensor (torch.Tensor subclass)
↓ __torch_dispatch__
Two-Level Registry (generic + layout handlers)
MixedPrecisionOps + Metadata Detection
```
### Representation
To represent these derived datatypes, ComfyUI uses a subclass of torch.Tensor to implements these using the `QuantizedTensor` class found in `comfy/quant_ops.py`
A `Layout` class defines how a specific quantization format behaves:
- Required parameters
- Quantize method
- De-Quantize method
```python
from comfy.quant_ops import QuantizedLayout
class MyLayout(QuantizedLayout):
@classmethod
def quantize(cls, tensor, **kwargs):
# Convert to quantized format
qdata = ...
params = {'scale': ..., 'orig_dtype': tensor.dtype}
return qdata, params
@staticmethod
def dequantize(qdata, scale, orig_dtype, **kwargs):
return qdata.to(orig_dtype) * scale
```
To then run operations using these QuantizedTensors we use two registry systems to define supported operations.
The first is a **generic registry** that handles operations common to all quantized formats (e.g., `.to()`, `.clone()`, `.reshape()`).
The second registry is layout-specific and allows to implement fast-paths like nn.Linear.
```python
from comfy.quant_ops import register_layout_op
@register_layout_op(torch.ops.aten.linear.default, MyLayout)
def my_linear(func, args, kwargs):
# Extract tensors, call optimized kernel
...
```
When `torch.nn.functional.linear()` is called with QuantizedTensor arguments, `__torch_dispatch__` automatically routes to the registered implementation.
For any unsupported operation, QuantizedTensor will fallback to call `dequantize` and dispatch using the high-precision implementation.
### Mixed Precision
The `MixedPrecisionOps` class (lines 542-648 in `comfy/ops.py`) enables per-layer quantization decisions, allowing different layers in a model to use different precisions. This is activated when a model config contains a `layer_quant_config` dictionary that specifies which layers should be quantized and how.
**Architecture:**
```python
class MixedPrecisionOps(disable_weight_init):
_layer_quant_config = {} # Maps layer names to quantization configs
_compute_dtype = torch.bfloat16 # Default compute / dequantize precision
```
**Key mechanism:**
The custom `Linear._load_from_state_dict()` method inspects each layer during model loading:
- If the layer name is **not** in `_layer_quant_config`: load weight as regular tensor in `_compute_dtype`
- If the layer name **is** in `_layer_quant_config`:
- Load weight as `QuantizedTensor` with the specified layout (e.g., `TensorCoreFP8Layout`)
- Load associated quantization parameters (scales, block_size, etc.)
**Why it's needed:**
Not all layers tolerate quantization equally. Sensitive operations like final projections can be kept in higher precision, while compute-heavy matmuls are quantized. This provides most of the performance benefits while maintaining quality.
The system is selected in `pick_operations()` when `model_config.layer_quant_config` is present, making it the highest-priority operation mode.
## Checkpoint Format
Quantized checkpoints are stored as standard safetensors files with quantized weight tensors and associated scaling parameters, plus a `_quantization_metadata` JSON entry describing the quantization scheme.
The quantized checkpoint will contain the same layers as the original checkpoint but:
- The weights are stored as quantized values, sometimes using a different storage datatype. E.g. uint8 container for fp8.
- For each quantized weight a number of additional scaling parameters are stored alongside depending on the recipe.
- We store a metadata.json in the metadata of the final safetensor containing the `_quantization_metadata` describing which layers are quantized and what layout has been used.
### Scaling Parameters details
We define 4 possible scaling parameters that should cover most recipes in the near-future:
- **weight_scale**: quantization scalers for the weights
- **weight_scale_2**: global scalers in the context of double scaling
- **pre_quant_scale**: scalers used for smoothing salient weights
- **input_scale**: quantization scalers for the activations
| Format | Storage dtype | weight_scale | weight_scale_2 | pre_quant_scale | input_scale |
|--------|---------------|--------------|----------------|-----------------|-------------|
| float8_e4m3fn | float32 | float32 (scalar) | - | - | float32 (scalar) |
You can find the defined formats in `comfy/quant_ops.py` (QUANT_ALGOS).
### Quantization Metadata
The metadata stored alongside the checkpoint contains:
- **format_version**: String to define a version of the standard
- **layers**: A dictionary mapping layer names to their quantization format. The format string maps to the definitions found in `QUANT_ALGOS`.
Example:
```json
{
"_quantization_metadata": {
"format_version": "1.0",
"layers": {
"model.layers.0.mlp.up_proj": "float8_e4m3fn",
"model.layers.0.mlp.down_proj": "float8_e4m3fn",
"model.layers.1.mlp.up_proj": "float8_e4m3fn"
}
}
}
```
## Creating Quantized Checkpoints
To create compatible checkpoints, use any quantization tool provided the output follows the checkpoint format described above and uses a layout defined in `QUANT_ALGOS`.
### Weight Quantization
Weight quantization is straightforward - compute the scaling factor directly from the weight tensor using the absolute maximum method described earlier. Each layer's weights are quantized independently and stored with their corresponding `weight_scale` parameter.
### Calibration (for Activation Quantization)
Activation quantization (e.g., for FP8 Tensor Core operations) requires `input_scale` parameters that cannot be determined from static weights alone. Since activation values depend on actual inputs, we use **post-training calibration (PTQ)**:
1. **Collect statistics**: Run inference on N representative samples
2. **Track activations**: Record the absolute maximum (`amax`) of inputs to each quantized layer
3. **Compute scales**: Derive `input_scale` from collected statistics
4. **Store in checkpoint**: Save `input_scale` parameters alongside weights
The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters.

View File

@@ -173,7 +173,7 @@ There is a portable standalone build for Windows that should work for running on
### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia.7z)
Simply download, extract with [7-Zip](https://7-zip.org) or with the windows explorer on recent windows versions and run. For smaller models you normally only need to put the checkpoints (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints but many of the larger models have multiple files. Make sure to follow the instructions to know which subfolder to put them in ComfyUI\models\
Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints
If you have trouble extracting it, right click the file -> properties -> unblock
@@ -183,9 +183,7 @@ Update your Nvidia drivers if it doesn't start.
[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z).
[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z) (Supports Nvidia 10 series and older GPUs).
#### How do I share models between another UI and ComfyUI?
@@ -202,7 +200,7 @@ comfy install
## Manual Install (Windows, Linux)
Python 3.14 works but you may encounter issues with the torch compile node. The free threaded variant is still missing some dependencies.
Python 3.14 will work if you comment out the `kornia` dependency in the requirements.txt file (breaks the canny node) but it is not recommended.
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
@@ -223,7 +221,7 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins
This is the command to install the nightly with ROCm 7.0 which might have some performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.0```
### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only.
@@ -244,7 +242,7 @@ RDNA 4 (RX 9000 series):
### Intel GPUs (Windows and Linux)
Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
1. To install PyTorch xpu, use the following command:
@@ -254,6 +252,10 @@ This is the command to install the Pytorch xpu nightly which might have some per
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu```
(Option 2) Alternatively, Intel GPUs supported by Intel Extension for PyTorch (IPEX) can leverage IPEX for improved performance.
1. visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information.
### NVIDIA
Nvidia users should install stable pytorch using this command:

View File

@@ -1,15 +1,15 @@
import torch
from torch import Tensor, nn
from comfy.ldm.flux.math import attention
from comfy.ldm.flux.layers import (
MLPEmbedder,
RMSNorm,
QKNorm,
SelfAttention,
ModulationOut,
)
# TODO: remove this in a few months
SingleStreamBlock = None
DoubleStreamBlock = None
class ChromaModulationOut(ModulationOut):
@@ -48,6 +48,124 @@ class Approximator(nn.Module):
return x
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
self.hidden_size = hidden_size
self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.img_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.txt_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
self.flipped_img_txt = flipped_img_txt
def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}):
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
# prepare image for attention
img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img))
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt))
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
# run actual attention
attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2),
pe=pe, mask=attn_mask, transformer_options=transformer_options)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img bloks
img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn))
img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img))))
# calculate the txt bloks
txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt))))
if txt.dtype == torch.float16:
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
return img, txt
class SingleStreamBlock(nn.Module):
"""
A DiT block with parallel linear layers as described in
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
"""
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float = 4.0,
qk_scale: float = None,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.hidden_dim = hidden_size
self.num_heads = num_heads
head_dim = hidden_size // num_heads
self.scale = qk_scale or head_dim**-0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
# qkv and mlp_in
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
# proj and mlp_out
self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
self.hidden_size = hidden_size
self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.mlp_act = nn.GELU(approximate="tanh")
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}) -> Tensor:
mod = vec
x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x.addcmul_(mod.gate, output)
if x.dtype == torch.float16:
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
return x
class LastLayer(nn.Module):
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
super().__init__()

View File

@@ -11,12 +11,12 @@ import comfy.ldm.common_dit
from comfy.ldm.flux.layers import (
EmbedND,
timestep_embedding,
DoubleStreamBlock,
SingleStreamBlock,
)
from .layers import (
DoubleStreamBlock,
LastLayer,
SingleStreamBlock,
Approximator,
ChromaModulationOut,
)
@@ -90,7 +90,6 @@ class Chroma(nn.Module):
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
modulation=False,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
@@ -99,7 +98,7 @@ class Chroma(nn.Module):
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=False, dtype=dtype, device=device, operations=operations)
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
for _ in range(params.depth_single_blocks)
]
)

View File

@@ -10,10 +10,12 @@ from torch import Tensor, nn
from einops import repeat
import comfy.ldm.common_dit
from comfy.ldm.flux.layers import EmbedND, DoubleStreamBlock, SingleStreamBlock
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.chroma.model import Chroma, ChromaParams
from comfy.ldm.chroma.layers import (
DoubleStreamBlock,
SingleStreamBlock,
Approximator,
)
from .layers import (
@@ -87,6 +89,7 @@ class ChromaRadiance(Chroma):
dtype=dtype, device=device, operations=operations
)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
@@ -94,7 +97,6 @@ class ChromaRadiance(Chroma):
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
modulation=False,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
@@ -107,7 +109,6 @@ class ChromaRadiance(Chroma):
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
modulation=False,
dtype=dtype, device=device, operations=operations,
)
for _ in range(params.depth_single_blocks)

View File

@@ -130,17 +130,13 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, dtype=None, device=None, operations=None):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
self.hidden_size = hidden_size
self.modulation = modulation
if self.modulation:
self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
@@ -151,9 +147,7 @@ class DoubleStreamBlock(nn.Module):
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
if self.modulation:
self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
@@ -166,65 +160,46 @@ class DoubleStreamBlock(nn.Module):
self.flipped_img_txt = flipped_img_txt
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
if self.modulation:
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
else:
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
img_qkv = self.img_attn.qkv(img_modulated)
del img_modulated
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
del img_qkv
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
txt_qkv = self.txt_attn.qkv(txt_modulated)
del txt_modulated
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
del txt_qkv
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
if self.flipped_img_txt:
q = torch.cat((img_q, txt_q), dim=2)
del img_q, txt_q
k = torch.cat((img_k, txt_k), dim=2)
del img_k, txt_k
v = torch.cat((img_v, txt_v), dim=2)
del img_v, txt_v
# run actual attention
attn = attention(q, k, v,
attn = attention(torch.cat((img_q, txt_q), dim=2),
torch.cat((img_k, txt_k), dim=2),
torch.cat((img_v, txt_v), dim=2),
pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
else:
q = torch.cat((txt_q, img_q), dim=2)
del txt_q, img_q
k = torch.cat((txt_k, img_k), dim=2)
del txt_k, img_k
v = torch.cat((txt_v, img_v), dim=2)
del txt_v, img_v
# run actual attention
attn = attention(q, k, v,
attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2),
pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
# calculate the img bloks
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
del img_attn
img += apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
# calculate the txt bloks
txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
del txt_attn
txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)
if txt.dtype == torch.float16:
@@ -245,7 +220,6 @@ class SingleStreamBlock(nn.Module):
num_heads: int,
mlp_ratio: float = 4.0,
qk_scale: float = None,
modulation=True,
dtype=None,
device=None,
operations=None
@@ -268,29 +242,19 @@ class SingleStreamBlock(nn.Module):
self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.mlp_act = nn.GELU(approximate="tanh")
if modulation:
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
else:
self.modulation = None
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor:
if self.modulation:
mod, _ = self.modulation(vec)
else:
mod = vec
mod, _ = self.modulation(vec)
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
del qkv
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
# compute activation in mlp stream, cat again and run second linear layer
mlp = self.mlp_act(mlp)
output = self.linear2(torch.cat((attn, mlp), 2))
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x += apply_mod(output, mod.gate, None, modulation_dims)
if x.dtype == torch.float16:
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)

View File

@@ -7,8 +7,7 @@ import comfy.model_management
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
if pe is not None:
q, k = apply_rope(q, k, pe)
q, k = apply_rope(q, k, pe)
heads = q.shape[1]
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
return x

View File

@@ -210,7 +210,7 @@ class Flux(nn.Module):
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img
def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
def process_img(self, x, index=0, h_offset=0, w_offset=0):
bs, c, h, w = x.shape
patch_size = self.patch_size
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
@@ -222,22 +222,10 @@ class Flux(nn.Module):
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
steps_h = h_len
steps_w = w_len
rope_options = transformer_options.get("rope_options", None)
if rope_options is not None:
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
index += rope_options.get("shift_t", 0.0)
h_offset += rope_options.get("shift_y", 0.0)
w_offset += rope_options.get("shift_x", 0.0)
img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, 0] = img_ids[:, :, 1] + index
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
@@ -253,7 +241,7 @@ class Flux(nn.Module):
h_len = ((h_orig + (patch_size // 2)) // patch_size)
w_len = ((w_orig + (patch_size // 2)) // patch_size)
img, img_ids = self.process_img(x, transformer_options=transformer_options)
img, img_ids = self.process_img(x)
img_tokens = img.shape[1]
if ref_latents is not None:
h = 0

View File

@@ -236,10 +236,10 @@ class QwenImageTransformerBlock(nn.Module):
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
del img_mod1
txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
del txt_mod1
img_normed = self.img_norm1(hidden_states)
img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
txt_normed = self.txt_norm1(encoder_hidden_states)
txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
img_attn_output, txt_attn_output = self.attn(
hidden_states=img_modulated,
@@ -248,20 +248,16 @@ class QwenImageTransformerBlock(nn.Module):
image_rotary_emb=image_rotary_emb,
transformer_options=transformer_options,
)
del img_modulated
del txt_modulated
hidden_states = hidden_states + img_gate1 * img_attn_output
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
del img_attn_output
del txt_attn_output
del img_gate1
del txt_gate1
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
img_normed2 = self.img_norm2(hidden_states)
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
txt_normed2 = self.txt_norm2(encoder_hidden_states)
txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
return encoder_hidden_states, hidden_states

View File

@@ -504,7 +504,6 @@ class LoadedModel:
if use_more_vram == 0:
use_more_vram = 1e32
self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
real_model = self.model.model
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None:
@@ -690,10 +689,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
current_free_mem = get_free_memory(torch_dev) + loaded_memory
lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
lowvram_model_memory = lowvram_model_memory - loaded_memory
if lowvram_model_memory == 0:
lowvram_model_memory = 0.1
lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)
if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 0.1
@@ -1104,9 +1100,6 @@ def pin_memory(tensor):
if MAX_PINNED_MEMORY <= 0:
return False
if type(tensor) is not torch.nn.parameter.Parameter:
return False
if not is_device_cpu(tensor.device):
return False
@@ -1116,9 +1109,6 @@ def pin_memory(tensor):
#on the GPU async. So dont trust the CUDA API and guard here
return False
if not tensor.is_contiguous():
return False
size = tensor.numel() * tensor.element_size()
if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
return False

View File

@@ -843,7 +843,7 @@ class ModelPatcher:
self.object_patches_backup.clear()
def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
def partially_unload(self, device_to, memory_to_free=0):
with self.use_ejected():
hooks_unpatched = False
memory_freed = 0
@@ -887,19 +887,13 @@ class ModelPatcher:
module_mem += move_weight_functions(m, device_to)
if lowvram_possible:
if weight_key in self.patches:
if force_patch_weights:
self.patch_weight_to_device(weight_key)
else:
_, set_func, convert_func = get_key_weight(self.model, weight_key)
m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func))
patch_counter += 1
_, set_func, convert_func = get_key_weight(self.model, weight_key)
m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func))
patch_counter += 1
if bias_key in self.patches:
if force_patch_weights:
self.patch_weight_to_device(bias_key)
else:
_, set_func, convert_func = get_key_weight(self.model, bias_key)
m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func))
patch_counter += 1
_, set_func, convert_func = get_key_weight(self.model, bias_key)
m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func))
patch_counter += 1
cast_weight = True
if cast_weight:
@@ -915,7 +909,6 @@ class ModelPatcher:
self.model.model_lowvram = True
self.model.lowvram_patch_counter += patch_counter
self.model.model_loaded_weight_memory -= memory_freed
logging.info("loaded partially: {:.2f} MB loaded, lowvram patches: {}".format(self.model.model_loaded_weight_memory / (1024 * 1024), self.model.lowvram_patch_counter))
return memory_freed
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
@@ -928,9 +921,6 @@ class ModelPatcher:
extra_memory += (used - self.model.model_loaded_weight_memory)
self.patch_model(load_weights=False)
if extra_memory < 0 and not unpatch_weights:
self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
return 0
full_load = False
if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
self.apply_hooks(self.forced_hooks, force_apply=True)

View File

@@ -58,8 +58,7 @@ except (ModuleNotFoundError, TypeError):
NVIDIA_MEMORY_CONV_BUG_WORKAROUND = False
try:
if comfy.model_management.is_nvidia():
cudnn_version = torch.backends.cudnn.version()
if (cudnn_version >= 91002 and cudnn_version < 91500) and comfy.model_management.torch_version_numeric >= (2, 9) and comfy.model_management.torch_version_numeric <= (2, 10):
if torch.backends.cudnn.version() >= 91002 and comfy.model_management.torch_version_numeric >= (2, 9) and comfy.model_management.torch_version_numeric <= (2, 10):
#TODO: change upper bound version once it's fixed'
NVIDIA_MEMORY_CONV_BUG_WORKAROUND = True
logging.info("working around nvidia conv3d memory bug.")
@@ -78,10 +77,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
# will add async-offload support to your cast and improve performance.
if input is not None:
if dtype is None:
if isinstance(input, QuantizedTensor):
dtype = input._layout_params["orig_dtype"]
else:
dtype = input.dtype
dtype = input.dtype
if bias_dtype is None:
bias_dtype = dtype
if device is None:
@@ -114,9 +110,9 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
for f in s.bias_function:
bias = f(bias)
if weight_has_function or weight.dtype != dtype:
weight = weight.to(dtype=dtype)
if weight_has_function:
with wf_context:
weight = weight.to(dtype=dtype)
for f in s.weight_function:
weight = f(weight)
@@ -538,7 +534,18 @@ if CUBLAS_IS_AVAILABLE:
# ==============================================================================
# Mixed Precision Operations
# ==============================================================================
from .quant_ops import QuantizedTensor, QUANT_ALGOS
from .quant_ops import QuantizedTensor
QUANT_FORMAT_MIXINS = {
"float8_e4m3fn": {
"dtype": torch.float8_e4m3fn,
"layout_type": "TensorCoreFP8Layout",
"parameters": {
"weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
"input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
}
}
}
class MixedPrecisionOps(disable_weight_init):
_layer_quant_config = {}
@@ -589,24 +596,23 @@ class MixedPrecisionOps(disable_weight_init):
if quant_format is None:
raise ValueError(f"Unknown quantization format for layer {layer_name}")
qconfig = QUANT_ALGOS[quant_format]
self.layout_type = qconfig["comfy_tensor_layout"]
mixin = QUANT_FORMAT_MIXINS[quant_format]
self.layout_type = mixin["layout_type"]
weight_scale_key = f"{prefix}weight_scale"
scale_key = f"{prefix}weight_scale"
layout_params = {
'scale': state_dict.pop(weight_scale_key, None),
'orig_dtype': MixedPrecisionOps._compute_dtype,
'block_size': qconfig.get("group_size", None),
'scale': state_dict.pop(scale_key, None),
'orig_dtype': MixedPrecisionOps._compute_dtype
}
if layout_params['scale'] is not None:
manually_loaded_keys.append(weight_scale_key)
manually_loaded_keys.append(scale_key)
self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
QuantizedTensor(weight.to(device=device, dtype=mixin["dtype"]), self.layout_type, layout_params),
requires_grad=False
)
for param_name in qconfig["parameters"]:
for param_name, param_value in mixin["parameters"].items():
param_key = f"{prefix}{param_name}"
_v = state_dict.pop(param_key, None)
if _v is None:
@@ -637,7 +643,7 @@ class MixedPrecisionOps(disable_weight_init):
if (getattr(self, 'layout_type', None) is not None and
getattr(self, 'input_scale', None) is not None and
not isinstance(input, QuantizedTensor)):
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, fp8_dtype=self.weight.dtype)
return self._forward(input, self.weight, self.bias)

View File

@@ -74,12 +74,6 @@ def _copy_layout_params(params):
new_params[k] = v
return new_params
def _copy_layout_params_inplace(src, dst, non_blocking=False):
for k, v in src.items():
if isinstance(v, torch.Tensor):
dst[k].copy_(v, non_blocking=non_blocking)
else:
dst[k] = v
class QuantizedLayout:
"""
@@ -324,13 +318,13 @@ def generic_to_dtype_layout(func, args, kwargs):
def generic_copy_(func, args, kwargs):
qt_dest = args[0]
src = args[1]
non_blocking = args[2] if len(args) > 2 else False
if isinstance(qt_dest, QuantizedTensor):
if isinstance(src, QuantizedTensor):
# Copy from another quantized tensor
qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking)
qt_dest._qdata.copy_(src._qdata)
qt_dest._layout_type = src._layout_type
_copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking)
qt_dest._layout_params = _copy_layout_params(src._layout_params)
else:
# Copy from regular tensor - just copy raw data
qt_dest._qdata.copy_(src)
@@ -342,26 +336,6 @@ def generic_copy_(func, args, kwargs):
def generic_has_compatible_shallow_copy_type(func, args, kwargs):
return True
@register_generic_util(torch.ops.aten.empty_like.default)
def generic_empty_like(func, args, kwargs):
"""Empty_like operation - creates an empty tensor with the same quantized structure."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
# Create empty tensor with same shape and dtype as the quantized data
hp_dtype = kwargs.pop('dtype', qt._layout_params["orig_dtype"])
new_qdata = torch.empty_like(qt._qdata, **kwargs)
# Handle device transfer for layout params
target_device = kwargs.get('device', new_qdata.device)
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
# Update orig_dtype if dtype is specified
new_params['orig_dtype'] = hp_dtype
return QuantizedTensor(new_qdata, qt._layout_type, new_params)
return func(*args, **kwargs)
# ==============================================================================
# FP8 Layout + Operation Handlers
# ==============================================================================
@@ -404,13 +378,6 @@ class TensorCoreFP8Layout(QuantizedLayout):
def get_plain_tensors(cls, qtensor):
return qtensor._qdata, qtensor._layout_params['scale']
QUANT_ALGOS = {
"float8_e4m3fn": {
"storage_t": torch.float8_e4m3fn,
"parameters": {"weight_scale", "input_scale"},
"comfy_tensor_layout": "TensorCoreFP8Layout",
},
}
LAYOUTS = {
"TensorCoreFP8Layout": TensorCoreFP8Layout,

View File

@@ -460,7 +460,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
return embed_out
class SDTokenizer:
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, tokenizer_data={}, tokenizer_args={}):
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data={}, tokenizer_args={}):
if tokenizer_path is None:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
@@ -468,7 +468,6 @@ class SDTokenizer:
self.min_length = tokenizer_data.get("{}_min_length".format(embedding_key), min_length)
self.end_token = None
self.min_padding = min_padding
self.pad_left = pad_left
empty = self.tokenizer('')["input_ids"]
self.tokenizer_adds_end_token = has_end_token
@@ -523,12 +522,6 @@ class SDTokenizer:
return (embed, "{} {}".format(embedding_name[len(stripped):], leftover))
return (embed, leftover)
def pad_tokens(self, tokens, amount):
if self.pad_left:
for i in range(amount):
tokens.insert(0, (self.pad_token, 1.0, 0))
else:
tokens.extend([(self.pad_token, 1.0, 0)] * amount)
def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_options={}, **kwargs):
'''
@@ -607,7 +600,7 @@ class SDTokenizer:
if self.end_token is not None:
batch.append((self.end_token, 1.0, 0))
if self.pad_to_max_length:
self.pad_tokens(batch, remaining_length)
batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
#start new batch
batch = []
if self.start_token is not None:
@@ -621,11 +614,11 @@ class SDTokenizer:
if self.end_token is not None:
batch.append((self.end_token, 1.0, 0))
if min_padding is not None:
self.pad_tokens(batch, min_padding)
batch.extend([(self.pad_token, 1.0, 0)] * min_padding)
if self.pad_to_max_length and len(batch) < self.max_length:
self.pad_tokens(batch, self.max_length - len(batch))
batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
if min_length is not None and len(batch) < min_length:
self.pad_tokens(batch, min_length - len(batch))
batch.extend([(self.pad_token, 1.0, 0)] * (min_length - len(batch)))
if not return_word_ids:
batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]

View File

@@ -32,7 +32,6 @@ class Llama2Config:
q_norm = None
k_norm = None
rope_scale = None
final_norm: bool = True
@dataclass
class Qwen25_3BConfig:
@@ -54,7 +53,6 @@ class Qwen25_3BConfig:
q_norm = None
k_norm = None
rope_scale = None
final_norm: bool = True
@dataclass
class Qwen25_7BVLI_Config:
@@ -76,7 +74,6 @@ class Qwen25_7BVLI_Config:
q_norm = None
k_norm = None
rope_scale = None
final_norm: bool = True
@dataclass
class Gemma2_2B_Config:
@@ -99,7 +96,6 @@ class Gemma2_2B_Config:
k_norm = None
sliding_attention = None
rope_scale = None
final_norm: bool = True
@dataclass
class Gemma3_4B_Config:
@@ -122,7 +118,6 @@ class Gemma3_4B_Config:
k_norm = "gemma3"
sliding_attention = [False, False, False, False, False, 1024]
rope_scale = [1.0, 8.0]
final_norm: bool = True
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
@@ -371,12 +366,7 @@ class Llama2_(nn.Module):
transformer(config, index=i, device=device, dtype=dtype, ops=ops)
for i in range(config.num_hidden_layers)
])
if config.final_norm:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
else:
self.norm = None
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]):
@@ -431,16 +421,14 @@ class Llama2_(nn.Module):
if i == intermediate_output:
intermediate = x.clone()
if self.norm is not None:
x = self.norm(x)
x = self.norm(x)
if all_intermediate is not None:
all_intermediate.append(x.unsqueeze(1).clone())
if all_intermediate is not None:
intermediate = torch.cat(all_intermediate, dim=1)
if intermediate is not None and final_layer_norm_intermediate and self.norm is not None:
if intermediate is not None and final_layer_norm_intermediate:
intermediate = self.norm(intermediate)
return x, intermediate

View File

@@ -7,7 +7,7 @@ from comfy_api.internal.singleton import ProxiedSingleton
from comfy_api.internal.async_to_sync import create_sync_class
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
from . import _io as io
from . import _ui as ui
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
@@ -104,8 +104,6 @@ class Types:
VideoCodec = VideoCodec
VideoContainer = VideoContainer
VideoComponents = VideoComponents
MESH = MESH
VOXEL = VOXEL
ComfyAPI = ComfyAPI_latest

View File

@@ -27,7 +27,6 @@ from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classpr
prune_dict, shallow_clone_class)
from comfy_api.latest._resources import Resources, ResourcesLocal
from comfy_execution.graph_utils import ExecutionBlocker
from ._util import MESH, VOXEL
# from comfy_extras.nodes_images import SVG as SVG_ # NOTE: needs to be moved before can be imported due to circular reference
@@ -657,11 +656,11 @@ class LossMap(ComfyTypeIO):
@comfytype(io_type="VOXEL")
class Voxel(ComfyTypeIO):
Type = VOXEL
Type = Any # TODO: VOXEL class is defined in comfy_extras/nodes_hunyuan3d.py; should be moved to somewhere else before referenced directly in v3
@comfytype(io_type="MESH")
class Mesh(ComfyTypeIO):
Type = MESH
Type = Any # TODO: MESH class is defined in comfy_extras/nodes_hunyuan3d.py; should be moved to somewhere else before referenced directly in v3
@comfytype(io_type="HOOKS")
class Hooks(ComfyTypeIO):

View File

@@ -1,11 +1,8 @@
from .video_types import VideoContainer, VideoCodec, VideoComponents
from .geometry_types import VOXEL, MESH
__all__ = [
# Utility Types
"VideoContainer",
"VideoCodec",
"VideoComponents",
"VOXEL",
"MESH",
]

View File

@@ -1,12 +0,0 @@
import torch
class VOXEL:
def __init__(self, data: torch.Tensor):
self.data = data
class MESH:
def __init__(self, vertices: torch.Tensor, faces: torch.Tensor):
self.vertices = vertices
self.faces = faces

View File

@@ -1,230 +1,22 @@
from datetime import date
from enum import Enum
from typing import Any
from typing import Optional
from pydantic import BaseModel, Field
class GeminiSafetyCategory(str, Enum):
HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"
HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"
HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT"
class GeminiSafetyThreshold(str, Enum):
OFF = "OFF"
BLOCK_NONE = "BLOCK_NONE"
BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE"
BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE"
BLOCK_ONLY_HIGH = "BLOCK_ONLY_HIGH"
class GeminiSafetySetting(BaseModel):
category: GeminiSafetyCategory
threshold: GeminiSafetyThreshold
class GeminiRole(str, Enum):
user = "user"
model = "model"
class GeminiMimeType(str, Enum):
application_pdf = "application/pdf"
audio_mpeg = "audio/mpeg"
audio_mp3 = "audio/mp3"
audio_wav = "audio/wav"
image_png = "image/png"
image_jpeg = "image/jpeg"
image_webp = "image/webp"
text_plain = "text/plain"
video_mov = "video/mov"
video_mpeg = "video/mpeg"
video_mp4 = "video/mp4"
video_mpg = "video/mpg"
video_avi = "video/avi"
video_wmv = "video/wmv"
video_mpegps = "video/mpegps"
video_flv = "video/flv"
class GeminiInlineData(BaseModel):
data: str | None = Field(
None,
description="The base64 encoding of the image, PDF, or video to include inline in the prompt. "
"When including media inline, you must also specify the media type (mimeType) of the data. Size limit: 20MB",
)
mimeType: GeminiMimeType | None = Field(None)
class GeminiPart(BaseModel):
inlineData: GeminiInlineData | None = Field(None)
text: str | None = Field(None)
class GeminiTextPart(BaseModel):
text: str | None = Field(None)
class GeminiContent(BaseModel):
parts: list[GeminiPart] = Field([])
role: GeminiRole = Field(..., examples=["user"])
class GeminiSystemInstructionContent(BaseModel):
parts: list[GeminiTextPart] = Field(
...,
description="A list of ordered parts that make up a single message. "
"Different parts may have different IANA MIME types.",
)
role: GeminiRole = Field(
...,
description="The identity of the entity that creates the message. "
"The following values are supported: "
"user: This indicates that the message is sent by a real person, typically a user-generated message. "
"model: This indicates that the message is generated by the model. "
"The model value is used to insert messages from model into the conversation during multi-turn conversations. "
"For non-multi-turn conversations, this field can be left blank or unset.",
)
class GeminiFunctionDeclaration(BaseModel):
description: str | None = Field(None)
name: str = Field(...)
parameters: dict[str, Any] = Field(..., description="JSON schema for the function parameters")
class GeminiTool(BaseModel):
functionDeclarations: list[GeminiFunctionDeclaration] | None = Field(None)
class GeminiOffset(BaseModel):
nanos: int | None = Field(None, ge=0, le=999999999)
seconds: int | None = Field(None, ge=-315576000000, le=315576000000)
class GeminiVideoMetadata(BaseModel):
endOffset: GeminiOffset | None = Field(None)
startOffset: GeminiOffset | None = Field(None)
class GeminiGenerationConfig(BaseModel):
maxOutputTokens: int | None = Field(None, ge=16, le=8192)
seed: int | None = Field(None)
stopSequences: list[str] | None = Field(None)
temperature: float | None = Field(1, ge=0.0, le=2.0)
topK: int | None = Field(40, ge=1)
topP: float | None = Field(0.95, ge=0.0, le=1.0)
from comfy_api_nodes.apis import GeminiGenerationConfig, GeminiContent, GeminiSafetySetting, GeminiSystemInstructionContent, GeminiTool, GeminiVideoMetadata
from pydantic import BaseModel
class GeminiImageConfig(BaseModel):
aspectRatio: str | None = Field(None)
imageSize: str | None = Field(None)
aspectRatio: Optional[str] = None
class GeminiImageGenerationConfig(GeminiGenerationConfig):
responseModalities: list[str] | None = Field(None)
imageConfig: GeminiImageConfig | None = Field(None)
responseModalities: Optional[list[str]] = None
imageConfig: Optional[GeminiImageConfig] = None
class GeminiImageGenerateContentRequest(BaseModel):
contents: list[GeminiContent] = Field(...)
generationConfig: GeminiImageGenerationConfig | None = Field(None)
safetySettings: list[GeminiSafetySetting] | None = Field(None)
systemInstruction: GeminiSystemInstructionContent | None = Field(None)
tools: list[GeminiTool] | None = Field(None)
videoMetadata: GeminiVideoMetadata | None = Field(None)
class GeminiGenerateContentRequest(BaseModel):
contents: list[GeminiContent] = Field(...)
generationConfig: GeminiGenerationConfig | None = Field(None)
safetySettings: list[GeminiSafetySetting] | None = Field(None)
systemInstruction: GeminiSystemInstructionContent | None = Field(None)
tools: list[GeminiTool] | None = Field(None)
videoMetadata: GeminiVideoMetadata | None = Field(None)
class Modality(str, Enum):
MODALITY_UNSPECIFIED = "MODALITY_UNSPECIFIED"
TEXT = "TEXT"
IMAGE = "IMAGE"
VIDEO = "VIDEO"
AUDIO = "AUDIO"
DOCUMENT = "DOCUMENT"
class ModalityTokenCount(BaseModel):
modality: Modality | None = None
tokenCount: int | None = Field(None, description="Number of tokens for the given modality.")
class Probability(str, Enum):
NEGLIGIBLE = "NEGLIGIBLE"
LOW = "LOW"
MEDIUM = "MEDIUM"
HIGH = "HIGH"
UNKNOWN = "UNKNOWN"
class GeminiSafetyRating(BaseModel):
category: GeminiSafetyCategory | None = None
probability: Probability | None = Field(
None,
description="The probability that the content violates the specified safety category",
)
class GeminiCitation(BaseModel):
authors: list[str] | None = None
endIndex: int | None = None
license: str | None = None
publicationDate: date | None = None
startIndex: int | None = None
title: str | None = None
uri: str | None = None
class GeminiCitationMetadata(BaseModel):
citations: list[GeminiCitation] | None = None
class GeminiCandidate(BaseModel):
citationMetadata: GeminiCitationMetadata | None = None
content: GeminiContent | None = None
finishReason: str | None = None
safetyRatings: list[GeminiSafetyRating] | None = None
class GeminiPromptFeedback(BaseModel):
blockReason: str | None = None
blockReasonMessage: str | None = None
safetyRatings: list[GeminiSafetyRating] | None = None
class GeminiUsageMetadata(BaseModel):
cachedContentTokenCount: int | None = Field(
None,
description="Output only. Number of tokens in the cached part in the input (the cached content).",
)
candidatesTokenCount: int | None = Field(None, description="Number of tokens in the response(s).")
candidatesTokensDetails: list[ModalityTokenCount] | None = Field(
None, description="Breakdown of candidate tokens by modality."
)
promptTokenCount: int | None = Field(
None,
description="Number of tokens in the request. When cachedContent is set, this is still the total effective prompt size meaning this includes the number of tokens in the cached content.",
)
promptTokensDetails: list[ModalityTokenCount] | None = Field(
None, description="Breakdown of prompt tokens by modality."
)
thoughtsTokenCount: int | None = Field(None, description="Number of tokens present in thoughts output.")
toolUsePromptTokenCount: int | None = Field(None, description="Number of tokens present in tool-use prompt(s).")
class GeminiGenerateContentResponse(BaseModel):
candidates: list[GeminiCandidate] | None = Field(None)
promptFeedback: GeminiPromptFeedback | None = Field(None)
usageMetadata: GeminiUsageMetadata | None = Field(None)
modelVersion: str | None = Field(None)
contents: list[GeminiContent]
generationConfig: Optional[GeminiImageGenerationConfig] = None
safetySettings: Optional[list[GeminiSafetySetting]] = None
systemInstruction: Optional[GeminiSystemInstructionContent] = None
tools: Optional[list[GeminiTool]] = None
videoMetadata: Optional[GeminiVideoMetadata] = None

View File

@@ -1,133 +0,0 @@
from typing import Optional, Union
from pydantic import BaseModel, Field
class ImageEnhanceRequest(BaseModel):
model: str = Field("Reimagine")
output_format: str = Field("jpeg")
subject_detection: str = Field("All")
face_enhancement: bool = Field(True)
face_enhancement_creativity: float = Field(0, description="Is ignored if face_enhancement is false")
face_enhancement_strength: float = Field(0.8, description="Is ignored if face_enhancement is false")
source_url: str = Field(...)
output_width: Optional[int] = Field(None)
output_height: Optional[int] = Field(None)
crop_to_fill: bool = Field(False)
prompt: Optional[str] = Field(None, description="Text prompt for creative upscaling guidance")
creativity: int = Field(3, description="Creativity settings range from 1 to 9")
face_preservation: str = Field("true", description="To preserve the identity of characters")
color_preservation: str = Field("true", description="To preserve the original color")
class ImageAsyncTaskResponse(BaseModel):
process_id: str = Field(...)
class ImageStatusResponse(BaseModel):
process_id: str = Field(...)
status: str = Field(...)
progress: Optional[int] = Field(None)
credits: int = Field(...)
class ImageDownloadResponse(BaseModel):
download_url: str = Field(...)
expiry: int = Field(...)
class Resolution(BaseModel):
width: int = Field(...)
height: int = Field(...)
class CreateCreateVideoRequestSource(BaseModel):
container: str = Field(...)
size: int = Field(..., description="Size of the video file in bytes")
duration: int = Field(..., description="Duration of the video file in seconds")
frameCount: int = Field(..., description="Total number of frames in the video")
frameRate: int = Field(...)
resolution: Resolution = Field(...)
class VideoFrameInterpolationFilter(BaseModel):
model: str = Field(...)
slowmo: Optional[int] = Field(None)
fps: int = Field(...)
duplicate: bool = Field(...)
duplicate_threshold: float = Field(...)
class VideoEnhancementFilter(BaseModel):
model: str = Field(...)
auto: Optional[str] = Field(None, description="Auto, Manual, Relative")
focusFixLevel: Optional[str] = Field(None, description="Downscales video input for correction of blurred subjects")
compression: Optional[float] = Field(None, description="Strength of compression recovery")
details: Optional[float] = Field(None, description="Amount of detail reconstruction")
prenoise: Optional[float] = Field(None, description="Amount of noise to add to input to reduce over-smoothing")
noise: Optional[float] = Field(None, description="Amount of noise reduction")
halo: Optional[float] = Field(None, description="Amount of halo reduction")
preblur: Optional[float] = Field(None, description="Anti-aliasing and deblurring strength")
blur: Optional[float] = Field(None, description="Amount of sharpness applied")
grain: Optional[float] = Field(None, description="Grain after AI model processing")
grainSize: Optional[float] = Field(None, description="Size of generated grain")
recoverOriginalDetailValue: Optional[float] = Field(None, description="Source details into the output video")
creativity: Optional[str] = Field(None, description="Creativity level(high, low) for slc-1 only")
isOptimizedMode: Optional[bool] = Field(None, description="Set to true for Starlight Creative (slc-1) only")
class OutputInformationVideo(BaseModel):
resolution: Resolution = Field(...)
frameRate: int = Field(...)
audioCodec: Optional[str] = Field(..., description="Required if audioTransfer is Copy or Convert")
audioTransfer: str = Field(..., description="Copy, Convert, None")
dynamicCompressionLevel: str = Field(..., description="Low, Mid, High")
class Overrides(BaseModel):
isPaidDiffusion: bool = Field(True)
class CreateVideoRequest(BaseModel):
source: CreateCreateVideoRequestSource = Field(...)
filters: list[Union[VideoFrameInterpolationFilter, VideoEnhancementFilter]] = Field(...)
output: OutputInformationVideo = Field(...)
overrides: Overrides = Field(Overrides(isPaidDiffusion=True))
class CreateVideoResponse(BaseModel):
requestId: str = Field(...)
class VideoAcceptResponse(BaseModel):
uploadId: str = Field(...)
urls: list[str] = Field(...)
class VideoCompleteUploadRequestPart(BaseModel):
partNum: int = Field(...)
eTag: str = Field(...)
class VideoCompleteUploadRequest(BaseModel):
uploadResults: list[VideoCompleteUploadRequestPart] = Field(...)
class VideoCompleteUploadResponse(BaseModel):
message: str = Field(..., description="Confirmation message")
class VideoStatusResponseEstimates(BaseModel):
cost: list[int] = Field(...)
class VideoStatusResponseDownloadUrl(BaseModel):
url: str = Field(...)
class VideoStatusResponse(BaseModel):
status: str = Field(...)
estimates: Optional[VideoStatusResponseEstimates] = Field(None)
progress: Optional[float] = Field(None)
message: Optional[str] = Field("")
download: Optional[VideoStatusResponseDownloadUrl] = Field(None)

View File

@@ -3,6 +3,8 @@ API Nodes for Gemini Multimodal LLM Usage via Remote API
See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
"""
from __future__ import annotations
import base64
import json
import os
@@ -10,7 +12,7 @@ import time
import uuid
from enum import Enum
from io import BytesIO
from typing import Literal
from typing import Literal, Optional
import torch
from typing_extensions import override
@@ -18,24 +20,23 @@ from typing_extensions import override
import folder_paths
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api.util import VideoCodec, VideoContainer
from comfy_api_nodes.apis.gemini_api import (
from comfy_api_nodes.apis import (
GeminiContent,
GeminiGenerateContentRequest,
GeminiGenerateContentResponse,
GeminiImageConfig,
GeminiImageGenerateContentRequest,
GeminiImageGenerationConfig,
GeminiInlineData,
GeminiMimeType,
GeminiPart,
GeminiRole,
Modality,
)
from comfy_api_nodes.apis.gemini_api import (
GeminiImageConfig,
GeminiImageGenerateContentRequest,
GeminiImageGenerationConfig,
)
from comfy_api_nodes.util import (
ApiEndpoint,
audio_to_base64_string,
bytesio_to_image_tensor,
get_number_of_images,
sync_op,
tensor_to_base64_string,
validate_string,
@@ -56,7 +57,6 @@ class GeminiModel(str, Enum):
gemini_2_5_flash_preview_04_17 = "gemini-2.5-flash-preview-04-17"
gemini_2_5_pro = "gemini-2.5-pro"
gemini_2_5_flash = "gemini-2.5-flash"
gemini_3_0_pro = "gemini-3-pro-preview"
class GeminiImageModel(str, Enum):
@@ -103,16 +103,6 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
Returns:
List of response parts matching the requested type.
"""
if response.candidates is None:
if response.promptFeedback.blockReason:
feedback = response.promptFeedback
raise ValueError(
f"Gemini API blocked the request. Reason: {feedback.blockReason} ({feedback.blockReasonMessage})"
)
raise NotImplementedError(
"Gemini returned no response candidates. "
"Please report to ComfyUI repository with the example of workflow to reproduce this."
)
parts = []
for part in response.candidates[0].content.parts:
if part_type == "text" and hasattr(part, "text") and part.text:
@@ -149,49 +139,6 @@ def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Te
return torch.cat(image_tensors, dim=0)
def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | None:
if not response.modelVersion:
return None
# Define prices (Cost per 1,000,000 tokens), see https://cloud.google.com/vertex-ai/generative-ai/pricing
if response.modelVersion in ("gemini-2.5-pro-preview-05-06", "gemini-2.5-pro"):
input_tokens_price = 1.25
output_text_tokens_price = 10.0
output_image_tokens_price = 0.0
elif response.modelVersion in (
"gemini-2.5-flash-preview-04-17",
"gemini-2.5-flash",
):
input_tokens_price = 0.30
output_text_tokens_price = 2.50
output_image_tokens_price = 0.0
elif response.modelVersion in (
"gemini-2.5-flash-image-preview",
"gemini-2.5-flash-image",
):
input_tokens_price = 0.30
output_text_tokens_price = 2.50
output_image_tokens_price = 30.0
elif response.modelVersion == "gemini-3-pro-preview":
input_tokens_price = 2
output_text_tokens_price = 12.0
output_image_tokens_price = 0.0
elif response.modelVersion == "gemini-3-pro-image-preview":
input_tokens_price = 2
output_text_tokens_price = 12.0
output_image_tokens_price = 120.0
else:
return None
final_price = response.usageMetadata.promptTokenCount * input_tokens_price
for i in response.usageMetadata.candidatesTokensDetails:
if i.modality == Modality.IMAGE:
final_price += output_image_tokens_price * i.tokenCount # for Nano Banana models
else:
final_price += output_text_tokens_price * i.tokenCount
if response.usageMetadata.thoughtsTokenCount:
final_price += output_text_tokens_price * response.usageMetadata.thoughtsTokenCount
return final_price / 1_000_000.0
class GeminiNode(IO.ComfyNode):
"""
Node to generate text responses from a Gemini model.
@@ -325,10 +272,10 @@ class GeminiNode(IO.ComfyNode):
prompt: str,
model: str,
seed: int,
images: torch.Tensor | None = None,
audio: Input.Audio | None = None,
video: Input.Video | None = None,
files: list[GeminiPart] | None = None,
images: Optional[torch.Tensor] = None,
audio: Optional[Input.Audio] = None,
video: Optional[Input.Video] = None,
files: Optional[list[GeminiPart]] = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
@@ -353,15 +300,15 @@ class GeminiNode(IO.ComfyNode):
data=GeminiGenerateContentRequest(
contents=[
GeminiContent(
role=GeminiRole.user,
role="user",
parts=parts,
)
]
),
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
)
# Get result output
output_text = get_text_from_response(response)
if output_text:
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
@@ -459,7 +406,7 @@ class GeminiInputFiles(IO.ComfyNode):
)
@classmethod
def execute(cls, file: str, GEMINI_INPUT_FILES: list[GeminiPart] | None = None) -> IO.NodeOutput:
def execute(cls, file: str, GEMINI_INPUT_FILES: Optional[list[GeminiPart]] = None) -> IO.NodeOutput:
"""Loads and formats input files for Gemini API."""
if GEMINI_INPUT_FILES is None:
GEMINI_INPUT_FILES = []
@@ -474,7 +421,7 @@ class GeminiImage(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="GeminiImageNode",
display_name="Nano Banana (Google Gemini Image)",
display_name="Google Gemini Image",
category="api node/image/Gemini",
description="Edit images synchronously via Google API.",
inputs=[
@@ -522,13 +469,6 @@ class GeminiImage(IO.ComfyNode):
"or otherwise generates 1:1 squares.",
optional=True,
),
IO.Combo.Input(
"response_modalities",
options=["IMAGE+TEXT", "IMAGE"],
tooltip="Choose 'IMAGE' for image-only output, or "
"'IMAGE+TEXT' to return both the generated image and a text response.",
optional=True,
),
],
outputs=[
IO.Image.Output(),
@@ -548,10 +488,9 @@ class GeminiImage(IO.ComfyNode):
prompt: str,
model: str,
seed: int,
images: torch.Tensor | None = None,
files: list[GeminiPart] | None = None,
images: Optional[torch.Tensor] = None,
files: Optional[list[GeminiPart]] = None,
aspect_ratio: str = "auto",
response_modalities: str = "IMAGE+TEXT",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
@@ -571,19 +510,20 @@ class GeminiImage(IO.ComfyNode):
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
data=GeminiImageGenerateContentRequest(
contents=[
GeminiContent(role=GeminiRole.user, parts=parts),
GeminiContent(role="user", parts=parts),
],
generationConfig=GeminiImageGenerationConfig(
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
responseModalities=["TEXT", "IMAGE"],
imageConfig=None if aspect_ratio == "auto" else image_config,
),
),
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
)
output_image = get_image_from_response(response)
output_text = get_text_from_response(response)
if output_text:
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
render_spec = {
"node_id": cls.hidden.unique_id,
"component": "ChatHistoryWidget",
@@ -604,150 +544,9 @@ class GeminiImage(IO.ComfyNode):
"display_component",
render_spec,
)
return IO.NodeOutput(get_image_from_response(response), output_text)
class GeminiImage2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="GeminiImage2Node",
display_name="Nano Banana Pro (Google Gemini Image)",
category="api node/image/Gemini",
description="Generate or edit images synchronously via Google Vertex API.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
tooltip="Text prompt describing the image to generate or the edits to apply. "
"Include any constraints, styles, or details the model should follow.",
default="",
),
IO.Combo.Input(
"model",
options=["gemini-3-pro-image-preview"],
),
IO.Int.Input(
"seed",
default=42,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="When the seed is fixed to a specific value, the model makes a best effort to provide "
"the same response for repeated requests. Deterministic output isn't guaranteed. "
"Also, changing the model or parameter settings, such as the temperature, "
"can cause variations in the response even when you use the same seed value. "
"By default, a random seed value is used.",
),
IO.Combo.Input(
"aspect_ratio",
options=["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"],
default="auto",
tooltip="If set to 'auto', matches your input image's aspect ratio; "
"if no image is provided, generates a 1:1 square.",
),
IO.Combo.Input(
"resolution",
options=["1K", "2K", "4K"],
tooltip="Target output resolution. For 2K/4K the native Gemini upscaler is used.",
),
IO.Combo.Input(
"response_modalities",
options=["IMAGE+TEXT", "IMAGE"],
tooltip="Choose 'IMAGE' for image-only output, or "
"'IMAGE+TEXT' to return both the generated image and a text response.",
),
IO.Image.Input(
"images",
optional=True,
tooltip="Optional reference image(s). "
"To include multiple images, use the Batch Images node (up to 14).",
),
IO.Custom("GEMINI_INPUT_FILES").Input(
"files",
optional=True,
tooltip="Optional file(s) to use as context for the model. "
"Accepts inputs from the Gemini Generate Content Input Files node.",
),
],
outputs=[
IO.Image.Output(),
IO.String.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
prompt: str,
model: str,
seed: int,
aspect_ratio: str,
resolution: str,
response_modalities: str,
images: torch.Tensor | None = None,
files: list[GeminiPart] | None = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
if images is not None:
if get_number_of_images(images) > 14:
raise ValueError("The current maximum number of supported images is 14.")
parts.extend(create_image_parts(images))
if files is not None:
parts.extend(files)
image_config = GeminiImageConfig(imageSize=resolution)
if aspect_ratio != "auto":
image_config.aspectRatio = aspect_ratio
response = await sync_op(
cls,
ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
data=GeminiImageGenerateContentRequest(
contents=[
GeminiContent(role=GeminiRole.user, parts=parts),
],
generationConfig=GeminiImageGenerationConfig(
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
imageConfig=image_config,
),
),
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
)
output_text = get_text_from_response(response)
if output_text:
render_spec = {
"node_id": cls.hidden.unique_id,
"component": "ChatHistoryWidget",
"props": {
"history": json.dumps(
[
{
"prompt": prompt,
"response": output_text,
"response_id": str(uuid.uuid4()),
"timestamp": time.time(),
}
]
),
},
}
PromptServer.instance.send_sync(
"display_component",
render_spec,
)
return IO.NodeOutput(get_image_from_response(response), output_text)
output_text = output_text or "Empty response from Gemini model..."
return IO.NodeOutput(output_image, output_text)
class GeminiExtension(ComfyExtension):
@@ -756,7 +555,6 @@ class GeminiExtension(ComfyExtension):
return [
GeminiNode,
GeminiImage,
GeminiImage2,
GeminiInputFiles,
]

View File

@@ -518,9 +518,7 @@ async def execute_lipsync(
# Upload the audio file to Comfy API and get download URL
if audio:
audio_url = await upload_audio_to_comfyapi(
cls, audio, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mpeg", filename="output.mp3"
)
audio_url = await upload_audio_to_comfyapi(cls, audio)
logging.info("Uploaded audio to Comfy API. URL: %s", audio_url)
else:
audio_url = None

View File

@@ -1,421 +0,0 @@
import builtins
from io import BytesIO
import aiohttp
import torch
from typing_extensions import override
from comfy_api.input.video_types import VideoInput
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis import topaz_api
from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_image_tensor,
download_url_to_video_output,
get_fs_object_size,
get_number_of_images,
poll_op,
sync_op,
upload_images_to_comfyapi,
validate_container_format_is_mp4,
)
UPSCALER_MODELS_MAP = {
"Starlight (Astra) Fast": "slf-1",
"Starlight (Astra) Creative": "slc-1",
}
UPSCALER_VALUES_MAP = {
"FullHD (1080p)": 1920,
"4K (2160p)": 3840,
}
class TopazImageEnhance(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TopazImageEnhance",
display_name="Topaz Image Enhance",
category="api node/image/Topaz",
description="Industry-standard upscaling and image enhancement.",
inputs=[
IO.Combo.Input("model", options=["Reimagine"]),
IO.Image.Input("image"),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Optional text prompt for creative upscaling guidance.",
optional=True,
),
IO.Combo.Input(
"subject_detection",
options=["All", "Foreground", "Background"],
optional=True,
),
IO.Boolean.Input(
"face_enhancement",
default=True,
optional=True,
tooltip="Enhance faces (if present) during processing.",
),
IO.Float.Input(
"face_enhancement_creativity",
default=0.0,
min=0.0,
max=1.0,
step=0.01,
display_mode=IO.NumberDisplay.number,
optional=True,
tooltip="Set the creativity level for face enhancement.",
),
IO.Float.Input(
"face_enhancement_strength",
default=1.0,
min=0.0,
max=1.0,
step=0.01,
display_mode=IO.NumberDisplay.number,
optional=True,
tooltip="Controls how sharp enhanced faces are relative to the background.",
),
IO.Boolean.Input(
"crop_to_fill",
default=False,
optional=True,
tooltip="By default, the image is letterboxed when the output aspect ratio differs. "
"Enable to crop the image to fill the output dimensions.",
),
IO.Int.Input(
"output_width",
default=0,
min=0,
max=32000,
step=1,
display_mode=IO.NumberDisplay.number,
optional=True,
tooltip="Zero value means to calculate automatically (usually it will be original size or output_height if specified).",
),
IO.Int.Input(
"output_height",
default=0,
min=0,
max=32000,
step=1,
display_mode=IO.NumberDisplay.number,
optional=True,
tooltip="Zero value means to output in the same height as original or output width.",
),
IO.Int.Input(
"creativity",
default=3,
min=1,
max=9,
step=1,
display_mode=IO.NumberDisplay.slider,
optional=True,
),
IO.Boolean.Input(
"face_preservation",
default=True,
optional=True,
tooltip="Preserve subjects' facial identity.",
),
IO.Boolean.Input(
"color_preservation",
default=True,
optional=True,
tooltip="Preserve the original colors.",
),
],
outputs=[
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
image: torch.Tensor,
prompt: str = "",
subject_detection: str = "All",
face_enhancement: bool = True,
face_enhancement_creativity: float = 1.0,
face_enhancement_strength: float = 0.8,
crop_to_fill: bool = False,
output_width: int = 0,
output_height: int = 0,
creativity: int = 3,
face_preservation: bool = True,
color_preservation: bool = True,
) -> IO.NodeOutput:
if get_number_of_images(image) != 1:
raise ValueError("Only one input image is supported.")
download_url = await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png")
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/topaz/image/v1/enhance-gen/async", method="POST"),
response_model=topaz_api.ImageAsyncTaskResponse,
data=topaz_api.ImageEnhanceRequest(
model=model,
prompt=prompt,
subject_detection=subject_detection,
face_enhancement=face_enhancement,
face_enhancement_creativity=face_enhancement_creativity,
face_enhancement_strength=face_enhancement_strength,
crop_to_fill=crop_to_fill,
output_width=output_width if output_width else None,
output_height=output_height if output_height else None,
creativity=creativity,
face_preservation=str(face_preservation).lower(),
color_preservation=str(color_preservation).lower(),
source_url=download_url[0],
output_format="png",
),
content_type="multipart/form-data",
)
await poll_op(
cls,
poll_endpoint=ApiEndpoint(path=f"/proxy/topaz/image/v1/status/{initial_response.process_id}"),
response_model=topaz_api.ImageStatusResponse,
status_extractor=lambda x: x.status,
progress_extractor=lambda x: getattr(x, "progress", 0),
price_extractor=lambda x: x.credits * 0.08,
poll_interval=8.0,
max_poll_attempts=160,
estimated_duration=60,
)
results = await sync_op(
cls,
ApiEndpoint(path=f"/proxy/topaz/image/v1/download/{initial_response.process_id}"),
response_model=topaz_api.ImageDownloadResponse,
monitor_progress=False,
)
return IO.NodeOutput(await download_url_to_image_tensor(results.download_url))
class TopazVideoEnhance(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TopazVideoEnhance",
display_name="Topaz Video Enhance",
category="api node/video/Topaz",
description="Breathe new life into video with powerful upscaling and recovery technology.",
inputs=[
IO.Video.Input("video"),
IO.Boolean.Input("upscaler_enabled", default=True),
IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())),
IO.Combo.Input("upscaler_resolution", options=list(UPSCALER_VALUES_MAP.keys())),
IO.Combo.Input(
"upscaler_creativity",
options=["low", "middle", "high"],
default="low",
tooltip="Creativity level (applies only to Starlight (Astra) Creative).",
optional=True,
),
IO.Boolean.Input("interpolation_enabled", default=False, optional=True),
IO.Combo.Input("interpolation_model", options=["apo-8"], default="apo-8", optional=True),
IO.Int.Input(
"interpolation_slowmo",
default=1,
min=1,
max=16,
display_mode=IO.NumberDisplay.number,
tooltip="Slow-motion factor applied to the input video. "
"For example, 2 makes the output twice as slow and doubles the duration.",
optional=True,
),
IO.Int.Input(
"interpolation_frame_rate",
default=60,
min=15,
max=240,
display_mode=IO.NumberDisplay.number,
tooltip="Output frame rate.",
optional=True,
),
IO.Boolean.Input(
"interpolation_duplicate",
default=False,
tooltip="Analyze the input for duplicate frames and remove them.",
optional=True,
),
IO.Float.Input(
"interpolation_duplicate_threshold",
default=0.01,
min=0.001,
max=0.1,
step=0.001,
display_mode=IO.NumberDisplay.number,
tooltip="Detection sensitivity for duplicate frames.",
optional=True,
),
IO.Combo.Input(
"dynamic_compression_level",
options=["Low", "Mid", "High"],
default="Low",
tooltip="CQP level.",
optional=True,
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
video: VideoInput,
upscaler_enabled: bool,
upscaler_model: str,
upscaler_resolution: str,
upscaler_creativity: str = "low",
interpolation_enabled: bool = False,
interpolation_model: str = "apo-8",
interpolation_slowmo: int = 1,
interpolation_frame_rate: int = 60,
interpolation_duplicate: bool = False,
interpolation_duplicate_threshold: float = 0.01,
dynamic_compression_level: str = "Low",
) -> IO.NodeOutput:
if upscaler_enabled is False and interpolation_enabled is False:
raise ValueError("There is nothing to do: both upscaling and interpolation are disabled.")
src_width, src_height = video.get_dimensions()
video_components = video.get_components()
src_frame_rate = int(video_components.frame_rate)
duration_sec = video.get_duration()
estimated_frames = int(duration_sec * src_frame_rate)
validate_container_format_is_mp4(video)
src_video_stream = video.get_stream_source()
target_width = src_width
target_height = src_height
target_frame_rate = src_frame_rate
filters = []
if upscaler_enabled:
target_width = UPSCALER_VALUES_MAP[upscaler_resolution]
target_height = UPSCALER_VALUES_MAP[upscaler_resolution]
filters.append(
topaz_api.VideoEnhancementFilter(
model=UPSCALER_MODELS_MAP[upscaler_model],
creativity=(upscaler_creativity if UPSCALER_MODELS_MAP[upscaler_model] == "slc-1" else None),
isOptimizedMode=(True if UPSCALER_MODELS_MAP[upscaler_model] == "slc-1" else None),
),
)
if interpolation_enabled:
target_frame_rate = interpolation_frame_rate
filters.append(
topaz_api.VideoFrameInterpolationFilter(
model=interpolation_model,
slowmo=interpolation_slowmo,
fps=interpolation_frame_rate,
duplicate=interpolation_duplicate,
duplicate_threshold=interpolation_duplicate_threshold,
),
)
initial_res = await sync_op(
cls,
ApiEndpoint(path="/proxy/topaz/video/", method="POST"),
response_model=topaz_api.CreateVideoResponse,
data=topaz_api.CreateVideoRequest(
source=topaz_api.CreateCreateVideoRequestSource(
container="mp4",
size=get_fs_object_size(src_video_stream),
duration=int(duration_sec),
frameCount=estimated_frames,
frameRate=src_frame_rate,
resolution=topaz_api.Resolution(width=src_width, height=src_height),
),
filters=filters,
output=topaz_api.OutputInformationVideo(
resolution=topaz_api.Resolution(width=target_width, height=target_height),
frameRate=target_frame_rate,
audioCodec="AAC",
audioTransfer="Copy",
dynamicCompressionLevel=dynamic_compression_level,
),
),
wait_label="Creating task",
final_label_on_success="Task created",
)
upload_res = await sync_op(
cls,
ApiEndpoint(
path=f"/proxy/topaz/video/{initial_res.requestId}/accept",
method="PATCH",
),
response_model=topaz_api.VideoAcceptResponse,
wait_label="Preparing upload",
final_label_on_success="Upload started",
)
if len(upload_res.urls) > 1:
raise NotImplementedError(
"Large files are not currently supported. Please open an issue in the ComfyUI repository."
)
async with aiohttp.ClientSession(headers={"Content-Type": "video/mp4"}) as session:
if isinstance(src_video_stream, BytesIO):
src_video_stream.seek(0)
async with session.put(upload_res.urls[0], data=src_video_stream, raise_for_status=True) as res:
upload_etag = res.headers["Etag"]
else:
with builtins.open(src_video_stream, "rb") as video_file:
async with session.put(upload_res.urls[0], data=video_file, raise_for_status=True) as res:
upload_etag = res.headers["Etag"]
await sync_op(
cls,
ApiEndpoint(
path=f"/proxy/topaz/video/{initial_res.requestId}/complete-upload",
method="PATCH",
),
response_model=topaz_api.VideoCompleteUploadResponse,
data=topaz_api.VideoCompleteUploadRequest(
uploadResults=[
topaz_api.VideoCompleteUploadRequestPart(
partNum=1,
eTag=upload_etag,
),
],
),
wait_label="Finalizing upload",
final_label_on_success="Upload completed",
)
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/topaz/video/{initial_res.requestId}/status"),
response_model=topaz_api.VideoStatusResponse,
status_extractor=lambda x: x.status,
progress_extractor=lambda x: getattr(x, "progress", 0),
price_extractor=lambda x: (x.estimates.cost[0] * 0.08 if x.estimates and x.estimates.cost[0] else None),
poll_interval=10.0,
max_poll_attempts=320,
)
return IO.NodeOutput(await download_url_to_video_output(final_response.download.url))
class TopazExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
TopazImageEnhance,
TopazVideoEnhance,
]
async def comfy_entrypoint() -> TopazExtension:
return TopazExtension()

View File

@@ -63,7 +63,6 @@ class _RequestConfig:
estimated_total: Optional[int] = None
final_label_on_success: Optional[str] = "Completed"
progress_origin_ts: Optional[float] = None
price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None
@dataclass
@@ -78,9 +77,9 @@ class _PollUIState:
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing"]
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done"]
FAILED_STATUSES = ["cancelled", "canceled", "fail", "failed", "error"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted"]
async def sync_op(
@@ -88,7 +87,6 @@ async def sync_op(
endpoint: ApiEndpoint,
*,
response_model: Type[M],
price_extractor: Optional[Callable[[M], Optional[float]]] = None,
data: Optional[BaseModel] = None,
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
content_type: str = "application/json",
@@ -106,7 +104,6 @@ async def sync_op(
raw = await sync_op_raw(
cls,
endpoint,
price_extractor=_wrap_model_extractor(response_model, price_extractor),
data=data,
files=files,
content_type=content_type,
@@ -178,7 +175,6 @@ async def sync_op_raw(
cls: type[IO.ComfyNode],
endpoint: ApiEndpoint,
*,
price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
data: Optional[Union[dict[str, Any], BaseModel]] = None,
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
content_type: str = "application/json",
@@ -220,7 +216,6 @@ async def sync_op_raw(
estimated_total=estimated_duration,
final_label_on_success=final_label_on_success,
progress_origin_ts=progress_origin_ts,
price_extractor=price_extractor,
)
return await _request_base(cfg, expect_binary=as_binary)
@@ -429,9 +424,7 @@ def _display_text(
if status:
display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}")
if price is not None:
p = f"{float(price):,.4f}".rstrip("0").rstrip(".")
if p != "0":
display_lines.append(f"Price: ${p}")
display_lines.append(f"Price: ${float(price):,.4f}")
if text is not None:
display_lines.append(text)
if display_lines:
@@ -587,7 +580,6 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
delay = cfg.retry_delay
operation_succeeded: bool = False
final_elapsed_seconds: Optional[int] = None
extracted_price: Optional[float] = None
while True:
attempt += 1
stop_event = asyncio.Event()
@@ -775,8 +767,6 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
except json.JSONDecodeError:
payload = {"_raw": text}
response_content_to_log = payload if isinstance(payload, dict) else text
with contextlib.suppress(Exception):
extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None
operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time)
try:
@@ -881,7 +871,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
else int(time.monotonic() - start_time)
),
estimated_total=cfg.estimated_total,
price=extracted_price,
price=None,
is_queued=False,
processing_elapsed_seconds=final_elapsed_seconds,
)

View File

@@ -11,13 +11,13 @@ if TYPE_CHECKING:
def easycache_forward_wrapper(executor, *args, **kwargs):
# get values from args
x: torch.Tensor = args[0]
transformer_options: dict[str] = args[-1]
if not isinstance(transformer_options, dict):
transformer_options = kwargs.get("transformer_options")
if not transformer_options:
transformer_options = args[-2]
easycache: EasyCacheHolder = transformer_options["easycache"]
x: torch.Tensor = args[0][:, :easycache.output_channels]
sigmas = transformer_options["sigmas"]
uuids = transformer_options["uuids"]
if sigmas is not None and easycache.is_past_end_timestep(sigmas):
@@ -82,13 +82,13 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
# get values from args
x: torch.Tensor = args[0]
timestep: float = args[1]
model_options: dict[str] = args[2]
easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"]
if easycache.is_past_end_timestep(timestep):
return executor(*args, **kwargs)
# prepare next x_prev
x: torch.Tensor = args[0][:, :easycache.output_channels]
next_x_prev = x
input_change = None
do_easycache = easycache.should_do_easycache(timestep)
@@ -173,7 +173,7 @@ def easycache_sample_wrapper(executor, *args, **kwargs):
class EasyCacheHolder:
def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False, output_channels: int=None):
def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False):
self.name = "EasyCache"
self.reuse_threshold = reuse_threshold
self.start_percent = start_percent
@@ -202,7 +202,6 @@ class EasyCacheHolder:
self.allow_mismatch = True
self.cut_from_start = True
self.state_metadata = None
self.output_channels = output_channels
def is_past_end_timestep(self, timestep: float) -> bool:
return not (timestep[0] > self.end_t).item()
@@ -265,7 +264,7 @@ class EasyCacheHolder:
else:
slicing.append(slice(None))
batch_slice = batch_slice + slicing
x[tuple(batch_slice)] += self.uuid_cache_diffs[uuid].to(x.device)
x[batch_slice] += self.uuid_cache_diffs[uuid].to(x.device)
return x
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]):
@@ -284,7 +283,7 @@ class EasyCacheHolder:
else:
slicing.append(slice(None))
skip_dim = False
x = x[tuple(slicing)]
x = x[slicing]
diff = output - x
batch_offset = diff.shape[0] // len(uuids)
for i, uuid in enumerate(uuids):
@@ -324,7 +323,7 @@ class EasyCacheHolder:
return self
def clone(self):
return EasyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose, output_channels=self.output_channels)
return EasyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose)
class EasyCacheNode(io.ComfyNode):
@@ -351,7 +350,7 @@ class EasyCacheNode(io.ComfyNode):
@classmethod
def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
model = model.clone()
model.model_options["transformer_options"]["easycache"] = EasyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose, output_channels=model.model.latent_format.latent_channels)
model.model_options["transformer_options"]["easycache"] = EasyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "easycache", easycache_sample_wrapper)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, "easycache", easycache_calc_cond_batch_wrapper)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "easycache", easycache_forward_wrapper)
@@ -359,7 +358,7 @@ class EasyCacheNode(io.ComfyNode):
class LazyCacheHolder:
def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False, output_channels: int=None):
def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False):
self.name = "LazyCache"
self.reuse_threshold = reuse_threshold
self.start_percent = start_percent
@@ -383,7 +382,6 @@ class LazyCacheHolder:
self.approx_output_change_rates = []
self.total_steps_skipped = 0
self.state_metadata = None
self.output_channels = output_channels
def has_cache_diff(self) -> bool:
return self.cache_diff is not None
@@ -458,7 +456,7 @@ class LazyCacheHolder:
return self
def clone(self):
return LazyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose, output_channels=self.output_channels)
return LazyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose)
class LazyCacheNode(io.ComfyNode):
@classmethod
@@ -484,7 +482,7 @@ class LazyCacheNode(io.ComfyNode):
@classmethod
def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
model = model.clone()
model.model_options["transformer_options"]["easycache"] = LazyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose, output_channels=model.model.latent_format.latent_channels)
model.model_options["transformer_options"]["easycache"] = LazyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "lazycache", easycache_sample_wrapper)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "lazycache", lazycache_predict_noise_wrapper)
return io.NodeOutput(model)

View File

@@ -7,79 +7,63 @@ from comfy.ldm.modules.diffusionmodules.mmdit import get_1d_sincos_pos_embed_fro
import folder_paths
import comfy.model_management
from comfy.cli_args import args
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO, Types
from comfy_api.latest._util import MESH, VOXEL # only for backward compatibility if someone import it from this file (will be removed later) # noqa
class EmptyLatentHunyuan3Dv2(IO.ComfyNode):
class EmptyLatentHunyuan3Dv2:
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="EmptyLatentHunyuan3Dv2",
category="latent/3d",
inputs=[
IO.Int.Input("resolution", default=3072, min=1, max=8192),
IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."),
],
outputs=[
IO.Latent.Output(),
]
)
def INPUT_TYPES(s):
return {
"required": {
"resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
}
}
@classmethod
def execute(cls, resolution, batch_size) -> IO.NodeOutput:
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "latent/3d"
def generate(self, resolution, batch_size):
latent = torch.zeros([batch_size, 64, resolution], device=comfy.model_management.intermediate_device())
return IO.NodeOutput({"samples": latent, "type": "hunyuan3dv2"})
return ({"samples": latent, "type": "hunyuan3dv2"}, )
generate = execute # TODO: remove
class Hunyuan3Dv2Conditioning(IO.ComfyNode):
class Hunyuan3Dv2Conditioning:
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="Hunyuan3Dv2Conditioning",
category="conditioning/video_models",
inputs=[
IO.ClipVisionOutput.Input("clip_vision_output"),
],
outputs=[
IO.Conditioning.Output(display_name="positive"),
IO.Conditioning.Output(display_name="negative"),
]
)
def INPUT_TYPES(s):
return {"required": {"clip_vision_output": ("CLIP_VISION_OUTPUT",),
}}
@classmethod
def execute(cls, clip_vision_output) -> IO.NodeOutput:
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, clip_vision_output):
embeds = clip_vision_output.last_hidden_state
positive = [[embeds, {}]]
negative = [[torch.zeros_like(embeds), {}]]
return IO.NodeOutput(positive, negative)
encode = execute # TODO: remove
return (positive, negative)
class Hunyuan3Dv2ConditioningMultiView(IO.ComfyNode):
class Hunyuan3Dv2ConditioningMultiView:
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="Hunyuan3Dv2ConditioningMultiView",
category="conditioning/video_models",
inputs=[
IO.ClipVisionOutput.Input("front", optional=True),
IO.ClipVisionOutput.Input("left", optional=True),
IO.ClipVisionOutput.Input("back", optional=True),
IO.ClipVisionOutput.Input("right", optional=True),
],
outputs=[
IO.Conditioning.Output(display_name="positive"),
IO.Conditioning.Output(display_name="negative"),
]
)
def INPUT_TYPES(s):
return {"required": {},
"optional": {"front": ("CLIP_VISION_OUTPUT",),
"left": ("CLIP_VISION_OUTPUT",),
"back": ("CLIP_VISION_OUTPUT",),
"right": ("CLIP_VISION_OUTPUT",), }}
@classmethod
def execute(cls, front=None, left=None, back=None, right=None) -> IO.NodeOutput:
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, front=None, left=None, back=None, right=None):
all_embeds = [front, left, back, right]
out = []
pos_embeds = None
@@ -92,35 +76,29 @@ class Hunyuan3Dv2ConditioningMultiView(IO.ComfyNode):
embeds = torch.cat(out, dim=1)
positive = [[embeds, {}]]
negative = [[torch.zeros_like(embeds), {}]]
return IO.NodeOutput(positive, negative)
encode = execute # TODO: remove
return (positive, negative)
class VAEDecodeHunyuan3D(IO.ComfyNode):
class VOXEL:
def __init__(self, data):
self.data = data
class VAEDecodeHunyuan3D:
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="VAEDecodeHunyuan3D",
category="latent/3d",
inputs=[
IO.Latent.Input("samples"),
IO.Vae.Input("vae"),
IO.Int.Input("num_chunks", default=8000, min=1000, max=500000),
IO.Int.Input("octree_resolution", default=256, min=16, max=512),
],
outputs=[
IO.Voxel.Output(),
]
)
def INPUT_TYPES(s):
return {"required": {"samples": ("LATENT", ),
"vae": ("VAE", ),
"num_chunks": ("INT", {"default": 8000, "min": 1000, "max": 500000}),
"octree_resolution": ("INT", {"default": 256, "min": 16, "max": 512}),
}}
RETURN_TYPES = ("VOXEL",)
FUNCTION = "decode"
@classmethod
def execute(cls, vae, samples, num_chunks, octree_resolution) -> IO.NodeOutput:
voxels = Types.VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution}))
return IO.NodeOutput(voxels)
decode = execute # TODO: remove
CATEGORY = "latent/3d"
def decode(self, vae, samples, num_chunks, octree_resolution):
voxels = VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution}))
return (voxels, )
def voxel_to_mesh(voxels, threshold=0.5, device=None):
if device is None:
@@ -418,24 +396,24 @@ def voxel_to_mesh_surfnet(voxels, threshold=0.5, device=None):
return final_vertices, faces
class MESH:
def __init__(self, vertices, faces):
self.vertices = vertices
self.faces = faces
class VoxelToMeshBasic(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="VoxelToMeshBasic",
category="3d",
inputs=[
IO.Voxel.Input("voxel"),
IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01),
],
outputs=[
IO.Mesh.Output(),
]
)
class VoxelToMeshBasic:
@classmethod
def execute(cls, voxel, threshold) -> IO.NodeOutput:
def INPUT_TYPES(s):
return {"required": {"voxel": ("VOXEL", ),
"threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("MESH",)
FUNCTION = "decode"
CATEGORY = "3d"
def decode(self, voxel, threshold):
vertices = []
faces = []
for x in voxel.data:
@@ -443,29 +421,21 @@ class VoxelToMeshBasic(IO.ComfyNode):
vertices.append(v)
faces.append(f)
return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces)))
return (MESH(torch.stack(vertices), torch.stack(faces)), )
decode = execute # TODO: remove
class VoxelToMesh(IO.ComfyNode):
class VoxelToMesh:
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="VoxelToMesh",
category="3d",
inputs=[
IO.Voxel.Input("voxel"),
IO.Combo.Input("algorithm", options=["surface net", "basic"]),
IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01),
],
outputs=[
IO.Mesh.Output(),
]
)
def INPUT_TYPES(s):
return {"required": {"voxel": ("VOXEL", ),
"algorithm": (["surface net", "basic"], ),
"threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("MESH",)
FUNCTION = "decode"
@classmethod
def execute(cls, voxel, algorithm, threshold) -> IO.NodeOutput:
CATEGORY = "3d"
def decode(self, voxel, algorithm, threshold):
vertices = []
faces = []
@@ -479,9 +449,7 @@ class VoxelToMesh(IO.ComfyNode):
vertices.append(v)
faces.append(f)
return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces)))
decode = execute # TODO: remove
return (MESH(torch.stack(vertices), torch.stack(faces)), )
def save_glb(vertices, faces, filepath, metadata=None):
@@ -613,32 +581,31 @@ def save_glb(vertices, faces, filepath, metadata=None):
return filepath
class SaveGLB(IO.ComfyNode):
class SaveGLB:
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="SaveGLB",
category="3d",
is_output_node=True,
inputs=[
IO.Mesh.Input("mesh"),
IO.String.Input("filename_prefix", default="mesh/ComfyUI"),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo]
)
def INPUT_TYPES(s):
return {"required": {"mesh": ("MESH", ),
"filename_prefix": ("STRING", {"default": "mesh/ComfyUI"}), },
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, }
@classmethod
def execute(cls, mesh, filename_prefix) -> IO.NodeOutput:
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "3d"
def save(self, mesh, filename_prefix, prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
results = []
metadata = {}
if not args.disable_metadata:
if cls.hidden.prompt is not None:
metadata["prompt"] = json.dumps(cls.hidden.prompt)
if cls.hidden.extra_pnginfo is not None:
for x in cls.hidden.extra_pnginfo:
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
if prompt is not None:
metadata["prompt"] = json.dumps(prompt)
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
for i in range(mesh.vertices.shape[0]):
f = f"{filename}_{counter:05}_.glb"
@@ -649,22 +616,15 @@ class SaveGLB(IO.ComfyNode):
"type": "output"
})
counter += 1
return IO.NodeOutput(ui={"3d": results})
return {"ui": {"3d": results}}
class Hunyuan3dExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
EmptyLatentHunyuan3Dv2,
Hunyuan3Dv2Conditioning,
Hunyuan3Dv2ConditioningMultiView,
VAEDecodeHunyuan3D,
VoxelToMeshBasic,
VoxelToMesh,
SaveGLB,
]
async def comfy_entrypoint() -> Hunyuan3dExtension:
return Hunyuan3dExtension()
NODE_CLASS_MAPPINGS = {
"EmptyLatentHunyuan3Dv2": EmptyLatentHunyuan3Dv2,
"Hunyuan3Dv2Conditioning": Hunyuan3Dv2Conditioning,
"Hunyuan3Dv2ConditioningMultiView": Hunyuan3Dv2ConditioningMultiView,
"VAEDecodeHunyuan3D": VAEDecodeHunyuan3D,
"VoxelToMeshBasic": VoxelToMeshBasic,
"VoxelToMesh": VoxelToMesh,
"SaveGLB": SaveGLB,
}

View File

@@ -1,39 +0,0 @@
from comfy_api.latest import ComfyExtension, io
from typing_extensions import override
# If you write a node that is so useless that it breaks ComfyUI it will be featured in this exclusive list
# "native" block swap nodes are placebo at best and break the ComfyUI memory management system.
# They are also considered harmful because instead of users reporting issues with the built in
# memory management they install these stupid nodes and complain even harder. Now it completely
# breaks with some of the new ComfyUI memory optimizations so I have made the decision to NOP it
# out of all workflows.
class wanBlockSwap(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="wanBlockSwap",
category="",
description="NOP",
inputs=[
io.Model.Input("model"),
],
outputs=[
io.Model.Output(),
],
is_deprecated=True,
)
@classmethod
def execute(cls, model) -> io.NodeOutput:
return io.NodeOutput(model)
class NopExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
wanBlockSwap
]
async def comfy_entrypoint() -> NopExtension:
return NopExtension()

View File

@@ -39,5 +39,5 @@ NODE_CLASS_MAPPINGS = {
}
NODE_DISPLAY_NAME_MAPPINGS = {
"PreviewAny": "Preview as Text",
"PreviewAny": "Preview Any",
}

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.70"
__version__ = "0.3.68"

View File

@@ -1852,11 +1852,6 @@ class ImageBatch:
CATEGORY = "image"
def batch(self, image1, image2):
if image1.shape[-1] != image2.shape[-1]:
if image1.shape[-1] > image2.shape[-1]:
image2 = torch.nn.functional.pad(image2, (0,1), mode='constant', value=1.0)
else:
image1 = torch.nn.functional.pad(image1, (0,1), mode='constant', value=1.0)
if image1.shape[1:] != image2.shape[1:]:
image2 = comfy.utils.common_upscale(image2.movedim(-1,1), image1.shape[2], image1.shape[1], "bilinear", "center").movedim(1,-1)
s = torch.cat((image1, image2), dim=0)
@@ -2335,7 +2330,6 @@ async def init_builtin_extra_nodes():
"nodes_easycache.py",
"nodes_audio_encoder.py",
"nodes_rope.py",
"nodes_nop.py",
]
import_failed = []
@@ -2364,7 +2358,6 @@ async def init_builtin_api_nodes():
"nodes_pika.py",
"nodes_runway.py",
"nodes_sora.py",
"nodes_topaz.py",
"nodes_tripo.py",
"nodes_moonvalley.py",
"nodes_rodin.py",

View File

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.70"
version = "0.3.68"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"
@@ -24,7 +24,7 @@ lint.select = [
exclude = ["*.ipynb", "**/generated/*.pyi"]
[tool.pylint]
master.py-version = "3.10"
master.py-version = "3.9"
master.extension-pkg-allow-list = [
"pydantic",
]

View File

@@ -1,5 +1,5 @@
comfyui-frontend-package==1.28.9
comfyui-workflow-templates==0.6.0
comfyui-frontend-package==1.28.8
comfyui-workflow-templates==0.3.1
comfyui-embedded-docs==0.3.1
torch
torchsde

View File

@@ -2,7 +2,6 @@ import os
import sys
import asyncio
import traceback
import time
import nodes
import folder_paths
@@ -734,7 +733,6 @@ class PromptServer():
for sensitive_val in execution.SENSITIVE_EXTRA_DATA_KEYS:
if sensitive_val in extra_data:
sensitive[sensitive_val] = extra_data.pop(sensitive_val)
extra_data["create_time"] = int(time.time() * 1000) # timestamp in milliseconds
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive))
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
return web.json_response(response)