Compare commits

..

1 Commits

41 changed files with 455 additions and 1956 deletions

View File

@@ -7,8 +7,6 @@ on:
jobs:
send-webhook:
runs-on: ubuntu-latest
env:
DESKTOP_REPO_DISPATCH_TOKEN: ${{ secrets.DESKTOP_REPO_DISPATCH_TOKEN }}
steps:
- name: Send release webhook
env:
@@ -108,37 +106,3 @@ jobs:
--fail --silent --show-error
echo "✅ Release webhook sent successfully"
- name: Send repository dispatch to desktop
env:
DISPATCH_TOKEN: ${{ env.DESKTOP_REPO_DISPATCH_TOKEN }}
RELEASE_TAG: ${{ github.event.release.tag_name }}
RELEASE_URL: ${{ github.event.release.html_url }}
run: |
set -euo pipefail
if [ -z "${DISPATCH_TOKEN:-}" ]; then
echo "::error::DESKTOP_REPO_DISPATCH_TOKEN is required but not set."
exit 1
fi
PAYLOAD="$(jq -n \
--arg release_tag "$RELEASE_TAG" \
--arg release_url "$RELEASE_URL" \
'{
event_type: "comfyui_release_published",
client_payload: {
release_tag: $release_tag,
release_url: $release_url
}
}')"
curl -fsSL \
-X POST \
-H "Accept: application/vnd.github+json" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer ${DISPATCH_TOKEN}" \
https://api.github.com/repos/Comfy-Org/desktop/dispatches \
-d "$PAYLOAD"
echo "✅ Dispatched ComfyUI release ${RELEASE_TAG} to Comfy-Org/desktop"

View File

@@ -1,6 +1,8 @@
from __future__ import annotations
from typing import TypedDict
import json
import logging
import os
import folder_paths
import glob
@@ -40,6 +42,7 @@ class SubgraphManager:
def __init__(self):
self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None
self.cached_blueprint_subgraphs: dict[SubgraphEntry] | None = None
self.distribution = os.environ.get("DISTRIBUTION", "localhost")
def _create_entry(self, file: str, source: str, node_pack: str) -> tuple[str, SubgraphEntry]:
"""Create a subgraph entry from a file path. Expects normalized path (forward slashes)."""
@@ -90,15 +93,56 @@ class SubgraphManager:
return subgraphs_dict
async def get_blueprint_subgraphs(self, force_reload=False):
"""Load subgraphs from the blueprints directory."""
"""Load subgraphs from the blueprints directory using index.json for discovery."""
if not force_reload and self.cached_blueprint_subgraphs is not None:
return self.cached_blueprint_subgraphs
subgraphs_dict: dict[SubgraphEntry] = {}
blueprints_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'blueprints')
if os.path.exists(blueprints_dir):
index_path = os.path.join(blueprints_dir, "index.json")
if os.path.isfile(index_path):
try:
with open(index_path, "r", encoding="utf-8") as f:
categories = json.load(f)
except (json.JSONDecodeError, OSError) as e:
logging.error("Failed to load blueprint index %s: %s", index_path, e)
categories = []
if not isinstance(categories, list):
logging.error("Blueprint index.json is not a list: %s", index_path)
categories = []
for category in categories:
module_name = category.get("moduleName", "default")
for blueprint in category.get("blueprints", []):
name = blueprint.get("name")
if not name:
logging.warning("Blueprint entry missing 'name' in category '%s', skipping", module_name)
continue
include_on = blueprint.get("includeOnDistributions")
if include_on is not None and self.distribution not in include_on:
continue
file_by_dist = blueprint.get("fileByDistribution", {})
filename = file_by_dist.get(self.distribution, f"{name}.json")
filepath = os.path.realpath(os.path.join(blueprints_dir, filename))
if not filepath.startswith(os.path.realpath(blueprints_dir) + os.sep):
logging.warning("Blueprint path escapes blueprints directory: %s", filepath)
continue
if not os.path.isfile(filepath):
logging.warning("Blueprint file not found: %s", filepath)
continue
entry_id, entry = self._create_entry(filepath, Source.templates, module_name)
subgraphs_dict[entry_id] = entry
elif os.path.exists(blueprints_dir):
logging.warning("No blueprint index.json found at %s, falling back to glob", index_path)
for file in glob.glob(os.path.join(blueprints_dir, "*.json")):
if os.path.basename(file) == "index.json":
continue
file = file.replace('\\', '/')
entry_id, entry = self._create_entry(file, Source.templates, "comfyui")
subgraphs_dict[entry_id] = entry

View File

@@ -1,11 +1,12 @@
import math
import time
from functools import partial
from scipy import integrate
import torch
from torch import nn
import torchsde
from tqdm.auto import tqdm
from tqdm.auto import trange as trange_, tqdm
from . import utils
from . import deis
@@ -14,7 +15,34 @@ import comfy.model_patcher
import comfy.model_sampling
import comfy.memory_management
from comfy.utils import model_trange as trange
def trange(*args, **kwargs):
if comfy.memory_management.aimdo_allocator is None:
return trange_(*args, **kwargs)
pbar = trange_(*args, **kwargs, smoothing=1.0)
pbar._i = 0
pbar.set_postfix_str(" Model Initializing ... ")
_update = pbar.update
def warmup_update(n=1):
pbar._i += 1
if pbar._i == 1:
pbar.i1_time = time.time()
pbar.set_postfix_str(" Model Initialization complete! ")
elif pbar._i == 2:
#bring forward the effective start time based the the diff between first and second iteration
#to attempt to remove load overhead from the final step rate estimate.
pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
pbar.set_postfix_str("")
_update(n)
pbar.update = warmup_update
return pbar
def append_zero(x):
return torch.cat([x, x.new_zeros([1])])

View File

@@ -7,67 +7,6 @@ from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
from comfy.ldm.flux.layers import timestep_embedding
def get_silence_latent(length, device):
head = torch.tensor([[[ 0.5707, 0.0982, 0.6909, -0.5658, 0.6266, 0.6996, -0.1365, -0.1291,
-0.0776, -0.1171, -0.2743, -0.8422, -0.1168, 1.5539, -4.6936, 0.7436,
-1.1846, -0.2637, 0.6933, -6.7266, 0.0966, -0.1187, -0.3501, -1.1736,
0.0587, -2.0517, -1.3651, 0.7508, -0.2490, -1.3548, -0.1290, -0.7261,
1.1132, -0.3249, 0.2337, 0.3004, 0.6605, -0.0298, -0.1989, -0.4041,
0.2843, -1.0963, -0.5519, 0.2639, -1.0436, -0.1183, 0.0640, 0.4460,
-1.1001, -0.6172, -1.3241, 1.1379, 0.5623, -0.1507, -0.1963, -0.4742,
-2.4697, 0.5302, 0.5381, 0.4636, -0.1782, -0.0687, 1.0333, 0.4202],
[ 0.3040, -0.1367, 0.6200, 0.0665, -0.0642, 0.4655, -0.1187, -0.0440,
0.2941, -0.2753, 0.0173, -0.2421, -0.0147, 1.5603, -2.7025, 0.7907,
-0.9736, -0.0682, 0.1294, -5.0707, -0.2167, 0.3302, -0.1513, -0.8100,
-0.3894, -0.2884, -0.3149, 0.8660, -0.3817, -1.7061, 0.5824, -0.4840,
0.6938, 0.1859, 0.1753, 0.3081, 0.0195, 0.1403, -0.0754, -0.2091,
0.1251, -0.1578, -0.4968, -0.1052, -0.4554, -0.0320, 0.1284, 0.4974,
-1.1889, -0.0344, -0.8313, 0.2953, 0.5445, -0.6249, -0.1595, -0.0682,
-3.1412, 0.0484, 0.4153, 0.8260, -0.1526, -0.0625, 0.5366, 0.8473],
[ 5.3524e-02, -1.7534e-01, 5.4443e-01, -4.3501e-01, -2.1317e-03,
3.7200e-01, -4.0143e-03, -1.5516e-01, -1.2968e-01, -1.5375e-01,
-7.7107e-02, -2.0593e-01, -3.2780e-01, 1.5142e+00, -2.6101e+00,
5.8698e-01, -1.2716e+00, -2.4773e-01, -2.7933e-02, -5.0799e+00,
1.1601e-01, 4.0987e-01, -2.2030e-02, -6.6495e-01, -2.0995e-01,
-6.3474e-01, -1.5893e-01, 8.2745e-01, -2.2992e-01, -1.6816e+00,
5.4440e-01, -4.9579e-01, 5.5128e-01, 3.0477e-01, 8.3052e-02,
-6.1782e-02, 5.9036e-03, 2.9553e-01, -8.0645e-02, -1.0060e-01,
1.9144e-01, -3.8124e-01, -7.2949e-01, 2.4520e-02, -5.0814e-01,
2.3977e-01, 9.2943e-02, 3.9256e-01, -1.1993e+00, -3.2752e-01,
-7.2707e-01, 2.9476e-01, 4.3542e-01, -8.8597e-01, -4.1686e-01,
-8.5390e-02, -2.9018e+00, 6.4988e-02, 5.3945e-01, 9.1988e-01,
5.8762e-02, -7.0098e-02, 6.4772e-01, 8.9118e-01],
[-3.2225e-02, -1.3195e-01, 5.6411e-01, -5.4766e-01, -5.2170e-03,
3.1425e-01, -5.4367e-02, -1.9419e-01, -1.3059e-01, -1.3660e-01,
-9.0984e-02, -1.9540e-01, -2.5590e-01, 1.5440e+00, -2.6349e+00,
6.8273e-01, -1.2532e+00, -1.9810e-01, -2.2793e-02, -5.0506e+00,
1.8818e-01, 5.0109e-01, 7.3546e-03, -6.8771e-01, -3.0676e-01,
-7.3257e-01, -1.6687e-01, 9.2232e-01, -1.8987e-01, -1.7267e+00,
5.3355e-01, -5.3179e-01, 4.4953e-01, 2.8820e-01, 1.3012e-01,
-2.0943e-01, -1.1348e-01, 3.3929e-01, -1.5069e-01, -1.2919e-01,
1.8929e-01, -3.6166e-01, -8.0756e-01, 6.6387e-02, -5.8867e-01,
1.6978e-01, 1.0134e-01, 3.3877e-01, -1.2133e+00, -3.2492e-01,
-8.1237e-01, 3.8101e-01, 4.3765e-01, -8.0596e-01, -4.4531e-01,
-4.7513e-02, -2.9266e+00, 1.1741e-03, 4.5123e-01, 9.3075e-01,
5.3688e-02, -1.9621e-01, 6.4530e-01, 9.3870e-01]]], device=device).movedim(-1, 1)
silence_latent = torch.tensor([[[-1.3672e-01, -1.5820e-01, 5.8594e-01, -5.7422e-01, 3.0273e-02,
2.7930e-01, -2.5940e-03, -2.0703e-01, -1.6113e-01, -1.4746e-01,
-2.7710e-02, -1.8066e-01, -2.9688e-01, 1.6016e+00, -2.6719e+00,
7.7734e-01, -1.3516e+00, -1.9434e-01, -7.1289e-02, -5.0938e+00,
2.4316e-01, 4.7266e-01, 4.6387e-02, -6.6406e-01, -2.1973e-01,
-6.7578e-01, -1.5723e-01, 9.5312e-01, -2.0020e-01, -1.7109e+00,
5.8984e-01, -5.7422e-01, 5.1562e-01, 2.8320e-01, 1.4551e-01,
-1.8750e-01, -5.9814e-02, 3.6719e-01, -1.0059e-01, -1.5723e-01,
2.0605e-01, -4.3359e-01, -8.2812e-01, 4.5654e-02, -6.6016e-01,
1.4844e-01, 9.4727e-02, 3.8477e-01, -1.2578e+00, -3.3203e-01,
-8.5547e-01, 4.3359e-01, 4.2383e-01, -8.9453e-01, -5.0391e-01,
-5.6152e-02, -2.9219e+00, -2.4658e-02, 5.0391e-01, 9.8438e-01,
7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, length)
silence_latent[:, :, :head.shape[-1]] = head
return silence_latent
def get_layer_class(operations, layer_name):
if operations is not None and hasattr(operations, layer_name):
return getattr(operations, layer_name)
@@ -738,7 +677,7 @@ class AttentionPooler(nn.Module):
def forward(self, x):
B, T, P, D = x.shape
x = self.embed_tokens(x)
special = comfy.model_management.cast_to(self.special_token, device=x.device, dtype=x.dtype).expand(B, T, 1, -1)
special = self.special_token.expand(B, T, 1, -1)
x = torch.cat([special, x], dim=2)
x = x.view(B * T, P + 1, D)
@@ -789,7 +728,7 @@ class FSQ(nn.Module):
self.register_buffer('implicit_codebook', implicit_codebook, persistent=False)
def bound(self, z):
levels_minus_1 = (comfy.model_management.cast_to(self._levels, device=z.device, dtype=z.dtype) - 1)
levels_minus_1 = (self._levels - 1).to(z.dtype)
scale = 2. / levels_minus_1
bracket = (levels_minus_1 * (torch.tanh(z) + 1) / 2.) + 0.5
@@ -804,8 +743,8 @@ class FSQ(nn.Module):
return codes_non_centered.float() * (2. / (self._levels.float() - 1)) - 1.
def codes_to_indices(self, zhat):
zhat_normalized = (zhat + 1.) / (2. / (comfy.model_management.cast_to(self._levels, device=zhat.device, dtype=zhat.dtype) - 1))
return (zhat_normalized * comfy.model_management.cast_to(self._basis, device=zhat.device, dtype=zhat.dtype)).sum(dim=-1).round().to(torch.int32)
zhat_normalized = (zhat + 1.) / (2. / (self._levels.to(zhat.dtype) - 1))
return (zhat_normalized * self._basis.to(zhat.dtype)).sum(dim=-1).round().to(torch.int32)
def forward(self, z):
orig_dtype = z.dtype
@@ -887,7 +826,7 @@ class ResidualFSQ(nn.Module):
x = self.project_in(x)
if hasattr(self, 'soft_clamp_input_value'):
sc_val = comfy.model_management.cast_to(self.soft_clamp_input_value, device=x.device, dtype=x.dtype)
sc_val = self.soft_clamp_input_value.to(x.dtype)
x = (x / sc_val).tanh() * sc_val
quantized_out = torch.tensor(0., device=x.device, dtype=x.dtype)
@@ -895,7 +834,7 @@ class ResidualFSQ(nn.Module):
all_indices = []
for layer, scale in zip(self.layers, self.scales):
scale = comfy.model_management.cast_to(scale, device=x.device, dtype=x.dtype)
scale = scale.to(residual.dtype)
quantized, indices = layer(residual / scale)
quantized = quantized * scale
@@ -1101,21 +1040,22 @@ class AceStepConditionGenerationModel(nn.Module):
lm_hints = self.detokenizer(lm_hints_5Hz)
lm_hints = lm_hints[:, :src_latents.shape[1], :]
if is_covers is None or is_covers is True:
if is_covers is None:
src_latents = lm_hints
elif is_covers is False:
src_latents = refer_audio_acoustic_hidden_states_packed
else:
src_latents = torch.where(is_covers.unsqueeze(-1).unsqueeze(-1) > 0, lm_hints, src_latents)
context_latents = torch.cat([src_latents, chunk_masks.to(src_latents.dtype)], dim=-1)
return encoder_hidden, encoder_mask, context_latents
def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, is_covers=None, replace_with_null_embeds=False, **kwargs):
def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, **kwargs):
text_attention_mask = None
lyric_attention_mask = None
refer_audio_order_mask = None
attention_mask = None
chunk_masks = None
is_covers = None
src_latents = None
precomputed_lm_hints_25Hz = None
lyric_hidden_states = lyric_embed
@@ -1127,7 +1067,7 @@ class AceStepConditionGenerationModel(nn.Module):
if refer_audio_order_mask is None:
refer_audio_order_mask = torch.zeros((x.shape[0],), device=x.device, dtype=torch.long)
if src_latents is None:
if src_latents is None and is_covers is None:
src_latents = x
if chunk_masks is None:
@@ -1140,9 +1080,6 @@ class AceStepConditionGenerationModel(nn.Module):
src_latents, chunk_masks, is_covers, precomputed_lm_hints_25Hz=precomputed_lm_hints_25Hz, audio_codes=audio_codes
)
if replace_with_null_embeds:
enc_hidden[:] = self.null_condition_emb.to(enc_hidden)
out = self.decoder(hidden_states=x,
timestep=timestep,
timestep_r=timestep,

View File

@@ -195,20 +195,8 @@ class Anima(MiniTrainDIT):
super().__init__(*args, **kwargs)
self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations"))
def preprocess_text_embeds(self, text_embeds, text_ids, t5xxl_weights=None):
def preprocess_text_embeds(self, text_embeds, text_ids):
if text_ids is not None:
out = self.llm_adapter(text_embeds, text_ids)
if t5xxl_weights is not None:
out = out * t5xxl_weights
if out.shape[1] < 512:
out = torch.nn.functional.pad(out, (0, 0, 0, 512 - out.shape[1]))
return out
return self.llm_adapter(text_embeds, text_ids)
else:
return text_embeds
def forward(self, x, timesteps, context, **kwargs):
t5xxl_ids = kwargs.pop("t5xxl_ids", None)
if t5xxl_ids is not None:
context = self.preprocess_text_embeds(context, t5xxl_ids, t5xxl_weights=kwargs.pop("t5xxl_weights", None))
return super().forward(x, timesteps, context, **kwargs)

View File

@@ -335,7 +335,7 @@ class FinalLayer(nn.Module):
device=None, dtype=None, operations=None
):
super().__init__()
self.layer_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = operations.Linear(
hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype
)
@@ -463,8 +463,6 @@ class Block(nn.Module):
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
transformer_options: Optional[dict] = {},
) -> torch.Tensor:
residual_dtype = x_B_T_H_W_D.dtype
compute_dtype = emb_B_T_D.dtype
if extra_per_block_pos_emb is not None:
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
@@ -514,7 +512,7 @@ class Block(nn.Module):
result_B_T_H_W_D = rearrange(
self.self_attn(
# normalized_x_B_T_HW_D,
rearrange(normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
None,
rope_emb=rope_emb_L_1_1_D,
transformer_options=transformer_options,
@@ -524,7 +522,7 @@ class Block(nn.Module):
h=H,
w=W,
)
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D
def _x_fn(
_x_B_T_H_W_D: torch.Tensor,
@@ -538,7 +536,7 @@ class Block(nn.Module):
)
_result_B_T_H_W_D = rearrange(
self.cross_attn(
rearrange(_normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
crossattn_emb,
rope_emb=rope_emb_L_1_1_D,
transformer_options=transformer_options,
@@ -557,7 +555,7 @@ class Block(nn.Module):
shift_cross_attn_B_T_1_1_D,
transformer_options=transformer_options,
)
x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D
x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
normalized_x_B_T_H_W_D = _fn(
x_B_T_H_W_D,
@@ -565,8 +563,8 @@ class Block(nn.Module):
scale_mlp_B_T_1_1_D,
shift_mlp_B_T_1_1_D,
)
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype))
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D)
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D
return x_B_T_H_W_D
@@ -878,14 +876,6 @@ class MiniTrainDIT(nn.Module):
"extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
"transformer_options": kwargs.get("transformer_options", {}),
}
# The residual stream for this model has large values. To make fp16 compute_dtype work, we keep the residual stream
# in fp32, but run attention and MLP modules in fp16.
# An alternate method that clamps fp16 values "works" in the sense that it makes coherent images, but there is noticeable
# quality degradation and visual artifacts.
if x_B_T_H_W_D.dtype == torch.float16:
x_B_T_H_W_D = x_B_T_H_W_D.float()
for block in self.blocks:
x_B_T_H_W_D = block(
x_B_T_H_W_D,
@@ -894,6 +884,6 @@ class MiniTrainDIT(nn.Module):
**block_kwargs,
)
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D.to(crossattn_emb.dtype), t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)[:, :, :orig_shape[-3], :orig_shape[-2], :orig_shape[-1]]
return x_B_C_Tt_Hp_Wp

View File

@@ -29,34 +29,19 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
return out.to(dtype=torch.float32, device=pos.device)
def _apply_rope1(x: Tensor, freqs_cis: Tensor):
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
x_out = freqs_cis[..., 0] * x_[..., 0]
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
return x_out.reshape(*x.shape).type_as(x)
def _apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
try:
import comfy.quant_ops
q_apply_rope = comfy.quant_ops.ck.apply_rope
q_apply_rope1 = comfy.quant_ops.ck.apply_rope1
def apply_rope(xq, xk, freqs_cis):
if comfy.model_management.in_training:
return _apply_rope(xq, xk, freqs_cis)
else:
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
def apply_rope1(x, freqs_cis):
if comfy.model_management.in_training:
return _apply_rope1(x, freqs_cis)
else:
return q_apply_rope1(x, freqs_cis)
apply_rope = comfy.quant_ops.ck.apply_rope
apply_rope1 = comfy.quant_ops.ck.apply_rope1
except:
logging.warning("No comfy kitchen, using old apply_rope functions.")
apply_rope = _apply_rope
apply_rope1 = _apply_rope1
def apply_rope1(x: Tensor, freqs_cis: Tensor):
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
x_out = freqs_cis[..., 0] * x_[..., 0]
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
return x_out.reshape(*x.shape).type_as(x)
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)

View File

@@ -147,11 +147,11 @@ class BaseModel(torch.nn.Module):
self.diffusion_model.to(memory_format=torch.channels_last)
logging.debug("using channels last mode for diffusion model")
logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
comfy.model_management.archive_model_dtypes(self.diffusion_model)
self.model_type = model_type
self.model_sampling = model_sampling(model_config, model_type)
comfy.model_management.archive_model_dtypes(self.diffusion_model)
self.adm_channels = unet_config.get("adm_in_channels", None)
if self.adm_channels is None:
self.adm_channels = 0
@@ -1160,16 +1160,12 @@ class Anima(BaseModel):
device = kwargs["device"]
if cross_attn is not None:
if t5xxl_ids is not None:
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.unsqueeze(0).to(device=device))
if t5xxl_weights is not None:
t5xxl_weights = t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
t5xxl_ids = t5xxl_ids.unsqueeze(0)
if torch.is_inference_mode_enabled(): # if not we are training
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype()))
else:
out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)
cross_attn *= t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
if cross_attn.shape[1] < 512:
cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, 0, 512 - cross_attn.shape[1]))
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
@@ -1556,8 +1552,6 @@ class ACEStep15(BaseModel):
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
if torch.count_nonzero(cross_attn) == 0:
out['replace_with_null_embeds'] = comfy.conds.CONDConstant(True)
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
conditioning_lyrics = kwargs.get("conditioning_lyrics", None)
@@ -1566,11 +1560,22 @@ class ACEStep15(BaseModel):
refer_audio = kwargs.get("reference_audio_timbre_latents", None)
if refer_audio is None or len(refer_audio) == 0:
refer_audio = comfy.ldm.ace.ace_step15.get_silence_latent(noise.shape[2], device)
refer_audio = torch.tensor([[[-1.3672e-01, -1.5820e-01, 5.8594e-01, -5.7422e-01, 3.0273e-02,
2.7930e-01, -2.5940e-03, -2.0703e-01, -1.6113e-01, -1.4746e-01,
-2.7710e-02, -1.8066e-01, -2.9688e-01, 1.6016e+00, -2.6719e+00,
7.7734e-01, -1.3516e+00, -1.9434e-01, -7.1289e-02, -5.0938e+00,
2.4316e-01, 4.7266e-01, 4.6387e-02, -6.6406e-01, -2.1973e-01,
-6.7578e-01, -1.5723e-01, 9.5312e-01, -2.0020e-01, -1.7109e+00,
5.8984e-01, -5.7422e-01, 5.1562e-01, 2.8320e-01, 1.4551e-01,
-1.8750e-01, -5.9814e-02, 3.6719e-01, -1.0059e-01, -1.5723e-01,
2.0605e-01, -4.3359e-01, -8.2812e-01, 4.5654e-02, -6.6016e-01,
1.4844e-01, 9.4727e-02, 3.8477e-01, -1.2578e+00, -3.3203e-01,
-8.5547e-01, 4.3359e-01, 4.2383e-01, -8.9453e-01, -5.0391e-01,
-5.6152e-02, -2.9219e+00, -2.4658e-02, 5.0391e-01, 9.8438e-01,
7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, noise.shape[2])
pass_audio_codes = True
else:
refer_audio = refer_audio[-1][:, :, :noise.shape[2]]
out['is_covers'] = comfy.conds.CONDConstant(True)
pass_audio_codes = False
if pass_audio_codes:
@@ -1578,12 +1583,6 @@ class ACEStep15(BaseModel):
if audio_codes is not None:
out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device))
refer_audio = refer_audio[:, :, :750]
else:
out['is_covers'] = comfy.conds.CONDConstant(False)
if refer_audio.shape[2] < noise.shape[2]:
pad = comfy.ldm.ace.ace_step15.get_silence_latent(noise.shape[2], device)
refer_audio = torch.cat([refer_audio.to(pad), pad[:, :, refer_audio.shape[2]:]], dim=2)
out['refer_audio'] = comfy.conds.CONDRegular(refer_audio)
return out

View File

@@ -19,7 +19,7 @@
import psutil
import logging
from enum import Enum
from comfy.cli_args import args, PerformanceFeature
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
import threading
import torch
import sys
@@ -55,11 +55,6 @@ cpu_state = CPUState.GPU
total_vram = 0
# Training Related State
in_training = False
def get_supported_float8_types():
float8_types = []
try:
@@ -656,7 +651,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
soft_empty_cache()
return unloaded_models
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
cleanup_models_gc()
global vram_state
@@ -752,6 +747,26 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
current_loaded_models.insert(0, loaded_model)
return
def load_models_gpu_thread(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load):
with torch.inference_mode():
load_models_gpu_orig(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
soft_empty_cache()
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
#Deliberately load models outside of the Aimdo mempool so they can be retained accross
#nodes. Use a dummy thread to do it as pytorch documents that mempool contexts are
#thread local. So exploit that to escape context
if enables_dynamic_vram():
t = threading.Thread(
target=load_models_gpu_thread,
args=(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
)
t.start()
t.join()
else:
load_models_gpu_orig(models, memory_required=memory_required, force_patch_weights=force_patch_weights,
minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
def load_model_gpu(model):
return load_models_gpu([model])
@@ -1211,20 +1226,21 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
if dtype is None:
dtype = weight._model_dtype
r = torch.empty_like(weight, dtype=dtype, device=device)
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
if signature is not None:
if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
v_tensor = weight._v_tensor
else:
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
weight._v_tensor = v_tensor
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
if not comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
weight._v_signature = signature
#Send it over
v_tensor.copy_(weight, non_blocking=non_blocking)
return v_tensor.to(dtype=dtype)
r = torch.empty_like(weight, dtype=dtype, device=device)
#always take a deep copy even if _v is good, as we have no reasonable point to unpin
#a non comfy weight
r.copy_(v_tensor)
comfy_aimdo.model_vbar.vbar_unpin(weight._v)
return r
if weight.dtype != r.dtype and weight.dtype != weight._model_dtype:
#Offloaded casting could skip this, however it would make the quantizations

View File

@@ -19,6 +19,7 @@
from __future__ import annotations
import collections
import copy
import inspect
import logging
import math
@@ -316,7 +317,7 @@ class ModelPatcher:
n.object_patches = self.object_patches.copy()
n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
n.model_options = comfy.utils.deepcopy_list_dict(self.model_options)
n.model_options = copy.deepcopy(self.model_options)
n.backup = self.backup
n.object_patches_backup = self.object_patches_backup
n.parent = self
@@ -1491,9 +1492,7 @@ class ModelPatcherDynamic(ModelPatcher):
if vbar is not None:
vbar.prioritize()
#We force reserve VRAM for the non comfy-weight so we dont have to deal
#with pin and unpin syncrhonization which can be expensive for small weights
#with a high layer rate (e.g. autoregressive LLMs).
#We have way more tools for acceleration on comfy weight offloading, so always
#prioritize the non-comfy weights (note the order reverse).
loading = self._load_list(prio_comfy_cast_weights=True)
loading.sort(reverse=True)
@@ -1525,7 +1524,7 @@ class ModelPatcherDynamic(ModelPatcher):
setattr(m, param_key + "_function", weight_function)
geometry = weight
if not isinstance(weight, QuantizedTensor):
model_dtype = getattr(m, param_key + "_comfy_model_dtype", None) or weight.dtype
model_dtype = getattr(m, param_key + "_comfy_model_dtype", weight.dtype)
weight._model_dtype = model_dtype
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
return comfy.memory_management.vram_aligned_size(geometry)
@@ -1551,14 +1550,13 @@ class ModelPatcherDynamic(ModelPatcher):
weight.seed_key = key
set_dirty(weight, dirty)
geometry = weight
model_dtype = getattr(m, param + "_comfy_model_dtype", None) or weight.dtype
model_dtype = getattr(m, param + "_comfy_model_dtype", weight.dtype)
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
weight_size = geometry.numel() * geometry.element_size()
if vbar is not None and not hasattr(weight, "_v"):
weight._v = vbar.alloc(weight_size)
weight._model_dtype = model_dtype
allocated_size += weight_size
vbar.set_watermark_limit(allocated_size)
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")

View File

@@ -83,18 +83,14 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
offload_stream = None
xfer_dest = None
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
if signature is not None:
if resident:
weight = s._v_weight
bias = s._v_bias
else:
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
if not resident:
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
cast_dest = None
xfer_source = [ s.weight, s.bias ]
@@ -144,13 +140,9 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
post_cast.copy_(pre_cast)
xfer_dest = cast_dest
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
weight = params[0]
bias = params[1]
if signature is not None:
s._v_weight = weight
s._v_bias = bias
s._v_signature=signature
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
weight = params[0]
bias = params[1]
def post_cast(s, param_key, x, dtype, resident, update_weight):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
@@ -177,8 +169,8 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
if orig.dtype == dtype and len(fns) == 0:
#The layer actually wants our freshly saved QT
x = y
elif update_weight:
y = comfy.float.stochastic_rounding(x, orig.dtype, seed = comfy.utils.string_to_seed(s.seed_key))
else:
y = x
if update_weight:
orig.copy_(y)
for f in fns:
@@ -190,6 +182,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
if s.bias is not None:
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
s._v_signature=signature
#FIXME: weird offload return protocol
return weight, bias, (offload_stream, device if signature is not None else None, None)

View File

@@ -122,26 +122,20 @@ def estimate_memory(model, noise_shape, conds):
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
return memory_required, minimum_memory_required
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
_prepare_sampling,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
)
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload)
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load)
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
real_model: BaseModel = None
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
if force_offload: # In training + offload enabled, we want to force prepare sampling to trigger partial load
memory_required = 1e20
minimum_memory_required = None
else:
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
memory_required += inference_memory
minimum_memory_required += inference_memory
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load)
real_model = model.model
return real_model, conds, models

View File

@@ -793,6 +793,8 @@ class VAE:
self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval()
model_management.archive_model_dtypes(self.first_stage_model)
if device is None:
device = model_management.vae_device()
self.device = device
@@ -801,7 +803,6 @@ class VAE:
dtype = model_management.vae_dtype(self.device, self.working_dtypes)
self.vae_dtype = dtype
self.first_stage_model.to(self.vae_dtype)
model_management.archive_model_dtypes(self.first_stage_model)
self.output_device = model_management.intermediate_device()
mp = comfy.model_patcher.CoreModelPatcher

View File

@@ -993,7 +993,7 @@ class CosmosT2IPredict2(supported_models_base.BASE):
memory_usage_factor = 1.0
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
supported_inference_dtypes = [torch.bfloat16, torch.float32]
def __init__(self, unet_config):
super().__init__(unet_config)
@@ -1023,7 +1023,11 @@ class Anima(supported_models_base.BASE):
memory_usage_factor = 1.0
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
supported_inference_dtypes = [torch.bfloat16, torch.float32]
def __init__(self, unet_config):
super().__init__(unet_config)
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Anima(self, device=device)
@@ -1034,12 +1038,6 @@ class Anima(supported_models_base.BASE):
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_06b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.anima.AnimaTokenizer, comfy.text_encoders.anima.te(**detect))
def set_inference_dtype(self, dtype, manual_cast_dtype, **kwargs):
self.memory_usage_factor = (self.unet_config.get("model_channels", 2048) / 2048) * 0.95
if dtype is torch.float16:
self.memory_usage_factor *= 1.4
return super().set_inference_dtype(dtype, manual_cast_dtype, **kwargs)
class CosmosI2VPredict2(CosmosT2IPredict2):
unet_config = {
"image_model": "cosmos_predict2",

View File

@@ -16,7 +16,6 @@ def sample_manual_loop_no_classes(
temperature: float = 0.85,
top_p: float = 0.9,
top_k: int = None,
min_p: float = 0.000,
seed: int = 1,
min_tokens: int = 1,
max_new_tokens: int = 2048,
@@ -24,8 +23,6 @@ def sample_manual_loop_no_classes(
audio_end_id: int = 215669,
eos_token_id: int = 151645,
):
if ids is None:
return []
device = model.execution_device
if execution_dtype is None:
@@ -35,7 +32,6 @@ def sample_manual_loop_no_classes(
execution_dtype = torch.float32
embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device)
embeds_batch = embeds.shape[0]
for i, t in enumerate(paddings):
attention_mask[i, :t] = 0
attention_mask[i, t:] = 1
@@ -45,27 +41,22 @@ def sample_manual_loop_no_classes(
generator = torch.Generator(device=device)
generator.manual_seed(seed)
model_config = model.transformer.model.config
past_kv_shape = [embeds_batch, model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim]
for x in range(model_config.num_hidden_layers):
past_key_values.append((torch.empty(past_kv_shape, device=device, dtype=execution_dtype), torch.empty(past_kv_shape, device=device, dtype=execution_dtype), 0))
past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), 0))
progress_bar = comfy.utils.ProgressBar(max_new_tokens)
for step in comfy.utils.model_trange(max_new_tokens, desc="LM sampling"):
for step in range(max_new_tokens):
outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values)
next_token_logits = model.transformer.logits(outputs[0])[:, -1]
past_key_values = outputs[2]
if cfg_scale != 1.0:
cond_logits = next_token_logits[0:1]
uncond_logits = next_token_logits[1:2]
cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
else:
cfg_logits = next_token_logits[0:1]
cond_logits = next_token_logits[0:1]
uncond_logits = next_token_logits[1:2]
cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
use_eos_score = eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step
if use_eos_score:
if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
eos_score = cfg_logits[:, eos_token_id].clone()
remove_logit_value = torch.finfo(cfg_logits.dtype).min
@@ -73,7 +64,7 @@ def sample_manual_loop_no_classes(
cfg_logits[:, :audio_start_id] = remove_logit_value
cfg_logits[:, audio_end_id:] = remove_logit_value
if use_eos_score:
if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
cfg_logits[:, eos_token_id] = eos_score
if top_k is not None and top_k > 0:
@@ -81,12 +72,6 @@ def sample_manual_loop_no_classes(
min_val = top_k_vals[..., -1, None]
cfg_logits[cfg_logits < min_val] = remove_logit_value
if min_p is not None and min_p > 0:
probs = torch.softmax(cfg_logits, dim=-1)
p_max = probs.max(dim=-1, keepdim=True).values
indices_to_remove = probs < (min_p * p_max)
cfg_logits[indices_to_remove] = remove_logit_value
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
@@ -108,8 +93,8 @@ def sample_manual_loop_no_classes(
break
embed, _, _, _ = model.process_tokens([[token]], device)
embeds = embed.repeat(embeds_batch, 1, 1)
attention_mask = torch.cat([attention_mask, torch.ones((embeds_batch, 1), device=device, dtype=attention_mask.dtype)], dim=1)
embeds = embed.repeat(2, 1, 1)
attention_mask = torch.cat([attention_mask, torch.ones((2, 1), device=device, dtype=attention_mask.dtype)], dim=1)
output_audio_codes.append(token - audio_start_id)
progress_bar.update_absolute(step)
@@ -117,31 +102,24 @@ def sample_manual_loop_no_classes(
return output_audio_codes
def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0, min_p=0.000):
def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0):
positive = [[token for token, _ in inner_list] for inner_list in positive]
negative = [[token for token, _ in inner_list] for inner_list in negative]
positive = positive[0]
negative = negative[0]
if cfg_scale != 1.0:
negative = [[token for token, _ in inner_list] for inner_list in negative]
negative = negative[0]
neg_pad = 0
if len(negative) < len(positive):
neg_pad = (len(positive) - len(negative))
negative = [model.special_tokens["pad"]] * neg_pad + negative
neg_pad = 0
if len(negative) < len(positive):
neg_pad = (len(positive) - len(negative))
negative = [model.special_tokens["pad"]] * neg_pad + negative
pos_pad = 0
if len(negative) > len(positive):
pos_pad = (len(negative) - len(positive))
positive = [model.special_tokens["pad"]] * pos_pad + positive
pos_pad = 0
if len(negative) > len(positive):
pos_pad = (len(negative) - len(positive))
positive = [model.special_tokens["pad"]] * pos_pad + positive
paddings = [pos_pad, neg_pad]
ids = [positive, negative]
else:
paddings = []
ids = [positive]
return sample_manual_loop_no_classes(model, ids, paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
paddings = [pos_pad, neg_pad]
return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
@@ -151,12 +129,12 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
def _metas_to_cot(self, *, return_yaml: bool = False, **kwargs) -> str:
user_metas = {
k: kwargs.pop(k)
for k in ("bpm", "duration", "keyscale", "timesignature")
for k in ("bpm", "duration", "keyscale", "timesignature", "language", "caption")
if k in kwargs
}
timesignature = user_metas.get("timesignature")
if isinstance(timesignature, str) and timesignature.endswith("/4"):
user_metas["timesignature"] = timesignature[:-2]
user_metas["timesignature"] = timesignature.rsplit("/", 1)[0]
user_metas = {
k: v if not isinstance(v, str) or not v.isdigit() else int(v)
for k, v in user_metas.items()
@@ -169,11 +147,8 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
return f"<think>\n{meta_yaml}\n</think>" if not return_yaml else meta_yaml
def _metas_to_cap(self, **kwargs) -> str:
use_keys = ("bpm", "timesignature", "keyscale", "duration")
use_keys = ("bpm", "duration", "keyscale", "timesignature")
user_metas = { k: kwargs.pop(k, "N/A") for k in use_keys }
timesignature = user_metas.get("timesignature")
if isinstance(timesignature, str) and timesignature.endswith("/4"):
user_metas["timesignature"] = timesignature[:-2]
duration = user_metas["duration"]
if duration == "N/A":
user_metas["duration"] = "30 seconds"
@@ -184,13 +159,9 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
return "\n".join(f"- {k}: {user_metas[k]}" for k in use_keys)
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
text = text.strip()
text_negative = kwargs.get("caption_negative", text).strip()
out = {}
lyrics = kwargs.get("lyrics", "")
lyrics_negative = kwargs.get("lyrics_negative", lyrics)
duration = kwargs.get("duration", 120)
if isinstance(duration, str):
duration = float(duration.split(None, 1)[0])
language = kwargs.get("language")
seed = kwargs.get("seed", 0)
@@ -199,55 +170,28 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
temperature = kwargs.get("temperature", 0.85)
top_p = kwargs.get("top_p", 0.9)
top_k = kwargs.get("top_k", 0.0)
min_p = kwargs.get("min_p", 0.000)
duration = math.ceil(duration)
kwargs["duration"] = duration
tokens_duration = duration * 5
min_tokens = int(kwargs.get("min_tokens", tokens_duration))
max_tokens = int(kwargs.get("max_tokens", tokens_duration))
metas_negative = {
k.rsplit("_", 1)[0]: kwargs.pop(k)
for k in ("bpm_negative", "duration_negative", "keyscale_negative", "timesignature_negative", "language_negative", "caption_negative")
if k in kwargs
}
if not kwargs.get("use_negative_caption"):
_ = metas_negative.pop("caption", None)
cot_text = self._metas_to_cot(caption=text, **kwargs)
cot_text_negative = "<think>\n\n</think>" if not metas_negative else self._metas_to_cot(**metas_negative)
cot_text = self._metas_to_cot(caption = text, **kwargs)
meta_cap = self._metas_to_cap(**kwargs)
lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n\n<|im_end|>\n"
lyrics_template = "# Languages\n{}\n\n# Lyric\n{}<|endoftext|><|endoftext|>"
qwen3_06b_template = "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}\n\n# Metas\n{}\n<|endoftext|>\n<|endoftext|>"
lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n<|im_end|>\n"
llm_prompts = {
"lm_prompt": lm_template.format(text, lyrics.strip(), cot_text),
"lm_prompt_negative": lm_template.format(text_negative, lyrics_negative.strip(), cot_text_negative),
"lyrics": lyrics_template.format(language if language is not None else "", lyrics),
"qwen3_06b": qwen3_06b_template.format(text, meta_cap),
}
out["lm_prompt"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, cot_text), disable_weights=True)
out["lm_prompt_negative"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, "<think>\n</think>"), disable_weights=True)
out = {
prompt_key: self.qwen3_06b.tokenize_with_weights(
prompt,
prompt_key == "qwen3_06b" and return_word_ids,
disable_weights = True,
**kwargs,
)
for prompt_key, prompt in llm_prompts.items()
}
out["lm_metadata"] = {"min_tokens": min_tokens,
"max_tokens": max_tokens,
out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric\n{}<|endoftext|><|endoftext|>".format(language if language is not None else "", lyrics), return_word_ids, disable_weights=True, **kwargs)
out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}\n# Metas\n{}\n<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs)
out["lm_metadata"] = {"min_tokens": duration * 5,
"seed": seed,
"generate_audio_codes": generate_audio_codes,
"cfg_scale": cfg_scale,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"min_p": min_p,
}
return out
@@ -308,7 +252,7 @@ class ACE15TEModel(torch.nn.Module):
lm_metadata = token_weight_pairs["lm_metadata"]
if lm_metadata["generate_audio_codes"]:
audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"], min_p=lm_metadata["min_p"])
audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"])
out["audio_codes"] = [audio_codes]
return base_out, None, out

View File

@@ -23,7 +23,7 @@ class AnimaTokenizer:
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
qwen_ids = self.qwen3_06b.tokenize_with_weights(text, return_word_ids, **kwargs)
out["qwen3_06b"] = [[(k[0], 1.0, k[2]) if return_word_ids else (k[0], 1.0) for k in inner_list] for inner_list in qwen_ids] # Set weights to 1.0
out["qwen3_06b"] = [[(token, 1.0) for token, _ in inner_list] for inner_list in qwen_ids] # Set weights to 1.0
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
return out

View File

@@ -25,7 +25,7 @@ def ltxv_te(*args, **kwargs):
class Gemma3_12BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}

View File

@@ -27,7 +27,6 @@ from PIL import Image
import logging
import itertools
from torch.nn.functional import interpolate
from tqdm.auto import trange
from einops import rearrange
from comfy.cli_args import args, enables_dynamic_vram
import json
@@ -1156,32 +1155,6 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
def model_trange(*args, **kwargs):
if comfy.memory_management.aimdo_allocator is None:
return trange(*args, **kwargs)
pbar = trange(*args, **kwargs, smoothing=1.0)
pbar._i = 0
pbar.set_postfix_str(" Model Initializing ... ")
_update = pbar.update
def warmup_update(n=1):
pbar._i += 1
if pbar._i == 1:
pbar.i1_time = time.time()
pbar.set_postfix_str(" Model Initialization complete! ")
elif pbar._i == 2:
#bring forward the effective start time based the the diff between first and second iteration
#to attempt to remove load overhead from the final step rate estimate.
pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
pbar.set_postfix_str("")
_update(n)
pbar.update = warmup_update
return pbar
PROGRESS_BAR_ENABLED = True
def set_progress_bar_enabled(enabled):
global PROGRESS_BAR_ENABLED
@@ -1403,21 +1376,3 @@ def string_to_seed(data):
else:
crc >>= 1
return crc ^ 0xFFFFFFFF
def deepcopy_list_dict(obj, memo=None):
if memo is None:
memo = {}
obj_id = id(obj)
if obj_id in memo:
return memo[obj_id]
if isinstance(obj, dict):
res = {deepcopy_list_dict(k, memo): deepcopy_list_dict(v, memo) for k, v in obj.items()}
elif isinstance(obj, list):
res = [deepcopy_list_dict(i, memo) for i in obj]
else:
res = obj
memo[obj_id] = res
return res

View File

@@ -21,7 +21,6 @@ from typing import Optional, Union
import torch
import torch.nn as nn
import comfy.model_management
from .base import WeightAdapterBase, WeightAdapterTrainBase
from comfy.patcher_extension import PatcherInjection
@@ -182,21 +181,18 @@ class BypassForwardHook:
)
return # Already injected
# Move adapter weights to compute device (GPU)
# Use get_torch_device() instead of module.weight.device because
# with offloading, module weights may be on CPU while compute happens on GPU
device = comfy.model_management.get_torch_device()
# Get dtype from module weight if available
# Move adapter weights to module's device to avoid CPU-GPU transfer on every forward
device = None
dtype = None
if hasattr(self.module, "weight") and self.module.weight is not None:
device = self.module.weight.device
dtype = self.module.weight.dtype
elif hasattr(self.module, "W_q"): # Quantized layers might use different attr
device = self.module.W_q.device
dtype = self.module.W_q.dtype
# Only use dtype if it's a standard float type, not quantized
if dtype is not None and dtype not in (torch.float32, torch.float16, torch.bfloat16):
dtype = None
self._move_adapter_weights_to_device(device, dtype)
if device is not None:
self._move_adapter_weights_to_device(device, dtype)
self.original_forward = self.module.forward
self.module.forward = self._bypass_forward

View File

@@ -34,21 +34,6 @@ class VideoInput(ABC):
"""
pass
@abstractmethod
def as_trimmed(
self,
start_time: float | None = None,
duration: float | None = None,
strict_duration: bool = False,
) -> VideoInput | None:
"""
Create a new VideoInput which is trimmed to have the corresponding start_time and duration
Returns:
A new VideoInput, or None if the result would have negative duration
"""
pass
def get_stream_source(self) -> Union[str, io.BytesIO]:
"""
Get a streamable source for the video. This allows processing without

View File

@@ -6,7 +6,6 @@ from typing import Optional
from .._input import AudioInput, VideoInput
import av
import io
import itertools
import json
import numpy as np
import math
@@ -30,6 +29,7 @@ def container_to_output_format(container_format: str | None) -> str | None:
formats = container_format.split(",")
return formats[0]
def get_open_write_kwargs(
dest: str | io.BytesIO, container_format: str, to_format: str | None
) -> dict:
@@ -57,14 +57,12 @@ class VideoFromFile(VideoInput):
Class representing video input from a file.
"""
def __init__(self, file: str | io.BytesIO, *, start_time: float=0, duration: float=0):
def __init__(self, file: str | io.BytesIO):
"""
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
containing the file contents.
"""
self.__file = file
self.__start_time = start_time
self.__duration = duration
def get_stream_source(self) -> str | io.BytesIO:
"""
@@ -98,16 +96,6 @@ class VideoFromFile(VideoInput):
Returns:
Duration in seconds
"""
raw_duration = self._get_raw_duration()
if self.__start_time < 0:
duration_from_start = min(raw_duration, -self.__start_time)
else:
duration_from_start = raw_duration - self.__start_time
if self.__duration:
return min(self.__duration, duration_from_start)
return duration_from_start
def _get_raw_duration(self) -> float:
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
@@ -125,13 +113,9 @@ class VideoFromFile(VideoInput):
if video_stream and video_stream.average_rate:
frame_count = 0
container.seek(0)
frame_iterator = (
container.decode(video_stream)
if video_stream.codec.capabilities & 0x100
else container.demux(video_stream)
)
for packet in frame_iterator:
frame_count += 1
for packet in container.demux(video_stream):
for _ in packet.decode():
frame_count += 1
if frame_count > 0:
return float(frame_count / video_stream.average_rate)
@@ -147,54 +131,36 @@ class VideoFromFile(VideoInput):
with av.open(self.__file, mode="r") as container:
video_stream = self._get_first_video_stream(container)
# 1. Prefer the frames field if available and usable
if (
video_stream.frames
and video_stream.frames > 0
and not self.__start_time
and not self.__duration
):
# 1. Prefer the frames field if available
if video_stream.frames and video_stream.frames > 0:
return int(video_stream.frames)
# 2. Try to estimate from duration and average_rate using only metadata
if container.duration is not None and video_stream.average_rate:
duration_seconds = float(container.duration / av.time_base)
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
if estimated_frames > 0:
return estimated_frames
if (
getattr(video_stream, "duration", None) is not None
and getattr(video_stream, "time_base", None) is not None
and video_stream.average_rate
):
raw_duration = float(video_stream.duration * video_stream.time_base)
if self.__start_time < 0:
duration_from_start = min(raw_duration, -self.__start_time)
else:
duration_from_start = raw_duration - self.__start_time
duration_seconds = min(self.__duration, duration_from_start)
duration_seconds = float(video_stream.duration * video_stream.time_base)
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
if estimated_frames > 0:
return estimated_frames
# 3. Last resort: decode frames and count them (streaming)
if self.__start_time < 0:
start_time = max(self._get_raw_duration() + self.__start_time, 0)
else:
start_time = self.__start_time
frame_count = 1
start_pts = int(start_time / video_stream.time_base)
end_pts = int((start_time + self.__duration) / video_stream.time_base)
container.seek(start_pts, stream=video_stream)
frame_iterator = (
container.decode(video_stream)
if video_stream.codec.capabilities & 0x100
else container.demux(video_stream)
)
for frame in frame_iterator:
if frame.pts >= start_pts:
break
else:
raise ValueError(f"Could not determine frame count for file '{self.__file}'\nNo frames exist for start_time {self.__start_time}")
for frame in frame_iterator:
if frame.pts >= end_pts:
break
frame_count += 1
frame_count = 0
container.seek(0)
for packet in container.demux(video_stream):
for _ in packet.decode():
frame_count += 1
if frame_count == 0:
raise ValueError(f"Could not determine frame count for file '{self.__file}'")
return frame_count
def get_frame_rate(self) -> Fraction:
@@ -233,21 +199,9 @@ class VideoFromFile(VideoInput):
return container.format.name
def get_components_internal(self, container: InputContainer) -> VideoComponents:
video_stream = self._get_first_video_stream(container)
if self.__start_time < 0:
start_time = max(self._get_raw_duration() + self.__start_time, 0)
else:
start_time = self.__start_time
# Get video frames
frames = []
start_pts = int(start_time / video_stream.time_base)
end_pts = int((start_time + self.__duration) / video_stream.time_base)
container.seek(start_pts, stream=video_stream)
for frame in container.decode(video_stream):
if frame.pts < start_pts:
continue
if self.__duration and frame.pts >= end_pts:
break
for frame in container.decode(video=0):
img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3)
img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3)
frames.append(img)
@@ -255,44 +209,31 @@ class VideoFromFile(VideoInput):
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)
# Get frame rate
frame_rate = Fraction(video_stream.average_rate) if video_stream.average_rate else Fraction(1)
video_stream = next(s for s in container.streams if s.type == 'video')
frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)
# Get audio if available
audio = None
container.seek(start_pts, stream=video_stream)
# Use last stream for consistency
if len(container.streams.audio):
audio_stream = container.streams.audio[-1]
audio_frames = []
resample = av.audio.resampler.AudioResampler(format='fltp').resample
frames = itertools.chain.from_iterable(
map(resample, container.decode(audio_stream))
)
has_first_frame = False
for frame in frames:
offset_seconds = start_time - frame.pts * audio_stream.time_base
to_skip = int(offset_seconds * audio_stream.sample_rate)
if to_skip < frame.samples:
has_first_frame = True
break
if has_first_frame:
audio_frames.append(frame.to_ndarray()[..., to_skip:])
for frame in frames:
if frame.time > start_time + self.__duration:
break
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
if len(audio_frames) > 0:
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
if self.__duration:
audio_data = audio_data[..., :int(self.__duration * audio_stream.sample_rate)]
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
audio = AudioInput({
"waveform": audio_tensor,
"sample_rate": int(audio_stream.sample_rate) if audio_stream.sample_rate else 1,
})
try:
container.seek(0) # Reset the container to the beginning
for stream in container.streams:
if stream.type != 'audio':
continue
assert isinstance(stream, av.AudioStream)
audio_frames = []
for packet in container.demux(stream):
for frame in packet.decode():
assert isinstance(frame, av.AudioFrame)
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
if len(audio_frames) > 0:
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
audio = AudioInput({
"waveform": audio_tensor,
"sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
})
except StopIteration:
pass # No audio stream
metadata = container.metadata
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
@@ -309,7 +250,7 @@ class VideoFromFile(VideoInput):
path: str | io.BytesIO,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None,
metadata: Optional[dict] = None
):
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
@@ -321,14 +262,15 @@ class VideoFromFile(VideoInput):
reuse_streams = False
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
reuse_streams = False
if self.__start_time or self.__duration:
reuse_streams = False
if not reuse_streams:
components = self.get_components_internal(container)
video = VideoFromComponents(components)
return video.save_to(
path, format=format, codec=codec, metadata=metadata
path,
format=format,
codec=codec,
metadata=metadata
)
streams = container.streams
@@ -362,21 +304,10 @@ class VideoFromFile(VideoInput):
output_container.mux(packet)
def _get_first_video_stream(self, container: InputContainer):
if len(container.streams.video):
return container.streams.video[0]
raise ValueError(f"No video stream found in file '{self.__file}'")
def as_trimmed(
self, start_time: float = 0, duration: float = 0, strict_duration: bool = True
) -> VideoInput | None:
trimmed = VideoFromFile(
self.get_stream_source(),
start_time=start_time + self.__start_time,
duration=duration,
)
if trimmed.get_duration() < duration and strict_duration:
return None
return trimmed
video_stream = next((s for s in container.streams if s.type == "video"), None)
if video_stream is None:
raise ValueError(f"No video stream found in file '{self.__file}'")
return video_stream
class VideoFromComponents(VideoInput):
@@ -391,7 +322,7 @@ class VideoFromComponents(VideoInput):
return VideoComponents(
images=self.__components.images,
audio=self.__components.audio,
frame_rate=self.__components.frame_rate,
frame_rate=self.__components.frame_rate
)
def save_to(
@@ -399,7 +330,7 @@ class VideoFromComponents(VideoInput):
path: str,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None,
metadata: Optional[dict] = None
):
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
raise ValueError("Only MP4 format is supported for now")
@@ -426,10 +357,7 @@ class VideoFromComponents(VideoInput):
audio_stream: Optional[av.AudioStream] = None
if self.__components.audio:
audio_sample_rate = int(self.__components.audio['sample_rate'])
waveform = self.__components.audio['waveform']
waveform = waveform[0, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
layout = {1: 'mono', 2: 'stereo', 6: '5.1'}.get(waveform.shape[0], 'stereo')
audio_stream = output.add_stream('aac', rate=audio_sample_rate, layout=layout)
audio_stream = output.add_stream('aac', rate=audio_sample_rate)
# Encode video
for i, frame in enumerate(self.__components.images):
@@ -444,21 +372,12 @@ class VideoFromComponents(VideoInput):
output.mux(packet)
if audio_stream and self.__components.audio:
frame = av.AudioFrame.from_ndarray(waveform.float().cpu().numpy(), format='fltp', layout=layout)
waveform = self.__components.audio['waveform']
waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().cpu().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo')
frame.sample_rate = audio_sample_rate
frame.pts = 0
output.mux(audio_stream.encode(frame))
# Flush encoder
output.mux(audio_stream.encode(None))
def as_trimmed(
self,
start_time: float | None = None,
duration: float | None = None,
strict_duration: bool = True,
) -> VideoInput | None:
if self.get_duration() < start_time + duration:
return None
#TODO Consider tracking duration and trimming at time of save?
return VideoFromFile(self.get_stream_source(), start_time=start_time, duration=duration)

View File

@@ -1197,6 +1197,12 @@ class KlingImageGenImageReferenceType(str, Enum):
face = 'face'
class KlingImageGenModelName(str, Enum):
kling_v1 = 'kling-v1'
kling_v1_5 = 'kling-v1-5'
kling_v2 = 'kling-v2'
class KlingImageGenerationsRequest(BaseModel):
aspect_ratio: Optional[KlingImageGenAspectRatio] = '16:9'
callback_url: Optional[AnyUrl] = Field(
@@ -1212,7 +1218,7 @@ class KlingImageGenerationsRequest(BaseModel):
0.5, description='Reference intensity for user-uploaded images', ge=0.0, le=1.0
)
image_reference: Optional[KlingImageGenImageReferenceType] = None
model_name: str = Field(...)
model_name: Optional[KlingImageGenModelName] = 'kling-v1'
n: Optional[int] = Field(1, description='Number of generated images', ge=1, le=9)
negative_prompt: Optional[str] = Field(
None, description='Negative text prompt', max_length=200

View File

@@ -1,22 +1,12 @@
from pydantic import BaseModel, Field
class MultiPromptEntry(BaseModel):
index: int = Field(...)
prompt: str = Field(...)
duration: str = Field(...)
class OmniProText2VideoRequest(BaseModel):
model_name: str = Field(..., description="kling-video-o1")
aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
multi_shot: bool | None = Field(None)
multi_prompt: list[MultiPromptEntry] | None = Field(None)
shot_type: str | None = Field(None)
sound: str = Field(..., description="'on' or 'off'")
class OmniParamImage(BaseModel):
@@ -36,10 +26,6 @@ class OmniProFirstLastFrameRequest(BaseModel):
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
sound: str | None = Field(None, description="'on' or 'off'")
multi_shot: bool | None = Field(None)
multi_prompt: list[MultiPromptEntry] | None = Field(None)
shot_type: str | None = Field(None)
class OmniProReferences2VideoRequest(BaseModel):
@@ -52,10 +38,6 @@ class OmniProReferences2VideoRequest(BaseModel):
duration: str | None = Field(..., description="From 3 to 10.")
prompt: str = Field(...)
mode: str = Field("pro")
sound: str | None = Field(None, description="'on' or 'off'")
multi_shot: bool | None = Field(None)
multi_prompt: list[MultiPromptEntry] | None = Field(None)
shot_type: str | None = Field(None)
class TaskStatusVideoResult(BaseModel):
@@ -72,7 +54,6 @@ class TaskStatusImageResult(BaseModel):
class TaskStatusResults(BaseModel):
videos: list[TaskStatusVideoResult] | None = Field(None)
images: list[TaskStatusImageResult] | None = Field(None)
series_images: list[TaskStatusImageResult] | None = Field(None)
class TaskStatusResponseData(BaseModel):
@@ -96,42 +77,31 @@ class OmniImageParamImage(BaseModel):
class OmniProImageRequest(BaseModel):
model_name: str = Field(...)
resolution: str = Field(...)
model_name: str = Field(..., description="kling-image-o1")
resolution: str = Field(..., description="'1k' or '2k'")
aspect_ratio: str | None = Field(...)
prompt: str = Field(...)
mode: str = Field("pro")
n: int | None = Field(1, le=9)
image_list: list[OmniImageParamImage] | None = Field(..., max_length=10)
result_type: str | None = Field(None, description="Set to 'series' for series generation")
series_amount: int | None = Field(None, ge=2, le=9, description="Number of images in a series")
class TextToVideoWithAudioRequest(BaseModel):
model_name: str = Field(...)
model_name: str = Field(..., description="kling-v2-6")
aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
duration: str = Field(...)
prompt: str | None = Field(...)
negative_prompt: str | None = Field(None)
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
sound: str = Field(..., description="'on' or 'off'")
multi_shot: bool | None = Field(None)
multi_prompt: list[MultiPromptEntry] | None = Field(None)
shot_type: str | None = Field(None)
class ImageToVideoWithAudioRequest(BaseModel):
model_name: str = Field(...)
model_name: str = Field(..., description="kling-v2-6")
image: str = Field(...)
image_tail: str | None = Field(None)
duration: str = Field(...)
prompt: str | None = Field(...)
negative_prompt: str | None = Field(None)
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
sound: str = Field(..., description="'on' or 'off'")
multi_shot: bool | None = Field(None)
multi_prompt: list[MultiPromptEntry] | None = Field(None)
shot_type: str | None = Field(None)
class MotionControlRequest(BaseModel):

File diff suppressed because it is too large Load Diff

View File

@@ -30,30 +30,6 @@ from comfy_api_nodes.util import (
validate_image_dimensions,
)
_EUR_TO_USD = 1.19
def _tier_price_eur(megapixels: float) -> float:
"""Price in EUR for a single Magnific upscaling step based on input megapixels."""
if megapixels <= 1.3:
return 0.143
if megapixels <= 3.0:
return 0.286
if megapixels <= 6.4:
return 0.429
return 1.716
def _calculate_magnific_upscale_price_usd(width: int, height: int, scale: int) -> float:
"""Calculate total Magnific upscale price in USD for given input dimensions and scale factor."""
num_steps = int(math.log2(scale))
total_eur = 0.0
pixels = width * height
for _ in range(num_steps):
total_eur += _tier_price_eur(pixels / 1_000_000)
pixels *= 4
return round(total_eur * _EUR_TO_USD, 2)
class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
@classmethod
@@ -127,20 +103,11 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["scale_factor", "auto_downscale"]),
depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]),
expr="""
(
$ad := widgets.auto_downscale;
$mins := $ad
? {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.515}
: {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.844};
$maxs := {"2x": 0.515, "4x": 0.844, "8x": 1.015, "16x": 1.187};
{
"type": "range_usd",
"min_usd": $lookup($mins, widgets.scale_factor),
"max_usd": $lookup($maxs, widgets.scale_factor),
"format": { "approximate": true }
}
$max := widgets.scale_factor = "2x" ? 1.326 : 1.657;
{"type": "range_usd", "min_usd": 0.11, "max_usd": $max}
)
""",
),
@@ -201,10 +168,6 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
f"Use a smaller input image or lower scale factor."
)
final_height, final_width = get_image_dimensions(image)
actual_scale = int(scale_factor.rstrip("x"))
price_usd = _calculate_magnific_upscale_price_usd(final_width, final_height, actual_scale)
initial_res = await sync_op(
cls,
ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler", method="POST"),
@@ -226,7 +189,6 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler/{initial_res.task_id}"),
response_model=TaskResponse,
status_extractor=lambda x: x.status,
price_extractor=lambda _: price_usd,
poll_interval=10.0,
max_poll_attempts=480,
)
@@ -295,14 +257,8 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]),
expr="""
(
$mins := {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.844};
$maxs := {"2x": 2.045, "4x": 2.545, "8x": 2.889, "16x": 3.06};
{
"type": "range_usd",
"min_usd": $lookup($mins, widgets.scale_factor),
"max_usd": $lookup($maxs, widgets.scale_factor),
"format": { "approximate": true }
}
$max := widgets.scale_factor = "2x" ? 1.326 : 1.657;
{"type": "range_usd", "min_usd": 0.11, "max_usd": $max}
)
""",
),
@@ -365,9 +321,6 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
f"Use a smaller input image or lower scale factor."
)
final_height, final_width = get_image_dimensions(image)
price_usd = _calculate_magnific_upscale_price_usd(final_width, final_height, requested_scale)
initial_res = await sync_op(
cls,
ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler-precision-v2", method="POST"),
@@ -386,7 +339,6 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler-precision-v2/{initial_res.task_id}"),
response_model=TaskResponse,
status_extractor=lambda x: x.status,
price_extractor=lambda _: price_usd,
poll_interval=10.0,
max_poll_attempts=480,
)
@@ -925,8 +877,8 @@ class MagnificExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
MagnificImageUpscalerCreativeNode,
MagnificImageUpscalerPreciseV2Node,
# MagnificImageUpscalerCreativeNode,
# MagnificImageUpscalerPreciseV2Node,
MagnificImageStyleTransferNode,
MagnificImageRelightNode,
MagnificImageSkinEnhancerNode,

View File

@@ -219,8 +219,8 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode):
),
IO.Int.Input(
"steps",
default=80,
min=75, # steps should be greater or equal to cooldown_steps(75) + warmup_steps(0)
default=33,
min=1,
max=100,
step=1,
tooltip="Number of denoising steps",
@@ -340,8 +340,8 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
),
IO.Int.Input(
"steps",
default=60,
min=60, # steps should be greater or equal to cooldown_steps(36) + warmup_steps(24)
default=33,
min=1,
max=100,
step=1,
display_mode=IO.NumberDisplay.number,
@@ -370,7 +370,7 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
video: Input.Video | None = None,
control_type: str = "Motion Transfer",
motion_intensity: int | None = 100,
steps=60,
steps=33,
prompt_adherence=4.5,
) -> IO.NodeOutput:
validated_video = validate_video_to_video_input(video)
@@ -465,8 +465,8 @@ class MoonvalleyTxt2VideoNode(IO.ComfyNode):
),
IO.Int.Input(
"steps",
default=80,
min=75, # steps should be greater or equal to cooldown_steps(75) + warmup_steps(0)
default=33,
min=1,
max=100,
step=1,
tooltip="Inference steps",

View File

@@ -143,9 +143,9 @@ async def poll_op(
poll_interval: float = 5.0,
max_poll_attempts: int = 160,
timeout_per_poll: float = 120.0,
max_retries_per_poll: int = 10,
max_retries_per_poll: int = 3,
retry_delay_per_poll: float = 1.0,
retry_backoff_per_poll: float = 1.4,
retry_backoff_per_poll: float = 2.0,
estimated_duration: int | None = None,
cancel_endpoint: ApiEndpoint | None = None,
cancel_timeout: float = 10.0,
@@ -240,9 +240,9 @@ async def poll_op_raw(
poll_interval: float = 5.0,
max_poll_attempts: int = 160,
timeout_per_poll: float = 120.0,
max_retries_per_poll: int = 10,
max_retries_per_poll: int = 3,
retry_delay_per_poll: float = 1.0,
retry_backoff_per_poll: float = 1.4,
retry_backoff_per_poll: float = 2.0,
estimated_duration: int | None = None,
cancel_endpoint: ApiEndpoint | None = None,
cancel_timeout: float = 10.0,

View File

@@ -20,60 +20,10 @@ class JobStatus:
# Media types that can be previewed in the frontend
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d'})
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio'})
# 3D file extensions for preview fallback (no dedicated media_type exists)
THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb', '.usdz'})
def has_3d_extension(filename: str) -> bool:
lower = filename.lower()
return any(lower.endswith(ext) for ext in THREE_D_EXTENSIONS)
def normalize_output_item(item):
"""Normalize a single output list item for the jobs API.
Returns the normalized item, or None to exclude it.
String items with 3D extensions become {filename, type, subfolder} dicts.
"""
if item is None:
return None
if isinstance(item, str):
if has_3d_extension(item):
return {'filename': item, 'type': 'output', 'subfolder': '', 'mediaType': '3d'}
return None
if isinstance(item, dict):
return item
return None
def normalize_outputs(outputs: dict) -> dict:
"""Normalize raw node outputs for the jobs API.
Transforms string 3D filenames into file output dicts and removes
None items. All other items (non-3D strings, dicts, etc.) are
preserved as-is.
"""
normalized = {}
for node_id, node_outputs in outputs.items():
if not isinstance(node_outputs, dict):
normalized[node_id] = node_outputs
continue
normalized_node = {}
for media_type, items in node_outputs.items():
if media_type == 'animated' or not isinstance(items, list):
normalized_node[media_type] = items
continue
normalized_items = []
for item in items:
if item is None:
continue
norm = normalize_output_item(item)
normalized_items.append(norm if norm is not None else item)
normalized_node[media_type] = normalized_items
normalized[node_id] = normalized_node
return normalized
THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb'})
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
@@ -95,9 +45,9 @@ def is_previewable(media_type: str, item: dict) -> bool:
Maintains backwards compatibility with existing logic.
Priority:
1. media_type is 'images', 'video', 'audio', or '3d'
1. media_type is 'images', 'video', or 'audio'
2. format field starts with 'video/' or 'audio/'
3. filename has a 3D extension (.obj, .fbx, .gltf, .glb, .usdz)
3. filename has a 3D extension (.obj, .fbx, .gltf, .glb)
"""
if media_type in PREVIEWABLE_MEDIA_TYPES:
return True
@@ -189,7 +139,7 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs:
})
if include_outputs:
job['outputs'] = normalize_outputs(outputs)
job['outputs'] = outputs
job['execution_status'] = status_info
job['workflow'] = {
'prompt': prompt,
@@ -221,23 +171,18 @@ def get_outputs_summary(outputs: dict) -> tuple[int, Optional[dict]]:
continue
for item in items:
normalized = normalize_output_item(item)
if normalized is None:
continue
count += 1
if preview_output is not None:
if not isinstance(item, dict):
continue
if isinstance(normalized, dict) and is_previewable(media_type, normalized):
if preview_output is None and is_previewable(media_type, item):
enriched = {
**normalized,
**item,
'nodeId': node_id,
'mediaType': media_type
}
if 'mediaType' not in normalized:
enriched['mediaType'] = media_type
if normalized.get('type') == 'output':
if item.get('type') == 'output':
preview_output = enriched
elif fallback_preview is None:
fallback_preview = enriched

View File

@@ -164,7 +164,21 @@ class WebUIProgressHandler(ProgressHandler):
if self.server_instance is None:
return
active_nodes = self.registry.get_serialized_state()
# Only send info for non-pending nodes
active_nodes = {
node_id: {
"value": state["value"],
"max": state["max"],
"state": state["state"].value,
"node_id": node_id,
"prompt_id": prompt_id,
"display_node_id": self.registry.dynprompt.get_display_node_id(node_id),
"parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id),
"real_node_id": self.registry.dynprompt.get_real_node_id(node_id),
}
for node_id, state in nodes.items()
if state["state"] != NodeState.Pending
}
# Send a combined progress_state message with all node states
# Include client_id to ensure message is only sent to the initiating client
@@ -300,24 +314,6 @@ class ProgressRegistry:
if handler.enabled:
handler.finish_handler(node_id, entry, self.prompt_id)
def get_serialized_state(self) -> Dict[str, dict]:
"""Return current node progress as a dict suitable for WS progress_state."""
active: Dict[str, dict] = {}
for nid, state in self.nodes.items():
if state["state"] == NodeState.Pending:
continue
active[nid] = {
"value": state["value"],
"max": state["max"],
"state": state["state"].value,
"node_id": nid,
"prompt_id": self.prompt_id,
"display_node_id": self.dynprompt.get_display_node_id(nid),
"parent_node_id": self.dynprompt.get_parent_node_id(nid),
"real_node_id": self.dynprompt.get_real_node_id(nid),
}
return active
def reset_handlers(self) -> None:
"""Reset all handlers"""
for handler in self.handlers.values():

View File

@@ -49,14 +49,13 @@ class TextEncodeAceStepAudio15(io.ComfyNode):
io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
io.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
io.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True),
],
outputs=[io.Conditioning.Output()],
)
@classmethod
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> io.NodeOutput:
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p)
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k) -> io.NodeOutput:
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k)
conditioning = clip.encode_from_tokens_scheduled(tokens)
return io.NodeOutput(conditioning)

View File

@@ -622,7 +622,6 @@ class SamplerSASolver(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SamplerSASolver",
search_aliases=["sde"],
category="sampling/custom_sampling/samplers",
inputs=[
io.Model.Input("model"),
@@ -667,7 +666,6 @@ class SamplerSEEDS2(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SamplerSEEDS2",
search_aliases=["sde", "exp heun"],
category="sampling/custom_sampling/samplers",
inputs=[
io.Combo.Input("solver_type", options=["phi_1", "phi_2"]),

View File

@@ -9,14 +9,6 @@ if TYPE_CHECKING:
from uuid import UUID
def _extract_tensor(data, output_channels):
"""Extract tensor from data, handling both single tensors and lists."""
if isinstance(data, list):
# LTX2 AV tensors: [video, audio]
return data[0][:, :output_channels], data[1][:, :output_channels]
return data[:, :output_channels], None
def easycache_forward_wrapper(executor, *args, **kwargs):
# get values from args
transformer_options: dict[str] = args[-1]
@@ -25,7 +17,7 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
if not transformer_options:
transformer_options = args[-2]
easycache: EasyCacheHolder = transformer_options["easycache"]
x, ax = _extract_tensor(args[0], easycache.output_channels)
x: torch.Tensor = args[0][:, :easycache.output_channels]
sigmas = transformer_options["sigmas"]
uuids = transformer_options["uuids"]
if sigmas is not None and easycache.is_past_end_timestep(sigmas):
@@ -43,11 +35,7 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
if easycache.skip_current_step and can_apply_cache_diff:
if easycache.verbose:
logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}")
result = easycache.apply_cache_diff(x, uuids)
if ax is not None:
result_audio = easycache.apply_cache_diff(ax, uuids, is_audio=True)
return [result, result_audio]
return result
return easycache.apply_cache_diff(x, uuids)
if easycache.initial_step:
easycache.first_cond_uuid = uuids[0]
has_first_cond_uuid = easycache.has_first_cond_uuid(uuids)
@@ -63,18 +51,13 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
# other conds should also skip this step, and instead use their cached values
easycache.skip_current_step = True
result = easycache.apply_cache_diff(x, uuids)
if ax is not None:
result_audio = easycache.apply_cache_diff(ax, uuids, is_audio=True)
return [result, result_audio]
return result
return easycache.apply_cache_diff(x, uuids)
else:
if easycache.verbose:
logging.info(f"EasyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
easycache.cumulative_change_rate = 0.0
full_output: torch.Tensor = executor(*args, **kwargs)
output, audio_output = _extract_tensor(full_output, easycache.output_channels)
output: torch.Tensor = executor(*args, **kwargs)
if has_first_cond_uuid and easycache.has_output_prev_norm():
output_change = (easycache.subsample(output, uuids, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean()
if easycache.verbose:
@@ -91,15 +74,13 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
logging.info(f"EasyCache [verbose] - output_change_rate: {output_change_rate}")
# TODO: allow cache_diff to be offloaded
easycache.update_cache_diff(output, next_x_prev, uuids)
if audio_output is not None:
easycache.update_cache_diff(audio_output, ax, uuids, is_audio=True)
if has_first_cond_uuid:
easycache.x_prev_subsampled = easycache.subsample(next_x_prev, uuids)
easycache.output_prev_subsampled = easycache.subsample(output, uuids)
easycache.output_prev_norm = output.flatten().abs().mean()
if easycache.verbose:
logging.info(f"EasyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")
return full_output
return output
def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
# get values from args
@@ -108,8 +89,8 @@ def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"]
if easycache.is_past_end_timestep(timestep):
return executor(*args, **kwargs)
x: torch.Tensor = args[0][:, :easycache.output_channels]
# prepare next x_prev
x: torch.Tensor = args[0][:, :easycache.output_channels]
next_x_prev = x
input_change = None
do_easycache = easycache.should_do_easycache(timestep)
@@ -216,7 +197,6 @@ class EasyCacheHolder:
self.output_prev_subsampled: torch.Tensor = None
self.output_prev_norm: torch.Tensor = None
self.uuid_cache_diffs: dict[UUID, torch.Tensor] = {}
self.uuid_cache_diffs_audio: dict[UUID, torch.Tensor] = {}
self.output_change_rates = []
self.approx_output_change_rates = []
self.total_steps_skipped = 0
@@ -265,21 +245,20 @@ class EasyCacheHolder:
def can_apply_cache_diff(self, uuids: list[UUID]) -> bool:
return all(uuid in self.uuid_cache_diffs for uuid in uuids)
def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID], is_audio: bool = False):
if self.first_cond_uuid in uuids and not is_audio:
def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]):
if self.first_cond_uuid in uuids:
self.total_steps_skipped += 1
cache_diffs = self.uuid_cache_diffs_audio if is_audio else self.uuid_cache_diffs
batch_offset = x.shape[0] // len(uuids)
for i, uuid in enumerate(uuids):
# slice out only what is relevant to this cond
batch_slice = [slice(i*batch_offset,(i+1)*batch_offset)]
# if cached dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
if x.shape[1:] != cache_diffs[uuid].shape[1:]:
if x.shape[1:] != self.uuid_cache_diffs[uuid].shape[1:]:
if not self.allow_mismatch:
raise ValueError(f"Cached dims {self.uuid_cache_diffs[uuid].shape} don't match x dims {x.shape} - this is no good")
slicing = []
skip_this_dim = True
for dim_u, dim_x in zip(cache_diffs[uuid].shape, x.shape):
for dim_u, dim_x in zip(self.uuid_cache_diffs[uuid].shape, x.shape):
if skip_this_dim:
skip_this_dim = False
continue
@@ -291,11 +270,10 @@ class EasyCacheHolder:
else:
slicing.append(slice(None))
batch_slice = batch_slice + slicing
x[tuple(batch_slice)] += cache_diffs[uuid].to(x.device)
x[tuple(batch_slice)] += self.uuid_cache_diffs[uuid].to(x.device)
return x
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID], is_audio: bool = False):
cache_diffs = self.uuid_cache_diffs_audio if is_audio else self.uuid_cache_diffs
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]):
# if output dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
if output.shape[1:] != x.shape[1:]:
if not self.allow_mismatch:
@@ -315,7 +293,7 @@ class EasyCacheHolder:
diff = output - x
batch_offset = diff.shape[0] // len(uuids)
for i, uuid in enumerate(uuids):
cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...]
self.uuid_cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...]
def has_first_cond_uuid(self, uuids: list[UUID]) -> bool:
return self.first_cond_uuid in uuids
@@ -346,8 +324,6 @@ class EasyCacheHolder:
self.output_prev_norm = None
del self.uuid_cache_diffs
self.uuid_cache_diffs = {}
del self.uuid_cache_diffs_audio
self.uuid_cache_diffs_audio = {}
self.total_steps_skipped = 0
self.state_metadata = None
return self

View File

@@ -391,9 +391,8 @@ class LatentOperationTonemapReinhard(io.ComfyNode):
latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None]
normalized_latent = latent / latent_vector_magnitude
dims = list(range(1, latent_vector_magnitude.ndim))
mean = torch.mean(latent_vector_magnitude, dim=dims, keepdim=True)
std = torch.std(latent_vector_magnitude, dim=dims, keepdim=True)
mean = torch.mean(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
std = torch.std(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
top = (std * 5 + mean) * multiplier

View File

@@ -4,7 +4,6 @@ import os
import numpy as np
import safetensors
import torch
import torch.nn as nn
import torch.utils.checkpoint
from tqdm.auto import trange
from PIL import Image, ImageDraw, ImageFont
@@ -28,11 +27,6 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
"""
CFGGuider with modifications for training specific logic
"""
def __init__(self, *args, offloading=False, **kwargs):
super().__init__(*args, **kwargs)
self.offloading = offloading
def outer_sample(
self,
noise,
@@ -51,11 +45,9 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
noise.shape,
self.conds,
self.model_options,
force_full_load=not self.offloading,
force_offload=self.offloading,
force_full_load=True, # mirror behavior in TrainLoraNode.execute() to keep model loaded
)
)
torch.cuda.empty_cache()
device = self.model_patcher.load_device
if denoise_mask is not None:
@@ -412,97 +404,16 @@ def find_all_highest_child_module_with_forward(
return result
def find_modules_at_depth(
model: nn.Module, depth: int = 1, result=None, current_depth=0, name=None
) -> list[nn.Module]:
"""
Find modules at a specific depth level for gradient checkpointing.
Args:
model: The model to search
depth: Target depth level (1 = top-level blocks, 2 = their children, etc.)
result: Accumulator for results
current_depth: Current recursion depth
name: Current module name for logging
Returns:
List of modules at the target depth
"""
if result is None:
result = []
name = name or "root"
# Skip container modules (they don't have meaningful forward)
is_container = isinstance(model, (nn.ModuleList, nn.Sequential, nn.ModuleDict))
has_forward = hasattr(model, "forward") and not is_container
if has_forward:
current_depth += 1
if current_depth == depth:
result.append(model)
logging.debug(f"Found module at depth {depth}: {name} ({model.__class__.__name__})")
return result
# Recurse into children
for next_name, child in model.named_children():
find_modules_at_depth(child, depth, result, current_depth, f"{name}.{next_name}")
return result
class OffloadCheckpointFunction(torch.autograd.Function):
"""
Gradient checkpointing that works with weight offloading.
Forward: no_grad -> compute -> weights can be freed
Backward: enable_grad -> recompute -> backward -> weights can be freed
For single input, single output modules (Linear, Conv*).
"""
@staticmethod
def forward(ctx, x: torch.Tensor, forward_fn):
ctx.save_for_backward(x)
ctx.forward_fn = forward_fn
with torch.no_grad():
return forward_fn(x)
@staticmethod
def backward(ctx, grad_out: torch.Tensor):
x, = ctx.saved_tensors
forward_fn = ctx.forward_fn
# Clear context early
ctx.forward_fn = None
with torch.enable_grad():
x_detached = x.detach().requires_grad_(True)
y = forward_fn(x_detached)
y.backward(grad_out)
grad_x = x_detached.grad
# Explicit cleanup
del y, x_detached, forward_fn
return grad_x, None
def patch(m, offloading=False):
def patch(m):
if not hasattr(m, "forward"):
return
org_forward = m.forward
# Branch 1: Linear/Conv* -> offload-compatible checkpoint (single input/output)
if offloading and isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
def checkpointing_fwd(x):
return OffloadCheckpointFunction.apply(x, org_forward)
# Branch 2: Others -> standard checkpoint
else:
def fwd(args, kwargs):
return org_forward(*args, **kwargs)
def fwd(args, kwargs):
return org_forward(*args, **kwargs)
def checkpointing_fwd(*args, **kwargs):
return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)
def checkpointing_fwd(*args, **kwargs):
return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)
m.org_forward = org_forward
m.forward = checkpointing_fwd
@@ -1025,18 +936,6 @@ class TrainLoraNode(io.ComfyNode):
default=True,
tooltip="Use gradient checkpointing for training.",
),
io.Int.Input(
"checkpoint_depth",
default=1,
min=1,
max=5,
tooltip="Depth level for gradient checkpointing.",
),
io.Boolean.Input(
"offloading",
default=False,
tooltip="Depth level for gradient checkpointing.",
),
io.Combo.Input(
"existing_lora",
options=folder_paths.get_filename_list("loras") + ["[None]"],
@@ -1083,8 +982,6 @@ class TrainLoraNode(io.ComfyNode):
lora_dtype,
algorithm,
gradient_checkpointing,
checkpoint_depth,
offloading,
existing_lora,
bucket_mode,
bypass_mode,
@@ -1103,8 +1000,6 @@ class TrainLoraNode(io.ComfyNode):
lora_dtype = lora_dtype[0]
algorithm = algorithm[0]
gradient_checkpointing = gradient_checkpointing[0]
offloading = offloading[0]
checkpoint_depth = checkpoint_depth[0]
existing_lora = existing_lora[0]
bucket_mode = bucket_mode[0]
bypass_mode = bypass_mode[0]
@@ -1159,18 +1054,16 @@ class TrainLoraNode(io.ComfyNode):
# Setup gradient checkpointing
if gradient_checkpointing:
modules_to_patch = find_modules_at_depth(
mp.model.diffusion_model, depth=checkpoint_depth
)
logging.info(f"Gradient checkpointing: patching {len(modules_to_patch)} modules at depth {checkpoint_depth}")
for m in modules_to_patch:
patch(m, offloading=offloading)
for m in find_all_highest_child_module_with_forward(
mp.model.diffusion_model
):
patch(m)
torch.cuda.empty_cache()
# With force_full_load=False we should be able to have offloading
# But for offloading in training we need custom AutoGrad hooks for fwd/bwd
comfy.model_management.load_models_gpu(
[mp], memory_required=1e20, force_full_load=not offloading
[mp], memory_required=1e20, force_full_load=True
)
torch.cuda.empty_cache()
@@ -1207,7 +1100,7 @@ class TrainLoraNode(io.ComfyNode):
)
# Setup guider
guider = TrainGuider(mp, offloading=offloading)
guider = TrainGuider(mp)
guider.set_conds(positive)
# Inject bypass hooks if bypass mode is enabled
@@ -1220,7 +1113,6 @@ class TrainLoraNode(io.ComfyNode):
# Run training loop
try:
comfy.model_management.in_training = True
_run_training_loop(
guider,
train_sampler,
@@ -1231,7 +1123,6 @@ class TrainLoraNode(io.ComfyNode):
multi_res,
)
finally:
comfy.model_management.in_training = False
# Eject bypass hooks if they were injected
if bypass_injections is not None:
for injection in bypass_injections:
@@ -1241,20 +1132,19 @@ class TrainLoraNode(io.ComfyNode):
unpatch(m)
del train_sampler, optimizer
for param in lora_sd:
lora_sd[param] = lora_sd[param].to(lora_dtype).detach()
# Finalize adapters
for adapter in all_weight_adapters:
adapter.requires_grad_(False)
del adapter
del all_weight_adapters
for param in lora_sd:
lora_sd[param] = lora_sd[param].to(lora_dtype)
# mp in train node is highly specialized for training
# use it in inference will result in bad behavior so we don't return it
return io.NodeOutput(lora_sd, loss_map, steps + existing_steps)
class LoraModelLoader(io.ComfyNode):
class LoraModelLoader(io.ComfyNode):#
@classmethod
def define_schema(cls):
return io.Schema(
@@ -1276,11 +1166,6 @@ class LoraModelLoader(io.ComfyNode):
max=100.0,
tooltip="How strongly to modify the diffusion model. This value can be negative.",
),
io.Boolean.Input(
"bypass",
default=False,
tooltip="When enabled, applies LoRA in bypass mode without modifying base model weights. Useful for training and when model weights are offloaded.",
),
],
outputs=[
io.Model.Output(
@@ -1290,18 +1175,13 @@ class LoraModelLoader(io.ComfyNode):
)
@classmethod
def execute(cls, model, lora, strength_model, bypass=False):
def execute(cls, model, lora, strength_model):
if strength_model == 0:
return io.NodeOutput(model)
if bypass:
model_lora, _ = comfy.sd.load_bypass_lora_for_models(
model, None, lora, strength_model, 0
)
else:
model_lora, _ = comfy.sd.load_lora_for_models(
model, None, lora, strength_model, 0
)
model_lora, _ = comfy.sd.load_lora_for_models(
model, None, lora, strength_model, 0
)
return io.NodeOutput(model_lora)

View File

@@ -202,56 +202,6 @@ class LoadVideo(io.ComfyNode):
return True
class VideoSlice(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="Video Slice",
display_name="Video Slice",
search_aliases=[
"trim video duration",
"skip first frames",
"frame load cap",
"start time",
],
category="image/video",
inputs=[
io.Video.Input("video"),
io.Float.Input(
"start_time",
default=0.0,
max=1e5,
min=-1e5,
step=0.001,
tooltip="Start time in seconds",
),
io.Float.Input(
"duration",
default=0.0,
min=0.0,
step=0.001,
tooltip="Duration in seconds, or 0 for unlimited duration",
),
io.Boolean.Input(
"strict_duration",
default=False,
tooltip="If True, when the specified duration is not possible, an error will be raised.",
),
],
outputs=[
io.Video.Output(),
],
)
@classmethod
def execute(cls, video: io.Video.Type, start_time: float, duration: float, strict_duration: bool) -> io.NodeOutput:
trimmed = video.as_trimmed(start_time, duration, strict_duration=strict_duration)
if trimmed is not None:
return io.NodeOutput(trimmed)
raise ValueError(
f"Failed to slice video:\nSource duration: {video.get_duration()}\nStart time: {start_time}\nTarget duration: {duration}"
)
class VideoExtension(ComfyExtension):
@override
@@ -262,7 +212,6 @@ class VideoExtension(ComfyExtension):
CreateVideo,
GetVideoComponents,
LoadVideo,
VideoSlice,
]
async def comfy_entrypoint() -> VideoExtension:

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.13.0"
__version__ = "0.12.3"

View File

@@ -13,11 +13,8 @@ from contextlib import nullcontext
import torch
from comfy.cli_args import args
import comfy.memory_management
import comfy.model_management
import comfy_aimdo.model_vbar
from latent_preview import set_preview_method
import nodes
from comfy_execution.caching import (
@@ -530,10 +527,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
finally:
if allocator is not None:
if args.verbose == "DEBUG":
comfy_aimdo.model_vbar.vbars_analyze()
comfy.model_management.reset_cast_buffers()
comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
torch.cuda.synchronize()
if has_pending_tasks:
pending_async_nodes[unique_id] = output_data
@@ -701,8 +696,6 @@ class PromptExecutor:
else:
self.server.client_id = None
self.server.current_prompt_id = prompt_id
self.status_messages = []
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
@@ -724,13 +717,10 @@ class PromptExecutor:
self.add_message("execution_cached",
{ "nodes": cached_nodes, "prompt_id": prompt_id},
broadcast=False)
self.server.current_cached_nodes = cached_nodes
pending_subgraph_results = {}
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
ui_node_outputs = {}
executed = set()
self.server.current_executed_nodes = executed
self.server.current_outputs_to_execute = list(execute_outputs)
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
current_outputs = self.caches.outputs.all_node_ids()
for node_id in list(execute_outputs):
@@ -767,10 +757,6 @@ class PromptExecutor:
"meta": meta_outputs,
}
self.server.last_node_id = None
self.server.current_prompt_id = None
self.server.current_outputs_to_execute = None
self.server.current_cached_nodes = None
self.server.current_executed_nodes = None
if comfy.model_management.DISABLE_SMART_MEMORY:
comfy.model_management.unload_all_models()

View File

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.13.0"
version = "0.12.3"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"

View File

@@ -1,6 +1,6 @@
comfyui-frontend-package==1.38.13
comfyui-workflow-templates==0.8.38
comfyui-embedded-docs==0.4.1
comfyui-workflow-templates==0.8.31
comfyui-embedded-docs==0.4.0
torch
torchsde
torchvision
@@ -22,7 +22,7 @@ alembic
SQLAlchemy
av>=14.2.0
comfy-kitchen>=0.2.7
comfy-aimdo>=0.1.8
comfy-aimdo>=0.1.7
requests
#non essential dependencies:

View File

@@ -242,10 +242,6 @@ class PromptServer():
self.routes = routes
self.last_node_id = None
self.client_id = None
self.current_prompt_id = None
self.current_outputs_to_execute = None
self.current_cached_nodes = None
self.current_executed_nodes = None
self.on_prompt_handlers = []
@@ -268,38 +264,9 @@ class PromptServer():
try:
# Send initial state to the new client
await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid)
# On reconnect if we are the currently executing client, replay catch-up events
if self.client_id == sid and self.current_prompt_id is not None:
await self.send("execution_start", {
"prompt_id": self.current_prompt_id,
"timestamp": int(time.time() * 1000),
"outputs_to_execute": self.current_outputs_to_execute or [],
"executed_node_ids": list(self.current_executed_nodes) if self.current_executed_nodes else [],
}, sid)
if self.current_cached_nodes:
await self.send("execution_cached", {
"nodes": self.current_cached_nodes,
"prompt_id": self.current_prompt_id,
"timestamp": int(time.time() * 1000),
}, sid)
from comfy_execution.progress import get_progress_state
progress = get_progress_state()
if progress.prompt_id == self.current_prompt_id:
active_nodes = progress.get_serialized_state()
if active_nodes:
await self.send("progress_state", {
"prompt_id": self.current_prompt_id,
"nodes": active_nodes,
}, sid)
if self.last_node_id is not None:
await self.send("executing", {
"node": self.last_node_id,
"display_node": self.last_node_id,
"prompt_id": self.current_prompt_id,
}, sid)
# On reconnect if we are the currently executing client send the current node
if self.client_id == sid and self.last_node_id is not None:
await self.send("executing", { "node": self.last_node_id }, sid)
# Flag to track if we've received the first message
first_message = True

View File

@@ -5,11 +5,8 @@ from comfy_execution.jobs import (
is_previewable,
normalize_queue_item,
normalize_history_item,
normalize_output_item,
normalize_outputs,
get_outputs_summary,
apply_sorting,
has_3d_extension,
)
@@ -38,8 +35,8 @@ class TestIsPreviewable:
"""Unit tests for is_previewable()"""
def test_previewable_media_types(self):
"""Images, video, audio, 3d media types should be previewable."""
for media_type in ['images', 'video', 'audio', '3d']:
"""Images, video, audio media types should be previewable."""
for media_type in ['images', 'video', 'audio']:
assert is_previewable(media_type, {}) is True
def test_non_previewable_media_types(self):
@@ -49,7 +46,7 @@ class TestIsPreviewable:
def test_3d_extensions_previewable(self):
"""3D file extensions should be previewable regardless of media_type."""
for ext in ['.obj', '.fbx', '.gltf', '.glb', '.usdz']:
for ext in ['.obj', '.fbx', '.gltf', '.glb']:
item = {'filename': f'model{ext}'}
assert is_previewable('files', item) is True
@@ -163,7 +160,7 @@ class TestGetOutputsSummary:
def test_3d_files_previewable(self):
"""3D file extensions should be previewable."""
for ext in ['.obj', '.fbx', '.gltf', '.glb', '.usdz']:
for ext in ['.obj', '.fbx', '.gltf', '.glb']:
outputs = {
'node1': {
'files': [{'filename': f'model{ext}', 'type': 'output'}]
@@ -195,64 +192,6 @@ class TestGetOutputsSummary:
assert preview['mediaType'] == 'images'
assert preview['subfolder'] == 'outputs'
def test_string_3d_filename_creates_preview(self):
"""String items with 3D extensions should synthesize a preview (Preview3D node output).
Only the .glb counts — nulls and non-file strings are excluded."""
outputs = {
'node1': {
'result': ['preview3d_abc123.glb', None, None]
}
}
count, preview = get_outputs_summary(outputs)
assert count == 1
assert preview is not None
assert preview['filename'] == 'preview3d_abc123.glb'
assert preview['mediaType'] == '3d'
assert preview['nodeId'] == 'node1'
assert preview['type'] == 'output'
def test_string_non_3d_filename_no_preview(self):
"""String items without 3D extensions should not create a preview."""
outputs = {
'node1': {
'result': ['data.json', None]
}
}
count, preview = get_outputs_summary(outputs)
assert count == 0
assert preview is None
def test_string_3d_filename_used_as_fallback(self):
"""String 3D preview should be used when no dict items are previewable."""
outputs = {
'node1': {
'latents': [{'filename': 'latent.safetensors'}],
},
'node2': {
'result': ['model.glb', None]
}
}
count, preview = get_outputs_summary(outputs)
assert preview is not None
assert preview['filename'] == 'model.glb'
assert preview['mediaType'] == '3d'
class TestHas3DExtension:
"""Unit tests for has_3d_extension()"""
def test_recognized_extensions(self):
for ext in ['.obj', '.fbx', '.gltf', '.glb', '.usdz']:
assert has_3d_extension(f'model{ext}') is True
def test_case_insensitive(self):
assert has_3d_extension('MODEL.GLB') is True
assert has_3d_extension('Scene.GLTF') is True
def test_non_3d_extensions(self):
for name in ['photo.png', 'video.mp4', 'data.json', 'model']:
assert has_3d_extension(name) is False
class TestApplySorting:
"""Unit tests for apply_sorting()"""
@@ -456,142 +395,3 @@ class TestNormalizeHistoryItem:
'prompt': {'nodes': {'1': {}}},
'extra_data': {'create_time': 1234567890, 'client_id': 'abc'},
}
def test_include_outputs_normalizes_3d_strings(self):
"""Detail view should transform string 3D filenames into file output dicts."""
history_item = {
'prompt': (
5,
'prompt-3d',
{'nodes': {}},
{'create_time': 1234567890},
['node1'],
),
'status': {'status_str': 'success', 'completed': True, 'messages': []},
'outputs': {
'node1': {
'result': ['preview3d_abc123.glb', None, None]
}
},
}
job = normalize_history_item('prompt-3d', history_item, include_outputs=True)
assert job['outputs_count'] == 1
result_items = job['outputs']['node1']['result']
assert len(result_items) == 1
assert result_items[0] == {
'filename': 'preview3d_abc123.glb',
'type': 'output',
'subfolder': '',
'mediaType': '3d',
}
def test_include_outputs_preserves_dict_items(self):
"""Detail view normalization should pass dict items through unchanged."""
history_item = {
'prompt': (
5,
'prompt-img',
{'nodes': {}},
{'create_time': 1234567890},
['node1'],
),
'status': {'status_str': 'success', 'completed': True, 'messages': []},
'outputs': {
'node1': {
'images': [
{'filename': 'photo.png', 'type': 'output', 'subfolder': ''},
]
}
},
}
job = normalize_history_item('prompt-img', history_item, include_outputs=True)
assert job['outputs_count'] == 1
assert job['outputs']['node1']['images'] == [
{'filename': 'photo.png', 'type': 'output', 'subfolder': ''},
]
class TestNormalizeOutputItem:
"""Unit tests for normalize_output_item()"""
def test_none_returns_none(self):
assert normalize_output_item(None) is None
def test_string_3d_extension_synthesizes_dict(self):
result = normalize_output_item('model.glb')
assert result == {'filename': 'model.glb', 'type': 'output', 'subfolder': '', 'mediaType': '3d'}
def test_string_non_3d_extension_returns_none(self):
assert normalize_output_item('data.json') is None
def test_string_no_extension_returns_none(self):
assert normalize_output_item('camera_info_string') is None
def test_dict_passes_through(self):
item = {'filename': 'test.png', 'type': 'output'}
assert normalize_output_item(item) is item
def test_other_types_return_none(self):
assert normalize_output_item(42) is None
assert normalize_output_item(True) is None
class TestNormalizeOutputs:
"""Unit tests for normalize_outputs()"""
def test_empty_outputs(self):
assert normalize_outputs({}) == {}
def test_dict_items_pass_through(self):
outputs = {
'node1': {
'images': [{'filename': 'a.png', 'type': 'output'}],
}
}
result = normalize_outputs(outputs)
assert result == outputs
def test_3d_string_synthesized(self):
outputs = {
'node1': {
'result': ['model.glb', None, None],
}
}
result = normalize_outputs(outputs)
assert result == {
'node1': {
'result': [
{'filename': 'model.glb', 'type': 'output', 'subfolder': '', 'mediaType': '3d'},
],
}
}
def test_animated_key_preserved(self):
outputs = {
'node1': {
'images': [{'filename': 'a.png', 'type': 'output'}],
'animated': [True],
}
}
result = normalize_outputs(outputs)
assert result['node1']['animated'] == [True]
def test_non_dict_node_outputs_preserved(self):
outputs = {'node1': 'unexpected_value'}
result = normalize_outputs(outputs)
assert result == {'node1': 'unexpected_value'}
def test_none_items_filtered_but_other_types_preserved(self):
outputs = {
'node1': {
'result': ['data.json', None, [1, 2, 3]],
}
}
result = normalize_outputs(outputs)
assert result == {
'node1': {
'result': ['data.json', [1, 2, 3]],
}
}