Compare commits

...

44 Commits

Author SHA1 Message Date
comfyanonymous
265adad858 ComfyUI version v0.3.68 2025-11-04 19:42:23 -05:00
comfyanonymous
7f3e4d486c Limit amount of pinned memory on windows to prevent issues. (#10638) 2025-11-04 17:37:50 -05:00
rattus
a389ee01bb caching: Handle None outputs tuple case (#10637) 2025-11-04 14:14:10 -08:00
ComfyUI Wiki
9c71a66790 chore: update workflow templates to v0.2.11 (#10634) 2025-11-04 10:51:53 -08:00
comfyanonymous
af4b7b5edb More fp8 torch.compile regressions fixed. (#10625) 2025-11-03 22:14:20 -05:00
comfyanonymous
0f4ef3afa0 This seems to slow things down slightly on Linux. (#10624) 2025-11-03 21:47:14 -05:00
comfyanonymous
6b88478f9f Bring back fp8 torch compile performance to what it should be. (#10622) 2025-11-03 19:22:10 -05:00
comfyanonymous
e199c8cc67 Fixes (#10621) 2025-11-03 17:58:24 -05:00
comfyanonymous
0652cb8e2d Speed up torch.compile (#10620) 2025-11-03 17:37:12 -05:00
comfyanonymous
958a17199a People should update their pytorch versions. (#10618) 2025-11-03 17:08:30 -05:00
ComfyUI Wiki
e974e554ca chore: update embedded docs to v0.3.1 (#10614) 2025-11-03 10:59:44 -08:00
Alexander Piskun
4e2110c794 feat(Pika-API-nodes): use new API client (#10608) 2025-11-03 00:29:08 -08:00
Alexander Piskun
e617cddf24 convert nodes_openai.py to V3 schema (#10604) 2025-11-03 00:28:13 -08:00
Alexander Piskun
1f3f7a2823 convert nodes_hypernetwork.py to V3 schema (#10583) 2025-11-03 00:21:47 -08:00
EverNebula
88df172790 fix(caching): treat bytes as hashable (#10567) 2025-11-03 00:16:40 -08:00
Alexander Piskun
6d6a18b0b7 fix(api-nodes-cloud): stop using sub-folder and absolute path for output of Rodin3D nodes (#10556) 2025-11-03 00:04:56 -08:00
comfyanonymous
97ff9fae7e Clarify help text for --fast argument (#10609)
Updated help text for the --fast argument to clarify potential risks.
2025-11-02 13:14:04 -05:00
rattus
135fa49ec2 Small speed improvements to --async-offload (#10593)
* ops: dont take an offload stream if you dont need one

* ops: prioritize mem transfer

The async offload streams reason for existence is to transfer from
RAM to GPU. The post processing compute steps are a bonus on the side
stream, but if the compute stream is running a long kernel, it can
stall the side stream, as it wait to type-cast the bias before
transferring the weight. So do a pure xfer of the weight straight up,
then do everything bias, then go back to fix the weight type and do
weight patches.
2025-11-01 18:48:53 -04:00
comfyanonymous
44869ff786 Fix issue with pinned memory. (#10597) 2025-11-01 17:25:59 -04:00
Alexander Piskun
20182a393f convert StabilityAI to use new API client (#10582) 2025-11-01 12:14:06 -07:00
Alexander Piskun
5f109fe6a0 added 12s-20s as available output durations for the LTXV API nodes (#10570) 2025-11-01 12:13:39 -07:00
comfyanonymous
c58c13b2ba Fix torch compile regression on fp8 ops. (#10580) 2025-11-01 00:25:17 -04:00
comfyanonymous
7f374e42c8 ScaleROPE now works on Lumina models. (#10578) 2025-10-31 15:41:40 -04:00
comfyanonymous
27d1bd8829 Fix rope scaling. (#10560) 2025-10-30 22:51:58 -04:00
comfyanonymous
614cf9805e Add a ScaleROPE node. Currently only works on WAN models. (#10559) 2025-10-30 22:11:38 -04:00
rattus
513b0c46fb Add RAM Pressure cache mode (#10454)
* execution: Roll the UI cache into the outputs

Currently the UI cache is parallel to the output cache with
expectations of being a content superset of the output cache.
At the same time the UI and output cache are maintained completely
seperately, making it awkward to free the output cache content without
changing the behaviour of the UI cache.

There are two actual users (getters) of the UI cache. The first is
the case of a direct content hit on the output cache when executing a
node. This case is very naturally handled by merging the UI and outputs
cache.

The second case is the history JSON generation at the end of the prompt.
This currently works by asking the cache for all_node_ids and then
pulling the cache contents for those nodes. all_node_ids is the nodes
of the dynamic prompt.

So fold the UI cache into the output cache. The current UI cache setter
now writes to a prompt-scope dict. When the output cache is set, just
get this value from the dict and tuple up with the outputs.

When generating the history, simply iterate prompt-scope dict.

This prepares support for more complex caching strategies (like RAM
pressure caching) where less than 1 workflow will be cached and it
will be desirable to keep the UI cache and output cache in sync.

* sd: Implement RAM getter for VAE

* model_patcher: Implement RAM getter for ModelPatcher

* sd: Implement RAM getter for CLIP

* Implement RAM Pressure cache

Implement a cache sensitive to RAM pressure. When RAM headroom drops
down below a certain threshold, evict RAM-expensive nodes from the
cache.

Models and tensors are measured directly for RAM usage. An OOM score
is then computed based on the RAM usage of the node.

Note the due to indirection through shared objects (like a model
patcher), multiple nodes can account the same RAM as their individual
usage. The intent is this will free chains of nodes particularly
model loaders and associate loras as they all score similar and are
sorted in close to each other.

Has a bias towards unloading model nodes mid flow while being able
to keep results like text encodings and VAE.

* execution: Convert the cache entry to NamedTuple

As commented in review.

Convert this to a named tuple and abstract away the tuple type
completely from graph.py.
2025-10-30 17:39:02 -04:00
Alexander Piskun
dfac94695b fix img2img operation in Dall2 node (#10552) 2025-10-30 10:22:35 -07:00
Alexander Piskun
163b629c70 use new API client in Pixverse and Ideogram nodes (#10543) 2025-10-29 23:49:03 -07:00
Jedrzej Kosinski
998bf60beb Add units/info for the numbers displayed on 'load completely' and 'load partially' log messages (#10538) 2025-10-29 19:37:06 -04:00
comfyanonymous
906c089957 Fix small performance regression with fp8 fast and scaled fp8. (#10537) 2025-10-29 19:29:01 -04:00
comfyanonymous
25de7b1bfa Try to fix slow load issue on low ram hardware with pinned mem. (#10536) 2025-10-29 17:20:27 -04:00
rattus
ab7ab5be23 Fix Race condition in --async-offload that can cause corruption (#10501)
* mm: factor out the current stream getter

Make this a reusable function.

* ops: sync the offload stream with the consumption of w&b

This sync is nessacary as pytorch will queue cuda async frees on the
same stream as created to tensor. In the case of async offload, this
will be on the offload stream.

Weights and biases can go out of scope in python which then
triggers the pytorch garbage collector to queue the free operation on
the offload stream possible before the compute stream has used the
weight. This causes a use after free on weight data leading to total
corruption of some workflows.

So sync the offload stream with the compute stream after the weight
has been used so the free has to wait for the weight to be used.

The cast_bias_weight is extended in a backwards compatible way with
the new behaviour opt-in on a defaulted parameter. This handles
custom node packs calling cast_bias_weight and defeatures
async-offload for them (as they do not handle the race).

The pattern is now:

cast_bias_weight(... , offloadable=True) #This might be offloaded
thing(weight, bias, ...)
uncast_bias_weight(...)

* controlnet: adopt new cast_bias_weight synchronization scheme

This is nessacary for safe async weight offloading.

* mm: sync the last stream in the queue, not the next

Currently this peeks ahead to sync the next stream in the queue of
streams with the compute stream. This doesnt allow a lot of
parallelization, as then end result is you can only get one weight load
ahead regardless of how many streams you have.

Rotate the loop logic here to synchronize the end of the queue before
returning the next stream. This allows weights to be loaded ahead of the
compute streams position.
2025-10-29 17:17:46 -04:00
comfyanonymous
ec4fc2a09a Fix case of weights not being unpinned. (#10533) 2025-10-29 15:48:06 -04:00
comfyanonymous
1a58087ac2 Reduce memory usage for fp8 scaled op. (#10531) 2025-10-29 15:43:51 -04:00
Alexander Piskun
6c14f3afac use new API client in Luma and Minimax nodes (#10528) 2025-10-29 11:14:56 -07:00
comfyanonymous
e525673f72 Fix issue. (#10527) 2025-10-29 00:37:00 -04:00
comfyanonymous
3fa7a5c04a Speed up offloading using pinned memory. (#10526)
To enable this feature use: --fast pinned_memory
2025-10-29 00:21:01 -04:00
Alexander Piskun
210f7a1ba5 convert nodes_recraft.py to V3 schema (#10507) 2025-10-28 14:38:05 -07:00
rattus
d202c2ba74 execution: Allow a subgraph nodes to execute multiple times (#10499)
In the case of --cache-none lazy and subgraph execution can cause
anything to be run multiple times per workflow. If that rerun nodes is
in itself a subgraph generator, this will crash for two reasons.

pending_subgraph_results[] does not cleanup entries after their use.
So when a pending_subgraph_result is consumed, remove it from the list
so that if the corresponding node is fully re-executed this misses
lookup and it fall through to execute the node as it should.

Secondly, theres is an explicit enforcement against dups in the
addition of subgraphs nodes as ephemerals to the dymprompt. Remove this
enforcement as the use case is now valid.
2025-10-28 16:22:08 -04:00
contentis
8817f8fc14 Mixed Precision Quantization System (#10498)
* Implement mixed precision operations with a registry design and metadate for quant spec in checkpoint.

* Updated design using Tensor Subclasses

* Fix FP8 MM

* An actually functional POC

* Remove CK reference and ensure correct compute dtype

* Update unit tests

* ruff lint

* Implement mixed precision operations with a registry design and metadate for quant spec in checkpoint.

* Updated design using Tensor Subclasses

* Fix FP8 MM

* An actually functional POC

* Remove CK reference and ensure correct compute dtype

* Update unit tests

* ruff lint

* Fix missing keys

* Rename quant dtype parameter

* Rename quant dtype parameter

* Fix unittests for CPU build
2025-10-28 16:20:53 -04:00
comfyanonymous
22e40d2ace Tell users to update their nvidia drivers if portable doesn't start. (#10518) 2025-10-28 15:08:08 -04:00
comfyanonymous
3bea4efc6b Tell users to update nvidia drivers if problem with portable. (#10510) 2025-10-28 04:45:45 -04:00
comfyanonymous
8cf2ba4ba6 Remove comfy api key from queue api. (#10502) 2025-10-28 03:23:52 -04:00
comfyanonymous
b61a40cbc9 Bump stable portable to cu130 python 3.13.9 (#10508) 2025-10-28 03:21:45 -04:00
53 changed files with 3329 additions and 2852 deletions

View File

@@ -1,2 +1,3 @@
..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
pause

View File

@@ -1,2 +1,3 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
pause

View File

@@ -1,2 +1,3 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
pause

View File

@@ -18,9 +18,9 @@ jobs:
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "cu129"
cache_tag: "cu130"
python_minor: "13"
python_patch: "6"
python_patch: "9"
rel_name: "nvidia"
rel_extra_name: ""
test_release: true

View File

@@ -176,6 +176,8 @@ Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you
If you have trouble extracting it, right click the file -> properties -> unblock
Update your Nvidia drivers if it doesn't start.
#### Alternative Downloads:
[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)

View File

@@ -105,6 +105,7 @@ cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
@@ -144,8 +145,9 @@ class PerformanceFeature(enum.Enum):
Fp8MatrixMultiplication = "fp8_matrix_mult"
CublasOps = "cublas_ops"
AutoTune = "autotune"
PinnedMem = "pinned_memory"
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")

View File

@@ -310,11 +310,13 @@ class ControlLoraOps:
self.bias = None
def forward(self, input):
weight, bias = comfy.ops.cast_bias_weight(self, input)
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
if self.up is not None:
return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
x = torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
else:
return torch.nn.functional.linear(input, weight, bias)
x = torch.nn.functional.linear(input, weight, bias)
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
return x
class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
def __init__(
@@ -350,12 +352,13 @@ class ControlLoraOps:
def forward(self, input):
weight, bias = comfy.ops.cast_bias_weight(self, input)
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
if self.up is not None:
return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
x = torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
else:
return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
x = torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
return x
class ControlLora(ControlNet):
def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options

View File

@@ -522,7 +522,7 @@ class NextDiT(nn.Module):
max_cap_len = max(l_effective_cap_len)
max_img_len = max(l_effective_img_len)
position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)
position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.float32, device=device)
for i in range(bsz):
cap_len = l_effective_cap_len[i]
@@ -531,10 +531,22 @@ class NextDiT(nn.Module):
H_tokens, W_tokens = H // pH, W // pW
assert H_tokens * W_tokens == img_len
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
rope_options = transformer_options.get("rope_options", None)
h_scale = 1.0
w_scale = 1.0
h_start = 0
w_start = 0
if rope_options is not None:
h_scale = rope_options.get("scale_y", 1.0)
w_scale = rope_options.get("scale_x", 1.0)
h_start = rope_options.get("shift_y", 0.0)
w_start = rope_options.get("shift_x", 0.0)
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.float32, device=device)
position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
row_ids = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
col_ids = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
position_ids[i, cap_len:cap_len+img_len, 2] = col_ids

View File

@@ -588,7 +588,7 @@ class WanModel(torch.nn.Module):
x = self.unpatchify(x, grid_sizes)
return x
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None):
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
@@ -601,10 +601,22 @@ class WanModel(torch.nn.Module):
if steps_w is None:
steps_w = w_len
h_start = 0
w_start = 0
rope_options = transformer_options.get("rope_options", None)
if rope_options is not None:
t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
t_start += rope_options.get("shift_t", 0.0)
h_start += rope_options.get("shift_y", 0.0)
w_start += rope_options.get("shift_x", 0.0)
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
freqs = self.rope_embedder(img_ids).movedim(1, 2)
@@ -630,7 +642,7 @@ class WanModel(torch.nn.Module):
if self.ref_conv is not None and "reference_latent" in kwargs:
t_len += 1
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype)
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
def unpatchify(self, x, grid_sizes):

View File

@@ -134,7 +134,7 @@ class BaseModel(torch.nn.Module):
if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None:
fp8 = model_config.optimizations.get("fp8", False)
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
else:
operations = model_config.custom_operations
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
@@ -333,6 +333,14 @@ class BaseModel(torch.nn.Module):
if self.model_config.scaled_fp8 is not None:
unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
# Save mixed precision metadata
if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
metadata = {
"format_version": "1.0",
"layers": self.model_config.layer_quant_config
}
unet_state_dict["_quantization_metadata"] = metadata
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
if self.model_type == ModelType.V_PREDICTION:

View File

@@ -6,6 +6,20 @@ import math
import logging
import torch
def detect_layer_quantization(metadata):
quant_key = "_quantization_metadata"
if metadata is not None and quant_key in metadata:
quant_metadata = metadata.pop(quant_key)
quant_metadata = json.loads(quant_metadata)
if isinstance(quant_metadata, dict) and "layers" in quant_metadata:
logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
return quant_metadata["layers"]
else:
raise ValueError("Invalid quantization metadata format")
return None
def count_blocks(state_dict_keys, prefix_string):
count = 0
while True:
@@ -701,6 +715,12 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
else:
model_config.optimizations["fp8"] = True
# Detect per-layer quantization (mixed precision)
layer_quant_config = detect_layer_quantization(metadata)
if layer_quant_config:
model_config.layer_quant_config = layer_quant_config
logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")
return model_config
def unet_prefix_from_state_dict(state_dict):

View File

@@ -1013,6 +1013,16 @@ if args.async_offload:
NUM_STREAMS = 2
logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
def current_stream(device):
if device is None:
return None
if is_device_cuda(device):
return torch.cuda.current_stream()
elif is_device_xpu(device):
return torch.xpu.current_stream()
else:
return None
stream_counters = {}
def get_offload_stream(device):
stream_counter = stream_counters.get(device, 0)
@@ -1021,21 +1031,17 @@ def get_offload_stream(device):
if device in STREAMS:
ss = STREAMS[device]
s = ss[stream_counter]
#Sync the oldest stream in the queue with the current
ss[stream_counter].wait_stream(current_stream(device))
stream_counter = (stream_counter + 1) % len(ss)
if is_device_cuda(device):
ss[stream_counter].wait_stream(torch.cuda.current_stream())
elif is_device_xpu(device):
ss[stream_counter].wait_stream(torch.xpu.current_stream())
stream_counters[device] = stream_counter
return s
return ss[stream_counter]
elif is_device_cuda(device):
ss = []
for k in range(NUM_STREAMS):
ss.append(torch.cuda.Stream(device=device, priority=0))
STREAMS[device] = ss
s = ss[stream_counter]
stream_counter = (stream_counter + 1) % len(ss)
stream_counters[device] = stream_counter
return s
elif is_device_xpu(device):
@@ -1044,18 +1050,14 @@ def get_offload_stream(device):
ss.append(torch.xpu.Stream(device=device, priority=0))
STREAMS[device] = ss
s = ss[stream_counter]
stream_counter = (stream_counter + 1) % len(ss)
stream_counters[device] = stream_counter
return s
return None
def sync_stream(device, stream):
if stream is None:
if stream is None or current_stream(device) is None:
return
if is_device_cuda(device):
torch.cuda.current_stream().wait_stream(stream)
elif is_device_xpu(device):
torch.xpu.current_stream().wait_stream(stream)
current_stream(device).wait_stream(stream)
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
if device is None or weight.device == device:
@@ -1080,6 +1082,60 @@ def cast_to_device(tensor, device, dtype, copy=False):
non_blocking = device_supports_non_blocking(device)
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
PINNED_MEMORY = {}
TOTAL_PINNED_MEMORY = 0
if PerformanceFeature.PinnedMem in args.fast:
if WINDOWS:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50%
else:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
else:
MAX_PINNED_MEMORY = -1
def pin_memory(tensor):
global TOTAL_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0:
return False
if not is_nvidia():
return False
if not is_device_cpu(tensor.device):
return False
size = tensor.numel() * tensor.element_size()
if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
return False
ptr = tensor.data_ptr()
if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0:
PINNED_MEMORY[ptr] = size
TOTAL_PINNED_MEMORY += size
return True
return False
def unpin_memory(tensor):
global TOTAL_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0:
return False
if not is_nvidia():
return False
if not is_device_cpu(tensor.device):
return False
ptr = tensor.data_ptr()
if torch.cuda.cudart().cudaHostUnregister(ptr) == 0:
TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr)
if len(PINNED_MEMORY) == 0:
TOTAL_PINNED_MEMORY = 0
return True
return False
def sage_attention_enabled():
return args.use_sage_attention

View File

@@ -238,6 +238,7 @@ class ModelPatcher:
self.force_cast_weights = False
self.patches_uuid = uuid.uuid4()
self.parent = None
self.pinned = set()
self.attachments: dict[str] = {}
self.additional_models: dict[str, list[ModelPatcher]] = {}
@@ -275,6 +276,9 @@ class ModelPatcher:
self.size = comfy.model_management.module_size(self.model)
return self.size
def get_ram_usage(self):
return self.model_size()
def loaded_size(self):
return self.model.model_loaded_weight_memory
@@ -294,6 +298,7 @@ class ModelPatcher:
n.backup = self.backup
n.object_patches_backup = self.object_patches_backup
n.parent = self
n.pinned = self.pinned
n.force_cast_weights = self.force_cast_weights
@@ -450,6 +455,19 @@ class ModelPatcher:
def set_model_post_input_patch(self, patch):
self.set_model_patch(patch, "post_input")
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
rope_options = self.model_options["transformer_options"].get("rope_options", {})
rope_options["scale_x"] = scale_x
rope_options["scale_y"] = scale_y
rope_options["scale_t"] = scale_t
rope_options["shift_x"] = shift_x
rope_options["shift_y"] = shift_y
rope_options["shift_t"] = shift_t
self.model_options["transformer_options"]["rope_options"] = rope_options
def add_object_patch(self, name, obj):
self.object_patches[name] = obj
@@ -618,6 +636,21 @@ class ModelPatcher:
else:
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
def pin_weight_to_device(self, key):
weight, set_func, convert_func = get_key_weight(self.model, key)
if comfy.model_management.pin_memory(weight):
self.pinned.add(key)
def unpin_weight(self, key):
if key in self.pinned:
weight, set_func, convert_func = get_key_weight(self.model, key)
comfy.model_management.unpin_memory(weight)
self.pinned.remove(key)
def unpin_all_weights(self):
for key in list(self.pinned):
self.unpin_weight(key)
def _load_list(self):
loading = []
for n, m in self.model.named_modules():
@@ -639,9 +672,11 @@ class ModelPatcher:
mem_counter = 0
patch_counter = 0
lowvram_counter = 0
lowvram_mem_counter = 0
loading = self._load_list()
load_completely = []
offloaded = []
loading.sort(reverse=True)
for x in loading:
n = x[1]
@@ -658,6 +693,7 @@ class ModelPatcher:
if mem_counter + module_mem >= lowvram_model_memory:
lowvram_weight = True
lowvram_counter += 1
lowvram_mem_counter += module_mem
if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
continue
@@ -683,6 +719,7 @@ class ModelPatcher:
patch_counter += 1
cast_weight = True
offloaded.append((module_mem, n, m, params))
else:
if hasattr(m, "comfy_cast_weights"):
wipe_lowvram_weight(m)
@@ -713,7 +750,9 @@ class ModelPatcher:
continue
for param in params:
self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to)
key = "{}.{}".format(n, param)
self.unpin_weight(key)
self.patch_weight_to_device(key, device_to=device_to)
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
m.comfy_patched_weights = True
@@ -721,11 +760,17 @@ class ModelPatcher:
for x in load_completely:
x[2].to(device_to)
for x in offloaded:
n = x[1]
params = x[3]
for param in params:
self.pin_weight_to_device("{}.{}".format(n, param))
if lowvram_counter > 0:
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), patch_counter))
self.model.model_lowvram = True
else:
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
self.model.model_lowvram = False
if full_load:
self.model.to(device_to)
@@ -762,6 +807,7 @@ class ModelPatcher:
self.eject_model()
if unpatch_weights:
self.unpatch_hooks()
self.unpin_all_weights()
if self.model.model_lowvram:
for m in self.model.modules():
move_weight_functions(m, device_to)
@@ -857,6 +903,9 @@ class ModelPatcher:
memory_freed += module_mem
logging.debug("freed {}".format(n))
for param in params:
self.pin_weight_to_device("{}.{}".format(n, param))
self.model.model_lowvram = True
self.model.lowvram_patch_counter += patch_counter
self.model.model_loaded_weight_memory -= memory_freed
@@ -1259,5 +1308,6 @@ class ModelPatcher:
self.clear_cached_hook_weights()
def __del__(self):
self.unpin_all_weights()
self.detach(unpatch_all=False)

View File

@@ -35,7 +35,7 @@ def scaled_dot_product_attention(q, k, v, *args, **kwargs):
try:
if torch.cuda.is_available():
if torch.cuda.is_available() and comfy.model_management.WINDOWS:
from torch.nn.attention import SDPBackend, sdpa_kernel
import inspect
if "set_priority" in inspect.signature(sdpa_kernel).parameters:
@@ -70,8 +70,11 @@ cast_to = comfy.model_management.cast_to #TODO: remove once no more references
def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
@torch.compiler.disable()
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
# NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
# will add async-offload support to your cast and improve performance.
if input is not None:
if dtype is None:
dtype = input.dtype
@@ -80,32 +83,58 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
if device is None:
device = input.device
offload_stream = comfy.model_management.get_offload_stream(device)
if offloadable and (device != s.weight.device or
(s.bias is not None and device != s.bias.device)):
offload_stream = comfy.model_management.get_offload_stream(device)
else:
offload_stream = None
if offload_stream is not None:
wf_context = offload_stream
else:
wf_context = contextlib.nullcontext()
bias = None
non_blocking = comfy.model_management.device_supports_non_blocking(device)
if s.bias is not None:
has_function = len(s.bias_function) > 0
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
if has_function:
weight_has_function = len(s.weight_function) > 0
bias_has_function = len(s.bias_function) > 0
weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
bias = None
if s.bias is not None:
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
if bias_has_function:
with wf_context:
for f in s.bias_function:
bias = f(bias)
has_function = len(s.weight_function) > 0
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
if has_function:
weight = weight.to(dtype=dtype)
if weight_has_function:
with wf_context:
for f in s.weight_function:
weight = f(weight)
comfy.model_management.sync_stream(device, offload_stream)
return weight, bias
if offloadable:
return weight, bias, offload_stream
else:
#Legacy function signature
return weight, bias
def uncast_bias_weight(s, weight, bias, offload_stream):
if offload_stream is None:
return
if weight is not None:
device = weight.device
else:
if bias is None:
return
device = bias.device
offload_stream.wait_stream(comfy.model_management.current_stream(device))
class CastWeightBiasOp:
comfy_cast_weights = False
@@ -118,8 +147,10 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.linear(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
@@ -133,8 +164,10 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._conv_forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
@@ -148,8 +181,10 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._conv_forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
@@ -172,8 +207,10 @@ class disable_weight_init:
return super()._conv_forward(input, weight, bias, *args, **kwargs)
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._conv_forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
@@ -187,8 +224,10 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
@@ -203,11 +242,14 @@ class disable_weight_init:
def forward_comfy_cast_weights(self, input):
if self.weight is not None:
weight, bias = cast_bias_weight(self, input)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
else:
weight = None
bias = None
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
offload_stream = None
x = torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
@@ -223,11 +265,15 @@ class disable_weight_init:
def forward_comfy_cast_weights(self, input):
if self.weight is not None:
weight, bias = cast_bias_weight(self, input)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
else:
weight = None
return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
# return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
bias = None
offload_stream = None
x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
# x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
@@ -246,10 +292,12 @@ class disable_weight_init:
input, output_size, self.stride, self.padding, self.kernel_size,
num_spatial_dims, self.dilation)
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.conv_transpose2d(
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.conv_transpose2d(
input, weight, bias, self.stride, self.padding,
output_padding, self.groups, self.dilation)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
@@ -268,10 +316,12 @@ class disable_weight_init:
input, output_size, self.stride, self.padding, self.kernel_size,
num_spatial_dims, self.dilation)
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.conv_transpose1d(
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.conv_transpose1d(
input, weight, bias, self.stride, self.padding,
output_padding, self.groups, self.dilation)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
@@ -289,8 +339,11 @@ class disable_weight_init:
output_dtype = out_dtype
if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
out_dtype = None
weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
weight, bias, offload_stream = cast_bias_weight(self, device=input.device, dtype=out_dtype, offloadable=True)
x = torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
@@ -344,20 +397,18 @@ class manual_cast(disable_weight_init):
def fp8_linear(self, input):
"""
Legacy FP8 linear function for backward compatibility.
Uses QuantizedTensor subclass for dispatch.
"""
dtype = self.weight.dtype
if dtype not in [torch.float8_e4m3fn]:
return None
tensor_2d = False
if len(input.shape) == 2:
tensor_2d = True
input = input.unsqueeze(1)
input_shape = input.shape
input_dtype = input.dtype
if len(input.shape) == 3:
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
w = w.t()
if input.ndim == 3 or input.ndim == 2:
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
scale_weight = self.scale_weight
scale_input = self.scale_input
@@ -369,23 +420,20 @@ def fp8_linear(self, input):
if scale_input is None:
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
input = torch.clamp(input, min=-448, max=448, out=input)
input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
else:
scale_input = scale_input.to(input.device)
input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()
quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype)
if bias is not None:
o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
else:
o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight)
# Wrap weight in QuantizedTensor - this enables unified dispatch
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
if isinstance(o, tuple):
o = o[0]
if tensor_2d:
return o.reshape(input_shape[0], -1)
return o.reshape((-1, input_shape[1], self.weight.shape[0]))
uncast_bias_weight(self, w, bias, offload_stream)
return o
return None
@@ -405,8 +453,10 @@ class fp8_ops(manual_cast):
except Exception as e:
logging.info("Exception during fp8 op: {}".format(e))
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.linear(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
@@ -434,12 +484,14 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
if out is not None:
return out
weight, bias = cast_bias_weight(self, input)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
if weight.numel() < input.numel(): #TODO: optimize
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
else:
return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def convert_weight(self, weight, inplace=False, **kwargs):
if inplace:
@@ -478,7 +530,130 @@ if CUBLAS_IS_AVAILABLE:
def forward(self, *args, **kwargs):
return super().forward(*args, **kwargs)
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
# ==============================================================================
# Mixed Precision Operations
# ==============================================================================
from .quant_ops import QuantizedTensor
QUANT_FORMAT_MIXINS = {
"float8_e4m3fn": {
"dtype": torch.float8_e4m3fn,
"layout_type": "TensorCoreFP8Layout",
"parameters": {
"weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
"input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
}
}
}
class MixedPrecisionOps(disable_weight_init):
_layer_quant_config = {}
_compute_dtype = torch.bfloat16
class Linear(torch.nn.Module, CastWeightBiasOp):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
) -> None:
super().__init__()
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
# self.factory_kwargs = {"device": device, "dtype": dtype}
self.in_features = in_features
self.out_features = out_features
if bias:
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
else:
self.register_parameter("bias", None)
self.tensor_class = None
def reset_parameters(self):
return None
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
device = self.factory_kwargs["device"]
layer_name = prefix.rstrip('.')
weight_key = f"{prefix}weight"
weight = state_dict.pop(weight_key, None)
if weight is None:
raise ValueError(f"Missing weight for layer {layer_name}")
manually_loaded_keys = [weight_key]
if layer_name not in MixedPrecisionOps._layer_quant_config:
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
else:
quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
if quant_format is None:
raise ValueError(f"Unknown quantization format for layer {layer_name}")
mixin = QUANT_FORMAT_MIXINS[quant_format]
self.layout_type = mixin["layout_type"]
scale_key = f"{prefix}weight_scale"
layout_params = {
'scale': state_dict.pop(scale_key, None),
'orig_dtype': MixedPrecisionOps._compute_dtype
}
if layout_params['scale'] is not None:
manually_loaded_keys.append(scale_key)
self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(device=device, dtype=mixin["dtype"]), self.layout_type, layout_params),
requires_grad=False
)
for param_name, param_value in mixin["parameters"].items():
param_key = f"{prefix}{param_name}"
_v = state_dict.pop(param_key, None)
if _v is None:
continue
setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
manually_loaded_keys.append(param_key)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
for key in manually_loaded_keys:
if key in missing_keys:
missing_keys.remove(key)
def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias)
def forward_comfy_cast_weights(self, input):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, input, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(input, *args, **kwargs)
if (getattr(self, 'layout_type', None) is not None and
getattr(self, 'input_scale', None) is not None and
not isinstance(input, QuantizedTensor)):
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, fp8_dtype=self.weight.dtype)
return self._forward(input, self.weight, self.bias)
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config
MixedPrecisionOps._compute_dtype = compute_dtype
logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
return MixedPrecisionOps
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
if scaled_fp8 is not None:
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)

512
comfy/quant_ops.py Normal file
View File

@@ -0,0 +1,512 @@
import torch
import logging
from typing import Tuple, Dict
_LAYOUT_REGISTRY = {}
_GENERIC_UTILS = {}
def register_layout_op(torch_op, layout_type):
"""
Decorator to register a layout-specific operation handler.
Args:
torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default)
layout_type: Layout class (e.g., TensorCoreFP8Layout)
Example:
@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
def fp8_linear(func, args, kwargs):
# FP8-specific linear implementation
...
"""
def decorator(handler_func):
if torch_op not in _LAYOUT_REGISTRY:
_LAYOUT_REGISTRY[torch_op] = {}
_LAYOUT_REGISTRY[torch_op][layout_type] = handler_func
return handler_func
return decorator
def register_generic_util(torch_op):
"""
Decorator to register a generic utility that works for all layouts.
Args:
torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default)
Example:
@register_generic_util(torch.ops.aten.detach.default)
def generic_detach(func, args, kwargs):
# Works for any layout
...
"""
def decorator(handler_func):
_GENERIC_UTILS[torch_op] = handler_func
return handler_func
return decorator
def _get_layout_from_args(args):
for arg in args:
if isinstance(arg, QuantizedTensor):
return arg._layout_type
elif isinstance(arg, (list, tuple)):
for item in arg:
if isinstance(item, QuantizedTensor):
return item._layout_type
return None
def _move_layout_params_to_device(params, device):
new_params = {}
for k, v in params.items():
if isinstance(v, torch.Tensor):
new_params[k] = v.to(device=device)
else:
new_params[k] = v
return new_params
def _copy_layout_params(params):
new_params = {}
for k, v in params.items():
if isinstance(v, torch.Tensor):
new_params[k] = v.clone()
else:
new_params[k] = v
return new_params
class QuantizedLayout:
"""
Base class for quantization layouts.
A layout encapsulates the format-specific logic for quantization/dequantization
and provides a uniform interface for extracting raw tensors needed for computation.
New quantization formats should subclass this and implement the required methods.
"""
@classmethod
def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]:
raise NotImplementedError(f"{cls.__name__} must implement quantize()")
@staticmethod
def dequantize(qdata, **layout_params) -> torch.Tensor:
raise NotImplementedError("TensorLayout must implement dequantize()")
@classmethod
def get_plain_tensors(cls, qtensor) -> torch.Tensor:
raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()")
class QuantizedTensor(torch.Tensor):
"""
Universal quantized tensor that works with any layout.
This tensor subclass uses a pluggable layout system to support multiple
quantization formats (FP8, INT4, INT8, etc.) without code duplication.
The layout_type determines format-specific behavior, while common operations
(detach, clone, to) are handled generically.
Attributes:
_qdata: The quantized tensor data
_layout_type: Layout class (e.g., TensorCoreFP8Layout)
_layout_params: Dict with layout-specific params (scale, zero_point, etc.)
"""
@staticmethod
def __new__(cls, qdata, layout_type, layout_params):
"""
Create a quantized tensor.
Args:
qdata: The quantized data tensor
layout_type: Layout class (subclass of QuantizedLayout)
layout_params: Dict with layout-specific parameters
"""
return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
def __init__(self, qdata, layout_type, layout_params):
self._qdata = qdata
self._layout_type = layout_type
self._layout_params = layout_params
def __repr__(self):
layout_name = self._layout_type
param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"
@property
def layout_type(self):
return self._layout_type
def __tensor_flatten__(self):
"""
Tensor flattening protocol for proper device movement.
"""
inner_tensors = ["_qdata"]
ctx = {
"layout_type": self._layout_type,
}
tensor_params = {}
non_tensor_params = {}
for k, v in self._layout_params.items():
if isinstance(v, torch.Tensor):
tensor_params[k] = v
else:
non_tensor_params[k] = v
ctx["tensor_param_keys"] = list(tensor_params.keys())
ctx["non_tensor_params"] = non_tensor_params
for k, v in tensor_params.items():
attr_name = f"_layout_param_{k}"
object.__setattr__(self, attr_name, v)
inner_tensors.append(attr_name)
return inner_tensors, ctx
@staticmethod
def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
"""
Tensor unflattening protocol for proper device movement.
Reconstructs the QuantizedTensor after device movement.
"""
layout_type = ctx["layout_type"]
layout_params = dict(ctx["non_tensor_params"])
for key in ctx["tensor_param_keys"]:
attr_name = f"_layout_param_{key}"
layout_params[key] = inner_tensors[attr_name]
return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params)
@classmethod
def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs)
return cls(qdata, layout_type, layout_params)
def dequantize(self) -> torch.Tensor:
return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
# Step 1: Check generic utilities first (detach, clone, to, etc.)
if func in _GENERIC_UTILS:
return _GENERIC_UTILS[func](func, args, kwargs)
# Step 2: Check layout-specific handlers (linear, matmul, etc.)
layout_type = _get_layout_from_args(args)
if layout_type and func in _LAYOUT_REGISTRY:
handler = _LAYOUT_REGISTRY[func].get(layout_type)
if handler:
return handler(func, args, kwargs)
# Step 3: Fallback to dequantization
if isinstance(args[0] if args else None, QuantizedTensor):
logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
return cls._dequant_and_fallback(func, args, kwargs)
@classmethod
def _dequant_and_fallback(cls, func, args, kwargs):
def dequant_arg(arg):
if isinstance(arg, QuantizedTensor):
return arg.dequantize()
elif isinstance(arg, (list, tuple)):
return type(arg)(dequant_arg(a) for a in arg)
return arg
new_args = dequant_arg(args)
new_kwargs = dequant_arg(kwargs)
return func(*new_args, **new_kwargs)
# ==============================================================================
# Generic Utilities (Layout-Agnostic Operations)
# ==============================================================================
def _create_transformed_qtensor(qt, transform_fn):
new_data = transform_fn(qt._qdata)
new_params = _copy_layout_params(qt._layout_params)
return QuantizedTensor(new_data, qt._layout_type, new_params)
def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
if target_dtype is not None and target_dtype != qt.dtype:
logging.warning(
f"QuantizedTensor: dtype conversion requested to {target_dtype}, "
f"but not supported for quantized tensors. Ignoring dtype."
)
if target_layout is not None and target_layout != torch.strided:
logging.warning(
f"QuantizedTensor: layout change requested to {target_layout}, "
f"but not supported. Ignoring layout."
)
# Handle device transfer
current_device = qt._qdata.device
if target_device is not None:
# Normalize device for comparison
if isinstance(target_device, str):
target_device = torch.device(target_device)
if isinstance(current_device, str):
current_device = torch.device(current_device)
if target_device != current_device:
logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
new_q_data = qt._qdata.to(device=target_device)
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
return new_qt
logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original")
return qt
@register_generic_util(torch.ops.aten.detach.default)
def generic_detach(func, args, kwargs):
"""Detach operation - creates a detached copy of the quantized tensor."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _create_transformed_qtensor(qt, lambda x: x.detach())
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.clone.default)
def generic_clone(func, args, kwargs):
"""Clone operation - creates a deep copy of the quantized tensor."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _create_transformed_qtensor(qt, lambda x: x.clone())
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten._to_copy.default)
def generic_to_copy(func, args, kwargs):
"""Device/dtype transfer operation - handles .to(device) calls."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _handle_device_transfer(
qt,
target_device=kwargs.get('device', None),
target_dtype=kwargs.get('dtype', None),
op_name="_to_copy"
)
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.to.dtype_layout)
def generic_to_dtype_layout(func, args, kwargs):
"""Handle .to(device) calls using the dtype_layout variant."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _handle_device_transfer(
qt,
target_device=kwargs.get('device', None),
target_dtype=kwargs.get('dtype', None),
target_layout=kwargs.get('layout', None),
op_name="to"
)
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.copy_.default)
def generic_copy_(func, args, kwargs):
qt_dest = args[0]
src = args[1]
if isinstance(qt_dest, QuantizedTensor):
if isinstance(src, QuantizedTensor):
# Copy from another quantized tensor
qt_dest._qdata.copy_(src._qdata)
qt_dest._layout_type = src._layout_type
qt_dest._layout_params = _copy_layout_params(src._layout_params)
else:
# Copy from regular tensor - just copy raw data
qt_dest._qdata.copy_(src)
return qt_dest
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
def generic_has_compatible_shallow_copy_type(func, args, kwargs):
return True
# ==============================================================================
# FP8 Layout + Operation Handlers
# ==============================================================================
class TensorCoreFP8Layout(QuantizedLayout):
"""
Storage format:
- qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
- scale: Scalar tensor (float32) for dequantization
- orig_dtype: Original dtype before quantization (for casting back)
"""
@classmethod
def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn):
orig_dtype = tensor.dtype
if scale is None:
scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max
if not isinstance(scale, torch.Tensor):
scale = torch.tensor(scale)
scale = scale.to(device=tensor.device, dtype=torch.float32)
tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype)
# TODO: uncomment this if it's actually needed because the clamp has a small performance penality'
# lp_amax = torch.finfo(dtype).max
# torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)
layout_params = {
'scale': scale,
'orig_dtype': orig_dtype
}
return qdata, layout_params
@staticmethod
def dequantize(qdata, scale, orig_dtype, **kwargs):
plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
return plain_tensor * scale
@classmethod
def get_plain_tensors(cls, qtensor):
return qtensor._qdata, qtensor._layout_params['scale']
LAYOUTS = {
"TensorCoreFP8Layout": TensorCoreFP8Layout,
}
@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout")
def fp8_linear(func, args, kwargs):
input_tensor = args[0]
weight = args[1]
bias = args[2] if len(args) > 2 else None
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
out_dtype = kwargs.get("out_dtype")
if out_dtype is None:
out_dtype = input_tensor._layout_params['orig_dtype']
weight_t = plain_weight.t()
tensor_2d = False
if len(plain_input.shape) == 2:
tensor_2d = True
plain_input = plain_input.unsqueeze(1)
input_shape = plain_input.shape
if len(input_shape) != 3:
return None
try:
output = torch._scaled_mm(
plain_input.reshape(-1, input_shape[2]).contiguous(),
weight_t,
bias=bias,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype,
)
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
output = output[0]
if not tensor_2d:
output = output.reshape((-1, input_shape[1], weight.shape[0]))
if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
output_scale = scale_a * scale_b
output_params = {
'scale': output_scale,
'orig_dtype': input_tensor._layout_params['orig_dtype']
}
return QuantizedTensor(output, "TensorCoreFP8Layout", output_params)
else:
return output
except Exception as e:
raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
# Case 2: DQ Fallback
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
if isinstance(input_tensor, QuantizedTensor):
input_tensor = input_tensor.dequantize()
return torch.nn.functional.linear(input_tensor, weight, bias)
def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None):
if out_dtype is None:
out_dtype = input_tensor._layout_params['orig_dtype']
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
output = torch._scaled_mm(
plain_input.contiguous(),
plain_weight,
bias=bias,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype,
)
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
output = output[0]
return output
@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout")
def fp8_addmm(func, args, kwargs):
input_tensor = args[1]
weight = args[2]
bias = args[0]
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None))
a = list(args)
if isinstance(args[0], QuantizedTensor):
a[0] = args[0].dequantize()
if isinstance(args[1], QuantizedTensor):
a[1] = args[1].dequantize()
if isinstance(args[2], QuantizedTensor):
a[2] = args[2].dequantize()
return func(*a, **kwargs)
@register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout")
def fp8_mm(func, args, kwargs):
input_tensor = args[0]
weight = args[1]
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None))
a = list(args)
if isinstance(args[0], QuantizedTensor):
a[0] = args[0].dequantize()
if isinstance(args[1], QuantizedTensor):
a[1] = args[1].dequantize()
return func(*a, **kwargs)
@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
def fp8_func(func, args, kwargs):
input_tensor = args[0]
if isinstance(input_tensor, QuantizedTensor):
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
ar = list(args)
ar[0] = plain_input
return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
return func(*args, **kwargs)

View File

@@ -143,6 +143,9 @@ class CLIP:
n.apply_hooks_to_conds = self.apply_hooks_to_conds
return n
def get_ram_usage(self):
return self.patcher.get_ram_usage()
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
return self.patcher.add_patches(patches, strength_patch, strength_model)
@@ -293,6 +296,7 @@ class VAE:
self.working_dtypes = [torch.bfloat16, torch.float32]
self.disable_offload = False
self.not_video = False
self.size = None
self.downscale_index_formula = None
self.upscale_index_formula = None
@@ -595,6 +599,16 @@ class VAE:
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
self.model_size()
def model_size(self):
if self.size is not None:
return self.size
self.size = comfy.model_management.module_size(self.first_stage_model)
return self.size
def get_ram_usage(self):
return self.model_size()
def throw_exception_if_invalid(self):
if self.first_stage_model is None:
@@ -1262,7 +1276,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
return (model_patcher, clip, vae, clipvision)
def load_diffusion_model_state_dict(sd, model_options={}):
def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
"""
Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
@@ -1296,7 +1310,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
weight_dtype = comfy.utils.weight_dtype(sd)
load_device = model_management.get_torch_device()
model_config = model_detection.model_config_from_unet(sd, "")
model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
if model_config is not None:
new_sd = sd
@@ -1330,7 +1344,10 @@ def load_diffusion_model_state_dict(sd, model_options={}):
else:
unet_dtype = dtype
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
if model_config.layer_quant_config is not None:
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
else:
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
if model_options.get("fp8_optimizations", False):
@@ -1346,8 +1363,8 @@ def load_diffusion_model_state_dict(sd, model_options={}):
def load_diffusion_model(unet_path, model_options={}):
sd = comfy.utils.load_torch_file(unet_path)
model = load_diffusion_model_state_dict(sd, model_options=model_options)
sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata)
if model is None:
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))

View File

@@ -50,6 +50,7 @@ class BASE:
manual_cast_dtype = None
custom_operations = None
scaled_fp8 = None
layer_quant_config = None # Per-layer quantization configuration for mixed precision
optimizations = {"fp8": False}
@classmethod

View File

@@ -1,261 +0,0 @@
from __future__ import annotations
import aiohttp
import mimetypes
from typing import Optional, Union
from comfy.utils import common_upscale
from comfy_api_nodes.apis.client import (
ApiClient,
ApiEndpoint,
HttpMethod,
SynchronousOperation,
UploadRequest,
UploadResponse,
)
from server import PromptServer
from comfy.cli_args import args
import numpy as np
from PIL import Image
import torch
import math
import base64
from .util import tensor_to_bytesio, bytesio_to_image_tensor
from io import BytesIO
async def validate_and_cast_response(
response, timeout: int = None, node_id: Union[str, None] = None
) -> torch.Tensor:
"""Validates and casts a response to a torch.Tensor.
Args:
response: The response to validate and cast.
timeout: Request timeout in seconds. Defaults to None (no timeout).
Returns:
A torch.Tensor representing the image (1, H, W, C).
Raises:
ValueError: If the response is not valid.
"""
# validate raw JSON response
data = response.data
if not data or len(data) == 0:
raise ValueError("No images returned from API endpoint")
# Initialize list to store image tensors
image_tensors: list[torch.Tensor] = []
# Process each image in the data array
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
for img_data in data:
img_bytes: bytes
if img_data.b64_json:
img_bytes = base64.b64decode(img_data.b64_json)
elif img_data.url:
if node_id:
PromptServer.instance.send_progress_text(f"Result URL: {img_data.url}", node_id)
async with session.get(img_data.url) as resp:
if resp.status != 200:
raise ValueError("Failed to download generated image")
img_bytes = await resp.read()
else:
raise ValueError("Invalid image payload neither URL nor base64 data present.")
pil_img = Image.open(BytesIO(img_bytes)).convert("RGBA")
arr = np.asarray(pil_img).astype(np.float32) / 255.0
image_tensors.append(torch.from_numpy(arr))
return torch.stack(image_tensors, dim=0)
def validate_aspect_ratio(
aspect_ratio: str,
minimum_ratio: float,
maximum_ratio: float,
minimum_ratio_str: str,
maximum_ratio_str: str,
) -> float:
"""Validates and casts an aspect ratio string to a float.
Args:
aspect_ratio: The aspect ratio string to validate.
minimum_ratio: The minimum aspect ratio.
maximum_ratio: The maximum aspect ratio.
minimum_ratio_str: The minimum aspect ratio string.
maximum_ratio_str: The maximum aspect ratio string.
Returns:
The validated and cast aspect ratio.
Raises:
Exception: If the aspect ratio is not valid.
"""
# get ratio values
numbers = aspect_ratio.split(":")
if len(numbers) != 2:
raise TypeError(
f"Aspect ratio must be in the format X:Y, such as 16:9, but was {aspect_ratio}."
)
try:
numerator = int(numbers[0])
denominator = int(numbers[1])
except ValueError as exc:
raise TypeError(
f"Aspect ratio must contain numbers separated by ':', such as 16:9, but was {aspect_ratio}."
) from exc
calculated_ratio = numerator / denominator
# if not close to minimum and maximum, check bounds
if not math.isclose(calculated_ratio, minimum_ratio) or not math.isclose(
calculated_ratio, maximum_ratio
):
if calculated_ratio < minimum_ratio:
raise TypeError(
f"Aspect ratio cannot reduce to any less than {minimum_ratio_str} ({minimum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
)
if calculated_ratio > maximum_ratio:
raise TypeError(
f"Aspect ratio cannot reduce to any greater than {maximum_ratio_str} ({maximum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
)
return aspect_ratio
async def download_url_to_bytesio(
url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None
) -> BytesIO:
"""Downloads content from a URL using requests and returns it as BytesIO.
Args:
url: The URL to download.
timeout: Request timeout in seconds. Defaults to None (no timeout).
Returns:
BytesIO object containing the downloaded content.
"""
headers = {}
if url.startswith("/proxy/"):
url = str(args.comfy_api_base).rstrip("/") + url
auth_token = auth_kwargs.get("auth_token")
comfy_api_key = auth_kwargs.get("comfy_api_key")
if auth_token:
headers["Authorization"] = f"Bearer {auth_token}"
elif comfy_api_key:
headers["X-API-KEY"] = comfy_api_key
timeout_cfg = aiohttp.ClientTimeout(total=timeout) if timeout else None
async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
async with session.get(url, headers=headers) as resp:
resp.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX)
return BytesIO(await resp.read())
def process_image_response(response_content: bytes | str) -> torch.Tensor:
"""Uses content from a Response object and converts it to a torch.Tensor"""
return bytesio_to_image_tensor(BytesIO(response_content))
def text_filepath_to_base64_string(filepath: str) -> str:
"""Converts a text file to a base64 string."""
with open(filepath, "rb") as f:
file_content = f.read()
return base64.b64encode(file_content).decode("utf-8")
def text_filepath_to_data_uri(filepath: str) -> str:
"""Converts a text file to a data URI."""
base64_string = text_filepath_to_base64_string(filepath)
mime_type, _ = mimetypes.guess_type(filepath)
if mime_type is None:
mime_type = "application/octet-stream"
return f"data:{mime_type};base64,{base64_string}"
async def upload_file_to_comfyapi(
file_bytes_io: BytesIO,
filename: str,
upload_mime_type: Optional[str],
auth_kwargs: Optional[dict[str, str]] = None,
) -> str:
"""
Uploads a single file to ComfyUI API and returns its download URL.
Args:
file_bytes_io: BytesIO object containing the file data.
filename: The filename of the file.
upload_mime_type: MIME type of the file.
auth_kwargs: Optional authentication token(s).
Returns:
The download URL for the uploaded file.
"""
if upload_mime_type is None:
request_object = UploadRequest(file_name=filename)
else:
request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/customers/storage",
method=HttpMethod.POST,
request_model=UploadRequest,
response_model=UploadResponse,
),
request=request_object,
auth_kwargs=auth_kwargs,
)
response: UploadResponse = await operation.execute()
await ApiClient.upload_file(response.upload_url, file_bytes_io, content_type=upload_mime_type)
return response.download_url
async def upload_images_to_comfyapi(
image: torch.Tensor,
max_images=8,
auth_kwargs: Optional[dict[str, str]] = None,
mime_type: Optional[str] = None,
) -> list[str]:
"""
Uploads images to ComfyUI API and returns download URLs.
To upload multiple images, stack them in the batch dimension first.
Args:
image: Input torch.Tensor image.
max_images: Maximum number of images to upload.
auth_kwargs: Optional authentication token(s).
mime_type: Optional MIME type for the image.
"""
# if batch, try to upload each file if max_images is greater than 0
download_urls: list[str] = []
is_batch = len(image.shape) > 3
batch_len = image.shape[0] if is_batch else 1
for idx in range(min(batch_len, max_images)):
tensor = image[idx] if is_batch else image
img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
url = await upload_file_to_comfyapi(img_io, img_io.name, mime_type, auth_kwargs)
download_urls.append(url)
return download_urls
def resize_mask_to_image(
mask: torch.Tensor,
image: torch.Tensor,
upscale_method="nearest-exact",
crop="disabled",
allow_gradient=True,
add_channel_dim=False,
):
"""
Resize mask to be the same dimensions as an image, while maintaining proper format for API calls.
"""
_, H, W, _ = image.shape
mask = mask.unsqueeze(-1)
mask = mask.movedim(-1, 1)
mask = common_upscale(
mask, width=W, height=H, upscale_method=upscale_method, crop=crop
)
mask = mask.movedim(1, -1)
if not add_channel_dim:
mask = mask.squeeze(-1)
if not allow_gradient:
mask = (mask > 0.5).float()
return mask

View File

@@ -0,0 +1,120 @@
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field
class MinimaxBaseResponse(BaseModel):
status_code: int = Field(
...,
description='Status code. 0 indicates success, other values indicate errors.',
)
status_msg: str = Field(
..., description='Specific error details or success message.'
)
class File(BaseModel):
bytes: Optional[int] = Field(None, description='File size in bytes')
created_at: Optional[int] = Field(
None, description='Unix timestamp when the file was created, in seconds'
)
download_url: Optional[str] = Field(
None, description='The URL to download the video'
)
backup_download_url: Optional[str] = Field(
None, description='The backup URL to download the video'
)
file_id: Optional[int] = Field(None, description='Unique identifier for the file')
filename: Optional[str] = Field(None, description='The name of the file')
purpose: Optional[str] = Field(None, description='The purpose of using the file')
class MinimaxFileRetrieveResponse(BaseModel):
base_resp: MinimaxBaseResponse
file: File
class MiniMaxModel(str, Enum):
T2V_01_Director = 'T2V-01-Director'
I2V_01_Director = 'I2V-01-Director'
S2V_01 = 'S2V-01'
I2V_01 = 'I2V-01'
I2V_01_live = 'I2V-01-live'
T2V_01 = 'T2V-01'
Hailuo_02 = 'MiniMax-Hailuo-02'
class Status6(str, Enum):
Queueing = 'Queueing'
Preparing = 'Preparing'
Processing = 'Processing'
Success = 'Success'
Fail = 'Fail'
class MinimaxTaskResultResponse(BaseModel):
base_resp: MinimaxBaseResponse
file_id: Optional[str] = Field(
None,
description='After the task status changes to Success, this field returns the file ID corresponding to the generated video.',
)
status: Status6 = Field(
...,
description="Task status: 'Queueing' (in queue), 'Preparing' (task is preparing), 'Processing' (generating), 'Success' (task completed successfully), or 'Fail' (task failed).",
)
task_id: str = Field(..., description='The task ID being queried.')
class SubjectReferenceItem(BaseModel):
image: Optional[str] = Field(
None, description='URL or base64 encoding of the subject reference image.'
)
mask: Optional[str] = Field(
None,
description='URL or base64 encoding of the mask for the subject reference image.',
)
class MinimaxVideoGenerationRequest(BaseModel):
callback_url: Optional[str] = Field(
None,
description='Optional. URL to receive real-time status updates about the video generation task.',
)
first_frame_image: Optional[str] = Field(
None,
description='URL or base64 encoding of the first frame image. Required when model is I2V-01, I2V-01-Director, or I2V-01-live.',
)
model: MiniMaxModel = Field(
...,
description='Required. ID of model. Options: T2V-01-Director, I2V-01-Director, S2V-01, I2V-01, I2V-01-live, T2V-01',
)
prompt: Optional[str] = Field(
None,
description='Description of the video. Should be less than 2000 characters. Supports camera movement instructions in [brackets].',
max_length=2000,
)
prompt_optimizer: Optional[bool] = Field(
True,
description='If true (default), the model will automatically optimize the prompt. Set to false for more precise control.',
)
subject_reference: Optional[list[SubjectReferenceItem]] = Field(
None,
description='Only available when model is S2V-01. The model will generate a video based on the subject uploaded through this parameter.',
)
duration: Optional[int] = Field(
None,
description="The length of the output video in seconds."
)
resolution: Optional[str] = Field(
None,
description="The dimensions of the video display. 1080p corresponds to 1920 x 1080 pixels, 768p corresponds to 1366 x 768 pixels."
)
class MinimaxVideoGenerationResponse(BaseModel):
base_resp: MinimaxBaseResponse
task_id: str = Field(
..., description='The task ID for the asynchronous video generation task.'
)

View File

@@ -5,10 +5,6 @@ import torch
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apinode_utils import (
resize_mask_to_image,
validate_aspect_ratio,
)
from comfy_api_nodes.apis.bfl_api import (
BFLFluxExpandImageRequest,
BFLFluxFillImageRequest,
@@ -23,8 +19,10 @@ from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_image_tensor,
poll_op,
resize_mask_to_image,
sync_op,
tensor_to_base64_string,
validate_aspect_ratio_string,
validate_string,
)
@@ -43,11 +41,6 @@ class FluxProUltraImageNode(IO.ComfyNode):
Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.
"""
MINIMUM_RATIO = 1 / 4
MAXIMUM_RATIO = 4 / 1
MINIMUM_RATIO_STR = "1:4"
MAXIMUM_RATIO_STR = "4:1"
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
@@ -112,16 +105,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
@classmethod
def validate_inputs(cls, aspect_ratio: str):
try:
validate_aspect_ratio(
aspect_ratio,
minimum_ratio=cls.MINIMUM_RATIO,
maximum_ratio=cls.MAXIMUM_RATIO,
minimum_ratio_str=cls.MINIMUM_RATIO_STR,
maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
)
except Exception as e:
return str(e)
validate_aspect_ratio_string(aspect_ratio, (1, 4), (4, 1))
return True
@classmethod
@@ -145,13 +129,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
prompt=prompt,
prompt_upsampling=prompt_upsampling,
seed=seed,
aspect_ratio=validate_aspect_ratio(
aspect_ratio,
minimum_ratio=cls.MINIMUM_RATIO,
maximum_ratio=cls.MAXIMUM_RATIO,
minimum_ratio_str=cls.MINIMUM_RATIO_STR,
maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
),
aspect_ratio=aspect_ratio,
raw=raw,
image_prompt=(image_prompt if image_prompt is None else tensor_to_base64_string(image_prompt)),
image_prompt_strength=(None if image_prompt is None else round(image_prompt_strength, 2)),
@@ -180,11 +158,6 @@ class FluxKontextProImageNode(IO.ComfyNode):
Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.
"""
MINIMUM_RATIO = 1 / 4
MAXIMUM_RATIO = 4 / 1
MINIMUM_RATIO_STR = "1:4"
MAXIMUM_RATIO_STR = "4:1"
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
@@ -261,13 +234,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
seed=0,
prompt_upsampling=False,
) -> IO.NodeOutput:
aspect_ratio = validate_aspect_ratio(
aspect_ratio,
minimum_ratio=cls.MINIMUM_RATIO,
maximum_ratio=cls.MAXIMUM_RATIO,
minimum_ratio_str=cls.MINIMUM_RATIO_STR,
maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
)
validate_aspect_ratio_string(aspect_ratio, (1, 4), (4, 1))
if input_image is None:
validate_string(prompt, strip_whitespace=False)
initial_response = await sync_op(

View File

@@ -17,7 +17,7 @@ from comfy_api_nodes.util import (
poll_op,
sync_op,
upload_images_to_comfyapi,
validate_image_aspect_ratio_range,
validate_image_aspect_ratio,
validate_image_dimensions,
validate_string,
)
@@ -403,7 +403,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
validate_string(prompt, strip_whitespace=True, min_length=1)
if get_number_of_images(image) != 1:
raise ValueError("Exactly one input image is required.")
validate_image_aspect_ratio_range(image, (1, 3), (3, 1))
validate_image_aspect_ratio(image, (1, 3), (3, 1))
source_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0]
payload = Image2ImageTaskCreationRequest(
model=model,
@@ -565,7 +565,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
reference_images_urls = []
if n_input_images:
for i in image:
validate_image_aspect_ratio_range(i, (1, 3), (3, 1))
validate_image_aspect_ratio(i, (1, 3), (3, 1))
reference_images_urls = await upload_images_to_comfyapi(
cls,
image,
@@ -798,7 +798,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
validate_string(prompt, strip_whitespace=True, min_length=1)
raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"])
validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000)
validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0]
prompt = (
@@ -923,7 +923,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"])
for i in (first_frame, last_frame):
validate_image_dimensions(i, min_width=300, min_height=300, max_width=6000, max_height=6000)
validate_image_aspect_ratio_range(i, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
validate_image_aspect_ratio(i, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
download_urls = await upload_images_to_comfyapi(
cls,
@@ -1045,7 +1045,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "watermark"])
for image in images:
validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000)
validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
image_urls = await upload_images_to_comfyapi(cls, images, max_images=4, mime_type="image/png")
prompt = (

View File

@@ -1,6 +1,6 @@
from io import BytesIO
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO
from comfy_api.latest import IO, ComfyExtension
from PIL import Image
import numpy as np
import torch
@@ -11,19 +11,13 @@ from comfy_api_nodes.apis import (
IdeogramV3Request,
IdeogramV3EditRequest,
)
from comfy_api_nodes.apis.client import (
from comfy_api_nodes.util import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
)
from comfy_api_nodes.apinode_utils import (
download_url_to_bytesio,
bytesio_to_image_tensor,
download_url_as_bytesio,
resize_mask_to_image,
sync_op,
)
from server import PromptServer
V1_V1_RES_MAP = {
"Auto":"AUTO",
@@ -220,7 +214,7 @@ async def download_and_process_images(image_urls):
for image_url in image_urls:
# Using functions from apinode_utils.py to handle downloading and processing
image_bytesio = await download_url_to_bytesio(image_url) # Download image content to BytesIO
image_bytesio = await download_url_as_bytesio(image_url) # Download image content to BytesIO
img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB") # Convert to torch.Tensor with RGB mode
image_tensors.append(img_tensor)
@@ -233,19 +227,6 @@ async def download_and_process_images(image_urls):
return stacked_tensors
def display_image_urls_on_node(image_urls, node_id):
if node_id and image_urls:
if len(image_urls) == 1:
PromptServer.instance.send_progress_text(
f"Generated Image URL:\n{image_urls[0]}", node_id
)
else:
urls_text = "Generated Image URLs:\n" + "\n".join(
f"{i+1}. {url}" for i, url in enumerate(image_urls)
)
PromptServer.instance.send_progress_text(urls_text, node_id)
class IdeogramV1(IO.ComfyNode):
@classmethod
@@ -334,44 +315,30 @@ class IdeogramV1(IO.ComfyNode):
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
model = "V_1_TURBO" if turbo else "V_1"
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/ideogram/generate",
method=HttpMethod.POST,
request_model=IdeogramGenerateRequest,
response_model=IdeogramGenerateResponse,
),
request=IdeogramGenerateRequest(
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
response_model=IdeogramGenerateResponse,
data=IdeogramGenerateRequest(
image_request=ImageRequest(
prompt=prompt,
model=model,
num_images=num_images,
seed=seed,
aspect_ratio=aspect_ratio if aspect_ratio != "ASPECT_1_1" else None,
magic_prompt_option=(
magic_prompt_option if magic_prompt_option != "AUTO" else None
),
magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
negative_prompt=negative_prompt if negative_prompt else None,
)
),
auth_kwargs=auth,
max_retries=1,
)
response = await operation.execute()
if not response.data or len(response.data) == 0:
raise Exception("No images were generated in the response")
image_urls = [image_data.url for image_data in response.data if image_data.url]
if not image_urls:
raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
return IO.NodeOutput(await download_and_process_images(image_urls))
@@ -500,18 +467,11 @@ class IdeogramV2(IO.ComfyNode):
else:
final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/ideogram/generate",
method=HttpMethod.POST,
request_model=IdeogramGenerateRequest,
response_model=IdeogramGenerateResponse,
),
request=IdeogramGenerateRequest(
response = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
response_model=IdeogramGenerateResponse,
data=IdeogramGenerateRequest(
image_request=ImageRequest(
prompt=prompt,
model=model,
@@ -519,28 +479,20 @@ class IdeogramV2(IO.ComfyNode):
seed=seed,
aspect_ratio=final_aspect_ratio,
resolution=final_resolution,
magic_prompt_option=(
magic_prompt_option if magic_prompt_option != "AUTO" else None
),
magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
style_type=style_type if style_type != "NONE" else None,
negative_prompt=negative_prompt if negative_prompt else None,
color_palette=color_palette if color_palette else None,
)
),
auth_kwargs=auth,
max_retries=1,
)
response = await operation.execute()
if not response.data or len(response.data) == 0:
raise Exception("No images were generated in the response")
image_urls = [image_data.url for image_data in response.data if image_data.url]
if not image_urls:
raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
return IO.NodeOutput(await download_and_process_images(image_urls))
@@ -656,10 +608,6 @@ class IdeogramV3(IO.ComfyNode):
character_image=None,
character_mask=None,
):
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
if rendering_speed == "BALANCED": # for backward compatibility
rendering_speed = "DEFAULT"
@@ -694,9 +642,6 @@ class IdeogramV3(IO.ComfyNode):
# Check if both image and mask are provided for editing mode
if image is not None and mask is not None:
# Edit mode
path = "/proxy/ideogram/ideogram-v3/edit"
# Process image and mask
input_tensor = image.squeeze().cpu()
# Resize mask to match image dimension
@@ -749,27 +694,20 @@ class IdeogramV3(IO.ComfyNode):
if character_mask_binary:
files["character_mask_binary"] = character_mask_binary
# Execute the operation for edit mode
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=IdeogramV3EditRequest,
response_model=IdeogramGenerateResponse,
),
request=edit_request,
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/ideogram/ideogram-v3/edit", method="POST"),
response_model=IdeogramGenerateResponse,
data=edit_request,
files=files,
content_type="multipart/form-data",
auth_kwargs=auth,
max_retries=1,
)
elif image is not None or mask is not None:
# If only one of image or mask is provided, raise an error
raise Exception("Ideogram V3 image editing requires both an image AND a mask")
else:
# Generation mode
path = "/proxy/ideogram/ideogram-v3/generate"
# Create generation request
gen_request = IdeogramV3Request(
prompt=prompt,
@@ -800,32 +738,22 @@ class IdeogramV3(IO.ComfyNode):
if files:
gen_request.style_type = "AUTO"
# Execute the operation for generation mode
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=IdeogramV3Request,
response_model=IdeogramGenerateResponse,
),
request=gen_request,
response = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/ideogram/ideogram-v3/generate", method="POST"),
response_model=IdeogramGenerateResponse,
data=gen_request,
files=files if files else None,
content_type="multipart/form-data",
auth_kwargs=auth,
max_retries=1,
)
# Execute the operation and process response
response = await operation.execute()
if not response.data or len(response.data) == 0:
raise Exception("No images were generated in the response")
image_urls = [image_data.url for image_data in response.data if image_data.url]
if not image_urls:
raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
return IO.NodeOutput(await download_and_process_images(image_urls))
@@ -838,5 +766,6 @@ class IdeogramExtension(ComfyExtension):
IdeogramV3,
]
async def comfy_entrypoint() -> IdeogramExtension:
return IdeogramExtension()

View File

@@ -282,7 +282,7 @@ def validate_input_image(image: torch.Tensor) -> None:
See: https://app.klingai.com/global/dev/document-api/apiReference/model/imageToVideo
"""
validate_image_dimensions(image, min_width=300, min_height=300)
validate_image_aspect_ratio(image, min_aspect_ratio=1 / 2.5, max_aspect_ratio=2.5)
validate_image_aspect_ratio(image, (1, 2.5), (2.5, 1))
def get_video_from_response(response) -> KlingVideoResult:

View File

@@ -46,7 +46,7 @@ class TextToVideoNode(IO.ComfyNode):
multiline=True,
default="",
),
IO.Combo.Input("duration", options=[6, 8, 10], default=8),
IO.Combo.Input("duration", options=[6, 8, 10, 12, 14, 16, 18, 20], default=8),
IO.Combo.Input(
"resolution",
options=[
@@ -85,6 +85,10 @@ class TextToVideoNode(IO.ComfyNode):
generate_audio: bool = False,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=10000)
if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25):
raise ValueError(
"Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS."
)
response = await sync_op_raw(
cls,
ApiEndpoint("/proxy/ltx/v1/text-to-video", "POST"),
@@ -118,7 +122,7 @@ class ImageToVideoNode(IO.ComfyNode):
multiline=True,
default="",
),
IO.Combo.Input("duration", options=[6, 8, 10], default=8),
IO.Combo.Input("duration", options=[6, 8, 10, 12, 14, 16, 18, 20], default=8),
IO.Combo.Input(
"resolution",
options=[
@@ -158,6 +162,10 @@ class ImageToVideoNode(IO.ComfyNode):
generate_audio: bool = False,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=10000)
if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25):
raise ValueError(
"Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS."
)
if get_number_of_images(image) != 1:
raise ValueError("Currently only one input image is supported.")
response = await sync_op_raw(

View File

@@ -1,69 +1,51 @@
from __future__ import annotations
from inspect import cleandoc
from typing import Optional
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO
from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.luma_api import (
LumaImageModel,
LumaVideoModel,
LumaVideoOutputResolution,
LumaVideoModelOutputDuration,
LumaAspectRatio,
LumaState,
LumaImageGenerationRequest,
LumaGenerationRequest,
LumaGeneration,
LumaCharacterRef,
LumaModifyImageRef,
LumaConceptChain,
LumaGeneration,
LumaGenerationRequest,
LumaImageGenerationRequest,
LumaImageIdentity,
LumaImageModel,
LumaImageReference,
LumaIO,
LumaKeyframes,
LumaModifyImageRef,
LumaReference,
LumaReferenceChain,
LumaImageReference,
LumaKeyframes,
LumaConceptChain,
LumaIO,
LumaVideoModel,
LumaVideoModelOutputDuration,
LumaVideoOutputResolution,
get_luma_concepts,
)
from comfy_api_nodes.apis.client import (
from comfy_api_nodes.util import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
download_url_to_image_tensor,
download_url_to_video_output,
poll_op,
sync_op,
upload_images_to_comfyapi,
process_image_response,
validate_string,
)
from server import PromptServer
from comfy_api_nodes.util import validate_string
import aiohttp
import torch
from io import BytesIO
LUMA_T2V_AVERAGE_DURATION = 105
LUMA_I2V_AVERAGE_DURATION = 100
def image_result_url_extractor(response: LumaGeneration):
return response.assets.image if hasattr(response, "assets") and hasattr(response.assets, "image") else None
def video_result_url_extractor(response: LumaGeneration):
return response.assets.video if hasattr(response, "assets") and hasattr(response.assets, "video") else None
class LumaReferenceNode(IO.ComfyNode):
"""
Holds an image and weight for use with Luma Generate Image node.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="LumaReferenceNode",
display_name="Luma Reference",
category="api node/image/Luma",
description=cleandoc(cls.__doc__ or ""),
description="Holds an image and weight for use with Luma Generate Image node.",
inputs=[
IO.Image.Input(
"image",
@@ -83,17 +65,10 @@ class LumaReferenceNode(IO.ComfyNode):
),
],
outputs=[IO.Custom(LumaIO.LUMA_REF).Output(display_name="luma_ref")],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
)
@classmethod
def execute(
cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None
) -> IO.NodeOutput:
def execute(cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None) -> IO.NodeOutput:
if luma_ref is not None:
luma_ref = luma_ref.clone()
else:
@@ -103,17 +78,13 @@ class LumaReferenceNode(IO.ComfyNode):
class LumaConceptsNode(IO.ComfyNode):
"""
Holds one or more Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="LumaConceptsNode",
display_name="Luma Concepts",
category="api node/video/Luma",
description=cleandoc(cls.__doc__ or ""),
description="Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.",
inputs=[
IO.Combo.Input(
"concept1",
@@ -138,11 +109,6 @@ class LumaConceptsNode(IO.ComfyNode):
),
],
outputs=[IO.Custom(LumaIO.LUMA_CONCEPTS).Output(display_name="luma_concepts")],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
)
@classmethod
@@ -161,17 +127,13 @@ class LumaConceptsNode(IO.ComfyNode):
class LumaImageGenerationNode(IO.ComfyNode):
"""
Generates images synchronously based on prompt and aspect ratio.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="LumaImageNode",
display_name="Luma Text to Image",
category="api node/image/Luma",
description=cleandoc(cls.__doc__ or ""),
description="Generates images synchronously based on prompt and aspect ratio.",
inputs=[
IO.String.Input(
"prompt",
@@ -237,45 +199,30 @@ class LumaImageGenerationNode(IO.ComfyNode):
aspect_ratio: str,
seed,
style_image_weight: float,
image_luma_ref: LumaReferenceChain = None,
style_image: torch.Tensor = None,
character_image: torch.Tensor = None,
image_luma_ref: Optional[LumaReferenceChain] = None,
style_image: Optional[torch.Tensor] = None,
character_image: Optional[torch.Tensor] = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=3)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
# handle image_luma_ref
api_image_ref = None
if image_luma_ref is not None:
api_image_ref = await cls._convert_luma_refs(
image_luma_ref, max_refs=4, auth_kwargs=auth_kwargs,
)
api_image_ref = await cls._convert_luma_refs(image_luma_ref, max_refs=4)
# handle style_luma_ref
api_style_ref = None
if style_image is not None:
api_style_ref = await cls._convert_style_image(
style_image, weight=style_image_weight, auth_kwargs=auth_kwargs,
)
api_style_ref = await cls._convert_style_image(style_image, weight=style_image_weight)
# handle character_ref images
character_ref = None
if character_image is not None:
download_urls = await upload_images_to_comfyapi(
character_image, max_images=4, auth_kwargs=auth_kwargs,
)
character_ref = LumaCharacterRef(
identity0=LumaImageIdentity(images=download_urls)
)
download_urls = await upload_images_to_comfyapi(cls, character_image, max_images=4)
character_ref = LumaCharacterRef(identity0=LumaImageIdentity(images=download_urls))
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/luma/generations/image",
method=HttpMethod.POST,
request_model=LumaImageGenerationRequest,
response_model=LumaGeneration,
),
request=LumaImageGenerationRequest(
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/luma/generations/image", method="POST"),
response_model=LumaGeneration,
data=LumaImageGenerationRequest(
prompt=prompt,
model=model,
aspect_ratio=aspect_ratio,
@@ -283,41 +230,21 @@ class LumaImageGenerationNode(IO.ComfyNode):
style_ref=api_style_ref,
character_ref=character_ref,
),
auth_kwargs=auth_kwargs,
)
response_api: LumaGeneration = await operation.execute()
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/luma/generations/{response_api.id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=LumaGeneration,
),
completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed],
response_poll = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
response_model=LumaGeneration,
status_extractor=lambda x: x.state,
result_url_extractor=image_result_url_extractor,
node_id=cls.hidden.unique_id,
auth_kwargs=auth_kwargs,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.image) as img_response:
img = process_image_response(await img_response.content.read())
return IO.NodeOutput(img)
return IO.NodeOutput(await download_url_to_image_tensor(response_poll.assets.image))
@classmethod
async def _convert_luma_refs(
cls, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None
):
async def _convert_luma_refs(cls, luma_ref: LumaReferenceChain, max_refs: int):
luma_urls = []
ref_count = 0
for ref in luma_ref.refs:
download_urls = await upload_images_to_comfyapi(
ref.image, max_images=1, auth_kwargs=auth_kwargs
)
download_urls = await upload_images_to_comfyapi(cls, ref.image, max_images=1)
luma_urls.append(download_urls[0])
ref_count += 1
if ref_count >= max_refs:
@@ -325,27 +252,19 @@ class LumaImageGenerationNode(IO.ComfyNode):
return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs)
@classmethod
async def _convert_style_image(
cls, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None
):
chain = LumaReferenceChain(
first_ref=LumaReference(image=style_image, weight=weight)
)
return await cls._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)
async def _convert_style_image(cls, style_image: torch.Tensor, weight: float):
chain = LumaReferenceChain(first_ref=LumaReference(image=style_image, weight=weight))
return await cls._convert_luma_refs(chain, max_refs=1)
class LumaImageModifyNode(IO.ComfyNode):
"""
Modifies images synchronously based on prompt and aspect ratio.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="LumaImageModifyNode",
display_name="Luma Image to Image",
category="api node/image/Luma",
description=cleandoc(cls.__doc__ or ""),
description="Modifies images synchronously based on prompt and aspect ratio.",
inputs=[
IO.Image.Input(
"image",
@@ -395,68 +314,37 @@ class LumaImageModifyNode(IO.ComfyNode):
image_weight: float,
seed,
) -> IO.NodeOutput:
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
# first, upload image
download_urls = await upload_images_to_comfyapi(
image, max_images=1, auth_kwargs=auth_kwargs,
)
download_urls = await upload_images_to_comfyapi(cls, image, max_images=1)
image_url = download_urls[0]
# next, make Luma call with download url provided
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/luma/generations/image",
method=HttpMethod.POST,
request_model=LumaImageGenerationRequest,
response_model=LumaGeneration,
),
request=LumaImageGenerationRequest(
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/luma/generations/image", method="POST"),
response_model=LumaGeneration,
data=LumaImageGenerationRequest(
prompt=prompt,
model=model,
modify_image_ref=LumaModifyImageRef(
url=image_url, weight=round(max(min(1.0-image_weight, 0.98), 0.0), 2)
url=image_url, weight=round(max(min(1.0 - image_weight, 0.98), 0.0), 2)
),
),
auth_kwargs=auth_kwargs,
)
response_api: LumaGeneration = await operation.execute()
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/luma/generations/{response_api.id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=LumaGeneration,
),
completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed],
response_poll = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
response_model=LumaGeneration,
status_extractor=lambda x: x.state,
result_url_extractor=image_result_url_extractor,
node_id=cls.hidden.unique_id,
auth_kwargs=auth_kwargs,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.image) as img_response:
img = process_image_response(await img_response.content.read())
return IO.NodeOutput(img)
return IO.NodeOutput(await download_url_to_image_tensor(response_poll.assets.image))
class LumaTextToVideoGenerationNode(IO.ComfyNode):
"""
Generates videos synchronously based on prompt and output_size.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="LumaVideoNode",
display_name="Luma Text to Video",
category="api node/video/Luma",
description=cleandoc(cls.__doc__ or ""),
description="Generates videos synchronously based on prompt and output_size.",
inputs=[
IO.String.Input(
"prompt",
@@ -498,7 +386,7 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode):
"luma_concepts",
tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.",
optional=True,
)
),
],
outputs=[IO.Video.Output()],
hidden=[
@@ -519,24 +407,17 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode):
duration: str,
loop: bool,
seed,
luma_concepts: LumaConceptChain = None,
luma_concepts: Optional[LumaConceptChain] = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False, min_length=3)
duration = duration if model != LumaVideoModel.ray_1_6 else None
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/luma/generations",
method=HttpMethod.POST,
request_model=LumaGenerationRequest,
response_model=LumaGeneration,
),
request=LumaGenerationRequest(
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/luma/generations", method="POST"),
response_model=LumaGeneration,
data=LumaGenerationRequest(
prompt=prompt,
model=model,
resolution=resolution,
@@ -545,47 +426,25 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode):
loop=loop,
concepts=luma_concepts.create_api_model() if luma_concepts else None,
),
auth_kwargs=auth_kwargs,
)
response_api: LumaGeneration = await operation.execute()
if cls.hidden.unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id)
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/luma/generations/{response_api.id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=LumaGeneration,
),
completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed],
response_poll = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
response_model=LumaGeneration,
status_extractor=lambda x: x.state,
result_url_extractor=video_result_url_extractor,
node_id=cls.hidden.unique_id,
estimated_duration=LUMA_T2V_AVERAGE_DURATION,
auth_kwargs=auth_kwargs,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.video) as vid_response:
return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
return IO.NodeOutput(await download_url_to_video_output(response_poll.assets.video))
class LumaImageToVideoGenerationNode(IO.ComfyNode):
"""
Generates videos synchronously based on prompt, input images, and output_size.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="LumaImageToVideoNode",
display_name="Luma Image to Video",
category="api node/video/Luma",
description=cleandoc(cls.__doc__ or ""),
description="Generates videos synchronously based on prompt, input images, and output_size.",
inputs=[
IO.String.Input(
"prompt",
@@ -637,7 +496,7 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
"luma_concepts",
tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.",
optional=True,
)
),
],
outputs=[IO.Video.Output()],
hidden=[
@@ -662,25 +521,15 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
luma_concepts: LumaConceptChain = None,
) -> IO.NodeOutput:
if first_image is None and last_image is None:
raise Exception(
"At least one of first_image and last_image requires an input."
)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
keyframes = await cls._convert_to_keyframes(first_image, last_image, auth_kwargs=auth_kwargs)
raise Exception("At least one of first_image and last_image requires an input.")
keyframes = await cls._convert_to_keyframes(first_image, last_image)
duration = duration if model != LumaVideoModel.ray_1_6 else None
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/luma/generations",
method=HttpMethod.POST,
request_model=LumaGenerationRequest,
response_model=LumaGeneration,
),
request=LumaGenerationRequest(
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/luma/generations", method="POST"),
response_model=LumaGeneration,
data=LumaGenerationRequest(
prompt=prompt,
model=model,
aspect_ratio=LumaAspectRatio.ratio_16_9, # ignored, but still needed by the API for some reason
@@ -690,54 +539,31 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
keyframes=keyframes,
concepts=luma_concepts.create_api_model() if luma_concepts else None,
),
auth_kwargs=auth_kwargs,
)
response_api: LumaGeneration = await operation.execute()
if cls.hidden.unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id)
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/luma/generations/{response_api.id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=LumaGeneration,
),
completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed],
response_poll = await poll_op(
cls,
poll_endpoint=ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
response_model=LumaGeneration,
status_extractor=lambda x: x.state,
result_url_extractor=video_result_url_extractor,
node_id=cls.hidden.unique_id,
estimated_duration=LUMA_I2V_AVERAGE_DURATION,
auth_kwargs=auth_kwargs,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.video) as vid_response:
return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
return IO.NodeOutput(await download_url_to_video_output(response_poll.assets.video))
@classmethod
async def _convert_to_keyframes(
cls,
first_image: torch.Tensor = None,
last_image: torch.Tensor = None,
auth_kwargs: Optional[dict[str,str]] = None,
):
if first_image is None and last_image is None:
return None
frame0 = None
frame1 = None
if first_image is not None:
download_urls = await upload_images_to_comfyapi(
first_image, max_images=1, auth_kwargs=auth_kwargs,
)
download_urls = await upload_images_to_comfyapi(cls, first_image, max_images=1)
frame0 = LumaImageReference(type="image", url=download_urls[0])
if last_image is not None:
download_urls = await upload_images_to_comfyapi(
last_image, max_images=1, auth_kwargs=auth_kwargs,
)
download_urls = await upload_images_to_comfyapi(cls, last_image, max_images=1)
frame1 = LumaImageReference(type="image", url=download_urls[0])
return LumaKeyframes(frame0=frame0, frame1=frame1)

View File

@@ -1,71 +1,57 @@
from inspect import cleandoc
from typing import Optional
import logging
import torch
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO
from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api_nodes.apis import (
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.minimax_api import (
MinimaxFileRetrieveResponse,
MiniMaxModel,
MinimaxTaskResultResponse,
MinimaxVideoGenerationRequest,
MinimaxVideoGenerationResponse,
MinimaxFileRetrieveResponse,
MinimaxTaskResultResponse,
SubjectReferenceItem,
MiniMaxModel,
)
from comfy_api_nodes.apis.client import (
from comfy_api_nodes.util import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
download_url_to_bytesio,
download_url_to_video_output,
poll_op,
sync_op,
upload_images_to_comfyapi,
validate_string,
)
from comfy_api_nodes.util import validate_string
from server import PromptServer
I2V_AVERAGE_DURATION = 114
T2V_AVERAGE_DURATION = 234
async def _generate_mm_video(
cls: type[IO.ComfyNode],
*,
auth: dict[str, str],
node_id: str,
prompt_text: str,
seed: int,
model: str,
image: Optional[torch.Tensor] = None, # used for ImageToVideo
subject: Optional[torch.Tensor] = None, # used for SubjectToVideo
image: Optional[torch.Tensor] = None, # used for ImageToVideo
subject: Optional[torch.Tensor] = None, # used for SubjectToVideo
average_duration: Optional[int] = None,
) -> IO.NodeOutput:
if image is None:
validate_string(prompt_text, field_name="prompt_text")
# upload image, if passed in
image_url = None
if image is not None:
image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=auth))[0]
image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0]
# TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model
subject_reference = None
if subject is not None:
subject_url = (await upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=auth))[0]
subject_url = (await upload_images_to_comfyapi(cls, subject, max_images=1))[0]
subject_reference = [SubjectReferenceItem(image=subject_url)]
video_generate_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/minimax/video_generation",
method=HttpMethod.POST,
request_model=MinimaxVideoGenerationRequest,
response_model=MinimaxVideoGenerationResponse,
),
request=MinimaxVideoGenerationRequest(
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/minimax/video_generation", method="POST"),
response_model=MinimaxVideoGenerationResponse,
data=MinimaxVideoGenerationRequest(
model=MiniMaxModel(model),
prompt=prompt_text,
callback_url=None,
@@ -73,81 +59,50 @@ async def _generate_mm_video(
subject_reference=subject_reference,
prompt_optimizer=None,
),
auth_kwargs=auth,
)
response = await video_generate_operation.execute()
task_id = response.task_id
if not task_id:
raise Exception(f"MiniMax generation failed: {response.base_resp}")
video_generate_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path="/proxy/minimax/query/video_generation",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MinimaxTaskResultResponse,
query_params={"task_id": task_id},
),
completed_statuses=["Success"],
failed_statuses=["Fail"],
task_result = await poll_op(
cls,
ApiEndpoint(path="/proxy/minimax/query/video_generation", query_params={"task_id": task_id}),
response_model=MinimaxTaskResultResponse,
status_extractor=lambda x: x.status.value,
estimated_duration=average_duration,
node_id=node_id,
auth_kwargs=auth,
)
task_result = await video_generate_operation.execute()
file_id = task_result.file_id
if file_id is None:
raise Exception("Request was not successful. Missing file ID.")
file_retrieve_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/minimax/files/retrieve",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MinimaxFileRetrieveResponse,
query_params={"file_id": int(file_id)},
),
request=EmptyRequest(),
auth_kwargs=auth,
file_result = await sync_op(
cls,
ApiEndpoint(path="/proxy/minimax/files/retrieve", query_params={"file_id": int(file_id)}),
response_model=MinimaxFileRetrieveResponse,
)
file_result = await file_retrieve_operation.execute()
file_url = file_result.file.download_url
if file_url is None:
raise Exception(
f"No video was found in the response. Full response: {file_result.model_dump()}"
)
logging.info("Generated video URL: %s", file_url)
if node_id:
if hasattr(file_result.file, "backup_download_url"):
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
else:
message = f"Result URL: {file_url}"
PromptServer.instance.send_progress_text(message, node_id)
# Download and return as VideoFromFile
video_io = await download_url_to_bytesio(file_url)
if video_io is None:
error_msg = f"Failed to download video from {file_url}"
logging.error(error_msg)
raise Exception(error_msg)
return IO.NodeOutput(VideoFromFile(video_io))
raise Exception(f"No video was found in the response. Full response: {file_result.model_dump()}")
if file_result.file.backup_download_url:
try:
return IO.NodeOutput(await download_url_to_video_output(file_url, timeout=10, max_retries=2))
except Exception: # if we have a second URL to retrieve the result, try again using that one
return IO.NodeOutput(
await download_url_to_video_output(file_result.file.backup_download_url, max_retries=3)
)
return IO.NodeOutput(await download_url_to_video_output(file_url))
class MinimaxTextToVideoNode(IO.ComfyNode):
"""
Generates videos synchronously based on a prompt, and optional parameters using MiniMax's API.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="MinimaxTextToVideoNode",
display_name="MiniMax Text to Video",
category="api node/video/MiniMax",
description=cleandoc(cls.__doc__ or ""),
description="Generates videos synchronously based on a prompt, and optional parameters.",
inputs=[
IO.String.Input(
"prompt_text",
@@ -189,11 +144,7 @@ class MinimaxTextToVideoNode(IO.ComfyNode):
seed: int = 0,
) -> IO.NodeOutput:
return await _generate_mm_video(
auth={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
node_id=cls.hidden.unique_id,
cls,
prompt_text=prompt_text,
seed=seed,
model=model,
@@ -204,17 +155,13 @@ class MinimaxTextToVideoNode(IO.ComfyNode):
class MinimaxImageToVideoNode(IO.ComfyNode):
"""
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="MinimaxImageToVideoNode",
display_name="MiniMax Image to Video",
category="api node/video/MiniMax",
description=cleandoc(cls.__doc__ or ""),
description="Generates videos synchronously based on an image and prompt, and optional parameters.",
inputs=[
IO.Image.Input(
"image",
@@ -261,11 +208,7 @@ class MinimaxImageToVideoNode(IO.ComfyNode):
seed: int = 0,
) -> IO.NodeOutput:
return await _generate_mm_video(
auth={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
node_id=cls.hidden.unique_id,
cls,
prompt_text=prompt_text,
seed=seed,
model=model,
@@ -276,17 +219,13 @@ class MinimaxImageToVideoNode(IO.ComfyNode):
class MinimaxSubjectToVideoNode(IO.ComfyNode):
"""
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="MinimaxSubjectToVideoNode",
display_name="MiniMax Subject to Video",
category="api node/video/MiniMax",
description=cleandoc(cls.__doc__ or ""),
description="Generates videos synchronously based on an image and prompt, and optional parameters.",
inputs=[
IO.Image.Input(
"subject",
@@ -333,11 +272,7 @@ class MinimaxSubjectToVideoNode(IO.ComfyNode):
seed: int = 0,
) -> IO.NodeOutput:
return await _generate_mm_video(
auth={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
node_id=cls.hidden.unique_id,
cls,
prompt_text=prompt_text,
seed=seed,
model=model,
@@ -348,15 +283,13 @@ class MinimaxSubjectToVideoNode(IO.ComfyNode):
class MinimaxHailuoVideoNode(IO.ComfyNode):
"""Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="MinimaxHailuoVideoNode",
display_name="MiniMax Hailuo Video",
category="api node/video/MiniMax",
description=cleandoc(cls.__doc__ or ""),
description="Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model.",
inputs=[
IO.String.Input(
"prompt_text",
@@ -420,10 +353,6 @@ class MinimaxHailuoVideoNode(IO.ComfyNode):
resolution: str = "768P",
model: str = "MiniMax-Hailuo-02",
) -> IO.NodeOutput:
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
if first_frame_image is None:
validate_string(prompt_text, field_name="prompt_text")
@@ -435,16 +364,13 @@ class MinimaxHailuoVideoNode(IO.ComfyNode):
# upload image, if passed in
image_url = None
if first_frame_image is not None:
image_url = (await upload_images_to_comfyapi(first_frame_image, max_images=1, auth_kwargs=auth))[0]
image_url = (await upload_images_to_comfyapi(cls, first_frame_image, max_images=1))[0]
video_generate_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/minimax/video_generation",
method=HttpMethod.POST,
request_model=MinimaxVideoGenerationRequest,
response_model=MinimaxVideoGenerationResponse,
),
request=MinimaxVideoGenerationRequest(
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/minimax/video_generation", method="POST"),
response_model=MinimaxVideoGenerationResponse,
data=MinimaxVideoGenerationRequest(
model=MiniMaxModel(model),
prompt=prompt_text,
callback_url=None,
@@ -453,67 +379,42 @@ class MinimaxHailuoVideoNode(IO.ComfyNode):
duration=duration,
resolution=resolution,
),
auth_kwargs=auth,
)
response = await video_generate_operation.execute()
task_id = response.task_id
if not task_id:
raise Exception(f"MiniMax generation failed: {response.base_resp}")
average_duration = 120 if resolution == "768P" else 240
video_generate_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path="/proxy/minimax/query/video_generation",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MinimaxTaskResultResponse,
query_params={"task_id": task_id},
),
completed_statuses=["Success"],
failed_statuses=["Fail"],
task_result = await poll_op(
cls,
ApiEndpoint(path="/proxy/minimax/query/video_generation", query_params={"task_id": task_id}),
response_model=MinimaxTaskResultResponse,
status_extractor=lambda x: x.status.value,
estimated_duration=average_duration,
node_id=cls.hidden.unique_id,
auth_kwargs=auth,
)
task_result = await video_generate_operation.execute()
file_id = task_result.file_id
if file_id is None:
raise Exception("Request was not successful. Missing file ID.")
file_retrieve_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/minimax/files/retrieve",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MinimaxFileRetrieveResponse,
query_params={"file_id": int(file_id)},
),
request=EmptyRequest(),
auth_kwargs=auth,
file_result = await sync_op(
cls,
ApiEndpoint(path="/proxy/minimax/files/retrieve", query_params={"file_id": int(file_id)}),
response_model=MinimaxFileRetrieveResponse,
)
file_result = await file_retrieve_operation.execute()
file_url = file_result.file.download_url
if file_url is None:
raise Exception(
f"No video was found in the response. Full response: {file_result.model_dump()}"
)
logging.info("Generated video URL: %s", file_url)
if cls.hidden.unique_id:
if hasattr(file_result.file, "backup_download_url"):
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
else:
message = f"Result URL: {file_url}"
PromptServer.instance.send_progress_text(message, cls.hidden.unique_id)
raise Exception(f"No video was found in the response. Full response: {file_result.model_dump()}")
video_io = await download_url_to_bytesio(file_url)
if video_io is None:
error_msg = f"Failed to download video from {file_url}"
logging.error(error_msg)
raise Exception(error_msg)
return IO.NodeOutput(VideoFromFile(video_io))
if file_result.file.backup_download_url:
try:
return IO.NodeOutput(await download_url_to_video_output(file_url, timeout=10, max_retries=2))
except Exception: # if we have a second URL to retrieve the result, try again using that one
return IO.NodeOutput(
await download_url_to_video_output(file_result.file.backup_download_url, max_retries=3)
)
return IO.NodeOutput(await download_url_to_video_output(file_url))
class MinimaxExtension(ComfyExtension):

File diff suppressed because it is too large Load Diff

View File

@@ -7,24 +7,23 @@ from __future__ import annotations
from io import BytesIO
import logging
from typing import Optional, TypeVar
from typing import Optional
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
from comfy_api_nodes.apis import pika_defs
from comfy_api_nodes.apis.client import (
from comfy_api_nodes.apis import pika_api as pika_defs
from comfy_api_nodes.util import (
validate_string,
download_url_to_video_output,
tensor_to_bytesio,
ApiEndpoint,
EmptyRequest,
HttpMethod,
PollingOperation,
SynchronousOperation,
sync_op,
poll_op,
)
from comfy_api_nodes.util import validate_string, download_url_to_video_output, tensor_to_bytesio
R = TypeVar("R")
PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions"
PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps"
@@ -40,28 +39,18 @@ PATH_VIDEO_GET = "/proxy/pika/videos"
async def execute_task(
initial_operation: SynchronousOperation[R, pika_defs.PikaGenerateResponse],
auth_kwargs: Optional[dict[str, str]] = None,
node_id: Optional[str] = None,
task_id: str,
cls: type[IO.ComfyNode],
) -> IO.NodeOutput:
task_id = (await initial_operation.execute()).video_id
final_response: pika_defs.PikaVideoResponse = await PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"{PATH_VIDEO_GET}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=pika_defs.PikaVideoResponse,
),
completed_statuses=["finished"],
failed_statuses=["failed", "cancelled"],
final_response: pika_defs.PikaVideoResponse = await poll_op(
cls,
ApiEndpoint(path=f"{PATH_VIDEO_GET}/{task_id}"),
response_model=pika_defs.PikaVideoResponse,
status_extractor=lambda response: (response.status.value if response.status else None),
progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None),
auth_kwargs=auth_kwargs,
result_url_extractor=lambda response: (response.url if hasattr(response, "url") else None),
node_id=node_id,
estimated_duration=60,
max_poll_attempts=240,
).execute()
)
if not final_response.url:
error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}"
logging.error(error_msg)
@@ -124,23 +113,15 @@ class PikaImageToVideo(IO.ComfyNode):
resolution=resolution,
duration=duration,
)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_IMAGE_TO_VIDEO,
method=HttpMethod.POST,
request_model=pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_request_data,
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
return await execute_task(initial_operation.video_id, cls)
class PikaTextToVideoNode(IO.ComfyNode):
@@ -183,18 +164,11 @@ class PikaTextToVideoNode(IO.ComfyNode):
duration: int,
aspect_ratio: float,
) -> IO.NodeOutput:
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_TEXT_TO_VIDEO,
method=HttpMethod.POST,
request_model=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
@@ -202,10 +176,9 @@ class PikaTextToVideoNode(IO.ComfyNode):
duration=duration,
aspectRatio=aspect_ratio,
),
auth_kwargs=auth,
content_type="application/x-www-form-urlencoded",
)
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
return await execute_task(initial_operation.video_id, cls)
class PikaScenes(IO.ComfyNode):
@@ -309,24 +282,16 @@ class PikaScenes(IO.ComfyNode):
duration=duration,
aspectRatio=aspect_ratio,
)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_PIKASCENES,
method=HttpMethod.POST,
request_model=pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_request_data,
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKASCENES, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
return await execute_task(initial_operation.video_id, cls)
class PikAdditionsNode(IO.ComfyNode):
@@ -383,24 +348,16 @@ class PikAdditionsNode(IO.ComfyNode):
negativePrompt=negative_prompt,
seed=seed,
)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_PIKADDITIONS,
method=HttpMethod.POST,
request_model=pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_request_data,
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKADDITIONS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
return await execute_task(initial_operation.video_id, cls)
class PikaSwapsNode(IO.ComfyNode):
@@ -472,23 +429,15 @@ class PikaSwapsNode(IO.ComfyNode):
seed=seed,
modifyRegionRoi=region_to_modify if region_to_modify else None,
)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_PIKASWAPS,
method=HttpMethod.POST,
request_model=pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_request_data,
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKASWAPS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
return await execute_task(initial_operation.video_id, cls)
class PikaffectsNode(IO.ComfyNode):
@@ -528,18 +477,11 @@ class PikaffectsNode(IO.ComfyNode):
negative_prompt: str,
seed: int,
) -> IO.NodeOutput:
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_PIKAFFECTS,
method=HttpMethod.POST,
request_model=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKAFFECTS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
pikaffect=pikaffect,
promptText=prompt_text,
negativePrompt=negative_prompt,
@@ -547,9 +489,8 @@ class PikaffectsNode(IO.ComfyNode):
),
files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
content_type="multipart/form-data",
auth_kwargs=auth,
)
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
return await execute_task(initial_operation.video_id, cls)
class PikaStartEndFrameNode(IO.ComfyNode):
@@ -592,18 +533,11 @@ class PikaStartEndFrameNode(IO.ComfyNode):
("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
]
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_PIKAFRAMES,
method=HttpMethod.POST,
request_model=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKAFRAMES, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
@@ -612,9 +546,8 @@ class PikaStartEndFrameNode(IO.ComfyNode):
),
files=pika_files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
return await execute_task(initial_operation.video_id, cls)
class PikaApiNodesExtension(ComfyExtension):

View File

@@ -1,7 +1,6 @@
from inspect import cleandoc
from typing import Optional
import torch
from typing_extensions import override
from io import BytesIO
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.pixverse_api import (
PixverseTextVideoRequest,
PixverseImageVideoRequest,
@@ -17,53 +16,30 @@ from comfy_api_nodes.apis.pixverse_api import (
PixverseIO,
pixverse_templates,
)
from comfy_api_nodes.apis.client import (
from comfy_api_nodes.util import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
download_url_to_video_output,
poll_op,
sync_op,
tensor_to_bytesio,
validate_string,
)
from comfy_api_nodes.util import validate_string, tensor_to_bytesio
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import ComfyExtension, IO
import torch
import aiohttp
AVERAGE_DURATION_T2V = 32
AVERAGE_DURATION_I2V = 30
AVERAGE_DURATION_T2T = 52
def get_video_url_from_response(
response: PixverseGenerationStatusResponse,
) -> Optional[str]:
if response.Resp is None or response.Resp.url is None:
return None
return str(response.Resp.url)
async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
# first, upload image to Pixverse and get image id to use in actual generation call
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/pixverse/image/upload",
method=HttpMethod.POST,
request_model=EmptyRequest,
response_model=PixverseImageUploadResponse,
),
request=EmptyRequest(),
async def upload_image_to_pixverse(cls: type[IO.ComfyNode], image: torch.Tensor):
response_upload = await sync_op(
cls,
ApiEndpoint(path="/proxy/pixverse/image/upload", method="POST"),
response_model=PixverseImageUploadResponse,
files={"image": tensor_to_bytesio(image)},
content_type="multipart/form-data",
auth_kwargs=auth_kwargs,
)
response_upload: PixverseImageUploadResponse = await operation.execute()
if response_upload.Resp is None:
raise Exception(f"PixVerse image upload request failed: '{response_upload.ErrMsg}'")
return response_upload.Resp.img_id
@@ -93,17 +69,13 @@ class PixverseTemplateNode(IO.ComfyNode):
class PixverseTextToVideoNode(IO.ComfyNode):
"""
Generates videos based on prompt and output_size.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PixverseTextToVideoNode",
display_name="PixVerse Text to Video",
category="api node/video/PixVerse",
description=cleandoc(cls.__doc__ or ""),
description="Generates videos based on prompt and output_size.",
inputs=[
IO.String.Input(
"prompt",
@@ -170,7 +142,7 @@ class PixverseTextToVideoNode(IO.ComfyNode):
negative_prompt: str = None,
pixverse_template: int = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
validate_string(prompt, strip_whitespace=False, min_length=1)
# 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration
if quality == PixverseQuality.res_1080p:
@@ -179,18 +151,11 @@ class PixverseTextToVideoNode(IO.ComfyNode):
elif duration_seconds != PixverseDuration.dur_5:
motion_mode = PixverseMotionMode.normal
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/pixverse/video/text/generate",
method=HttpMethod.POST,
request_model=PixverseTextVideoRequest,
response_model=PixverseVideoResponse,
),
request=PixverseTextVideoRequest(
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/pixverse/video/text/generate", method="POST"),
response_model=PixverseVideoResponse,
data=PixverseTextVideoRequest(
prompt=prompt,
aspect_ratio=aspect_ratio,
quality=quality,
@@ -200,20 +165,14 @@ class PixverseTextToVideoNode(IO.ComfyNode):
template_id=pixverse_template,
seed=seed,
),
auth_kwargs=auth,
)
response_api = await operation.execute()
if response_api.Resp is None:
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=PixverseGenerationStatusResponse,
),
response_poll = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
response_model=PixverseGenerationStatusResponse,
completed_statuses=[PixverseStatus.successful],
failed_statuses=[
PixverseStatus.contents_moderation,
@@ -221,30 +180,19 @@ class PixverseTextToVideoNode(IO.ComfyNode):
PixverseStatus.deleted,
],
status_extractor=lambda x: x.Resp.status,
auth_kwargs=auth,
node_id=cls.hidden.unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
class PixverseImageToVideoNode(IO.ComfyNode):
"""
Generates videos based on prompt and output_size.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PixverseImageToVideoNode",
display_name="PixVerse Image to Video",
category="api node/video/PixVerse",
description=cleandoc(cls.__doc__ or ""),
description="Generates videos based on prompt and output_size.",
inputs=[
IO.Image.Input("image"),
IO.String.Input(
@@ -309,11 +257,7 @@ class PixverseImageToVideoNode(IO.ComfyNode):
pixverse_template: int = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
img_id = await upload_image_to_pixverse(image, auth_kwargs=auth)
img_id = await upload_image_to_pixverse(cls, image)
# 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration
@@ -323,14 +267,11 @@ class PixverseImageToVideoNode(IO.ComfyNode):
elif duration_seconds != PixverseDuration.dur_5:
motion_mode = PixverseMotionMode.normal
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/pixverse/video/img/generate",
method=HttpMethod.POST,
request_model=PixverseImageVideoRequest,
response_model=PixverseVideoResponse,
),
request=PixverseImageVideoRequest(
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/pixverse/video/img/generate", method="POST"),
response_model=PixverseVideoResponse,
data=PixverseImageVideoRequest(
img_id=img_id,
prompt=prompt,
quality=quality,
@@ -340,20 +281,15 @@ class PixverseImageToVideoNode(IO.ComfyNode):
template_id=pixverse_template,
seed=seed,
),
auth_kwargs=auth,
)
response_api = await operation.execute()
if response_api.Resp is None:
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=PixverseGenerationStatusResponse,
),
response_poll = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
response_model=PixverseGenerationStatusResponse,
completed_statuses=[PixverseStatus.successful],
failed_statuses=[
PixverseStatus.contents_moderation,
@@ -361,30 +297,19 @@ class PixverseImageToVideoNode(IO.ComfyNode):
PixverseStatus.deleted,
],
status_extractor=lambda x: x.Resp.status,
auth_kwargs=auth,
node_id=cls.hidden.unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_I2V,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
class PixverseTransitionVideoNode(IO.ComfyNode):
"""
Generates videos based on prompt and output_size.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PixverseTransitionVideoNode",
display_name="PixVerse Transition Video",
category="api node/video/PixVerse",
description=cleandoc(cls.__doc__ or ""),
description="Generates videos based on prompt and output_size.",
inputs=[
IO.Image.Input("first_frame"),
IO.Image.Input("last_frame"),
@@ -445,12 +370,8 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
negative_prompt: str = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=auth)
last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=auth)
first_frame_id = await upload_image_to_pixverse(cls, first_frame)
last_frame_id = await upload_image_to_pixverse(cls, last_frame)
# 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration
@@ -460,14 +381,11 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
elif duration_seconds != PixverseDuration.dur_5:
motion_mode = PixverseMotionMode.normal
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/pixverse/video/transition/generate",
method=HttpMethod.POST,
request_model=PixverseTransitionVideoRequest,
response_model=PixverseVideoResponse,
),
request=PixverseTransitionVideoRequest(
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/pixverse/video/transition/generate", method="POST"),
response_model=PixverseVideoResponse,
data=PixverseTransitionVideoRequest(
first_frame_img=first_frame_id,
last_frame_img=last_frame_id,
prompt=prompt,
@@ -477,20 +395,15 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
negative_prompt=negative_prompt if negative_prompt else None,
seed=seed,
),
auth_kwargs=auth,
)
response_api = await operation.execute()
if response_api.Resp is None:
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=PixverseGenerationStatusResponse,
),
response_poll = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
response_model=PixverseGenerationStatusResponse,
completed_statuses=[PixverseStatus.successful],
failed_statuses=[
PixverseStatus.contents_moderation,
@@ -498,16 +411,9 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
PixverseStatus.deleted,
],
status_extractor=lambda x: x.Resp.status,
auth_kwargs=auth,
node_id=cls.hidden.unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
class PixVerseExtension(ComfyExtension):

File diff suppressed because it is too large Load Diff

View File

@@ -225,21 +225,20 @@ async def get_rodin_download_list(uuid, auth_kwargs: Optional[dict[str, str]] =
async def download_files(url_list, task_uuid):
save_path = os.path.join(comfy_paths.get_output_directory(), f"Rodin3D_{task_uuid}")
result_folder_name = f"Rodin3D_{task_uuid}"
save_path = os.path.join(comfy_paths.get_output_directory(), result_folder_name)
os.makedirs(save_path, exist_ok=True)
model_file_path = None
async with aiohttp.ClientSession() as session:
for i in url_list.list:
url = i.url
file_name = i.name
file_path = os.path.join(save_path, file_name)
file_path = os.path.join(save_path, i.name)
if file_path.endswith(".glb"):
model_file_path = file_path
model_file_path = os.path.join(result_folder_name, i.name)
logging.info("[ Rodin3D API - download_files ] Downloading file: %s", file_path)
max_retries = 5
for attempt in range(max_retries):
try:
async with session.get(url) as resp:
async with session.get(i.url) as resp:
resp.raise_for_status()
with open(file_path, "wb") as f:
async for chunk in resp.content.iter_chunked(32 * 1024):

View File

@@ -200,7 +200,7 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
) -> IO.NodeOutput:
validate_string(prompt, min_length=1)
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))
download_urls = await upload_images_to_comfyapi(
cls,
@@ -290,7 +290,7 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
) -> IO.NodeOutput:
validate_string(prompt, min_length=1)
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))
download_urls = await upload_images_to_comfyapi(
cls,
@@ -390,8 +390,8 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
validate_string(prompt, min_length=1)
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
validate_image_dimensions(end_frame, max_width=7999, max_height=7999)
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
validate_image_aspect_ratio(end_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))
validate_image_aspect_ratio(end_frame, (1, 2), (2, 1))
stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame)
download_urls = await upload_images_to_comfyapi(
@@ -475,7 +475,7 @@ class RunwayTextToImageNode(IO.ComfyNode):
reference_images = None
if reference_image is not None:
validate_image_dimensions(reference_image, max_width=7999, max_height=7999)
validate_image_aspect_ratio(reference_image, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
validate_image_aspect_ratio(reference_image, (1, 2), (2, 1))
download_urls = await upload_images_to_comfyapi(
cls,
reference_image,

View File

@@ -20,13 +20,6 @@ from comfy_api_nodes.apis.stability_api import (
StabilityAudioInpaintRequest,
StabilityAudioResponse,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.util import (
validate_audio_duration,
validate_string,
@@ -34,6 +27,9 @@ from comfy_api_nodes.util import (
bytesio_to_image_tensor,
tensor_to_bytesio,
audio_bytes_to_audio_input,
sync_op,
poll_op,
ApiEndpoint,
)
import torch
@@ -161,19 +157,11 @@ class StabilityStableImageUltraNode(IO.ComfyNode):
"image": image_binary
}
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/generate/ultra",
method=HttpMethod.POST,
request_model=StabilityStableUltraRequest,
response_model=StabilityStableUltraResponse,
),
request=StabilityStableUltraRequest(
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/ultra", method="POST"),
response_model=StabilityStableUltraResponse,
data=StabilityStableUltraRequest(
prompt=prompt,
negative_prompt=negative_prompt,
aspect_ratio=aspect_ratio,
@@ -183,9 +171,7 @@ class StabilityStableImageUltraNode(IO.ComfyNode):
),
files=files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
response_api = await operation.execute()
if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.")
@@ -313,19 +299,11 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode):
"image": image_binary
}
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/generate/sd3",
method=HttpMethod.POST,
request_model=StabilityStable3_5Request,
response_model=StabilityStableUltraResponse,
),
request=StabilityStable3_5Request(
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/sd3", method="POST"),
response_model=StabilityStableUltraResponse,
data=StabilityStable3_5Request(
prompt=prompt,
negative_prompt=negative_prompt,
aspect_ratio=aspect_ratio,
@@ -338,9 +316,7 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode):
),
files=files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
response_api = await operation.execute()
if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.")
@@ -427,19 +403,11 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode):
"image": image_binary
}
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/upscale/conservative",
method=HttpMethod.POST,
request_model=StabilityUpscaleConservativeRequest,
response_model=StabilityStableUltraResponse,
),
request=StabilityUpscaleConservativeRequest(
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/conservative", method="POST"),
response_model=StabilityStableUltraResponse,
data=StabilityUpscaleConservativeRequest(
prompt=prompt,
negative_prompt=negative_prompt,
creativity=round(creativity,2),
@@ -447,9 +415,7 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode):
),
files=files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
response_api = await operation.execute()
if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.")
@@ -544,19 +510,11 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode):
"image": image_binary
}
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/upscale/creative",
method=HttpMethod.POST,
request_model=StabilityUpscaleCreativeRequest,
response_model=StabilityAsyncResponse,
),
request=StabilityUpscaleCreativeRequest(
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/creative", method="POST"),
response_model=StabilityAsyncResponse,
data=StabilityUpscaleCreativeRequest(
prompt=prompt,
negative_prompt=negative_prompt,
creativity=round(creativity,2),
@@ -565,25 +523,15 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode):
),
files=files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
response_api = await operation.execute()
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/stability/v2beta/results/{response_api.id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=StabilityResultsGetResponse,
),
response_poll = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/stability/v2beta/results/{response_api.id}"),
response_model=StabilityResultsGetResponse,
poll_interval=3,
completed_statuses=[StabilityPollStatus.finished],
failed_statuses=[StabilityPollStatus.failed],
status_extractor=lambda x: get_async_dummy_status(x),
auth_kwargs=auth,
node_id=cls.hidden.unique_id,
)
response_poll: StabilityResultsGetResponse = await operation.execute()
if response_poll.finish_reason != "SUCCESS":
raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.")
@@ -628,24 +576,13 @@ class StabilityUpscaleFastNode(IO.ComfyNode):
"image": image_binary
}
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/upscale/fast",
method=HttpMethod.POST,
request_model=EmptyRequest,
response_model=StabilityStableUltraResponse,
),
request=EmptyRequest(),
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/fast", method="POST"),
response_model=StabilityStableUltraResponse,
files=files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
response_api = await operation.execute()
if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.")
@@ -717,21 +654,13 @@ class StabilityTextToAudio(IO.ComfyNode):
async def execute(cls, model: str, prompt: str, duration: int, seed: int, steps: int) -> IO.NodeOutput:
validate_string(prompt, max_length=10000)
payload = StabilityTextToAudioRequest(prompt=prompt, model=model, duration=duration, seed=seed, steps=steps)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio",
method=HttpMethod.POST,
request_model=StabilityTextToAudioRequest,
response_model=StabilityAudioResponse,
),
request=payload,
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio", method="POST"),
response_model=StabilityAudioResponse,
data=payload,
content_type="multipart/form-data",
auth_kwargs= {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
)
response_api = await operation.execute()
if not response_api.audio:
raise ValueError("No audio file was received in response.")
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
@@ -814,22 +743,14 @@ class StabilityAudioToAudio(IO.ComfyNode):
payload = StabilityAudioToAudioRequest(
prompt=prompt, model=model, duration=duration, seed=seed, steps=steps, strength=strength
)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio",
method=HttpMethod.POST,
request_model=StabilityAudioToAudioRequest,
response_model=StabilityAudioResponse,
),
request=payload,
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio", method="POST"),
response_model=StabilityAudioResponse,
data=payload,
content_type="multipart/form-data",
files={"audio": audio_input_to_mp3(audio)},
auth_kwargs= {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
)
response_api = await operation.execute()
if not response_api.audio:
raise ValueError("No audio file was received in response.")
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
@@ -935,22 +856,14 @@ class StabilityAudioInpaint(IO.ComfyNode):
mask_start=mask_start,
mask_end=mask_end,
)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint",
method=HttpMethod.POST,
request_model=StabilityAudioInpaintRequest,
response_model=StabilityAudioResponse,
),
request=payload,
response_api = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint", method="POST"),
response_model=StabilityAudioResponse,
data=payload,
content_type="multipart/form-data",
files={"audio": audio_input_to_mp3(audio)},
auth_kwargs={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
)
response_api = await operation.execute()
if not response_api.audio:
raise ValueError("No audio file was received in response.")
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))

View File

@@ -14,9 +14,9 @@ from comfy_api_nodes.util import (
poll_op,
sync_op,
upload_images_to_comfyapi,
validate_aspect_ratio_closeness,
validate_image_aspect_ratio_range,
validate_image_aspect_ratio,
validate_image_dimensions,
validate_images_aspect_ratio_closeness,
)
VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video"
@@ -114,7 +114,7 @@ async def execute_task(
cls,
ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % response.task_id),
response_model=TaskStatusResponse,
status_extractor=lambda r: r.state.value,
status_extractor=lambda r: r.state,
estimated_duration=estimated_duration,
)
@@ -307,7 +307,7 @@ class ViduImageToVideoNode(IO.ComfyNode):
) -> IO.NodeOutput:
if get_number_of_images(image) > 1:
raise ValueError("Only one input image is allowed.")
validate_image_aspect_ratio_range(image, (1, 4), (4, 1))
validate_image_aspect_ratio(image, (1, 4), (4, 1))
payload = TaskCreationRequest(
model_name=model,
prompt=prompt,
@@ -423,7 +423,7 @@ class ViduReferenceVideoNode(IO.ComfyNode):
if a > 7:
raise ValueError("Too many images, maximum allowed is 7.")
for image in images:
validate_image_aspect_ratio_range(image, (1, 4), (4, 1))
validate_image_aspect_ratio(image, (1, 4), (4, 1))
validate_image_dimensions(image, min_width=128, min_height=128)
payload = TaskCreationRequest(
model_name=model,
@@ -533,7 +533,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
resolution: str,
movement_amplitude: str,
) -> IO.NodeOutput:
validate_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
payload = TaskCreationRequest(
model_name=model,
prompt=prompt,

View File

@@ -14,9 +14,12 @@ from .conversions import (
downscale_image_tensor,
image_tensor_pair_to_batch,
pil_to_bytesio,
resize_mask_to_image,
tensor_to_base64_string,
tensor_to_bytesio,
tensor_to_pil,
text_filepath_to_base64_string,
text_filepath_to_data_uri,
trim_video,
video_to_base64_string,
)
@@ -34,12 +37,12 @@ from .upload_helpers import (
)
from .validation_utils import (
get_number_of_images,
validate_aspect_ratio_closeness,
validate_aspect_ratio_string,
validate_audio_duration,
validate_container_format_is_mp4,
validate_image_aspect_ratio,
validate_image_aspect_ratio_range,
validate_image_dimensions,
validate_images_aspect_ratio_closeness,
validate_string,
validate_video_dimensions,
validate_video_duration,
@@ -70,19 +73,22 @@ __all__ = [
"downscale_image_tensor",
"image_tensor_pair_to_batch",
"pil_to_bytesio",
"resize_mask_to_image",
"tensor_to_base64_string",
"tensor_to_bytesio",
"tensor_to_pil",
"text_filepath_to_base64_string",
"text_filepath_to_data_uri",
"trim_video",
"video_to_base64_string",
# Validation utilities
"get_number_of_images",
"validate_aspect_ratio_closeness",
"validate_aspect_ratio_string",
"validate_audio_duration",
"validate_container_format_is_mp4",
"validate_image_aspect_ratio",
"validate_image_aspect_ratio_range",
"validate_image_dimensions",
"validate_images_aspect_ratio_closeness",
"validate_string",
"validate_video_dimensions",
"validate_video_duration",

View File

@@ -77,8 +77,8 @@ class _PollUIState:
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed"]
FAILED_STATUSES = ["cancelled", "canceled", "failed", "error"]
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished"]
FAILED_STATUSES = ["cancelled", "canceled", "fail", "failed", "error"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted"]
@@ -589,7 +589,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
payload_headers = {"Accept": "*/*"}
payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
payload_headers.update(get_auth_header(cfg.node_cls))
if cfg.endpoint.headers:

View File

@@ -1,6 +1,7 @@
import base64
import logging
import math
import mimetypes
import uuid
from io import BytesIO
from typing import Optional
@@ -12,7 +13,7 @@ from PIL import Image
from comfy.utils import common_upscale
from comfy_api.latest import Input, InputImpl
from comfy_api.util import VideoContainer, VideoCodec
from comfy_api.util import VideoCodec, VideoContainer
from ._helpers import mimetype_to_extension
@@ -430,3 +431,40 @@ def audio_bytes_to_audio_input(audio_bytes: bytes) -> dict:
wav = torch.cat(frames, dim=1) # [C, T]
wav = _f32_pcm(wav)
return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr}
def resize_mask_to_image(
mask: torch.Tensor,
image: torch.Tensor,
upscale_method="nearest-exact",
crop="disabled",
allow_gradient=True,
add_channel_dim=False,
):
"""Resize mask to be the same dimensions as an image, while maintaining proper format for API calls."""
_, height, width, _ = image.shape
mask = mask.unsqueeze(-1)
mask = mask.movedim(-1, 1)
mask = common_upscale(mask, width=width, height=height, upscale_method=upscale_method, crop=crop)
mask = mask.movedim(1, -1)
if not add_channel_dim:
mask = mask.squeeze(-1)
if not allow_gradient:
mask = (mask > 0.5).float()
return mask
def text_filepath_to_base64_string(filepath: str) -> str:
"""Converts a text file to a base64 string."""
with open(filepath, "rb") as f:
file_content = f.read()
return base64.b64encode(file_content).decode("utf-8")
def text_filepath_to_data_uri(filepath: str) -> str:
"""Converts a text file to a data URI."""
base64_string = text_filepath_to_base64_string(filepath)
mime_type, _ = mimetypes.guess_type(filepath)
if mime_type is None:
mime_type = "application/octet-stream"
return f"data:{mime_type};base64,{base64_string}"

View File

@@ -232,11 +232,12 @@ async def download_url_to_video_output(
video_url: str,
*,
timeout: float = None,
max_retries: int = 5,
cls: type[COMFY_IO.ComfyNode] = None,
) -> VideoFromFile:
"""Downloads a video from a URL and returns a `VIDEO` output."""
result = BytesIO()
await download_url_to_bytesio(video_url, result, timeout=timeout, cls=cls)
await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls)
return VideoFromFile(result)

View File

@@ -37,63 +37,62 @@ def validate_image_dimensions(
def validate_image_aspect_ratio(
image: torch.Tensor,
min_aspect_ratio: Optional[float] = None,
max_aspect_ratio: Optional[float] = None,
):
width, height = get_image_dimensions(image)
aspect_ratio = width / height
if min_aspect_ratio is not None and aspect_ratio < min_aspect_ratio:
raise ValueError(f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}")
if max_aspect_ratio is not None and aspect_ratio > max_aspect_ratio:
raise ValueError(f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}")
def validate_image_aspect_ratio_range(
image: torch.Tensor,
min_ratio: tuple[float, float], # e.g. (1, 4)
max_ratio: tuple[float, float], # e.g. (4, 1)
min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4)
max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1)
*,
strict: bool = True, # True -> (min, max); False -> [min, max]
) -> float:
a1, b1 = min_ratio
a2, b2 = max_ratio
if a1 <= 0 or b1 <= 0 or a2 <= 0 or b2 <= 0:
raise ValueError("Ratios must be positive, like (1, 4) or (4, 1).")
lo, hi = (a1 / b1), (a2 / b2)
if lo > hi:
lo, hi = hi, lo
a1, b1, a2, b2 = a2, b2, a1, b1 # swap only for error text
"""Validates that image aspect ratio is within min and max. If a bound is None, that side is not checked."""
w, h = get_image_dimensions(image)
if w <= 0 or h <= 0:
raise ValueError(f"Invalid image dimensions: {w}x{h}")
ar = w / h
ok = (lo < ar < hi) if strict else (lo <= ar <= hi)
if not ok:
op = "<" if strict else ""
raise ValueError(f"Image aspect ratio {ar:.6g} is outside allowed range: {a1}:{b1} {op} ratio {op} {a2}:{b2}")
_assert_ratio_bounds(ar, min_ratio=min_ratio, max_ratio=max_ratio, strict=strict)
return ar
def validate_aspect_ratio_closeness(
start_img,
end_img,
min_rel: float,
max_rel: float,
def validate_images_aspect_ratio_closeness(
first_image: torch.Tensor,
second_image: torch.Tensor,
min_rel: float, # e.g. 0.8
max_rel: float, # e.g. 1.25
*,
strict: bool = False, # True => exclusive, False => inclusive
) -> None:
w1, h1 = get_image_dimensions(start_img)
w2, h2 = get_image_dimensions(end_img)
strict: bool = False, # True -> (min, max); False -> [min, max]
) -> float:
"""
Validates that the two images' aspect ratios are 'close'.
The closeness factor is C = max(ar1, ar2) / min(ar1, ar2) (C >= 1).
We require C <= limit, where limit = max(max_rel, 1.0 / min_rel).
Returns the computed closeness factor C.
"""
w1, h1 = get_image_dimensions(first_image)
w2, h2 = get_image_dimensions(second_image)
if min(w1, h1, w2, h2) <= 0:
raise ValueError("Invalid image dimensions")
ar1 = w1 / h1
ar2 = w2 / h2
# Normalize so it is symmetric (no need to check both ar1/ar2 and ar2/ar1)
closeness = max(ar1, ar2) / min(ar1, ar2)
limit = max(max_rel, 1.0 / min_rel) # for 0.8..1.25 this is 1.25
limit = max(max_rel, 1.0 / min_rel)
if (closeness >= limit) if strict else (closeness > limit):
raise ValueError(f"Aspect ratios must be close: start/end={ar1/ar2:.4f}, allowed range {min_rel}{max_rel}.")
raise ValueError(
f"Aspect ratios must be close: ar1/ar2={ar1/ar2:.2g}, "
f"allowed range {min_rel}{max_rel} (limit {limit:.2g})."
)
return closeness
def validate_aspect_ratio_string(
aspect_ratio: str,
min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4)
max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1)
*,
strict: bool = False, # True -> (min, max); False -> [min, max]
) -> float:
"""Parses 'X:Y' and validates it against optional bounds. Returns the numeric ratio."""
ar = _parse_aspect_ratio_string(aspect_ratio)
_assert_ratio_bounds(ar, min_ratio=min_ratio, max_ratio=max_ratio, strict=strict)
return ar
def validate_video_dimensions(
@@ -183,3 +182,49 @@ def validate_container_format_is_mp4(video: VideoInput) -> None:
container_format = video.get_container_format()
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
raise ValueError(f"Only MP4 container format supported. Got: {container_format}")
def _ratio_from_tuple(r: tuple[float, float]) -> float:
a, b = r
if a <= 0 or b <= 0:
raise ValueError(f"Ratios must be positive, got {a}:{b}.")
return a / b
def _assert_ratio_bounds(
ar: float,
*,
min_ratio: Optional[tuple[float, float]] = None,
max_ratio: Optional[tuple[float, float]] = None,
strict: bool = True,
) -> None:
"""Validate a numeric aspect ratio against optional min/max ratio bounds."""
lo = _ratio_from_tuple(min_ratio) if min_ratio is not None else None
hi = _ratio_from_tuple(max_ratio) if max_ratio is not None else None
if lo is not None and hi is not None and lo > hi:
lo, hi = hi, lo # normalize order if caller swapped them
if lo is not None:
if (ar <= lo) if strict else (ar < lo):
op = "<" if strict else ""
raise ValueError(f"Aspect ratio `{ar:.2g}` must be {op} {lo:.2g}.")
if hi is not None:
if (ar >= hi) if strict else (ar > hi):
op = "<" if strict else ""
raise ValueError(f"Aspect ratio `{ar:.2g}` must be {op} {hi:.2g}.")
def _parse_aspect_ratio_string(ar_str: str) -> float:
"""Parse 'X:Y' with integer parts into a positive float ratio X/Y."""
parts = ar_str.split(":")
if len(parts) != 2:
raise ValueError(f"Aspect ratio must be 'X:Y' (e.g., 16:9), got '{ar_str}'.")
try:
a = int(parts[0].strip())
b = int(parts[1].strip())
except ValueError as exc:
raise ValueError(f"Aspect ratio must contain integers separated by ':', got '{ar_str}'.") from exc
if a <= 0 or b <= 0:
raise ValueError(f"Aspect ratio parts must be positive integers, got {a}:{b}.")
return a / b

View File

@@ -1,4 +1,9 @@
import bisect
import gc
import itertools
import psutil
import time
import torch
from typing import Sequence, Mapping, Dict
from comfy_execution.graph import DynamicPrompt
from abc import ABC, abstractmethod
@@ -48,7 +53,7 @@ class Unhashable:
def to_hashable(obj):
# So that we don't infinitely recurse since frozenset and tuples
# are Sequences.
if isinstance(obj, (int, float, str, bool, type(None))):
if isinstance(obj, (int, float, str, bool, bytes, type(None))):
return obj
elif isinstance(obj, Mapping):
return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
@@ -188,6 +193,9 @@ class BasicCache:
self._clean_cache()
self._clean_subcaches()
def poll(self, **kwargs):
pass
def _set_immediate(self, node_id, value):
assert self.initialized
cache_key = self.cache_key_set.get_data_key(node_id)
@@ -276,6 +284,9 @@ class NullCache:
def clean_unused(self):
pass
def poll(self, **kwargs):
pass
def get(self, node_id):
return None
@@ -336,3 +347,77 @@ class LRUCache(BasicCache):
self._mark_used(child_id)
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
return self
#Iterating the cache for usage analysis might be expensive, so if we trigger make sure
#to take a chunk out to give breathing space on high-node / low-ram-per-node flows.
RAM_CACHE_HYSTERESIS = 1.1
#This is kinda in GB but not really. It needs to be non-zero for the below heuristic
#and as long as Multi GB models dwarf this it will approximate OOM scoring OK
RAM_CACHE_DEFAULT_RAM_USAGE = 0.1
#Exponential bias towards evicting older workflows so garbage will be taken out
#in constantly changing setups.
RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
class RAMPressureCache(LRUCache):
def __init__(self, key_class):
super().__init__(key_class, 0)
self.timestamps = {}
def clean_unused(self):
self._clean_subcaches()
def set(self, node_id, value):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
super().set(node_id, value)
def get(self, node_id):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
return super().get(node_id)
def poll(self, ram_headroom):
def _ram_gb():
return psutil.virtual_memory().available / (1024**3)
if _ram_gb() > ram_headroom:
return
gc.collect()
if _ram_gb() > ram_headroom:
return
clean_list = []
for key, (outputs, _), in self.cache.items():
oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])
ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
def scan_list_for_ram_usage(outputs):
nonlocal ram_usage
if outputs is None:
return
for output in outputs:
if isinstance(output, list):
scan_list_for_ram_usage(output)
elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
#score Tensors at a 50% discount for RAM usage as they are likely to
#be high value intermediates
ram_usage += (output.numel() * output.element_size()) * 0.5
elif hasattr(output, "get_ram_usage"):
ram_usage += output.get_ram_usage()
scan_list_for_ram_usage(outputs)
oom_score *= ram_usage
#In the case where we have no information on the node ram usage at all,
#break OOM score ties on the last touch timestamp (pure LRU)
bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list:
_, _, key = clean_list.pop()
del self.cache[key]
gc.collect()

View File

@@ -209,10 +209,15 @@ class ExecutionList(TopologicalSort):
self.execution_cache_listeners[from_node_id] = set()
self.execution_cache_listeners[from_node_id].add(to_node_id)
def get_output_cache(self, from_node_id, to_node_id):
def get_cache(self, from_node_id, to_node_id):
if not to_node_id in self.execution_cache:
return None
return self.execution_cache[to_node_id].get(from_node_id)
value = self.execution_cache[to_node_id].get(from_node_id)
if value is None:
return None
#Write back to the main cache on touch.
self.output_cache.set(from_node_id, value)
return value
def cache_update(self, node_id, value):
if node_id in self.execution_cache_listeners:

View File

@@ -2,6 +2,9 @@ import comfy.utils
import folder_paths
import torch
import logging
from comfy_api.latest import IO, ComfyExtension
from typing_extensions import override
def load_hypernetwork_patch(path, strength):
sd = comfy.utils.load_torch_file(path, safe_load=True)
@@ -94,27 +97,42 @@ def load_hypernetwork_patch(path, strength):
return hypernetwork_patch(out, strength)
class HypernetworkLoader:
class HypernetworkLoader(IO.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"hypernetwork_name": (folder_paths.get_filename_list("hypernetworks"), ),
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "load_hypernetwork"
def define_schema(cls):
return IO.Schema(
node_id="HypernetworkLoader",
category="loaders",
inputs=[
IO.Model.Input("model"),
IO.Combo.Input("hypernetwork_name", options=folder_paths.get_filename_list("hypernetworks")),
IO.Float.Input("strength", default=1.0, min=-10.0, max=10.0, step=0.01),
],
outputs=[
IO.Model.Output(),
],
)
CATEGORY = "loaders"
def load_hypernetwork(self, model, hypernetwork_name, strength):
@classmethod
def execute(cls, model, hypernetwork_name, strength) -> IO.NodeOutput:
hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
model_hypernetwork = model.clone()
patch = load_hypernetwork_patch(hypernetwork_path, strength)
if patch is not None:
model_hypernetwork.set_model_attn1_patch(patch)
model_hypernetwork.set_model_attn2_patch(patch)
return (model_hypernetwork,)
return IO.NodeOutput(model_hypernetwork)
NODE_CLASS_MAPPINGS = {
"HypernetworkLoader": HypernetworkLoader
}
load_hypernetwork = execute # TODO: remove
class HyperNetworkExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
HypernetworkLoader,
]
async def comfy_entrypoint() -> HyperNetworkExtension:
return HyperNetworkExtension()

View File

@@ -0,0 +1,47 @@
from comfy_api.latest import ComfyExtension, io
from typing_extensions import override
class ScaleROPE(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ScaleROPE",
category="advanced/model_patches",
description="Scale and shift the ROPE of the model.",
is_experimental=True,
inputs=[
io.Model.Input("model"),
io.Float.Input("scale_x", default=1.0, min=0.0, max=100.0, step=0.1),
io.Float.Input("shift_x", default=0.0, min=-256.0, max=256.0, step=0.1),
io.Float.Input("scale_y", default=1.0, min=0.0, max=100.0, step=0.1),
io.Float.Input("shift_y", default=0.0, min=-256.0, max=256.0, step=0.1),
io.Float.Input("scale_t", default=1.0, min=0.0, max=100.0, step=0.1),
io.Float.Input("shift_t", default=0.0, min=-256.0, max=256.0, step=0.1),
],
outputs=[
io.Model.Output(),
],
)
@classmethod
def execute(cls, model, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t) -> io.NodeOutput:
m = model.clone()
m.set_model_rope_options(scale_x, shift_x, scale_y, shift_y, scale_t, shift_t)
return io.NodeOutput(m)
class RopeExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
ScaleROPE
]
async def comfy_entrypoint() -> RopeExtension:
return RopeExtension()

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.67"
__version__ = "0.3.68"

View File

@@ -21,6 +21,7 @@ from comfy_execution.caching import (
NullCache,
HierarchicalCache,
LRUCache,
RAMPressureCache,
)
from comfy_execution.graph import (
DynamicPrompt,
@@ -88,49 +89,56 @@ class IsChangedCache:
return self.is_changed[node_id]
class CacheEntry(NamedTuple):
ui: dict
outputs: list
class CacheType(Enum):
CLASSIC = 0
LRU = 1
NONE = 2
RAM_PRESSURE = 3
class CacheSet:
def __init__(self, cache_type=None, cache_size=None):
def __init__(self, cache_type=None, cache_args={}):
if cache_type == CacheType.NONE:
self.init_null_cache()
logging.info("Disabling intermediate node cache.")
elif cache_type == CacheType.RAM_PRESSURE:
cache_ram = cache_args.get("ram", 16.0)
self.init_ram_cache(cache_ram)
logging.info("Using RAM pressure cache.")
elif cache_type == CacheType.LRU:
if cache_size is None:
cache_size = 0
cache_size = cache_args.get("lru", 0)
self.init_lru_cache(cache_size)
logging.info("Using LRU cache")
else:
self.init_classic_cache()
self.all = [self.outputs, self.ui, self.objects]
self.all = [self.outputs, self.objects]
# Performs like the old cache -- dump data ASAP
def init_classic_cache(self):
self.outputs = HierarchicalCache(CacheKeySetInputSignature)
self.ui = HierarchicalCache(CacheKeySetInputSignature)
self.objects = HierarchicalCache(CacheKeySetID)
def init_lru_cache(self, cache_size):
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.objects = HierarchicalCache(CacheKeySetID)
def init_ram_cache(self, min_headroom):
self.outputs = RAMPressureCache(CacheKeySetInputSignature)
self.objects = HierarchicalCache(CacheKeySetID)
def init_null_cache(self):
self.outputs = NullCache()
#The UI cache is expected to be iterable at the end of each workflow
#so it must cache at least a full workflow. Use Heirachical
self.ui = HierarchicalCache(CacheKeySetInputSignature)
self.objects = NullCache()
def recursive_debug_dump(self):
result = {
"outputs": self.outputs.recursive_debug_dump(),
"ui": self.ui.recursive_debug_dump(),
}
return result
@@ -157,14 +165,14 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
if execution_list is None:
mark_missing()
continue # This might be a lazily-evaluated input
cached_output = execution_list.get_output_cache(input_unique_id, unique_id)
if cached_output is None:
cached = execution_list.get_cache(input_unique_id, unique_id)
if cached is None or cached.outputs is None:
mark_missing()
continue
if output_index >= len(cached_output):
if output_index >= len(cached.outputs):
mark_missing()
continue
obj = cached_output[output_index]
obj = cached.outputs[output_index]
input_data_all[x] = obj
elif input_category is not None:
input_data_all[x] = [input_data]
@@ -393,7 +401,7 @@ def format_value(x):
else:
return str(x)
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes):
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
unique_id = current_item
real_node_id = dynprompt.get_real_node_id(unique_id)
display_node_id = dynprompt.get_display_node_id(unique_id)
@@ -401,12 +409,15 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
inputs = dynprompt.get_node(unique_id)['inputs']
class_type = dynprompt.get_node(unique_id)['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if caches.outputs.get(unique_id) is not None:
cached = caches.outputs.get(unique_id)
if cached is not None:
if server.client_id is not None:
cached_output = caches.ui.get(unique_id) or {}
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
cached_ui = cached.ui or {}
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id)
if cached.ui is not None:
ui_outputs[unique_id] = cached.ui
get_progress_state().finish_progress(unique_id)
execution_list.cache_update(unique_id, caches.outputs.get(unique_id))
execution_list.cache_update(unique_id, cached)
return (ExecutionResult.SUCCESS, None, None)
input_data_all = None
@@ -436,8 +447,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
for r in result:
if is_link(r):
source_node, source_output = r[0], r[1]
node_output = execution_list.get_output_cache(source_node, unique_id)[source_output]
for o in node_output:
node_cached = execution_list.get_cache(source_node, unique_id)
for o in node_cached.outputs[source_output]:
resolved_output.append(o)
else:
@@ -445,6 +456,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
resolved_outputs.append(tuple(resolved_output))
output_data = merge_result_data(resolved_outputs, class_def)
output_ui = []
del pending_subgraph_results[unique_id]
has_subgraph = False
else:
get_progress_state().start_progress(unique_id)
@@ -506,7 +518,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
asyncio.create_task(await_completion())
return (ExecutionResult.PENDING, None, None)
if len(output_ui) > 0:
caches.ui.set(unique_id, {
ui_outputs[unique_id] = {
"meta": {
"node_id": unique_id,
"display_node": display_node_id,
@@ -514,7 +526,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
"real_node_id": real_node_id,
},
"output": output_ui
})
}
if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
if has_subgraph:
@@ -527,10 +539,6 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
if new_graph is None:
cached_outputs.append((False, node_outputs))
else:
# Check for conflicts
for node_id in new_graph.keys():
if dynprompt.has_node(node_id):
raise DuplicateNodeError(f"Attempt to add duplicate node {node_id}. Ensure node ids are unique and deterministic or use graph_utils.GraphBuilder.")
for node_id, node_info in new_graph.items():
new_node_ids.append(node_id)
display_id = node_info.get("override_display_id", unique_id)
@@ -557,8 +565,9 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
pending_subgraph_results[unique_id] = cached_outputs
return (ExecutionResult.PENDING, None, None)
caches.outputs.set(unique_id, output_data)
execution_list.cache_update(unique_id, output_data)
cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data)
execution_list.cache_update(unique_id, cache_entry)
caches.outputs.set(unique_id, cache_entry)
except comfy.model_management.InterruptProcessingException as iex:
logging.info("Processing interrupted")
@@ -603,14 +612,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
return (ExecutionResult.SUCCESS, None, None)
class PromptExecutor:
def __init__(self, server, cache_type=False, cache_size=None):
self.cache_size = cache_size
def __init__(self, server, cache_type=False, cache_args=None):
self.cache_args = cache_args
self.cache_type = cache_type
self.server = server
self.reset()
def reset(self):
self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size)
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
self.status_messages = []
self.success = True
@@ -685,6 +694,7 @@ class PromptExecutor:
broadcast=False)
pending_subgraph_results = {}
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
ui_node_outputs = {}
executed = set()
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
current_outputs = self.caches.outputs.all_node_ids()
@@ -698,7 +708,7 @@ class PromptExecutor:
break
assert node_id is not None, "Node ID should not be None at this point"
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
self.success = result != ExecutionResult.FAILURE
if result == ExecutionResult.FAILURE:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
@@ -707,18 +717,16 @@ class PromptExecutor:
execution_list.unstage_node_execution()
else: # result == ExecutionResult.SUCCESS:
execution_list.complete_node_execution()
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
else:
# Only execute when the while-loop ends without break
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
ui_outputs = {}
meta_outputs = {}
all_node_ids = self.caches.ui.all_node_ids()
for node_id in all_node_ids:
ui_info = self.caches.ui.get(node_id)
if ui_info is not None:
ui_outputs[node_id] = ui_info["output"]
meta_outputs[node_id] = ui_info["meta"]
for node_id, ui_info in ui_node_outputs.items():
ui_outputs[node_id] = ui_info["output"]
meta_outputs[node_id] = ui_info["meta"]
self.history_result = {
"outputs": ui_outputs,
"meta": meta_outputs,
@@ -1116,7 +1124,7 @@ class PromptQueue:
messages: List[str]
def task_done(self, item_id, history_result,
status: Optional['PromptQueue.ExecutionStatus']):
status: Optional['PromptQueue.ExecutionStatus'], process_item=None):
with self.mutex:
prompt = self.currently_running.pop(item_id)
if len(self.history) > MAXIMUM_HISTORY_SIZE:
@@ -1126,10 +1134,8 @@ class PromptQueue:
if status is not None:
status_dict = copy.deepcopy(status._asdict())
# Remove sensitive data from extra_data before storing in history
for sensitive_val in SENSITIVE_EXTRA_DATA_KEYS:
if sensitive_val in prompt[3]:
prompt[3].pop(sensitive_val)
if process_item is not None:
prompt = process_item(prompt)
self.history[prompt[1]] = {
"prompt": prompt,

15
main.py
View File

@@ -172,10 +172,12 @@ def prompt_worker(q, server_instance):
cache_type = execution.CacheType.CLASSIC
if args.cache_lru > 0:
cache_type = execution.CacheType.LRU
elif args.cache_ram > 0:
cache_type = execution.CacheType.RAM_PRESSURE
elif args.cache_none:
cache_type = execution.CacheType.NONE
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru)
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : args.cache_ram } )
last_gc_collect = 0
need_gc = False
gc_collect_interval = 10.0
@@ -192,14 +194,21 @@ def prompt_worker(q, server_instance):
prompt_id = item[1]
server_instance.last_prompt_id = prompt_id
e.execute(item[2], prompt_id, item[3], item[4])
sensitive = item[5]
extra_data = item[3].copy()
for k in sensitive:
extra_data[k] = sensitive[k]
e.execute(item[2], prompt_id, extra_data, item[4])
need_gc = True
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
q.task_done(item_id,
e.history_result,
status=execution.PromptQueue.ExecutionStatus(
status_str='success' if e.success else 'error',
completed=e.success,
messages=e.status_messages))
messages=e.status_messages), process_item=remove_sensitive)
if server_instance.client_id is not None:
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)

View File

@@ -2329,6 +2329,7 @@ async def init_builtin_extra_nodes():
"nodes_model_patch.py",
"nodes_easycache.py",
"nodes_audio_encoder.py",
"nodes_rope.py",
]
import_failed = []

View File

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.67"
version = "0.3.68"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

View File

@@ -1,6 +1,6 @@
comfyui-frontend-package==1.28.8
comfyui-workflow-templates==0.2.4
comfyui-embedded-docs==0.3.0
comfyui-workflow-templates==0.2.11
comfyui-embedded-docs==0.3.1
torch
torchsde
torchvision

View File

@@ -691,8 +691,9 @@ class PromptServer():
async def get_queue(request):
queue_info = {}
current_queue = self.prompt_queue.get_current_queue_volatile()
queue_info['queue_running'] = current_queue[0]
queue_info['queue_pending'] = current_queue[1]
remove_sensitive = lambda queue: [x[:5] for x in queue]
queue_info['queue_running'] = remove_sensitive(current_queue[0])
queue_info['queue_pending'] = remove_sensitive(current_queue[1])
return web.json_response(queue_info)
@routes.post("/prompt")
@@ -728,7 +729,11 @@ class PromptServer():
extra_data["client_id"] = json_data["client_id"]
if valid[0]:
outputs_to_execute = valid[2]
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
sensitive = {}
for sensitive_val in execution.SENSITIVE_EXTRA_DATA_KEYS:
if sensitive_val in extra_data:
sensitive[sensitive_val] = extra_data.pop(sensitive_val)
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive))
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
return web.json_response(response)
else:

View File

@@ -0,0 +1,232 @@
import unittest
import torch
import sys
import os
# Add comfy to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
def has_gpu():
return torch.cuda.is_available()
from comfy.cli_args import args
if not has_gpu():
args.cpu = True
from comfy import ops
from comfy.quant_ops import QuantizedTensor
class SimpleModel(torch.nn.Module):
def __init__(self, operations=ops.disable_weight_init):
super().__init__()
self.layer1 = operations.Linear(10, 20, device="cpu", dtype=torch.bfloat16)
self.layer2 = operations.Linear(20, 30, device="cpu", dtype=torch.bfloat16)
self.layer3 = operations.Linear(30, 40, device="cpu", dtype=torch.bfloat16)
def forward(self, x):
x = self.layer1(x)
x = torch.nn.functional.relu(x)
x = self.layer2(x)
x = torch.nn.functional.relu(x)
x = self.layer3(x)
return x
class TestMixedPrecisionOps(unittest.TestCase):
def test_all_layers_standard(self):
"""Test that model with no quantization works normally"""
# Configure no quantization
ops.MixedPrecisionOps._layer_quant_config = {}
# Create model
model = SimpleModel(operations=ops.MixedPrecisionOps)
# Initialize weights manually
model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16))
model.layer1.bias = torch.nn.Parameter(torch.randn(20, dtype=torch.bfloat16))
model.layer2.weight = torch.nn.Parameter(torch.randn(30, 20, dtype=torch.bfloat16))
model.layer2.bias = torch.nn.Parameter(torch.randn(30, dtype=torch.bfloat16))
model.layer3.weight = torch.nn.Parameter(torch.randn(40, 30, dtype=torch.bfloat16))
model.layer3.bias = torch.nn.Parameter(torch.randn(40, dtype=torch.bfloat16))
# Initialize weight_function and bias_function
for layer in [model.layer1, model.layer2, model.layer3]:
layer.weight_function = []
layer.bias_function = []
# Forward pass
input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
output = model(input_tensor)
self.assertEqual(output.shape, (5, 40))
self.assertEqual(output.dtype, torch.bfloat16)
def test_mixed_precision_load(self):
"""Test loading a mixed precision model from state dict"""
# Configure mixed precision: layer1 is FP8, layer2 and layer3 are standard
layer_quant_config = {
"layer1": {
"format": "float8_e4m3fn",
"params": {}
},
"layer3": {
"format": "float8_e4m3fn",
"params": {}
}
}
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
# Create state dict with mixed precision
fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
fp8_weight3 = torch.randn(40, 30, dtype=torch.float32).to(torch.float8_e4m3fn)
state_dict = {
# Layer 1: FP8 E4M3FN
"layer1.weight": fp8_weight1,
"layer1.bias": torch.randn(20, dtype=torch.bfloat16),
"layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32),
# Layer 2: Standard BF16
"layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
"layer2.bias": torch.randn(30, dtype=torch.bfloat16),
# Layer 3: FP8 E4M3FN
"layer3.weight": fp8_weight3,
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
"layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32),
}
# Create model and load state dict (strict=False because custom loading pops keys)
model = SimpleModel(operations=ops.MixedPrecisionOps)
model.load_state_dict(state_dict, strict=False)
# Verify weights are wrapped in QuantizedTensor
self.assertIsInstance(model.layer1.weight, QuantizedTensor)
self.assertEqual(model.layer1.weight._layout_type, "TensorCoreFP8Layout")
# Layer 2 should NOT be quantized
self.assertNotIsInstance(model.layer2.weight, QuantizedTensor)
# Layer 3 should be quantized
self.assertIsInstance(model.layer3.weight, QuantizedTensor)
self.assertEqual(model.layer3.weight._layout_type, "TensorCoreFP8Layout")
# Verify scales were loaded
self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0)
self.assertEqual(model.layer3.weight._layout_params['scale'].item(), 1.5)
# Forward pass
input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
output = model(input_tensor)
self.assertEqual(output.shape, (5, 40))
def test_state_dict_quantized_preserved(self):
"""Test that quantized weights are preserved in state_dict()"""
# Configure mixed precision
layer_quant_config = {
"layer1": {
"format": "float8_e4m3fn",
"params": {}
}
}
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
# Create and load model
fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
state_dict1 = {
"layer1.weight": fp8_weight,
"layer1.bias": torch.randn(20, dtype=torch.bfloat16),
"layer1.weight_scale": torch.tensor(3.0, dtype=torch.float32),
"layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
"layer2.bias": torch.randn(30, dtype=torch.bfloat16),
"layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
}
model = SimpleModel(operations=ops.MixedPrecisionOps)
model.load_state_dict(state_dict1, strict=False)
# Save state dict
state_dict2 = model.state_dict()
# Verify layer1.weight is a QuantizedTensor with scale preserved
self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0)
self.assertEqual(state_dict2["layer1.weight"]._layout_type, "TensorCoreFP8Layout")
# Verify non-quantized layers are standard tensors
self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)
self.assertNotIsInstance(state_dict2["layer3.weight"], QuantizedTensor)
def test_weight_function_compatibility(self):
"""Test that weight_function (LoRA) works with quantized layers"""
# Configure FP8 quantization
layer_quant_config = {
"layer1": {
"format": "float8_e4m3fn",
"params": {}
}
}
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
# Create and load model
fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
state_dict = {
"layer1.weight": fp8_weight,
"layer1.bias": torch.randn(20, dtype=torch.bfloat16),
"layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32),
"layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
"layer2.bias": torch.randn(30, dtype=torch.bfloat16),
"layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
}
model = SimpleModel(operations=ops.MixedPrecisionOps)
model.load_state_dict(state_dict, strict=False)
# Add a weight function (simulating LoRA)
# This should trigger dequantization during forward pass
def apply_lora(weight):
lora_delta = torch.randn_like(weight) * 0.01
return weight + lora_delta
model.layer1.weight_function.append(apply_lora)
# Forward pass should work with LoRA (triggers weight_function path)
input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
output = model(input_tensor)
self.assertEqual(output.shape, (5, 40))
def test_error_handling_unknown_format(self):
"""Test that unknown formats raise error"""
# Configure with unknown format
layer_quant_config = {
"layer1": {
"format": "unknown_format_xyz",
"params": {}
}
}
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
# Create state dict
state_dict = {
"layer1.weight": torch.randn(20, 10, dtype=torch.bfloat16),
"layer1.bias": torch.randn(20, dtype=torch.bfloat16),
"layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
"layer2.bias": torch.randn(30, dtype=torch.bfloat16),
"layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
}
# Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS
model = SimpleModel(operations=ops.MixedPrecisionOps)
with self.assertRaises(KeyError):
model.load_state_dict(state_dict, strict=False)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,190 @@
import unittest
import torch
import sys
import os
# Add comfy to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
def has_gpu():
return torch.cuda.is_available()
from comfy.cli_args import args
if not has_gpu():
args.cpu = True
from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout
class TestQuantizedTensor(unittest.TestCase):
"""Test the QuantizedTensor subclass with FP8 layout"""
def test_creation(self):
"""Test creating a QuantizedTensor with TensorCoreFP8Layout"""
fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
scale = torch.tensor(2.0)
layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
self.assertIsInstance(qt, QuantizedTensor)
self.assertEqual(qt.shape, (256, 128))
self.assertEqual(qt.dtype, torch.float8_e4m3fn)
self.assertEqual(qt._layout_params['scale'], scale)
self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
self.assertEqual(qt._layout_type, "TensorCoreFP8Layout")
def test_dequantize(self):
"""Test explicit dequantization"""
fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
scale = torch.tensor(3.0)
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
dequantized = qt.dequantize()
self.assertEqual(dequantized.dtype, torch.float32)
self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1))
def test_from_float(self):
"""Test creating QuantizedTensor from float tensor"""
float_tensor = torch.randn(64, 32, dtype=torch.float32)
scale = torch.tensor(1.5)
qt = QuantizedTensor.from_float(
float_tensor,
"TensorCoreFP8Layout",
scale=scale,
dtype=torch.float8_e4m3fn
)
self.assertIsInstance(qt, QuantizedTensor)
self.assertEqual(qt.dtype, torch.float8_e4m3fn)
self.assertEqual(qt.shape, (64, 32))
# Verify dequantization gives approximately original values
dequantized = qt.dequantize()
mean_rel_error = ((dequantized - float_tensor).abs() / (float_tensor.abs() + 1e-6)).mean()
self.assertLess(mean_rel_error, 0.1)
class TestGenericUtilities(unittest.TestCase):
"""Test generic utility operations"""
def test_detach(self):
"""Test detach operation on quantized tensor"""
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
scale = torch.tensor(1.5)
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
# Detach should return a new QuantizedTensor
qt_detached = qt.detach()
self.assertIsInstance(qt_detached, QuantizedTensor)
self.assertEqual(qt_detached.shape, qt.shape)
self.assertEqual(qt_detached._layout_type, "TensorCoreFP8Layout")
def test_clone(self):
"""Test clone operation on quantized tensor"""
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
scale = torch.tensor(1.5)
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
# Clone should return a new QuantizedTensor
qt_cloned = qt.clone()
self.assertIsInstance(qt_cloned, QuantizedTensor)
self.assertEqual(qt_cloned.shape, qt.shape)
self.assertEqual(qt_cloned._layout_type, "TensorCoreFP8Layout")
# Verify it's a deep copy
self.assertIsNot(qt_cloned._qdata, qt._qdata)
@unittest.skipUnless(has_gpu(), "GPU not available")
def test_to_device(self):
"""Test device transfer"""
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
scale = torch.tensor(1.5)
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
# Moving to same device should work (CPU to CPU)
qt_cpu = qt.to('cpu')
self.assertIsInstance(qt_cpu, QuantizedTensor)
self.assertEqual(qt_cpu.device.type, 'cpu')
self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu')
class TestTensorCoreFP8Layout(unittest.TestCase):
"""Test the TensorCoreFP8Layout implementation"""
def test_quantize(self):
"""Test quantization method"""
float_tensor = torch.randn(32, 64, dtype=torch.float32)
scale = torch.tensor(1.5)
qdata, layout_params = TensorCoreFP8Layout.quantize(
float_tensor,
scale=scale,
dtype=torch.float8_e4m3fn
)
self.assertEqual(qdata.dtype, torch.float8_e4m3fn)
self.assertEqual(qdata.shape, float_tensor.shape)
self.assertIn('scale', layout_params)
self.assertIn('orig_dtype', layout_params)
self.assertEqual(layout_params['orig_dtype'], torch.float32)
def test_dequantize(self):
"""Test dequantization method"""
float_tensor = torch.ones(10, 20, dtype=torch.float32) * 3.0
scale = torch.tensor(1.0)
qdata, layout_params = TensorCoreFP8Layout.quantize(
float_tensor,
scale=scale,
dtype=torch.float8_e4m3fn
)
dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params)
# Should approximately match original
self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1))
class TestFallbackMechanism(unittest.TestCase):
"""Test fallback for unsupported operations"""
def test_unsupported_op_dequantizes(self):
"""Test that unsupported operations fall back to dequantization"""
# Set seed for reproducibility
torch.manual_seed(42)
# Create quantized tensor
a_fp32 = torch.randn(10, 20, dtype=torch.float32)
scale = torch.tensor(1.0)
a_q = QuantizedTensor.from_float(
a_fp32,
"TensorCoreFP8Layout",
scale=scale,
dtype=torch.float8_e4m3fn
)
# Call an operation that doesn't have a registered handler
# For example, torch.abs
result = torch.abs(a_q)
# Should work via fallback (dequantize → abs → return)
self.assertNotIsInstance(result, QuantizedTensor)
expected = torch.abs(a_fp32)
# FP8 introduces quantization error, so use loose tolerance
mean_error = (result - expected).abs().mean()
self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large")
if __name__ == "__main__":
unittest.main()