mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-25 15:07:31 +00:00
Compare commits
4 Commits
deepme987/
...
mark-dtype
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
782f09da4b | ||
|
|
cc6e9052af | ||
|
|
102eb9f99d | ||
|
|
d62334eb9e |
@@ -93,50 +93,6 @@ class IndexListCallbacks:
|
||||
return {}
|
||||
|
||||
|
||||
def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, device, temporal_dim: int, temporal_scale: int=1, temporal_offset: int=0, retain_index_list: list[int]=[]):
|
||||
if not (hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor)):
|
||||
return None
|
||||
cond_tensor = cond_value.cond
|
||||
if temporal_dim >= cond_tensor.ndim:
|
||||
return None
|
||||
|
||||
cond_size = cond_tensor.size(temporal_dim)
|
||||
|
||||
if temporal_scale == 1:
|
||||
expected_size = x_in.size(window.dim) - temporal_offset
|
||||
if cond_size != expected_size:
|
||||
return None
|
||||
|
||||
if temporal_offset == 0 and temporal_scale == 1:
|
||||
sliced = window.get_tensor(cond_tensor, device, dim=temporal_dim, retain_index_list=retain_index_list)
|
||||
return cond_value._copy_with(sliced)
|
||||
|
||||
# skip leading latent positions that have no corresponding conditioning (e.g. reference frames)
|
||||
if temporal_offset > 0:
|
||||
indices = [i - temporal_offset for i in window.index_list[temporal_offset:]]
|
||||
indices = [i for i in indices if 0 <= i]
|
||||
else:
|
||||
indices = list(window.index_list)
|
||||
|
||||
if not indices:
|
||||
return None
|
||||
|
||||
if temporal_scale > 1:
|
||||
scaled = []
|
||||
for i in indices:
|
||||
for k in range(temporal_scale):
|
||||
si = i * temporal_scale + k
|
||||
if si < cond_size:
|
||||
scaled.append(si)
|
||||
indices = scaled
|
||||
if not indices:
|
||||
return None
|
||||
|
||||
idx = tuple([slice(None)] * temporal_dim + [indices])
|
||||
sliced = cond_tensor[idx].to(device)
|
||||
return cond_value._copy_with(sliced)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextSchedule:
|
||||
name: str
|
||||
@@ -221,17 +177,10 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
new_cond_item[cond_key] = result
|
||||
handled = True
|
||||
break
|
||||
if not handled and self._model is not None:
|
||||
result = self._model.resize_cond_for_context_window(
|
||||
cond_key, cond_value, window, x_in, device,
|
||||
retain_index_list=self.cond_retain_index_list)
|
||||
if result is not None:
|
||||
new_cond_item[cond_key] = result
|
||||
handled = True
|
||||
if handled:
|
||||
continue
|
||||
if isinstance(cond_value, torch.Tensor):
|
||||
if (self.dim < cond_value.ndim and cond_value.size(self.dim) == x_in.size(self.dim)) or \
|
||||
if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \
|
||||
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
|
||||
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
|
||||
# Handle audio_embed (temporal dim is 1)
|
||||
@@ -275,7 +224,6 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
return context_windows
|
||||
|
||||
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
self._model = model
|
||||
self.set_step(timestep, model_options)
|
||||
context_windows = self.get_context_windows(model, x_in, model_options)
|
||||
enumerated_context_windows = list(enumerate(context_windows))
|
||||
|
||||
@@ -136,7 +136,16 @@ class ResBlock(nn.Module):
|
||||
ops.Linear(c_hidden, c),
|
||||
)
|
||||
|
||||
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=False)
|
||||
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
|
||||
|
||||
# Init weights
|
||||
def _basic_init(module):
|
||||
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
||||
torch.nn.init.xavier_uniform_(module.weight)
|
||||
if module.bias is not None:
|
||||
nn.init.constant_(module.bias, 0)
|
||||
|
||||
self.apply(_basic_init)
|
||||
|
||||
def _norm(self, x, norm):
|
||||
return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
|
||||
@@ -23,11 +23,6 @@ class CausalConv3d(nn.Module):
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
if isinstance(stride, int):
|
||||
self.time_stride = stride
|
||||
else:
|
||||
self.time_stride = stride[0]
|
||||
|
||||
kernel_size = (kernel_size, kernel_size, kernel_size)
|
||||
self.time_kernel_size = kernel_size[0]
|
||||
|
||||
@@ -63,25 +58,16 @@ class CausalConv3d(nn.Module):
|
||||
pieces = [ cached, x ]
|
||||
if is_end and not causal:
|
||||
pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)))
|
||||
input_length = sum([piece.shape[2] for piece in pieces])
|
||||
cache_length = (self.time_kernel_size - self.time_stride) + ((input_length - self.time_kernel_size) % self.time_stride)
|
||||
|
||||
needs_caching = not is_end
|
||||
if needs_caching and cache_length == 0:
|
||||
self.temporal_cache_state[tid] = (x[:, :, :0, :, :], False)
|
||||
if needs_caching and x.shape[2] >= self.time_kernel_size - 1:
|
||||
needs_caching = False
|
||||
if needs_caching and x.shape[2] >= cache_length:
|
||||
needs_caching = False
|
||||
self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False)
|
||||
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
|
||||
|
||||
x = torch.cat(pieces, dim=2)
|
||||
del pieces
|
||||
del cached
|
||||
|
||||
if needs_caching:
|
||||
self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False)
|
||||
elif is_end:
|
||||
self.temporal_cache_state[tid] = (None, True)
|
||||
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
|
||||
|
||||
return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :]
|
||||
|
||||
|
||||
@@ -233,7 +233,10 @@ class Encoder(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _forward_chunk(self, sample: torch.FloatTensor) -> Optional[torch.FloatTensor]:
|
||||
def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Encoder` class."""
|
||||
|
||||
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
checkpoint_fn = (
|
||||
@@ -244,14 +247,10 @@ class Encoder(nn.Module):
|
||||
|
||||
for down_block in self.down_blocks:
|
||||
sample = checkpoint_fn(down_block)(sample)
|
||||
if sample is None or sample.shape[2] == 0:
|
||||
return None
|
||||
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
if sample is None or sample.shape[2] == 0:
|
||||
return None
|
||||
|
||||
if self.latent_log_var == "uniform":
|
||||
last_channel = sample[:, -1:, ...]
|
||||
@@ -283,35 +282,9 @@ class Encoder(nn.Module):
|
||||
|
||||
return sample
|
||||
|
||||
def forward_orig(self, sample: torch.FloatTensor, device=None) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Encoder` class."""
|
||||
|
||||
max_chunk_size = get_max_chunk_size(sample.device if device is None else device) * 2 # encoder is more memory-efficient than decoder
|
||||
frame_size = sample[:, :, :1, :, :].numel() * sample.element_size()
|
||||
frame_size = int(frame_size * (self.conv_in.out_channels / self.conv_in.in_channels))
|
||||
|
||||
outputs = []
|
||||
samples = [sample[:, :, :1, :, :]]
|
||||
if sample.shape[2] > 1:
|
||||
chunk_t = max(2, max_chunk_size // frame_size)
|
||||
if chunk_t < 4:
|
||||
chunk_t = 2
|
||||
elif chunk_t < 8:
|
||||
chunk_t = 4
|
||||
else:
|
||||
chunk_t = (chunk_t // 8) * 8
|
||||
samples += list(torch.split(sample[:, :, 1:, :, :], chunk_t, dim=2))
|
||||
for chunk_idx, chunk in enumerate(samples):
|
||||
if chunk_idx == len(samples) - 1:
|
||||
mark_conv3d_ended(self)
|
||||
chunk = patchify(chunk, patch_size_hw=self.patch_size, patch_size_t=1).to(device=device)
|
||||
output = self._forward_chunk(chunk)
|
||||
if output is not None:
|
||||
outputs.append(output)
|
||||
|
||||
return torch_cat_if_needed(outputs, dim=2)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
#No encoder support so just flag the end so it doesnt use the cache.
|
||||
mark_conv3d_ended(self)
|
||||
try:
|
||||
return self.forward_orig(*args, **kwargs)
|
||||
finally:
|
||||
@@ -324,23 +297,7 @@ class Encoder(nn.Module):
|
||||
module.temporal_cache_state.pop(tid, None)
|
||||
|
||||
|
||||
MIN_VRAM_FOR_CHUNK_SCALING = 6 * 1024 ** 3
|
||||
MAX_VRAM_FOR_CHUNK_SCALING = 24 * 1024 ** 3
|
||||
MIN_CHUNK_SIZE = 32 * 1024 ** 2
|
||||
MAX_CHUNK_SIZE = 128 * 1024 ** 2
|
||||
|
||||
def get_max_chunk_size(device: torch.device) -> int:
|
||||
total_memory = comfy.model_management.get_total_memory(dev=device)
|
||||
|
||||
if total_memory <= MIN_VRAM_FOR_CHUNK_SCALING:
|
||||
return MIN_CHUNK_SIZE
|
||||
if total_memory >= MAX_VRAM_FOR_CHUNK_SCALING:
|
||||
return MAX_CHUNK_SIZE
|
||||
|
||||
interp = (total_memory - MIN_VRAM_FOR_CHUNK_SCALING) / (
|
||||
MAX_VRAM_FOR_CHUNK_SCALING - MIN_VRAM_FOR_CHUNK_SCALING
|
||||
)
|
||||
return int(MIN_CHUNK_SIZE + interp * (MAX_CHUNK_SIZE - MIN_CHUNK_SIZE))
|
||||
MAX_CHUNK_SIZE=(128 * 1024 ** 2)
|
||||
|
||||
class Decoder(nn.Module):
|
||||
r"""
|
||||
@@ -500,17 +457,6 @@ class Decoder(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Precompute output scale factors: (channels, (t_scale, h_scale, w_scale), t_offset)
|
||||
ts, hs, ws, to = 1, 1, 1, 0
|
||||
for block in self.up_blocks:
|
||||
if isinstance(block, DepthToSpaceUpsample):
|
||||
ts *= block.stride[0]
|
||||
hs *= block.stride[1]
|
||||
ws *= block.stride[2]
|
||||
if block.stride[0] > 1:
|
||||
to = to * block.stride[0] + 1
|
||||
self._output_scale = (out_channels // (patch_size ** 2), (ts, hs * patch_size, ws * patch_size), to)
|
||||
|
||||
self.timestep_conditioning = timestep_conditioning
|
||||
|
||||
if timestep_conditioning:
|
||||
@@ -532,62 +478,11 @@ class Decoder(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
def decode_output_shape(self, input_shape):
|
||||
c, (ts, hs, ws), to = self._output_scale
|
||||
return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws)
|
||||
|
||||
def run_up(self, idx, sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size):
|
||||
sample = sample_ref[0]
|
||||
sample_ref[0] = None
|
||||
if idx >= len(self.up_blocks):
|
||||
sample = self.conv_norm_out(sample)
|
||||
if timestep_shift_scale is not None:
|
||||
shift, scale = timestep_shift_scale
|
||||
sample = sample * (1 + scale) + shift
|
||||
sample = self.conv_act(sample)
|
||||
if ended:
|
||||
mark_conv3d_ended(self.conv_out)
|
||||
sample = self.conv_out(sample, causal=self.causal)
|
||||
if sample is not None and sample.shape[2] > 0:
|
||||
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
t = sample.shape[2]
|
||||
output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample)
|
||||
output_offset[0] += t
|
||||
return
|
||||
|
||||
up_block = self.up_blocks[idx]
|
||||
if ended:
|
||||
mark_conv3d_ended(up_block)
|
||||
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
|
||||
sample = checkpoint_fn(up_block)(
|
||||
sample, causal=self.causal, timestep=scaled_timestep
|
||||
)
|
||||
else:
|
||||
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
|
||||
|
||||
if sample is None or sample.shape[2] == 0:
|
||||
return
|
||||
|
||||
total_bytes = sample.numel() * sample.element_size()
|
||||
num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size
|
||||
|
||||
if num_chunks == 1:
|
||||
# when we are not chunking, detach our x so the callee can free it as soon as they are done
|
||||
next_sample_ref = [sample]
|
||||
del sample
|
||||
self.run_up(idx + 1, next_sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
|
||||
return
|
||||
else:
|
||||
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
|
||||
|
||||
for chunk_idx, sample1 in enumerate(samples):
|
||||
self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
|
||||
|
||||
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
|
||||
def forward_orig(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Optional[torch.Tensor] = None,
|
||||
output_buffer: Optional[torch.Tensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Decoder` class."""
|
||||
batch_size = sample.shape[0]
|
||||
@@ -602,7 +497,6 @@ class Decoder(nn.Module):
|
||||
)
|
||||
|
||||
timestep_shift_scale = None
|
||||
scaled_timestep = None
|
||||
if self.timestep_conditioning:
|
||||
assert (
|
||||
timestep is not None
|
||||
@@ -630,18 +524,48 @@ class Decoder(nn.Module):
|
||||
)
|
||||
timestep_shift_scale = ada_values.unbind(dim=1)
|
||||
|
||||
if output_buffer is None:
|
||||
output_buffer = torch.empty(
|
||||
self.decode_output_shape(sample.shape),
|
||||
dtype=sample.dtype, device=comfy.model_management.intermediate_device(),
|
||||
)
|
||||
output_offset = [0]
|
||||
output = []
|
||||
|
||||
max_chunk_size = get_max_chunk_size(sample.device)
|
||||
def run_up(idx, sample, ended):
|
||||
if idx >= len(self.up_blocks):
|
||||
sample = self.conv_norm_out(sample)
|
||||
if timestep_shift_scale is not None:
|
||||
shift, scale = timestep_shift_scale
|
||||
sample = sample * (1 + scale) + shift
|
||||
sample = self.conv_act(sample)
|
||||
if ended:
|
||||
mark_conv3d_ended(self.conv_out)
|
||||
sample = self.conv_out(sample, causal=self.causal)
|
||||
if sample is not None and sample.shape[2] > 0:
|
||||
output.append(sample.to(comfy.model_management.intermediate_device()))
|
||||
return
|
||||
|
||||
self.run_up(0, [sample], True, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
|
||||
up_block = self.up_blocks[idx]
|
||||
if (ended):
|
||||
mark_conv3d_ended(up_block)
|
||||
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
|
||||
sample = checkpoint_fn(up_block)(
|
||||
sample, causal=self.causal, timestep=scaled_timestep
|
||||
)
|
||||
else:
|
||||
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
|
||||
|
||||
return output_buffer
|
||||
if sample is None or sample.shape[2] == 0:
|
||||
return
|
||||
|
||||
total_bytes = sample.numel() * sample.element_size()
|
||||
num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE
|
||||
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
|
||||
|
||||
for chunk_idx, sample1 in enumerate(samples):
|
||||
run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1)
|
||||
|
||||
run_up(0, sample, True)
|
||||
sample = torch.cat(output, dim=2)
|
||||
|
||||
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
|
||||
return sample
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
try:
|
||||
@@ -765,25 +689,12 @@ class SpaceToDepthDownsample(nn.Module):
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
self.temporal_cache_state = {}
|
||||
|
||||
def forward(self, x, causal: bool = True):
|
||||
tid = threading.get_ident()
|
||||
cached, pad_first, cached_x, cached_input = self.temporal_cache_state.get(tid, (None, True, None, None))
|
||||
if cached_input is not None:
|
||||
x = torch_cat_if_needed([cached_input, x], dim=2)
|
||||
cached_input = None
|
||||
|
||||
if self.stride[0] == 2 and pad_first:
|
||||
if self.stride[0] == 2:
|
||||
x = torch.cat(
|
||||
[x[:, :, :1, :, :], x], dim=2
|
||||
) # duplicate first frames for padding
|
||||
pad_first = False
|
||||
|
||||
if x.shape[2] < self.stride[0]:
|
||||
cached_input = x
|
||||
self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input)
|
||||
return None
|
||||
|
||||
# skip connection
|
||||
x_in = rearrange(
|
||||
@@ -798,26 +709,15 @@ class SpaceToDepthDownsample(nn.Module):
|
||||
|
||||
# conv
|
||||
x = self.conv(x, causal=causal)
|
||||
if self.stride[0] == 2 and x.shape[2] == 1:
|
||||
if cached_x is not None:
|
||||
x = torch_cat_if_needed([cached_x, x], dim=2)
|
||||
cached_x = None
|
||||
else:
|
||||
cached_x = x
|
||||
x = None
|
||||
x = rearrange(
|
||||
x,
|
||||
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
|
||||
p1=self.stride[0],
|
||||
p2=self.stride[1],
|
||||
p3=self.stride[2],
|
||||
)
|
||||
|
||||
if x is not None:
|
||||
x = rearrange(
|
||||
x,
|
||||
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
|
||||
p1=self.stride[0],
|
||||
p2=self.stride[1],
|
||||
p3=self.stride[2],
|
||||
)
|
||||
|
||||
cached = add_exchange_cache(x, cached, x_in, dim=2)
|
||||
|
||||
self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input)
|
||||
x = x + x_in
|
||||
|
||||
return x
|
||||
|
||||
@@ -1150,8 +1050,6 @@ class processor(nn.Module):
|
||||
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
|
||||
|
||||
class VideoVAE(nn.Module):
|
||||
comfy_has_chunked_io = True
|
||||
|
||||
def __init__(self, version=0, config=None):
|
||||
super().__init__()
|
||||
|
||||
@@ -1294,15 +1192,14 @@ class VideoVAE(nn.Module):
|
||||
}
|
||||
return config
|
||||
|
||||
def encode(self, x, device=None):
|
||||
x = x[:, :, :max(1, 1 + ((x.shape[2] - 1) // 8) * 8), :, :]
|
||||
means, logvar = torch.chunk(self.encoder(x, device=device), 2, dim=1)
|
||||
def encode(self, x):
|
||||
frames_count = x.shape[2]
|
||||
if ((frames_count - 1) % 8) != 0:
|
||||
raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames (e.g., 1, 9, 17, ...). Please check your input.")
|
||||
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
|
||||
return self.per_channel_statistics.normalize(means)
|
||||
|
||||
def decode_output_shape(self, input_shape):
|
||||
return self.decoder.decode_output_shape(input_shape)
|
||||
|
||||
def decode(self, x, output_buffer=None):
|
||||
def decode(self, x):
|
||||
if self.timestep_conditioning: #TODO: seed
|
||||
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
|
||||
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep, output_buffer=output_buffer)
|
||||
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep)
|
||||
|
||||
@@ -99,7 +99,7 @@ class Resample(nn.Module):
|
||||
else:
|
||||
self.resample = nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
b, c, t, h, w = x.size()
|
||||
if self.mode == 'upsample3d':
|
||||
if feat_cache is not None:
|
||||
@@ -109,7 +109,22 @@ class Resample(nn.Module):
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
|
||||
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[
|
||||
idx] is not None and feat_cache[idx] != 'Rep':
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
if cache_x.shape[2] < 2 and feat_cache[
|
||||
idx] is not None and feat_cache[idx] == 'Rep':
|
||||
cache_x = torch.cat([
|
||||
torch.zeros_like(cache_x).to(cache_x.device),
|
||||
cache_x
|
||||
],
|
||||
dim=2)
|
||||
if feat_cache[idx] == 'Rep':
|
||||
x = self.time_conv(x)
|
||||
else:
|
||||
@@ -130,24 +145,19 @@ class Resample(nn.Module):
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
feat_cache[idx] = x
|
||||
feat_cache[idx] = x.clone()
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
|
||||
cache_x = x[:, :, -1:, :, :]
|
||||
cache_x = x[:, :, -1:, :, :].clone()
|
||||
# if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
|
||||
# # cache last frame of last two chunk
|
||||
# cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
||||
|
||||
x = self.time_conv(
|
||||
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
||||
feat_cache[idx] = cache_x
|
||||
|
||||
deferred_x = feat_cache[idx + 1]
|
||||
if deferred_x is not None:
|
||||
x = torch.cat([deferred_x, x], 2)
|
||||
feat_cache[idx + 1] = None
|
||||
|
||||
if x.shape[2] == 1 and not final:
|
||||
feat_cache[idx + 1] = x
|
||||
x = None
|
||||
|
||||
feat_idx[0] += 2
|
||||
feat_idx[0] += 1
|
||||
return x
|
||||
|
||||
|
||||
@@ -167,12 +177,19 @@ class ResidualBlock(nn.Module):
|
||||
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
|
||||
if in_dim != out_dim else nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
old_x = x
|
||||
for layer in self.residual:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = layer(x, cache_list=feat_cache, cache_idx=idx)
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
@@ -196,7 +213,7 @@ class AttentionBlock(nn.Module):
|
||||
self.proj = ops.Conv2d(dim, dim, 1)
|
||||
self.optimized_attention = vae_attention()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
b, c, t, h, w = x.size()
|
||||
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
||||
@@ -266,10 +283,17 @@ class Encoder3d(nn.Module):
|
||||
RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(out_dim, z_dim, 3, padding=1))
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
@@ -279,16 +303,14 @@ class Encoder3d(nn.Module):
|
||||
## downsamples
|
||||
for layer in self.downsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx, final=final)
|
||||
if x is None:
|
||||
return None
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx, final=final)
|
||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
@@ -296,7 +318,14 @@ class Encoder3d(nn.Module):
|
||||
for layer in self.head:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
@@ -360,48 +389,18 @@ class Decoder3d(nn.Module):
|
||||
RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(out_dim, output_channels, 3, padding=1))
|
||||
|
||||
def run_up(self, layer_idx, x_ref, feat_cache, feat_idx, out_chunks):
|
||||
x = x_ref[0]
|
||||
x_ref[0] = None
|
||||
if layer_idx >= len(self.upsamples):
|
||||
for layer in self.head:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||
x = layer(x, feat_cache[feat_idx[0]])
|
||||
feat_cache[feat_idx[0]] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
out_chunks.append(x)
|
||||
return
|
||||
|
||||
layer = self.upsamples[layer_idx]
|
||||
if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1:
|
||||
for frame_idx in range(x.shape[2]):
|
||||
self.run_up(
|
||||
layer_idx,
|
||||
[x[:, :, frame_idx:frame_idx + 1, :, :]],
|
||||
feat_cache,
|
||||
feat_idx.copy(),
|
||||
out_chunks,
|
||||
)
|
||||
del x
|
||||
return
|
||||
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
next_x_ref = [x]
|
||||
del x
|
||||
self.run_up(layer_idx + 1, next_x_ref, feat_cache, feat_idx, out_chunks)
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
## conv1
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
@@ -410,21 +409,42 @@ class Decoder3d(nn.Module):
|
||||
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## upsamples
|
||||
for layer in self.upsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
out_chunks = []
|
||||
|
||||
self.run_up(0, [x], feat_cache, feat_idx, out_chunks)
|
||||
return out_chunks
|
||||
## head
|
||||
for layer in self.head:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
def count_cache_layers(model):
|
||||
def count_conv3d(model):
|
||||
count = 0
|
||||
for m in model.modules():
|
||||
if isinstance(m, CausalConv3d) or (isinstance(m, Resample) and m.mode == 'downsample3d'):
|
||||
if isinstance(m, CausalConv3d):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
@@ -462,12 +482,11 @@ class WanVAE(nn.Module):
|
||||
conv_idx = [0]
|
||||
## cache
|
||||
t = x.shape[2]
|
||||
t = 1 + ((t - 1) // 4) * 4
|
||||
iter_ = 1 + (t - 1) // 2
|
||||
iter_ = 1 + (t - 1) // 4
|
||||
feat_map = None
|
||||
if iter_ > 1:
|
||||
feat_map = [None] * count_cache_layers(self.encoder)
|
||||
## 对encode输入的x,按时间拆分为1、2、2、2....(总帧数先按4N+1向下取整)
|
||||
feat_map = [None] * count_conv3d(self.encoder)
|
||||
## 对encode输入的x,按时间拆分为1、4、4、4....
|
||||
for i in range(iter_):
|
||||
conv_idx = [0]
|
||||
if i == 0:
|
||||
@@ -477,23 +496,20 @@ class WanVAE(nn.Module):
|
||||
feat_idx=conv_idx)
|
||||
else:
|
||||
out_ = self.encoder(
|
||||
x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :],
|
||||
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||
feat_cache=feat_map,
|
||||
feat_idx=conv_idx,
|
||||
final=(i == (iter_ - 1)))
|
||||
if out_ is None:
|
||||
continue
|
||||
feat_idx=conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
|
||||
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
||||
return mu
|
||||
|
||||
def decode(self, z):
|
||||
conv_idx = [0]
|
||||
# z: [b,c,t,h,w]
|
||||
iter_ = 1 + z.shape[2] // 2
|
||||
iter_ = z.shape[2]
|
||||
feat_map = None
|
||||
if iter_ > 1:
|
||||
feat_map = [None] * count_cache_layers(self.decoder)
|
||||
feat_map = [None] * count_conv3d(self.decoder)
|
||||
x = self.conv2(z)
|
||||
for i in range(iter_):
|
||||
conv_idx = [0]
|
||||
@@ -504,8 +520,8 @@ class WanVAE(nn.Module):
|
||||
feat_idx=conv_idx)
|
||||
else:
|
||||
out_ = self.decoder(
|
||||
x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :],
|
||||
x[:, :, i:i + 1, :, :],
|
||||
feat_cache=feat_map,
|
||||
feat_idx=conv_idx)
|
||||
out += out_
|
||||
return torch.cat(out, 2)
|
||||
out = torch.cat([out, out_], 2)
|
||||
return out
|
||||
|
||||
@@ -39,10 +39,7 @@ def read_tensor_file_slice_into(tensor, destination):
|
||||
if (destination.device.type != "cpu"
|
||||
or file_obj is None
|
||||
or threading.get_ident() != info.thread_id
|
||||
or destination.numel() * destination.element_size() < info.size
|
||||
or tensor.numel() * tensor.element_size() != info.size
|
||||
or tensor.storage_offset() != 0
|
||||
or not tensor.is_contiguous()):
|
||||
or destination.numel() * destination.element_size() < info.size):
|
||||
return False
|
||||
|
||||
if info.size == 0:
|
||||
|
||||
@@ -285,12 +285,6 @@ class BaseModel(torch.nn.Module):
|
||||
return data
|
||||
return None
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
"""Override in subclasses to handle model-specific cond slicing for context windows.
|
||||
Return a sliced cond object, or None to fall through to default handling.
|
||||
Use comfy.context_windows.slice_cond() for common cases."""
|
||||
return None
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
concat_cond = self.concat_cond(**kwargs)
|
||||
@@ -1381,12 +1375,6 @@ class WAN21_Vace(WAN21):
|
||||
out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
|
||||
return out
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
if cond_key == "vace_context":
|
||||
import comfy.context_windows
|
||||
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=3, retain_index_list=retain_index_list)
|
||||
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
|
||||
|
||||
class WAN21_Camera(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel)
|
||||
@@ -1439,12 +1427,6 @@ class WAN21_HuMo(WAN21):
|
||||
|
||||
return out
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
if cond_key == "audio_embed":
|
||||
import comfy.context_windows
|
||||
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1)
|
||||
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
|
||||
|
||||
class WAN22_Animate(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_animate.AnimateWanModel)
|
||||
@@ -1462,14 +1444,6 @@ class WAN22_Animate(WAN21):
|
||||
out['pose_latents'] = comfy.conds.CONDRegular(self.process_latent_in(pose_latents))
|
||||
return out
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
import comfy.context_windows
|
||||
if cond_key == "face_pixel_values":
|
||||
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_scale=4, temporal_offset=1)
|
||||
if cond_key == "pose_latents":
|
||||
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=1)
|
||||
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
|
||||
|
||||
class WAN22_S2V(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
|
||||
@@ -1506,12 +1480,6 @@ class WAN22_S2V(WAN21):
|
||||
out['reference_motion'] = reference_motion.shape
|
||||
return out
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
if cond_key == "audio_embed":
|
||||
import comfy.context_windows
|
||||
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1)
|
||||
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
|
||||
|
||||
class WAN22(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
||||
|
||||
@@ -64,10 +64,10 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
|
||||
sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
|
||||
|
||||
samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||
samples = samples.to(comfy.model_management.intermediate_device())
|
||||
return samples
|
||||
|
||||
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||
samples = samples.to(comfy.model_management.intermediate_device())
|
||||
return samples
|
||||
|
||||
34
comfy/sd.py
34
comfy/sd.py
@@ -455,7 +455,7 @@ class VAE:
|
||||
self.output_channels = 3
|
||||
self.pad_channel_value = None
|
||||
self.process_input = lambda image: image * 2.0 - 1.0
|
||||
self.process_output = lambda image: image.add_(1.0).div_(2.0).clamp_(0.0, 1.0)
|
||||
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||
self.disable_offload = False
|
||||
self.not_video = False
|
||||
@@ -951,23 +951,12 @@ class VAE:
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = max(1, batch_number)
|
||||
|
||||
# Pre-allocate output for VAEs that support direct buffer writes
|
||||
preallocated = False
|
||||
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
||||
pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
preallocated = True
|
||||
|
||||
for x in range(0, samples_in.shape[0], batch_number):
|
||||
samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
|
||||
if preallocated:
|
||||
self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
|
||||
else:
|
||||
out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
|
||||
if pixel_samples is None:
|
||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
pixel_samples[x:x+batch_number].copy_(out)
|
||||
del out
|
||||
self.process_output(pixel_samples[x:x+batch_number])
|
||||
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
|
||||
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).to(dtype=self.vae_output_dtype()))
|
||||
if pixel_samples is None:
|
||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
pixel_samples[x:x+batch_number] = out
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||
@@ -978,7 +967,6 @@ class VAE:
|
||||
do_tile = True
|
||||
|
||||
if do_tile:
|
||||
comfy.model_management.soft_empty_cache()
|
||||
dims = samples_in.ndim - 2
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
pixel_samples = self.decode_tiled_1d(samples_in)
|
||||
@@ -1039,13 +1027,8 @@ class VAE:
|
||||
batch_number = max(1, batch_number)
|
||||
samples = None
|
||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
|
||||
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
||||
out = self.first_stage_model.encode(pixels_in, device=self.device)
|
||||
else:
|
||||
pixels_in = pixels_in.to(self.device)
|
||||
out = self.first_stage_model.encode(pixels_in)
|
||||
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
|
||||
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
|
||||
out = self.first_stage_model.encode(pixels_in).to(self.output_device).to(dtype=self.vae_output_dtype())
|
||||
if samples is None:
|
||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
samples[x:x + batch_number] = out
|
||||
@@ -1060,7 +1043,6 @@ class VAE:
|
||||
do_tile = True
|
||||
|
||||
if do_tile:
|
||||
comfy.model_management.soft_empty_cache()
|
||||
if self.latent_dim == 3:
|
||||
tile = 256
|
||||
overlap = tile // 4
|
||||
|
||||
@@ -46,7 +46,7 @@ class ClipTokenWeightEncoder:
|
||||
out, pooled = o[:2]
|
||||
|
||||
if pooled is not None:
|
||||
first_pooled = pooled[0:1].to(device=model_management.intermediate_device())
|
||||
first_pooled = pooled[0:1].to(model_management.intermediate_device())
|
||||
else:
|
||||
first_pooled = pooled
|
||||
|
||||
@@ -63,16 +63,16 @@ class ClipTokenWeightEncoder:
|
||||
output.append(z)
|
||||
|
||||
if (len(output) == 0):
|
||||
r = (out[-1:].to(device=model_management.intermediate_device()), first_pooled)
|
||||
r = (out[-1:].to(model_management.intermediate_device()), first_pooled)
|
||||
else:
|
||||
r = (torch.cat(output, dim=-2).to(device=model_management.intermediate_device()), first_pooled)
|
||||
r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled)
|
||||
|
||||
if len(o) > 2:
|
||||
extra = {}
|
||||
for k in o[2]:
|
||||
v = o[2][k]
|
||||
if k == "attention_mask":
|
||||
v = v[:sections].flatten().unsqueeze(dim=0).to(device=model_management.intermediate_device())
|
||||
v = v[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device())
|
||||
extra[k] = v
|
||||
|
||||
r = r + (extra,)
|
||||
|
||||
@@ -1135,8 +1135,8 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
||||
pbar.update(1)
|
||||
continue
|
||||
|
||||
out = output[b:b+1].zero_()
|
||||
out_div = torch.zeros([s.shape[0], 1] + mult_list_upscale(s.shape[2:]), device=output_device)
|
||||
out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
|
||||
out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
|
||||
|
||||
positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
|
||||
|
||||
@@ -1151,7 +1151,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
||||
upscaled.append(round(get_pos(d, pos)))
|
||||
|
||||
ps = function(s_in).to(output_device)
|
||||
mask = torch.ones([1, 1] + list(ps.shape[2:]), device=output_device)
|
||||
mask = torch.ones_like(ps)
|
||||
|
||||
for d in range(2, dims + 2):
|
||||
feather = round(get_scale(d - 2, overlap[d - 2]))
|
||||
@@ -1174,7 +1174,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
||||
if pbar is not None:
|
||||
pbar.update(1)
|
||||
|
||||
out.div_(out_div)
|
||||
output[b:b+1] = out/out_div
|
||||
return output
|
||||
|
||||
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
||||
|
||||
@@ -15,7 +15,6 @@ from comfy_execution.progress import get_progress_state, PreviewImageTuple
|
||||
from PIL import Image
|
||||
from comfy.cli_args import args
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
|
||||
class ComfyAPI_latest(ComfyAPIBase):
|
||||
@@ -26,7 +25,6 @@ class ComfyAPI_latest(ComfyAPIBase):
|
||||
super().__init__()
|
||||
self.node_replacement = self.NodeReplacement()
|
||||
self.execution = self.Execution()
|
||||
self.environment = self.Environment()
|
||||
self.caching = self.Caching()
|
||||
|
||||
class NodeReplacement(ProxiedSingleton):
|
||||
@@ -87,27 +85,6 @@ class ComfyAPI_latest(ComfyAPIBase):
|
||||
image=to_display,
|
||||
)
|
||||
|
||||
class Environment(ProxiedSingleton):
|
||||
"""
|
||||
Query the current execution environment.
|
||||
|
||||
Managed deployments set the ``COMFY_EXECUTION_ENVIRONMENT`` env var
|
||||
so custom nodes can adapt their behaviour at runtime.
|
||||
|
||||
Example::
|
||||
|
||||
from comfy_api.latest import api
|
||||
|
||||
env = api.environment.get() # "local" | "cloud" | "remote"
|
||||
"""
|
||||
|
||||
_VALID = {"local", "cloud", "remote"}
|
||||
|
||||
async def get(self) -> str:
|
||||
"""Return the execution environment: ``"local"``, ``"cloud"``, or ``"remote"``."""
|
||||
value = os.environ.get("COMFY_EXECUTION_ENVIRONMENT", "local").lower().strip()
|
||||
return value if value in self._VALID else "local"
|
||||
|
||||
class Caching(ProxiedSingleton):
|
||||
"""
|
||||
External cache provider API for sharing cached node outputs
|
||||
|
||||
@@ -67,7 +67,6 @@ class GeminiPart(BaseModel):
|
||||
inlineData: GeminiInlineData | None = Field(None)
|
||||
fileData: GeminiFileData | None = Field(None)
|
||||
text: str | None = Field(None)
|
||||
thought: bool | None = Field(None)
|
||||
|
||||
|
||||
class GeminiTextPart(BaseModel):
|
||||
|
||||
@@ -47,10 +47,6 @@ SEEDREAM_MODELS = {
|
||||
BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
|
||||
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
|
||||
|
||||
DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-250428"}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
|
||||
if response.error:
|
||||
@@ -139,7 +135,6 @@ class ByteDanceImageNode(IO.ComfyNode):
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.03}""",
|
||||
),
|
||||
is_deprecated=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -947,7 +942,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
||||
]
|
||||
return await process_video_task(
|
||||
cls,
|
||||
payload=Image2VideoTaskCreationRequest(model=model, content=x, generate_audio=None),
|
||||
payload=Image2VideoTaskCreationRequest(model=model, content=x),
|
||||
estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))),
|
||||
)
|
||||
|
||||
@@ -957,12 +952,6 @@ async def process_video_task(
|
||||
payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest,
|
||||
estimated_duration: int | None,
|
||||
) -> IO.NodeOutput:
|
||||
if payload.model in DEPRECATED_MODELS:
|
||||
logger.warning(
|
||||
"Model '%s' is deprecated and will be deactivated on May 13, 2026. "
|
||||
"Please switch to a newer model. Recommended: seedance-1-0-pro-fast-251015.",
|
||||
payload.model,
|
||||
)
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
|
||||
|
||||
@@ -63,7 +63,7 @@ GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge(
|
||||
$m := widgets.model;
|
||||
$r := widgets.resolution;
|
||||
$isFlash := $contains($m, "nano banana 2");
|
||||
$flashPrices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154};
|
||||
$flashPrices := {"1k": 0.0696, "2k": 0.0696, "4k": 0.123};
|
||||
$proPrices := {"1k": 0.134, "2k": 0.134, "4k": 0.24};
|
||||
$prices := $isFlash ? $flashPrices : $proPrices;
|
||||
{"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}}
|
||||
@@ -188,12 +188,10 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
|
||||
return "\n".join([part.text for part in parts])
|
||||
|
||||
|
||||
async def get_image_from_response(response: GeminiGenerateContentResponse, thought: bool = False) -> Input.Image:
|
||||
async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
|
||||
image_tensors: list[Input.Image] = []
|
||||
parts = get_parts_by_type(response, "image/*")
|
||||
for part in parts:
|
||||
if (part.thought is True) != thought:
|
||||
continue
|
||||
if part.inlineData:
|
||||
image_data = base64.b64decode(part.inlineData.data)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
@@ -933,11 +931,6 @@ class GeminiNanoBanana2(IO.ComfyNode):
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
IO.String.Output(),
|
||||
IO.Image.Output(
|
||||
display_name="thought_image",
|
||||
tooltip="First image from the model's thinking process. "
|
||||
"Only available with thinking_level HIGH and IMAGE+TEXT modality.",
|
||||
),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@@ -999,11 +992,7 @@ class GeminiNanoBanana2(IO.ComfyNode):
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
price_extractor=calculate_tokens_price,
|
||||
)
|
||||
return IO.NodeOutput(
|
||||
await get_image_from_response(response),
|
||||
get_text_from_response(response),
|
||||
await get_image_from_response(response, thought=True),
|
||||
)
|
||||
return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
|
||||
|
||||
|
||||
class GeminiExtension(ComfyExtension):
|
||||
|
||||
@@ -27,8 +27,8 @@ class ContextWindowsManualNode(io.ComfyNode):
|
||||
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
||||
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
|
||||
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
||||
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
||||
io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
|
||||
#io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
||||
#io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(tooltip="The model with context windows applied during sampling."),
|
||||
|
||||
@@ -1 +1 @@
|
||||
comfyui_manager==4.1b6
|
||||
comfyui_manager==4.1b5
|
||||
@@ -1,4 +1,4 @@
|
||||
comfyui-frontend-package==1.41.21
|
||||
comfyui-frontend-package==1.41.20
|
||||
comfyui-workflow-templates==0.9.26
|
||||
comfyui-embedded-docs==0.4.3
|
||||
torch
|
||||
|
||||
Reference in New Issue
Block a user