mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-19 12:07:31 +00:00
Compare commits
6 Commits
feature/nu
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f6b869d7d3 | ||
|
|
56ff88f951 | ||
|
|
9fff091f35 | ||
|
|
dcd659590f | ||
|
|
b67ed2a45f | ||
|
|
06957022d4 |
@@ -473,6 +473,17 @@ class Decoder(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Precompute output scale factors: (channels, (t_scale, h_scale, w_scale), t_offset)
|
||||
ts, hs, ws, to = 1, 1, 1, 0
|
||||
for block in self.up_blocks:
|
||||
if isinstance(block, DepthToSpaceUpsample):
|
||||
ts *= block.stride[0]
|
||||
hs *= block.stride[1]
|
||||
ws *= block.stride[2]
|
||||
if block.stride[0] > 1:
|
||||
to = to * block.stride[0] + 1
|
||||
self._output_scale = (out_channels // (patch_size ** 2), (ts, hs * patch_size, ws * patch_size), to)
|
||||
|
||||
self.timestep_conditioning = timestep_conditioning
|
||||
|
||||
if timestep_conditioning:
|
||||
@@ -494,11 +505,15 @@ class Decoder(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
|
||||
def decode_output_shape(self, input_shape):
|
||||
c, (ts, hs, ws), to = self._output_scale
|
||||
return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws)
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Optional[torch.Tensor] = None,
|
||||
output_buffer: Optional[torch.Tensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Decoder` class."""
|
||||
batch_size = sample.shape[0]
|
||||
@@ -540,7 +555,13 @@ class Decoder(nn.Module):
|
||||
)
|
||||
timestep_shift_scale = ada_values.unbind(dim=1)
|
||||
|
||||
output = []
|
||||
if output_buffer is None:
|
||||
output_buffer = torch.empty(
|
||||
self.decode_output_shape(sample.shape),
|
||||
dtype=sample.dtype, device=comfy.model_management.intermediate_device(),
|
||||
)
|
||||
output_offset = [0]
|
||||
|
||||
max_chunk_size = get_max_chunk_size(sample.device)
|
||||
|
||||
def run_up(idx, sample_ref, ended):
|
||||
@@ -556,7 +577,10 @@ class Decoder(nn.Module):
|
||||
mark_conv3d_ended(self.conv_out)
|
||||
sample = self.conv_out(sample, causal=self.causal)
|
||||
if sample is not None and sample.shape[2] > 0:
|
||||
output.append(sample.to(comfy.model_management.intermediate_device()))
|
||||
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
t = sample.shape[2]
|
||||
output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample)
|
||||
output_offset[0] += t
|
||||
return
|
||||
|
||||
up_block = self.up_blocks[idx]
|
||||
@@ -588,11 +612,8 @@ class Decoder(nn.Module):
|
||||
run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1)
|
||||
|
||||
run_up(0, [sample], True)
|
||||
sample = torch.cat(output, dim=2)
|
||||
|
||||
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
|
||||
return sample
|
||||
return output_buffer
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
try:
|
||||
@@ -1226,7 +1247,10 @@ class VideoVAE(nn.Module):
|
||||
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
|
||||
return self.per_channel_statistics.normalize(means)
|
||||
|
||||
def decode(self, x):
|
||||
def decode_output_shape(self, input_shape):
|
||||
return self.decoder.decode_output_shape(input_shape)
|
||||
|
||||
def decode(self, x, output_buffer=None):
|
||||
if self.timestep_conditioning: #TODO: seed
|
||||
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
|
||||
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep)
|
||||
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep, output_buffer=output_buffer)
|
||||
|
||||
@@ -64,10 +64,10 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
|
||||
sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
|
||||
|
||||
samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
samples = samples.to(comfy.model_management.intermediate_device())
|
||||
samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||
return samples
|
||||
|
||||
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
samples = samples.to(comfy.model_management.intermediate_device())
|
||||
samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||
return samples
|
||||
|
||||
19
comfy/sd.py
19
comfy/sd.py
@@ -951,12 +951,23 @@ class VAE:
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = max(1, batch_number)
|
||||
|
||||
# Pre-allocate output for VAEs that support direct buffer writes
|
||||
preallocated = False
|
||||
if hasattr(self.first_stage_model, 'decode_output_shape'):
|
||||
pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
preallocated = True
|
||||
|
||||
for x in range(0, samples_in.shape[0], batch_number):
|
||||
samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
|
||||
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True))
|
||||
if pixel_samples is None:
|
||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
pixel_samples[x:x+batch_number] = out
|
||||
if preallocated:
|
||||
self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
|
||||
else:
|
||||
out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
|
||||
if pixel_samples is None:
|
||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
pixel_samples[x:x+batch_number].copy_(out)
|
||||
del out
|
||||
self.process_output(pixel_samples[x:x+batch_number])
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||
|
||||
@@ -46,7 +46,7 @@ class ClipTokenWeightEncoder:
|
||||
out, pooled = o[:2]
|
||||
|
||||
if pooled is not None:
|
||||
first_pooled = pooled[0:1].to(model_management.intermediate_device())
|
||||
first_pooled = pooled[0:1].to(device=model_management.intermediate_device())
|
||||
else:
|
||||
first_pooled = pooled
|
||||
|
||||
@@ -63,16 +63,16 @@ class ClipTokenWeightEncoder:
|
||||
output.append(z)
|
||||
|
||||
if (len(output) == 0):
|
||||
r = (out[-1:].to(model_management.intermediate_device()), first_pooled)
|
||||
r = (out[-1:].to(device=model_management.intermediate_device()), first_pooled)
|
||||
else:
|
||||
r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled)
|
||||
r = (torch.cat(output, dim=-2).to(device=model_management.intermediate_device()), first_pooled)
|
||||
|
||||
if len(o) > 2:
|
||||
extra = {}
|
||||
for k in o[2]:
|
||||
v = o[2][k]
|
||||
if k == "attention_mask":
|
||||
v = v[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device())
|
||||
v = v[:sections].flatten().unsqueeze(dim=0).to(device=model_management.intermediate_device())
|
||||
extra[k] = v
|
||||
|
||||
r = r + (extra,)
|
||||
|
||||
@@ -67,6 +67,7 @@ class GeminiPart(BaseModel):
|
||||
inlineData: GeminiInlineData | None = Field(None)
|
||||
fileData: GeminiFileData | None = Field(None)
|
||||
text: str | None = Field(None)
|
||||
thought: bool | None = Field(None)
|
||||
|
||||
|
||||
class GeminiTextPart(BaseModel):
|
||||
|
||||
@@ -63,7 +63,7 @@ GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge(
|
||||
$m := widgets.model;
|
||||
$r := widgets.resolution;
|
||||
$isFlash := $contains($m, "nano banana 2");
|
||||
$flashPrices := {"1k": 0.0696, "2k": 0.0696, "4k": 0.123};
|
||||
$flashPrices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154};
|
||||
$proPrices := {"1k": 0.134, "2k": 0.134, "4k": 0.24};
|
||||
$prices := $isFlash ? $flashPrices : $proPrices;
|
||||
{"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}}
|
||||
@@ -188,10 +188,12 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
|
||||
return "\n".join([part.text for part in parts])
|
||||
|
||||
|
||||
async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
|
||||
async def get_image_from_response(response: GeminiGenerateContentResponse, thought: bool = False) -> Input.Image:
|
||||
image_tensors: list[Input.Image] = []
|
||||
parts = get_parts_by_type(response, "image/*")
|
||||
for part in parts:
|
||||
if (part.thought is True) != thought:
|
||||
continue
|
||||
if part.inlineData:
|
||||
image_data = base64.b64decode(part.inlineData.data)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
@@ -931,6 +933,11 @@ class GeminiNanoBanana2(IO.ComfyNode):
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
IO.String.Output(),
|
||||
IO.Image.Output(
|
||||
display_name="thought_image",
|
||||
tooltip="First image from the model's thinking process. "
|
||||
"Only available with thinking_level HIGH and IMAGE+TEXT modality.",
|
||||
),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@@ -992,7 +999,11 @@ class GeminiNanoBanana2(IO.ComfyNode):
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
price_extractor=calculate_tokens_price,
|
||||
)
|
||||
return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
|
||||
return IO.NodeOutput(
|
||||
await get_image_from_response(response),
|
||||
get_text_from_response(response),
|
||||
await get_image_from_response(response, thought=True),
|
||||
)
|
||||
|
||||
|
||||
class GeminiExtension(ComfyExtension):
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
"""Number Convert node for unified numeric type conversion.
|
||||
|
||||
Provides a single node that converts INT, FLOAT, STRING, and BOOL
|
||||
inputs into FLOAT and INT outputs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class NumberConvertNode(io.ComfyNode):
|
||||
"""Converts various types to numeric FLOAT and INT outputs."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="ComfyNumberConvert",
|
||||
display_name="Number Convert",
|
||||
category="math",
|
||||
search_aliases=[
|
||||
"int to float", "float to int", "number convert",
|
||||
"int2float", "float2int", "cast", "parse number",
|
||||
"string to number", "bool to int",
|
||||
],
|
||||
inputs=[
|
||||
io.MultiType.Input(
|
||||
"value",
|
||||
[io.Int, io.Float, io.String, io.Boolean],
|
||||
display_name="value",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
io.Float.Output(display_name="FLOAT"),
|
||||
io.Int.Output(display_name="INT"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, value) -> io.NodeOutput:
|
||||
if isinstance(value, bool):
|
||||
float_val = 1.0 if value else 0.0
|
||||
elif isinstance(value, (int, float)):
|
||||
float_val = float(value)
|
||||
elif isinstance(value, str):
|
||||
text = value.strip()
|
||||
if not text:
|
||||
raise ValueError("Cannot convert empty string to number.")
|
||||
try:
|
||||
float_val = float(text)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Cannot convert string to number: {value!r}"
|
||||
) from None
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Unsupported input type: {type(value).__name__}"
|
||||
)
|
||||
|
||||
if not math.isfinite(float_val):
|
||||
raise ValueError(
|
||||
f"Cannot convert non-finite value to number: {float_val}"
|
||||
)
|
||||
|
||||
return io.NodeOutput(float_val, int(float_val))
|
||||
|
||||
|
||||
class NumberConvertExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [NumberConvertNode]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> NumberConvertExtension:
|
||||
return NumberConvertExtension()
|
||||
1
nodes.py
1
nodes.py
@@ -2452,7 +2452,6 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_nag.py",
|
||||
"nodes_sdpose.py",
|
||||
"nodes_math.py",
|
||||
"nodes_number_convert.py",
|
||||
"nodes_painter.py",
|
||||
]
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
comfyui-frontend-package==1.41.20
|
||||
comfyui-frontend-package==1.41.21
|
||||
comfyui-workflow-templates==0.9.26
|
||||
comfyui-embedded-docs==0.4.3
|
||||
torch
|
||||
|
||||
@@ -1,123 +0,0 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
mock_nodes = MagicMock()
|
||||
mock_nodes.MAX_RESOLUTION = 16384
|
||||
mock_server = MagicMock()
|
||||
|
||||
with patch.dict("sys.modules", {"nodes": mock_nodes, "server": mock_server}):
|
||||
from comfy_extras.nodes_number_convert import NumberConvertNode
|
||||
|
||||
|
||||
class TestNumberConvertExecute:
|
||||
@staticmethod
|
||||
def _exec(value) -> object:
|
||||
return NumberConvertNode.execute(value)
|
||||
|
||||
# --- INT input ---
|
||||
|
||||
def test_int_input(self):
|
||||
result = self._exec(42)
|
||||
assert result[0] == 42.0
|
||||
assert result[1] == 42
|
||||
|
||||
def test_int_zero(self):
|
||||
result = self._exec(0)
|
||||
assert result[0] == 0.0
|
||||
assert result[1] == 0
|
||||
|
||||
def test_int_negative(self):
|
||||
result = self._exec(-7)
|
||||
assert result[0] == -7.0
|
||||
assert result[1] == -7
|
||||
|
||||
# --- FLOAT input ---
|
||||
|
||||
def test_float_input(self):
|
||||
result = self._exec(3.14)
|
||||
assert result[0] == 3.14
|
||||
assert result[1] == 3
|
||||
|
||||
def test_float_truncation_toward_zero(self):
|
||||
result = self._exec(-2.9)
|
||||
assert result[0] == -2.9
|
||||
assert result[1] == -2 # int() truncates toward zero, not floor
|
||||
|
||||
def test_float_output_type(self):
|
||||
result = self._exec(5)
|
||||
assert isinstance(result[0], float)
|
||||
|
||||
def test_int_output_type(self):
|
||||
result = self._exec(5.7)
|
||||
assert isinstance(result[1], int)
|
||||
|
||||
# --- BOOL input ---
|
||||
|
||||
def test_bool_true(self):
|
||||
result = self._exec(True)
|
||||
assert result[0] == 1.0
|
||||
assert result[1] == 1
|
||||
|
||||
def test_bool_false(self):
|
||||
result = self._exec(False)
|
||||
assert result[0] == 0.0
|
||||
assert result[1] == 0
|
||||
|
||||
# --- STRING input ---
|
||||
|
||||
def test_string_integer(self):
|
||||
result = self._exec("42")
|
||||
assert result[0] == 42.0
|
||||
assert result[1] == 42
|
||||
|
||||
def test_string_float(self):
|
||||
result = self._exec("3.14")
|
||||
assert result[0] == 3.14
|
||||
assert result[1] == 3
|
||||
|
||||
def test_string_negative(self):
|
||||
result = self._exec("-5.5")
|
||||
assert result[0] == -5.5
|
||||
assert result[1] == -5
|
||||
|
||||
def test_string_with_whitespace(self):
|
||||
result = self._exec(" 7.0 ")
|
||||
assert result[0] == 7.0
|
||||
assert result[1] == 7
|
||||
|
||||
def test_string_scientific_notation(self):
|
||||
result = self._exec("1e3")
|
||||
assert result[0] == 1000.0
|
||||
assert result[1] == 1000
|
||||
|
||||
# --- STRING error paths ---
|
||||
|
||||
def test_empty_string_raises(self):
|
||||
with pytest.raises(ValueError, match="Cannot convert empty string"):
|
||||
self._exec("")
|
||||
|
||||
def test_whitespace_only_string_raises(self):
|
||||
with pytest.raises(ValueError, match="Cannot convert empty string"):
|
||||
self._exec(" ")
|
||||
|
||||
def test_non_numeric_string_raises(self):
|
||||
with pytest.raises(ValueError, match="Cannot convert string to number"):
|
||||
self._exec("abc")
|
||||
|
||||
def test_string_inf_raises(self):
|
||||
with pytest.raises(ValueError, match="non-finite"):
|
||||
self._exec("inf")
|
||||
|
||||
def test_string_nan_raises(self):
|
||||
with pytest.raises(ValueError, match="non-finite"):
|
||||
self._exec("nan")
|
||||
|
||||
def test_string_negative_inf_raises(self):
|
||||
with pytest.raises(ValueError, match="non-finite"):
|
||||
self._exec("-inf")
|
||||
|
||||
# --- Unsupported type ---
|
||||
|
||||
def test_unsupported_type_raises(self):
|
||||
with pytest.raises(TypeError, match="Unsupported input type"):
|
||||
self._exec([1, 2, 3])
|
||||
Reference in New Issue
Block a user