Compare commits

...

11 Commits

Author SHA1 Message Date
dante01yoon
2a6e3dc7a8 Add isfinite guard, exception chaining, and unit tests for Number Convert node
- Add math.isfinite() check to prevent int() crash on inf/nan string inputs
- Use 'from None' for cleaner exception chaining on string parse failure
- Add 21 unit tests covering all input types and error paths
2026-03-19 10:41:02 +09:00
dante01yoon
dae107e430 Register nodes_number_convert.py in extras_files list
Without this entry in nodes.py, the Number Convert node file
would not be discovered and loaded at startup.
2026-03-18 19:56:43 +09:00
dante01yoon
82cf5d88c2 Add Number Convert node for unified numeric type conversion
Consolidates fragmented IntToFloat/FloatToInt nodes (previously only
available via third-party packs like ComfyMath, FillNodes, etc.) into
a single core node.

- Single input accepting INT, FLOAT, STRING, and BOOL types
- Two outputs: FLOAT and INT
- Conversion: bool→0/1, string→parsed number, float↔int standard cast
- Follows Math Expression node patterns (comfy_api, io.Schema, etc.)

Refs: COM-16925
2026-03-18 19:56:43 +09:00
Anton Bukov
b941913f1d fix: run text encoders on MPS GPU instead of CPU for Apple Silicon (#12809)
On Apple Silicon, `vram_state` is set to `VRAMState.SHARED` because
CPU and GPU share unified memory. However, `text_encoder_device()`
only checked for `HIGH_VRAM` and `NORMAL_VRAM`, causing all text
encoders to fall back to CPU on MPS devices.

Adding `VRAMState.SHARED` to the condition allows non-quantized text
encoders (e.g. bf16 Gemma 3 12B) to run on the MPS GPU, providing
significant speedup for text encoding and prompt generation.

Note: quantized models (fp4/fp8) that use float8_e4m3fn internally
will still fall back to CPU via the `supports_cast()` check in
`CLIP.__init__()`, since MPS does not support fp8 dtypes.
2026-03-17 21:21:32 -04:00
rattus
cad24ce262 cascade: remove dead weight init code (#13026)
This weight init process is fully shadowed be the weight load and
doesnt work in dynamic_vram were the weight allocation is deferred.
2026-03-17 20:59:10 -04:00
comfyanonymous
68d542cc06 Fix case where pixel space VAE could cause issues. (#13030) 2026-03-17 20:46:22 -04:00
Jukka Seppänen
735a0465e5 Inplace VAE output processing to reduce peak RAM consumption. (#13028) 2026-03-17 20:20:49 -04:00
Dr.Lt.Data
8b9d039f26 bump manager version to 4.1b6 (#13022) 2026-03-17 18:17:03 -04:00
rattus
035414ede4 Reduce WAN VAE VRAM, Save use cases for OOM/Tiler (#13014)
* wan: vae: encoder: Add feature cache layer that corks singles

If a downsample only gives you a single frame, save it to the feature
cache and return nothing to the top level. This increases the
efficiency of cacheability, but also prepares support for going two
by two rather than four by four on the frames.

* wan: remove all concatentation with the feature cache

The loopers are now responsible for ensuring that non-final frames are
processes at least two-by-two, elimiating the need for this cat case.

* wan: vae: recurse and chunk for 2+2 frames on decode

Avoid having to clone off slices of 4 frame chunks and reduce the size
of the big 6 frame convolutions down to 4. Save the VRAMs.

* wan: encode frames 2x2.

Reduce VRAM usage greatly by encoding frames 2 at a time rather than
4.

* wan: vae: remove cloning

The loopers now control the chunking such there is noever more than 2
frames, so just cache these slices directly and avoid the clone
allocations completely.

* wan: vae: free consumer caller tensors on recursion

* wan: vae: restyle a little to match LTX
2026-03-17 17:34:39 -04:00
rattus
1a157e1f97 Reduce LTX VAE VRAM usage and save use cases from OOMs/Tiler (#13013)
* ltx: vae: scale the chunk size with the users VRAM

Scale this linearly down for users with low VRAM.

* ltx: vae: free non-chunking recursive intermediates

* ltx: vae: cleanup some intermediates

The conv layer can be the VRAM peak and it does a torch.cat. So cleanup
the pieces of the cat. Also clear our the cache ASAP as each layer detect
its end as this VAE surges in VRAM at the end due to the ended padding
increasing the size of the final frame convolutions off-the-books to
the chunker. So if all the earlier layers free up their cache it can
offset that surge.

Its a fragmentation nightmare, and the chance of it having to recache the
pyt allocator is very high, but you wont OOM.
2026-03-17 17:32:43 -04:00
Christian Byrne
ed7c2c6579 Mark weight_dtype as advanced input in Load Diffusion Model node (#12769)
Mark the weight_dtype parameter in UNETLoader (Load Diffusion Model) as
an advanced input to reduce UI complexity for new users. The parameter
is now hidden behind an expandable Advanced section, matching the
pattern used for other advanced inputs like device, tile_size, and
overlap.

Amp-Thread-ID: https://ampcode.com/threads/T-019cbaf1-d3c0-718e-a325-318baba86dec
2026-03-17 07:24:00 -07:00
10 changed files with 329 additions and 122 deletions

View File

@@ -136,16 +136,7 @@ class ResBlock(nn.Module):
ops.Linear(c_hidden, c),
)
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
# Init weights
def _basic_init(module):
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=False)
def _norm(self, x, norm):
return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

View File

@@ -65,9 +65,13 @@ class CausalConv3d(nn.Module):
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
x = torch.cat(pieces, dim=2)
del pieces
del cached
if needs_caching:
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
elif is_end:
self.temporal_cache_state[tid] = (None, True)
return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :]

View File

@@ -297,7 +297,23 @@ class Encoder(nn.Module):
module.temporal_cache_state.pop(tid, None)
MAX_CHUNK_SIZE=(128 * 1024 ** 2)
MIN_VRAM_FOR_CHUNK_SCALING = 6 * 1024 ** 3
MAX_VRAM_FOR_CHUNK_SCALING = 24 * 1024 ** 3
MIN_CHUNK_SIZE = 32 * 1024 ** 2
MAX_CHUNK_SIZE = 128 * 1024 ** 2
def get_max_chunk_size(device: torch.device) -> int:
total_memory = comfy.model_management.get_total_memory(dev=device)
if total_memory <= MIN_VRAM_FOR_CHUNK_SCALING:
return MIN_CHUNK_SIZE
if total_memory >= MAX_VRAM_FOR_CHUNK_SCALING:
return MAX_CHUNK_SIZE
interp = (total_memory - MIN_VRAM_FOR_CHUNK_SCALING) / (
MAX_VRAM_FOR_CHUNK_SCALING - MIN_VRAM_FOR_CHUNK_SCALING
)
return int(MIN_CHUNK_SIZE + interp * (MAX_CHUNK_SIZE - MIN_CHUNK_SIZE))
class Decoder(nn.Module):
r"""
@@ -525,8 +541,11 @@ class Decoder(nn.Module):
timestep_shift_scale = ada_values.unbind(dim=1)
output = []
max_chunk_size = get_max_chunk_size(sample.device)
def run_up(idx, sample, ended):
def run_up(idx, sample_ref, ended):
sample = sample_ref[0]
sample_ref[0] = None
if idx >= len(self.up_blocks):
sample = self.conv_norm_out(sample)
if timestep_shift_scale is not None:
@@ -554,13 +573,21 @@ class Decoder(nn.Module):
return
total_bytes = sample.numel() * sample.element_size()
num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size
for chunk_idx, sample1 in enumerate(samples):
run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1)
if num_chunks == 1:
# when we are not chunking, detach our x so the callee can free it as soon as they are done
next_sample_ref = [sample]
del sample
run_up(idx + 1, next_sample_ref, ended)
return
else:
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
run_up(0, sample, True)
for chunk_idx, sample1 in enumerate(samples):
run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1)
run_up(0, [sample], True)
sample = torch.cat(output, dim=2)
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)

View File

@@ -99,7 +99,7 @@ class Resample(nn.Module):
else:
self.resample = nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
b, c, t, h, w = x.size()
if self.mode == 'upsample3d':
if feat_cache is not None:
@@ -109,22 +109,7 @@ class Resample(nn.Module):
feat_idx[0] += 1
else:
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] != 'Rep':
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] == 'Rep':
cache_x = torch.cat([
torch.zeros_like(cache_x).to(cache_x.device),
cache_x
],
dim=2)
cache_x = x[:, :, -CACHE_T:, :, :]
if feat_cache[idx] == 'Rep':
x = self.time_conv(x)
else:
@@ -145,19 +130,24 @@ class Resample(nn.Module):
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x.clone()
feat_idx[0] += 1
feat_cache[idx] = x
else:
cache_x = x[:, :, -1:, :, :].clone()
# if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
# # cache last frame of last two chunk
# cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
cache_x = x[:, :, -1:, :, :]
x = self.time_conv(
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
feat_idx[0] += 1
deferred_x = feat_cache[idx + 1]
if deferred_x is not None:
x = torch.cat([deferred_x, x], 2)
feat_cache[idx + 1] = None
if x.shape[2] == 1 and not final:
feat_cache[idx + 1] = x
x = None
feat_idx[0] += 2
return x
@@ -177,19 +167,12 @@ class ResidualBlock(nn.Module):
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
if in_dim != out_dim else nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
old_x = x
for layer in self.residual:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
cache_x = x[:, :, -CACHE_T:, :, :]
x = layer(x, cache_list=feat_cache, cache_idx=idx)
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -213,7 +196,7 @@ class AttentionBlock(nn.Module):
self.proj = ops.Conv2d(dim, dim, 1)
self.optimized_attention = vae_attention()
def forward(self, x):
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
identity = x
b, c, t, h, w = x.size()
x = rearrange(x, 'b c t h w -> (b t) c h w')
@@ -283,17 +266,10 @@ class Encoder3d(nn.Module):
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, z_dim, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]):
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
cache_x = x[:, :, -CACHE_T:, :, :]
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -303,14 +279,16 @@ class Encoder3d(nn.Module):
## downsamples
for layer in self.downsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
x = layer(x, feat_cache, feat_idx, final=final)
if x is None:
return None
else:
x = layer(x)
## middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx, final=final)
else:
x = layer(x)
@@ -318,14 +296,7 @@ class Encoder3d(nn.Module):
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
cache_x = x[:, :, -CACHE_T:, :, :]
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -393,14 +364,7 @@ class Decoder3d(nn.Module):
## conv1
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
cache_x = x[:, :, -CACHE_T:, :, :]
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -409,42 +373,56 @@ class Decoder3d(nn.Module):
## middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## upsamples
for layer in self.upsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
out_chunks = []
def run_up(layer_idx, x_ref, feat_idx):
x = x_ref[0]
x_ref[0] = None
if layer_idx >= len(self.upsamples):
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
cache_x = x[:, :, -CACHE_T:, :, :]
x = layer(x, feat_cache[feat_idx[0]])
feat_cache[feat_idx[0]] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
out_chunks.append(x)
return
layer = self.upsamples[layer_idx]
if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1:
for frame_idx in range(x.shape[2]):
run_up(
layer_idx,
[x[:, :, frame_idx:frame_idx + 1, :, :]],
feat_idx.copy(),
)
del x
return
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
return x
next_x_ref = [x]
del x
run_up(layer_idx + 1, next_x_ref, feat_idx)
run_up(0, [x], feat_idx)
return out_chunks
def count_conv3d(model):
def count_cache_layers(model):
count = 0
for m in model.modules():
if isinstance(m, CausalConv3d):
if isinstance(m, CausalConv3d) or (isinstance(m, Resample) and m.mode == 'downsample3d'):
count += 1
return count
@@ -482,11 +460,12 @@ class WanVAE(nn.Module):
conv_idx = [0]
## cache
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
t = 1 + ((t - 1) // 4) * 4
iter_ = 1 + (t - 1) // 2
feat_map = None
if iter_ > 1:
feat_map = [None] * count_conv3d(self.encoder)
## 对encode输入的x按时间拆分为1、4、4、4....
feat_map = [None] * count_cache_layers(self.encoder)
## 对encode输入的x按时间拆分为1、2、2、2....(总帧数先按4N+1向下取整)
for i in range(iter_):
conv_idx = [0]
if i == 0:
@@ -496,20 +475,23 @@ class WanVAE(nn.Module):
feat_idx=conv_idx)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :],
feat_cache=feat_map,
feat_idx=conv_idx)
feat_idx=conv_idx,
final=(i == (iter_ - 1)))
if out_ is None:
continue
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
return mu
def decode(self, z):
conv_idx = [0]
# z: [b,c,t,h,w]
iter_ = z.shape[2]
iter_ = 1 + z.shape[2] // 2
feat_map = None
if iter_ > 1:
feat_map = [None] * count_conv3d(self.decoder)
feat_map = [None] * count_cache_layers(self.decoder)
x = self.conv2(z)
for i in range(iter_):
conv_idx = [0]
@@ -520,8 +502,8 @@ class WanVAE(nn.Module):
feat_idx=conv_idx)
else:
out_ = self.decoder(
x[:, :, i:i + 1, :, :],
x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :],
feat_cache=feat_map,
feat_idx=conv_idx)
out = torch.cat([out, out_], 2)
return out
out += out_
return torch.cat(out, 2)

View File

@@ -1003,7 +1003,7 @@ def text_encoder_offload_device():
def text_encoder_device():
if args.gpu_only:
return get_torch_device()
elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM) or comfy.memory_management.aimdo_enabled:
elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM, VRAMState.SHARED) or comfy.memory_management.aimdo_enabled:
if should_use_fp16(prioritize_performance=False):
return get_torch_device()
else:

View File

@@ -455,7 +455,7 @@ class VAE:
self.output_channels = 3
self.pad_channel_value = None
self.process_input = lambda image: image * 2.0 - 1.0
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
self.process_output = lambda image: image.add_(1.0).div_(2.0).clamp_(0.0, 1.0)
self.working_dtypes = [torch.bfloat16, torch.float32]
self.disable_offload = False
self.not_video = False
@@ -952,8 +952,8 @@ class VAE:
batch_number = max(1, batch_number)
for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).to(dtype=self.vae_output_dtype()))
samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True))
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
pixel_samples[x:x+batch_number] = out

View File

@@ -0,0 +1,79 @@
"""Number Convert node for unified numeric type conversion.
Provides a single node that converts INT, FLOAT, STRING, and BOOL
inputs into FLOAT and INT outputs.
"""
from __future__ import annotations
import math
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class NumberConvertNode(io.ComfyNode):
"""Converts various types to numeric FLOAT and INT outputs."""
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="ComfyNumberConvert",
display_name="Number Convert",
category="math",
search_aliases=[
"int to float", "float to int", "number convert",
"int2float", "float2int", "cast", "parse number",
"string to number", "bool to int",
],
inputs=[
io.MultiType.Input(
"value",
[io.Int, io.Float, io.String, io.Boolean],
display_name="value",
),
],
outputs=[
io.Float.Output(display_name="FLOAT"),
io.Int.Output(display_name="INT"),
],
)
@classmethod
def execute(cls, value) -> io.NodeOutput:
if isinstance(value, bool):
float_val = 1.0 if value else 0.0
elif isinstance(value, (int, float)):
float_val = float(value)
elif isinstance(value, str):
text = value.strip()
if not text:
raise ValueError("Cannot convert empty string to number.")
try:
float_val = float(text)
except ValueError:
raise ValueError(
f"Cannot convert string to number: {value!r}"
) from None
else:
raise TypeError(
f"Unsupported input type: {type(value).__name__}"
)
if not math.isfinite(float_val):
raise ValueError(
f"Cannot convert non-finite value to number: {float_val}"
)
return io.NodeOutput(float_val, int(float_val))
class NumberConvertExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [NumberConvertNode]
async def comfy_entrypoint() -> NumberConvertExtension:
return NumberConvertExtension()

View File

@@ -1 +1 @@
comfyui_manager==4.1b5
comfyui_manager==4.1b6

View File

@@ -952,7 +952,7 @@ class UNETLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],)
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"], {"advanced": True})
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "load_unet"
@@ -2452,6 +2452,7 @@ async def init_builtin_extra_nodes():
"nodes_nag.py",
"nodes_sdpose.py",
"nodes_math.py",
"nodes_number_convert.py",
"nodes_painter.py",
]

View File

@@ -0,0 +1,123 @@
import pytest
from unittest.mock import patch, MagicMock
mock_nodes = MagicMock()
mock_nodes.MAX_RESOLUTION = 16384
mock_server = MagicMock()
with patch.dict("sys.modules", {"nodes": mock_nodes, "server": mock_server}):
from comfy_extras.nodes_number_convert import NumberConvertNode
class TestNumberConvertExecute:
@staticmethod
def _exec(value) -> object:
return NumberConvertNode.execute(value)
# --- INT input ---
def test_int_input(self):
result = self._exec(42)
assert result[0] == 42.0
assert result[1] == 42
def test_int_zero(self):
result = self._exec(0)
assert result[0] == 0.0
assert result[1] == 0
def test_int_negative(self):
result = self._exec(-7)
assert result[0] == -7.0
assert result[1] == -7
# --- FLOAT input ---
def test_float_input(self):
result = self._exec(3.14)
assert result[0] == 3.14
assert result[1] == 3
def test_float_truncation_toward_zero(self):
result = self._exec(-2.9)
assert result[0] == -2.9
assert result[1] == -2 # int() truncates toward zero, not floor
def test_float_output_type(self):
result = self._exec(5)
assert isinstance(result[0], float)
def test_int_output_type(self):
result = self._exec(5.7)
assert isinstance(result[1], int)
# --- BOOL input ---
def test_bool_true(self):
result = self._exec(True)
assert result[0] == 1.0
assert result[1] == 1
def test_bool_false(self):
result = self._exec(False)
assert result[0] == 0.0
assert result[1] == 0
# --- STRING input ---
def test_string_integer(self):
result = self._exec("42")
assert result[0] == 42.0
assert result[1] == 42
def test_string_float(self):
result = self._exec("3.14")
assert result[0] == 3.14
assert result[1] == 3
def test_string_negative(self):
result = self._exec("-5.5")
assert result[0] == -5.5
assert result[1] == -5
def test_string_with_whitespace(self):
result = self._exec(" 7.0 ")
assert result[0] == 7.0
assert result[1] == 7
def test_string_scientific_notation(self):
result = self._exec("1e3")
assert result[0] == 1000.0
assert result[1] == 1000
# --- STRING error paths ---
def test_empty_string_raises(self):
with pytest.raises(ValueError, match="Cannot convert empty string"):
self._exec("")
def test_whitespace_only_string_raises(self):
with pytest.raises(ValueError, match="Cannot convert empty string"):
self._exec(" ")
def test_non_numeric_string_raises(self):
with pytest.raises(ValueError, match="Cannot convert string to number"):
self._exec("abc")
def test_string_inf_raises(self):
with pytest.raises(ValueError, match="non-finite"):
self._exec("inf")
def test_string_nan_raises(self):
with pytest.raises(ValueError, match="non-finite"):
self._exec("nan")
def test_string_negative_inf_raises(self):
with pytest.raises(ValueError, match="non-finite"):
self._exec("-inf")
# --- Unsupported type ---
def test_unsupported_type_raises(self):
with pytest.raises(TypeError, match="Unsupported input type"):
self._exec([1, 2, 3])