mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-19 20:17:32 +00:00
Compare commits
1 Commits
master
...
feature/ha
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cb385b4504 |
@@ -23,11 +23,6 @@ class CausalConv3d(nn.Module):
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
if isinstance(stride, int):
|
||||
self.time_stride = stride
|
||||
else:
|
||||
self.time_stride = stride[0]
|
||||
|
||||
kernel_size = (kernel_size, kernel_size, kernel_size)
|
||||
self.time_kernel_size = kernel_size[0]
|
||||
|
||||
@@ -63,23 +58,18 @@ class CausalConv3d(nn.Module):
|
||||
pieces = [ cached, x ]
|
||||
if is_end and not causal:
|
||||
pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)))
|
||||
input_length = sum([piece.shape[2] for piece in pieces])
|
||||
cache_length = (self.time_kernel_size - self.time_stride) + ((input_length - self.time_kernel_size) % self.time_stride)
|
||||
|
||||
needs_caching = not is_end
|
||||
if needs_caching and cache_length == 0:
|
||||
self.temporal_cache_state[tid] = (x[:, :, :0, :, :], False)
|
||||
if needs_caching and x.shape[2] >= self.time_kernel_size - 1:
|
||||
needs_caching = False
|
||||
if needs_caching and x.shape[2] >= cache_length:
|
||||
needs_caching = False
|
||||
self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False)
|
||||
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
|
||||
|
||||
x = torch.cat(pieces, dim=2)
|
||||
del pieces
|
||||
del cached
|
||||
|
||||
if needs_caching:
|
||||
self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False)
|
||||
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
|
||||
elif is_end:
|
||||
self.temporal_cache_state[tid] = (None, True)
|
||||
|
||||
|
||||
@@ -233,7 +233,10 @@ class Encoder(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _forward_chunk(self, sample: torch.FloatTensor) -> Optional[torch.FloatTensor]:
|
||||
def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Encoder` class."""
|
||||
|
||||
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
checkpoint_fn = (
|
||||
@@ -244,14 +247,10 @@ class Encoder(nn.Module):
|
||||
|
||||
for down_block in self.down_blocks:
|
||||
sample = checkpoint_fn(down_block)(sample)
|
||||
if sample is None or sample.shape[2] == 0:
|
||||
return None
|
||||
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
if sample is None or sample.shape[2] == 0:
|
||||
return None
|
||||
|
||||
if self.latent_log_var == "uniform":
|
||||
last_channel = sample[:, -1:, ...]
|
||||
@@ -283,35 +282,9 @@ class Encoder(nn.Module):
|
||||
|
||||
return sample
|
||||
|
||||
def forward_orig(self, sample: torch.FloatTensor, device=None) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Encoder` class."""
|
||||
|
||||
max_chunk_size = get_max_chunk_size(sample.device if device is None else device) * 2 # encoder is more memory-efficient than decoder
|
||||
frame_size = sample[:, :, :1, :, :].numel() * sample.element_size()
|
||||
frame_size = int(frame_size * (self.conv_in.out_channels / self.conv_in.in_channels))
|
||||
|
||||
outputs = []
|
||||
samples = [sample[:, :, :1, :, :]]
|
||||
if sample.shape[2] > 1:
|
||||
chunk_t = max(2, max_chunk_size // frame_size)
|
||||
if chunk_t < 4:
|
||||
chunk_t = 2
|
||||
elif chunk_t < 8:
|
||||
chunk_t = 4
|
||||
else:
|
||||
chunk_t = (chunk_t // 8) * 8
|
||||
samples += list(torch.split(sample[:, :, 1:, :, :], chunk_t, dim=2))
|
||||
for chunk_idx, chunk in enumerate(samples):
|
||||
if chunk_idx == len(samples) - 1:
|
||||
mark_conv3d_ended(self)
|
||||
chunk = patchify(chunk, patch_size_hw=self.patch_size, patch_size_t=1).to(device=device)
|
||||
output = self._forward_chunk(chunk)
|
||||
if output is not None:
|
||||
outputs.append(output)
|
||||
|
||||
return torch_cat_if_needed(outputs, dim=2)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
#No encoder support so just flag the end so it doesnt use the cache.
|
||||
mark_conv3d_ended(self)
|
||||
try:
|
||||
return self.forward_orig(*args, **kwargs)
|
||||
finally:
|
||||
@@ -500,17 +473,6 @@ class Decoder(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Precompute output scale factors: (channels, (t_scale, h_scale, w_scale), t_offset)
|
||||
ts, hs, ws, to = 1, 1, 1, 0
|
||||
for block in self.up_blocks:
|
||||
if isinstance(block, DepthToSpaceUpsample):
|
||||
ts *= block.stride[0]
|
||||
hs *= block.stride[1]
|
||||
ws *= block.stride[2]
|
||||
if block.stride[0] > 1:
|
||||
to = to * block.stride[0] + 1
|
||||
self._output_scale = (out_channels // (patch_size ** 2), (ts, hs * patch_size, ws * patch_size), to)
|
||||
|
||||
self.timestep_conditioning = timestep_conditioning
|
||||
|
||||
if timestep_conditioning:
|
||||
@@ -532,15 +494,11 @@ class Decoder(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
def decode_output_shape(self, input_shape):
|
||||
c, (ts, hs, ws), to = self._output_scale
|
||||
return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws)
|
||||
|
||||
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
|
||||
def forward_orig(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Optional[torch.Tensor] = None,
|
||||
output_buffer: Optional[torch.Tensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Decoder` class."""
|
||||
batch_size = sample.shape[0]
|
||||
@@ -582,13 +540,7 @@ class Decoder(nn.Module):
|
||||
)
|
||||
timestep_shift_scale = ada_values.unbind(dim=1)
|
||||
|
||||
if output_buffer is None:
|
||||
output_buffer = torch.empty(
|
||||
self.decode_output_shape(sample.shape),
|
||||
dtype=sample.dtype, device=comfy.model_management.intermediate_device(),
|
||||
)
|
||||
output_offset = [0]
|
||||
|
||||
output = []
|
||||
max_chunk_size = get_max_chunk_size(sample.device)
|
||||
|
||||
def run_up(idx, sample_ref, ended):
|
||||
@@ -604,10 +556,7 @@ class Decoder(nn.Module):
|
||||
mark_conv3d_ended(self.conv_out)
|
||||
sample = self.conv_out(sample, causal=self.causal)
|
||||
if sample is not None and sample.shape[2] > 0:
|
||||
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
t = sample.shape[2]
|
||||
output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample)
|
||||
output_offset[0] += t
|
||||
output.append(sample.to(comfy.model_management.intermediate_device()))
|
||||
return
|
||||
|
||||
up_block = self.up_blocks[idx]
|
||||
@@ -639,8 +588,11 @@ class Decoder(nn.Module):
|
||||
run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1)
|
||||
|
||||
run_up(0, [sample], True)
|
||||
sample = torch.cat(output, dim=2)
|
||||
|
||||
return output_buffer
|
||||
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
|
||||
return sample
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
try:
|
||||
@@ -764,25 +716,12 @@ class SpaceToDepthDownsample(nn.Module):
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
self.temporal_cache_state = {}
|
||||
|
||||
def forward(self, x, causal: bool = True):
|
||||
tid = threading.get_ident()
|
||||
cached, pad_first, cached_x, cached_input = self.temporal_cache_state.get(tid, (None, True, None, None))
|
||||
if cached_input is not None:
|
||||
x = torch_cat_if_needed([cached_input, x], dim=2)
|
||||
cached_input = None
|
||||
|
||||
if self.stride[0] == 2 and pad_first:
|
||||
if self.stride[0] == 2:
|
||||
x = torch.cat(
|
||||
[x[:, :, :1, :, :], x], dim=2
|
||||
) # duplicate first frames for padding
|
||||
pad_first = False
|
||||
|
||||
if x.shape[2] < self.stride[0]:
|
||||
cached_input = x
|
||||
self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input)
|
||||
return None
|
||||
|
||||
# skip connection
|
||||
x_in = rearrange(
|
||||
@@ -797,26 +736,15 @@ class SpaceToDepthDownsample(nn.Module):
|
||||
|
||||
# conv
|
||||
x = self.conv(x, causal=causal)
|
||||
if self.stride[0] == 2 and x.shape[2] == 1:
|
||||
if cached_x is not None:
|
||||
x = torch_cat_if_needed([cached_x, x], dim=2)
|
||||
cached_x = None
|
||||
else:
|
||||
cached_x = x
|
||||
x = None
|
||||
x = rearrange(
|
||||
x,
|
||||
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
|
||||
p1=self.stride[0],
|
||||
p2=self.stride[1],
|
||||
p3=self.stride[2],
|
||||
)
|
||||
|
||||
if x is not None:
|
||||
x = rearrange(
|
||||
x,
|
||||
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
|
||||
p1=self.stride[0],
|
||||
p2=self.stride[1],
|
||||
p3=self.stride[2],
|
||||
)
|
||||
|
||||
cached = add_exchange_cache(x, cached, x_in, dim=2)
|
||||
|
||||
self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input)
|
||||
x = x + x_in
|
||||
|
||||
return x
|
||||
|
||||
@@ -1149,8 +1077,6 @@ class processor(nn.Module):
|
||||
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
|
||||
|
||||
class VideoVAE(nn.Module):
|
||||
comfy_has_chunked_io = True
|
||||
|
||||
def __init__(self, version=0, config=None):
|
||||
super().__init__()
|
||||
|
||||
@@ -1293,15 +1219,14 @@ class VideoVAE(nn.Module):
|
||||
}
|
||||
return config
|
||||
|
||||
def encode(self, x, device=None):
|
||||
x = x[:, :, :max(1, 1 + ((x.shape[2] - 1) // 8) * 8), :, :]
|
||||
means, logvar = torch.chunk(self.encoder(x, device=device), 2, dim=1)
|
||||
def encode(self, x):
|
||||
frames_count = x.shape[2]
|
||||
if ((frames_count - 1) % 8) != 0:
|
||||
raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames (e.g., 1, 9, 17, ...). Please check your input.")
|
||||
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
|
||||
return self.per_channel_statistics.normalize(means)
|
||||
|
||||
def decode_output_shape(self, input_shape):
|
||||
return self.decoder.decode_output_shape(input_shape)
|
||||
|
||||
def decode(self, x, output_buffer=None):
|
||||
def decode(self, x):
|
||||
if self.timestep_conditioning: #TODO: seed
|
||||
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
|
||||
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep, output_buffer=output_buffer)
|
||||
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep)
|
||||
|
||||
@@ -39,10 +39,7 @@ def read_tensor_file_slice_into(tensor, destination):
|
||||
if (destination.device.type != "cpu"
|
||||
or file_obj is None
|
||||
or threading.get_ident() != info.thread_id
|
||||
or destination.numel() * destination.element_size() < info.size
|
||||
or tensor.numel() * tensor.element_size() != info.size
|
||||
or tensor.storage_offset() != 0
|
||||
or not tensor.is_contiguous()):
|
||||
or destination.numel() * destination.element_size() < info.size):
|
||||
return False
|
||||
|
||||
if info.size == 0:
|
||||
|
||||
@@ -1003,7 +1003,7 @@ def text_encoder_offload_device():
|
||||
def text_encoder_device():
|
||||
if args.gpu_only:
|
||||
return get_torch_device()
|
||||
elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM) or comfy.memory_management.aimdo_enabled:
|
||||
elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM, VRAMState.SHARED) or comfy.memory_management.aimdo_enabled:
|
||||
if should_use_fp16(prioritize_performance=False):
|
||||
return get_torch_device()
|
||||
else:
|
||||
|
||||
@@ -64,10 +64,10 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
|
||||
sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
|
||||
|
||||
samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||
samples = samples.to(comfy.model_management.intermediate_device())
|
||||
return samples
|
||||
|
||||
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||
samples = samples.to(comfy.model_management.intermediate_device())
|
||||
return samples
|
||||
|
||||
28
comfy/sd.py
28
comfy/sd.py
@@ -951,23 +951,12 @@ class VAE:
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = max(1, batch_number)
|
||||
|
||||
# Pre-allocate output for VAEs that support direct buffer writes
|
||||
preallocated = False
|
||||
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
||||
pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
preallocated = True
|
||||
|
||||
for x in range(0, samples_in.shape[0], batch_number):
|
||||
samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
|
||||
if preallocated:
|
||||
self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
|
||||
else:
|
||||
out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
|
||||
if pixel_samples is None:
|
||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
pixel_samples[x:x+batch_number].copy_(out)
|
||||
del out
|
||||
self.process_output(pixel_samples[x:x+batch_number])
|
||||
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True))
|
||||
if pixel_samples is None:
|
||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
pixel_samples[x:x+batch_number] = out
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||
@@ -1038,13 +1027,8 @@ class VAE:
|
||||
batch_number = max(1, batch_number)
|
||||
samples = None
|
||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
|
||||
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
||||
out = self.first_stage_model.encode(pixels_in, device=self.device)
|
||||
else:
|
||||
pixels_in = pixels_in.to(self.device)
|
||||
out = self.first_stage_model.encode(pixels_in)
|
||||
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
|
||||
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
|
||||
out = self.first_stage_model.encode(pixels_in).to(self.output_device).to(dtype=self.vae_output_dtype())
|
||||
if samples is None:
|
||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
samples[x:x + batch_number] = out
|
||||
|
||||
@@ -46,7 +46,7 @@ class ClipTokenWeightEncoder:
|
||||
out, pooled = o[:2]
|
||||
|
||||
if pooled is not None:
|
||||
first_pooled = pooled[0:1].to(device=model_management.intermediate_device())
|
||||
first_pooled = pooled[0:1].to(model_management.intermediate_device())
|
||||
else:
|
||||
first_pooled = pooled
|
||||
|
||||
@@ -63,16 +63,16 @@ class ClipTokenWeightEncoder:
|
||||
output.append(z)
|
||||
|
||||
if (len(output) == 0):
|
||||
r = (out[-1:].to(device=model_management.intermediate_device()), first_pooled)
|
||||
r = (out[-1:].to(model_management.intermediate_device()), first_pooled)
|
||||
else:
|
||||
r = (torch.cat(output, dim=-2).to(device=model_management.intermediate_device()), first_pooled)
|
||||
r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled)
|
||||
|
||||
if len(o) > 2:
|
||||
extra = {}
|
||||
for k in o[2]:
|
||||
v = o[2][k]
|
||||
if k == "attention_mask":
|
||||
v = v[:sections].flatten().unsqueeze(dim=0).to(device=model_management.intermediate_device())
|
||||
v = v[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device())
|
||||
extra[k] = v
|
||||
|
||||
r = r + (extra,)
|
||||
|
||||
@@ -1135,8 +1135,8 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
||||
pbar.update(1)
|
||||
continue
|
||||
|
||||
out = output[b:b+1].zero_()
|
||||
out_div = torch.zeros([s.shape[0], 1] + mult_list_upscale(s.shape[2:]), device=output_device)
|
||||
out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
|
||||
out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
|
||||
|
||||
positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
|
||||
|
||||
@@ -1151,7 +1151,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
||||
upscaled.append(round(get_pos(d, pos)))
|
||||
|
||||
ps = function(s_in).to(output_device)
|
||||
mask = torch.ones([1, 1] + list(ps.shape[2:]), device=output_device)
|
||||
mask = torch.ones_like(ps)
|
||||
|
||||
for d in range(2, dims + 2):
|
||||
feather = round(get_scale(d - 2, overlap[d - 2]))
|
||||
@@ -1174,7 +1174,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
||||
if pbar is not None:
|
||||
pbar.update(1)
|
||||
|
||||
out.div_(out_div)
|
||||
output[b:b+1] = out/out_div
|
||||
return output
|
||||
|
||||
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
||||
|
||||
@@ -1353,6 +1353,7 @@ class NodeInfoV1:
|
||||
python_module: Any=None
|
||||
category: str=None
|
||||
output_node: bool=None
|
||||
has_intermediate_output: bool=None
|
||||
deprecated: bool=None
|
||||
experimental: bool=None
|
||||
dev_only: bool=None
|
||||
@@ -1465,6 +1466,16 @@ class Schema:
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#output-node
|
||||
"""
|
||||
has_intermediate_output: bool=False
|
||||
"""Flags this node as having intermediate output that should persist across page refreshes.
|
||||
|
||||
Nodes with this flag behave like output nodes (their UI results are cached and resent
|
||||
to the frontend) but do NOT automatically get added to the execution list. This means
|
||||
they will only execute if they are on the dependency path of a real output node.
|
||||
|
||||
Use this for nodes with interactive/operable UI regions that produce intermediate outputs
|
||||
(e.g., Image Crop, Painter) rather than final outputs (e.g., Save Image).
|
||||
"""
|
||||
is_deprecated: bool=False
|
||||
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
|
||||
is_experimental: bool=False
|
||||
@@ -1582,6 +1593,7 @@ class Schema:
|
||||
category=self.category,
|
||||
description=self.description,
|
||||
output_node=self.is_output_node,
|
||||
has_intermediate_output=self.has_intermediate_output,
|
||||
deprecated=self.is_deprecated,
|
||||
experimental=self.is_experimental,
|
||||
dev_only=self.is_dev_only,
|
||||
@@ -1873,6 +1885,14 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
||||
cls.GET_SCHEMA()
|
||||
return cls._OUTPUT_NODE
|
||||
|
||||
_HAS_INTERMEDIATE_OUTPUT = None
|
||||
@final
|
||||
@classproperty
|
||||
def HAS_INTERMEDIATE_OUTPUT(cls): # noqa
|
||||
if cls._HAS_INTERMEDIATE_OUTPUT is None:
|
||||
cls.GET_SCHEMA()
|
||||
return cls._HAS_INTERMEDIATE_OUTPUT
|
||||
|
||||
_INPUT_IS_LIST = None
|
||||
@final
|
||||
@classproperty
|
||||
@@ -1965,6 +1985,8 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
||||
cls._API_NODE = schema.is_api_node
|
||||
if cls._OUTPUT_NODE is None:
|
||||
cls._OUTPUT_NODE = schema.is_output_node
|
||||
if cls._HAS_INTERMEDIATE_OUTPUT is None:
|
||||
cls._HAS_INTERMEDIATE_OUTPUT = schema.has_intermediate_output
|
||||
if cls._INPUT_IS_LIST is None:
|
||||
cls._INPUT_IS_LIST = schema.is_input_list
|
||||
if cls._NOT_IDEMPOTENT is None:
|
||||
|
||||
@@ -118,6 +118,13 @@ class TopologicalSort:
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
return get_input_info(class_def, input_name)
|
||||
|
||||
def is_intermediate_output(self, node_id):
|
||||
class_type = self.dynprompt.get_node(node_id)["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS.get(class_type)
|
||||
if class_def is None:
|
||||
return False
|
||||
return hasattr(class_def, 'HAS_INTERMEDIATE_OUTPUT') and class_def.HAS_INTERMEDIATE_OUTPUT == True
|
||||
|
||||
def make_input_strong_link(self, to_node_id, to_input):
|
||||
inputs = self.dynprompt.get_node(to_node_id)["inputs"]
|
||||
if to_input not in inputs:
|
||||
@@ -129,7 +136,7 @@ class TopologicalSort:
|
||||
self.add_strong_link(from_node_id, from_socket, to_node_id)
|
||||
|
||||
def add_strong_link(self, from_node_id, from_socket, to_node_id):
|
||||
if not self.is_cached(from_node_id):
|
||||
if not self.is_cached(from_node_id) or self.is_intermediate_output(from_node_id):
|
||||
self.add_node(from_node_id)
|
||||
if to_node_id not in self.blocking[from_node_id]:
|
||||
self.blocking[from_node_id][to_node_id] = {}
|
||||
@@ -159,7 +166,7 @@ class TopologicalSort:
|
||||
_, _, input_info = self.get_input_info(unique_id, input_name)
|
||||
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
|
||||
if (include_lazy or not is_lazy):
|
||||
if not self.is_cached(from_node_id):
|
||||
if not self.is_cached(from_node_id) or self.is_intermediate_output(from_node_id):
|
||||
node_ids.append(from_node_id)
|
||||
links.append((from_node_id, from_socket, unique_id))
|
||||
|
||||
@@ -277,6 +284,8 @@ class ExecutionList(TopologicalSort):
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
|
||||
return True
|
||||
if hasattr(class_def, 'HAS_INTERMEDIATE_OUTPUT') and class_def.HAS_INTERMEDIATE_OUTPUT == True:
|
||||
return True
|
||||
return False
|
||||
|
||||
# If an available node is async, do that first.
|
||||
|
||||
@@ -59,6 +59,7 @@ class ImageCropV2(IO.ComfyNode):
|
||||
display_name="Image Crop",
|
||||
category="image/transform",
|
||||
essentials_category="Image Tools",
|
||||
has_intermediate_output=True,
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.BoundingBox.Input("crop_region", component="ImageCrop"),
|
||||
|
||||
@@ -30,6 +30,7 @@ class PainterNode(io.ComfyNode):
|
||||
node_id="Painter",
|
||||
display_name="Painter",
|
||||
category="image",
|
||||
has_intermediate_output=True,
|
||||
inputs=[
|
||||
io.Image.Input(
|
||||
"image",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
comfyui-frontend-package==1.41.21
|
||||
comfyui-frontend-package==1.41.20
|
||||
comfyui-workflow-templates==0.9.26
|
||||
comfyui-embedded-docs==0.4.3
|
||||
torch
|
||||
|
||||
@@ -709,6 +709,9 @@ class PromptServer():
|
||||
else:
|
||||
info['output_node'] = False
|
||||
|
||||
if hasattr(obj_class, 'HAS_INTERMEDIATE_OUTPUT') and obj_class.HAS_INTERMEDIATE_OUTPUT == True:
|
||||
info['has_intermediate_output'] = True
|
||||
|
||||
if hasattr(obj_class, 'CATEGORY'):
|
||||
info['category'] = obj_class.CATEGORY
|
||||
|
||||
|
||||
Reference in New Issue
Block a user