mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-23 07:59:07 +00:00
Compare commits
20 Commits
deepme987/
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3cdc0d523f | ||
|
|
749d5b4e8d | ||
|
|
e988df72f8 | ||
|
|
0be87b082a | ||
|
|
ec4b1659ab | ||
|
|
cb388e2912 | ||
|
|
9949c19c63 | ||
|
|
cc6f9500a1 | ||
|
|
db85cf03ff | ||
|
|
91e1f45d80 | ||
|
|
6045c11d8b | ||
|
|
529c80255f | ||
|
|
43a1263b60 | ||
|
|
102773cd2c | ||
|
|
1e1d4f1254 | ||
|
|
eb22225387 | ||
|
|
b38dd0ff23 | ||
|
|
ad94d47221 | ||
|
|
e75f775ae8 | ||
|
|
c514890325 |
@@ -4,9 +4,6 @@ import math
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.model_patcher
|
||||
import comfy.utils as utils
|
||||
from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||
from comfy.ldm.lightricks.vae.causal_audio_autoencoder import (
|
||||
@@ -43,30 +40,6 @@ class AudioVAEComponentConfig:
|
||||
|
||||
return cls(autoencoder=audio_config, vocoder=vocoder_config)
|
||||
|
||||
|
||||
class ModelDeviceManager:
|
||||
"""Manages device placement and GPU residency for the composed model."""
|
||||
|
||||
def __init__(self, module: torch.nn.Module):
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
offload_device = comfy.model_management.vae_offload_device()
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(module, load_device, offload_device)
|
||||
|
||||
def ensure_model_loaded(self) -> None:
|
||||
comfy.model_management.free_memory(
|
||||
self.patcher.model_size(),
|
||||
self.patcher.load_device,
|
||||
)
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
|
||||
def move_to_load_device(self, tensor: torch.Tensor) -> torch.Tensor:
|
||||
return tensor.to(self.patcher.load_device)
|
||||
|
||||
@property
|
||||
def load_device(self):
|
||||
return self.patcher.load_device
|
||||
|
||||
|
||||
class AudioLatentNormalizer:
|
||||
"""Applies per-channel statistics in patch space and restores original layout."""
|
||||
|
||||
@@ -132,23 +105,17 @@ class AudioPreprocessor:
|
||||
class AudioVAE(torch.nn.Module):
|
||||
"""High-level Audio VAE wrapper exposing encode and decode entry points."""
|
||||
|
||||
def __init__(self, state_dict: dict, metadata: dict):
|
||||
def __init__(self, metadata: dict):
|
||||
super().__init__()
|
||||
|
||||
component_config = AudioVAEComponentConfig.from_metadata(metadata)
|
||||
|
||||
vae_sd = utils.state_dict_prefix_replace(state_dict, {"audio_vae.": ""}, filter_keys=True)
|
||||
vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)
|
||||
|
||||
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
|
||||
if "bwe" in component_config.vocoder:
|
||||
self.vocoder = VocoderWithBWE(config=component_config.vocoder)
|
||||
else:
|
||||
self.vocoder = Vocoder(config=component_config.vocoder)
|
||||
|
||||
self.autoencoder.load_state_dict(vae_sd, strict=False)
|
||||
self.vocoder.load_state_dict(vocoder_sd, strict=False)
|
||||
|
||||
autoencoder_config = self.autoencoder.get_config()
|
||||
self.normalizer = AudioLatentNormalizer(
|
||||
AudioPatchifier(
|
||||
@@ -168,18 +135,12 @@ class AudioVAE(torch.nn.Module):
|
||||
n_fft=autoencoder_config["n_fft"],
|
||||
)
|
||||
|
||||
self.device_manager = ModelDeviceManager(self)
|
||||
|
||||
def encode(self, audio: dict) -> torch.Tensor:
|
||||
def encode(self, audio, sample_rate=44100) -> torch.Tensor:
|
||||
"""Encode a waveform dictionary into normalized latent tensors."""
|
||||
|
||||
waveform = audio["waveform"]
|
||||
waveform_sample_rate = audio["sample_rate"]
|
||||
waveform = audio
|
||||
waveform_sample_rate = sample_rate
|
||||
input_device = waveform.device
|
||||
# Ensure that Audio VAE is loaded on the correct device.
|
||||
self.device_manager.ensure_model_loaded()
|
||||
|
||||
waveform = self.device_manager.move_to_load_device(waveform)
|
||||
expected_channels = self.autoencoder.encoder.in_channels
|
||||
if waveform.shape[1] != expected_channels:
|
||||
if waveform.shape[1] == 1:
|
||||
@@ -190,7 +151,7 @@ class AudioVAE(torch.nn.Module):
|
||||
)
|
||||
|
||||
mel_spec = self.preprocessor.waveform_to_mel(
|
||||
waveform, waveform_sample_rate, device=self.device_manager.load_device
|
||||
waveform, waveform_sample_rate, device=waveform.device
|
||||
)
|
||||
|
||||
latents = self.autoencoder.encode(mel_spec)
|
||||
@@ -204,17 +165,13 @@ class AudioVAE(torch.nn.Module):
|
||||
"""Decode normalized latent tensors into an audio waveform."""
|
||||
original_shape = latents.shape
|
||||
|
||||
# Ensure that Audio VAE is loaded on the correct device.
|
||||
self.device_manager.ensure_model_loaded()
|
||||
|
||||
latents = self.device_manager.move_to_load_device(latents)
|
||||
latents = self.normalizer.denormalize(latents)
|
||||
|
||||
target_shape = self.target_shape_from_latents(original_shape)
|
||||
mel_spec = self.autoencoder.decode(latents, target_shape=target_shape)
|
||||
|
||||
waveform = self.run_vocoder(mel_spec)
|
||||
return self.device_manager.move_to_load_device(waveform)
|
||||
return waveform
|
||||
|
||||
def target_shape_from_latents(self, latents_shape):
|
||||
batch, _, time, _ = latents_shape
|
||||
|
||||
596
comfy/ldm/sam3/detector.py
Normal file
596
comfy/ldm/sam3/detector.py
Normal file
@@ -0,0 +1,596 @@
|
||||
# SAM3 detector: transformer encoder-decoder, segmentation head, geometry encoder, scoring.
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torchvision.ops import roi_align
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.sam3.tracker import SAM3Tracker, SAM31Tracker
|
||||
from comfy.ldm.sam3.sam import SAM3VisionBackbone # noqa: used in __init__
|
||||
from comfy.ldm.sam3.sam import MLP, PositionEmbeddingSine
|
||||
|
||||
TRACKER_CLASSES = {"SAM3": SAM3Tracker, "SAM31": SAM31Tracker}
|
||||
from comfy.ops import cast_to_input
|
||||
|
||||
|
||||
def box_cxcywh_to_xyxy(x):
|
||||
cx, cy, w, h = x.unbind(-1)
|
||||
return torch.stack([cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h], dim=-1)
|
||||
|
||||
|
||||
def gen_sineembed_for_position(pos_tensor, num_feats=256):
|
||||
"""Per-coordinate sinusoidal embedding: (..., N) -> (..., N * num_feats)."""
|
||||
assert num_feats % 2 == 0
|
||||
hdim = num_feats // 2
|
||||
freqs = 10000.0 ** (2 * (torch.arange(hdim, dtype=torch.float32, device=pos_tensor.device) // 2) / hdim)
|
||||
embeds = []
|
||||
for c in range(pos_tensor.shape[-1]):
|
||||
raw = (pos_tensor[..., c].float() * 2 * math.pi).unsqueeze(-1) / freqs
|
||||
embeds.append(torch.stack([raw[..., 0::2].sin(), raw[..., 1::2].cos()], dim=-1).flatten(-2))
|
||||
return torch.cat(embeds, dim=-1).to(pos_tensor.dtype)
|
||||
|
||||
|
||||
class SplitMHA(nn.Module):
|
||||
"""Multi-head attention with separate Q/K/V projections (split from fused in_proj_weight)."""
|
||||
def __init__(self, d_model, num_heads=8, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.q_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
|
||||
self.k_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
|
||||
self.v_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
|
||||
self.out_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, q_input, k_input=None, v_input=None, mask=None):
|
||||
q = self.q_proj(q_input)
|
||||
if k_input is None:
|
||||
k = self.k_proj(q_input)
|
||||
v = self.v_proj(q_input)
|
||||
else:
|
||||
k = self.k_proj(k_input)
|
||||
v = self.v_proj(v_input if v_input is not None else k_input)
|
||||
if mask is not None and mask.ndim == 2:
|
||||
mask = mask[:, None, None, :] # [B, T] -> [B, 1, 1, T] for SDPA broadcast
|
||||
dtype = q.dtype # manual_cast may produce mixed dtypes
|
||||
out = optimized_attention(q, k.to(dtype), v.to(dtype), self.num_heads, mask=mask)
|
||||
return self.out_proj(out)
|
||||
|
||||
|
||||
class MLPWithNorm(nn.Module):
|
||||
"""MLP with residual connection and output LayerNorm."""
|
||||
def __init__(self, input_dim, hidden_dim, output_dim, num_layers, residual=True, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
|
||||
self.layers = nn.ModuleList([
|
||||
operations.Linear(dims[i], dims[i + 1], device=device, dtype=dtype)
|
||||
for i in range(num_layers)
|
||||
])
|
||||
self.out_norm = operations.LayerNorm(output_dim, device=device, dtype=dtype)
|
||||
self.residual = residual and (input_dim == output_dim)
|
||||
|
||||
def forward(self, x):
|
||||
orig = x
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
if i < len(self.layers) - 1:
|
||||
x = F.relu(x)
|
||||
if self.residual:
|
||||
x = x + orig
|
||||
return self.out_norm(x)
|
||||
|
||||
|
||||
class EncoderLayer(nn.Module):
|
||||
def __init__(self, d_model=256, num_heads=8, dim_ff=2048, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.self_attn = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations)
|
||||
self.cross_attn_image = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations)
|
||||
self.linear1 = operations.Linear(d_model, dim_ff, device=device, dtype=dtype)
|
||||
self.linear2 = operations.Linear(dim_ff, d_model, device=device, dtype=dtype)
|
||||
self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, pos, text_memory=None, text_mask=None):
|
||||
normed = self.norm1(x)
|
||||
q_k = normed + pos
|
||||
x = x + self.self_attn(q_k, q_k, normed)
|
||||
if text_memory is not None:
|
||||
normed = self.norm2(x)
|
||||
x = x + self.cross_attn_image(normed, text_memory, text_memory, mask=text_mask)
|
||||
normed = self.norm3(x)
|
||||
x = x + self.linear2(F.relu(self.linear1(normed)))
|
||||
return x
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
"""Checkpoint: transformer.encoder.layers.N.*"""
|
||||
def __init__(self, d_model=256, num_heads=8, dim_ff=2048, num_layers=6, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([
|
||||
EncoderLayer(d_model, num_heads, dim_ff, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
def forward(self, x, pos, text_memory=None, text_mask=None):
|
||||
for layer in self.layers:
|
||||
x = layer(x, pos, text_memory, text_mask)
|
||||
return x
|
||||
|
||||
|
||||
class DecoderLayer(nn.Module):
|
||||
def __init__(self, d_model=256, num_heads=8, dim_ff=2048, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.self_attn = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations)
|
||||
self.cross_attn = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations)
|
||||
self.ca_text = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations)
|
||||
self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
self.catext_norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
self.linear1 = operations.Linear(d_model, dim_ff, device=device, dtype=dtype)
|
||||
self.linear2 = operations.Linear(dim_ff, d_model, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, memory, x_pos, memory_pos, text_memory=None, text_mask=None, cross_attn_bias=None):
|
||||
q_k = x + x_pos
|
||||
x = self.norm2(x + self.self_attn(q_k, q_k, x))
|
||||
if text_memory is not None:
|
||||
x = self.catext_norm(x + self.ca_text(x + x_pos, text_memory, text_memory, mask=text_mask))
|
||||
x = self.norm1(x + self.cross_attn(x + x_pos, memory + memory_pos, memory, mask=cross_attn_bias))
|
||||
x = self.norm3(x + self.linear2(F.relu(self.linear1(x))))
|
||||
return x
|
||||
|
||||
|
||||
class TransformerDecoder(nn.Module):
|
||||
def __init__(self, d_model=256, num_heads=8, dim_ff=2048, num_layers=6,
|
||||
num_queries=200, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.num_queries = num_queries
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
DecoderLayer(d_model, num_heads, dim_ff, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
self.query_embed = operations.Embedding(num_queries, d_model, device=device, dtype=dtype)
|
||||
self.reference_points = operations.Embedding(num_queries, 4, device=device, dtype=dtype) # Reference points: Embedding(num_queries, 4) — learned anchor boxes
|
||||
self.ref_point_head = MLP(d_model * 2, d_model, d_model, 2, device=device, dtype=dtype, operations=operations) # ref_point_head input: 512 (4 coords * 128 sine features each)
|
||||
self.bbox_embed = MLP(d_model, d_model, 4, 3, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.boxRPB_embed_x = MLP(2, d_model, num_heads, 2, device=device, dtype=dtype, operations=operations)
|
||||
self.boxRPB_embed_y = MLP(2, d_model, num_heads, 2, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.presence_token = operations.Embedding(1, d_model, device=device, dtype=dtype)
|
||||
self.presence_token_head = MLP(d_model, d_model, 1, 3, device=device, dtype=dtype, operations=operations)
|
||||
self.presence_token_out_norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
|
||||
@staticmethod
|
||||
def _inverse_sigmoid(x):
|
||||
return torch.log(x / (1 - x + 1e-6) + 1e-6)
|
||||
|
||||
def _compute_box_rpb(self, ref_points, H, W):
|
||||
"""Box rotary position bias: (B, Q, 4) cxcywh -> (B, n_heads, Q+1, H*W) bias."""
|
||||
boxes_xyxy = box_cxcywh_to_xyxy(ref_points)
|
||||
B, Q, _ = boxes_xyxy.shape
|
||||
coords_h = torch.arange(H, device=ref_points.device, dtype=torch.float32) / H
|
||||
coords_w = torch.arange(W, device=ref_points.device, dtype=torch.float32) / W
|
||||
deltas_x = coords_w.view(1, 1, -1, 1) - boxes_xyxy[:, :, None, 0:3:2]
|
||||
deltas_y = coords_h.view(1, 1, -1, 1) - boxes_xyxy[:, :, None, 1:4:2]
|
||||
|
||||
log2_8 = float(math.log2(8))
|
||||
def log_scale(d):
|
||||
return torch.sign(d * 8) * torch.log2(torch.abs(d * 8) + 1.0) / log2_8
|
||||
|
||||
rpb_x = self.boxRPB_embed_x(log_scale(deltas_x).to(ref_points.dtype))
|
||||
rpb_y = self.boxRPB_embed_y(log_scale(deltas_y).to(ref_points.dtype))
|
||||
|
||||
bias = (rpb_y.unsqueeze(3) + rpb_x.unsqueeze(2)).flatten(2, 3).permute(0, 3, 1, 2)
|
||||
pres_bias = torch.zeros(B, bias.shape[1], 1, bias.shape[3], device=bias.device, dtype=bias.dtype)
|
||||
return torch.cat([pres_bias, bias], dim=2)
|
||||
|
||||
def forward(self, memory, memory_pos, text_memory=None, text_mask=None, H=72, W=72):
|
||||
B = memory.shape[0]
|
||||
tgt = cast_to_input(self.query_embed.weight, memory).unsqueeze(0).expand(B, -1, -1)
|
||||
presence_out = cast_to_input(self.presence_token.weight, memory)[None].expand(B, -1, -1)
|
||||
ref_points = cast_to_input(self.reference_points.weight, memory).unsqueeze(0).expand(B, -1, -1).sigmoid()
|
||||
|
||||
for layer_idx, layer in enumerate(self.layers):
|
||||
query_pos = self.ref_point_head(gen_sineembed_for_position(ref_points, self.d_model))
|
||||
tgt_with_pres = torch.cat([presence_out, tgt], dim=1)
|
||||
pos_with_pres = torch.cat([torch.zeros_like(presence_out), query_pos], dim=1)
|
||||
tgt_with_pres = layer(tgt_with_pres, memory, pos_with_pres, memory_pos,
|
||||
text_memory, text_mask, self._compute_box_rpb(ref_points, H, W))
|
||||
presence_out, tgt = tgt_with_pres[:, :1], tgt_with_pres[:, 1:]
|
||||
if layer_idx < len(self.layers) - 1:
|
||||
ref_inv = self._inverse_sigmoid(ref_points)
|
||||
ref_points = (ref_inv + self.bbox_embed(self.norm(tgt))).sigmoid().detach()
|
||||
|
||||
query_out = self.norm(tgt)
|
||||
ref_inv = self._inverse_sigmoid(ref_points)
|
||||
boxes = (ref_inv + self.bbox_embed(query_out)).sigmoid()
|
||||
presence = self.presence_token_head(self.presence_token_out_norm(presence_out)).squeeze(-1)
|
||||
return {"decoder_output": query_out, "pred_boxes": boxes, "presence": presence}
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, d_model=256, num_heads=8, dim_ff=2048, enc_layers=6, dec_layers=6,
|
||||
num_queries=200, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.encoder = TransformerEncoder(d_model, num_heads, dim_ff, enc_layers, device=device, dtype=dtype, operations=operations)
|
||||
self.decoder = TransformerDecoder(d_model, num_heads, dim_ff, dec_layers, num_queries, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
|
||||
class GeometryEncoder(nn.Module):
|
||||
def __init__(self, d_model=256, num_heads=8, num_layers=3, roi_size=7, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.roi_size = roi_size
|
||||
self.pos_enc = PositionEmbeddingSine(num_pos_feats=d_model, normalize=True)
|
||||
self.points_direct_project = operations.Linear(2, d_model, device=device, dtype=dtype)
|
||||
self.points_pool_project = operations.Linear(d_model, d_model, device=device, dtype=dtype)
|
||||
self.points_pos_enc_project = operations.Linear(d_model, d_model, device=device, dtype=dtype)
|
||||
self.boxes_direct_project = operations.Linear(4, d_model, device=device, dtype=dtype)
|
||||
self.boxes_pool_project = operations.Conv2d(d_model, d_model, kernel_size=roi_size, device=device, dtype=dtype)
|
||||
self.boxes_pos_enc_project = operations.Linear(d_model + 2, d_model, device=device, dtype=dtype)
|
||||
self.label_embed = operations.Embedding(2, d_model, device=device, dtype=dtype)
|
||||
self.cls_embed = operations.Embedding(1, d_model, device=device, dtype=dtype)
|
||||
self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
self.img_pre_norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
self.encode = nn.ModuleList([
|
||||
EncoderLayer(d_model, num_heads, 2048, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
self.encode_norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
self.final_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
|
||||
|
||||
def _encode_points(self, coords, labels, img_feat_2d):
|
||||
"""Encode point prompts: direct + pool + pos_enc + label. coords: [B, N, 2] normalized."""
|
||||
B, N, _ = coords.shape
|
||||
embed = self.points_direct_project(coords)
|
||||
# Pool features from backbone at point locations via grid_sample
|
||||
grid = (coords * 2 - 1).unsqueeze(2) # [B, N, 1, 2] in [-1, 1]
|
||||
sampled = F.grid_sample(img_feat_2d, grid, align_corners=False) # [B, C, N, 1]
|
||||
embed = embed + self.points_pool_project(sampled.squeeze(-1).permute(0, 2, 1)) # [B, N, C]
|
||||
# Positional encoding of coordinates
|
||||
x, y = coords[:, :, 0], coords[:, :, 1] # [B, N]
|
||||
pos_x, pos_y = self.pos_enc._encode_xy(x.flatten(), y.flatten())
|
||||
enc = torch.cat([pos_x, pos_y], dim=-1).view(B, N, -1)
|
||||
embed = embed + self.points_pos_enc_project(cast_to_input(enc, embed))
|
||||
embed = embed + cast_to_input(self.label_embed(labels.long()), embed)
|
||||
return embed
|
||||
|
||||
def _encode_boxes(self, boxes, labels, img_feat_2d):
|
||||
"""Encode box prompts: direct + pool + pos_enc + label. boxes: [B, N, 4] normalized cxcywh."""
|
||||
B, N, _ = boxes.shape
|
||||
embed = self.boxes_direct_project(boxes)
|
||||
# ROI align from backbone at box regions
|
||||
H, W = img_feat_2d.shape[-2:]
|
||||
boxes_xyxy = box_cxcywh_to_xyxy(boxes)
|
||||
scale = torch.tensor([W, H, W, H], dtype=boxes_xyxy.dtype, device=boxes_xyxy.device)
|
||||
boxes_scaled = boxes_xyxy * scale
|
||||
sampled = roi_align(img_feat_2d, boxes_scaled.view(-1, 4).split(N), self.roi_size)
|
||||
proj = self.boxes_pool_project(sampled).view(B, N, -1) # Conv2d(roi_size) -> [B*N, C, 1, 1] -> [B, N, C]
|
||||
embed = embed + proj
|
||||
# Positional encoding of box center + size
|
||||
cx, cy, w, h = boxes[:, :, 0], boxes[:, :, 1], boxes[:, :, 2], boxes[:, :, 3]
|
||||
enc = self.pos_enc.encode_boxes(cx.flatten(), cy.flatten(), w.flatten(), h.flatten())
|
||||
enc = enc.view(B, N, -1)
|
||||
embed = embed + self.boxes_pos_enc_project(cast_to_input(enc, embed))
|
||||
embed = embed + cast_to_input(self.label_embed(labels.long()), embed)
|
||||
return embed
|
||||
|
||||
def forward(self, points=None, boxes=None, image_features=None):
|
||||
"""Encode geometry prompts. image_features: [B, HW, C] flattened backbone features."""
|
||||
# Prepare 2D image features for pooling
|
||||
img_feat_2d = None
|
||||
if image_features is not None:
|
||||
B = image_features.shape[0]
|
||||
HW, C = image_features.shape[1], image_features.shape[2]
|
||||
hw = int(math.sqrt(HW))
|
||||
img_normed = self.img_pre_norm(image_features)
|
||||
img_feat_2d = img_normed.permute(0, 2, 1).view(B, C, hw, hw)
|
||||
|
||||
embeddings = []
|
||||
if points is not None:
|
||||
coords, labels = points
|
||||
embeddings.append(self._encode_points(coords, labels, img_feat_2d))
|
||||
if boxes is not None:
|
||||
B = boxes.shape[0]
|
||||
box_labels = torch.ones(B, boxes.shape[1], dtype=torch.long, device=boxes.device)
|
||||
embeddings.append(self._encode_boxes(boxes, box_labels, img_feat_2d))
|
||||
if not embeddings:
|
||||
return None
|
||||
geo = torch.cat(embeddings, dim=1)
|
||||
geo = self.norm(geo)
|
||||
if image_features is not None:
|
||||
for layer in self.encode:
|
||||
geo = layer(geo, torch.zeros_like(geo), image_features)
|
||||
geo = self.encode_norm(geo)
|
||||
return self.final_proj(geo)
|
||||
|
||||
|
||||
class PixelDecoder(nn.Module):
|
||||
"""Top-down FPN pixel decoder with GroupNorm + ReLU + nearest interpolation."""
|
||||
def __init__(self, d_model=256, num_stages=3, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.conv_layers = nn.ModuleList([operations.Conv2d(d_model, d_model, kernel_size=3, padding=1, device=device, dtype=dtype) for _ in range(num_stages)])
|
||||
self.norms = nn.ModuleList([operations.GroupNorm(8, d_model, device=device, dtype=dtype) for _ in range(num_stages)])
|
||||
|
||||
def forward(self, backbone_features):
|
||||
prev = backbone_features[-1]
|
||||
for i, feat in enumerate(backbone_features[:-1][::-1]):
|
||||
prev = F.relu(self.norms[i](self.conv_layers[i](feat + F.interpolate(prev, size=feat.shape[-2:], mode="nearest"))))
|
||||
return prev
|
||||
|
||||
|
||||
class MaskPredictor(nn.Module):
|
||||
def __init__(self, d_model=256, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.mask_embed = MLP(d_model, d_model, d_model, 3, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
def forward(self, query_embeddings, pixel_features):
|
||||
mask_embed = self.mask_embed(query_embeddings)
|
||||
return torch.einsum("bqc,bchw->bqhw", mask_embed, pixel_features)
|
||||
|
||||
|
||||
class SegmentationHead(nn.Module):
|
||||
def __init__(self, d_model=256, num_heads=8, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.pixel_decoder = PixelDecoder(d_model, 3, device=device, dtype=dtype, operations=operations)
|
||||
self.mask_predictor = MaskPredictor(d_model, device=device, dtype=dtype, operations=operations)
|
||||
self.cross_attend_prompt = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations)
|
||||
self.cross_attn_norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
self.instance_seg_head = operations.Conv2d(d_model, d_model, kernel_size=1, device=device, dtype=dtype)
|
||||
self.semantic_seg_head = operations.Conv2d(d_model, 1, kernel_size=1, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, query_embeddings, backbone_features, encoder_hidden_states=None, prompt=None, prompt_mask=None):
|
||||
if encoder_hidden_states is not None and prompt is not None:
|
||||
enc_normed = self.cross_attn_norm(encoder_hidden_states)
|
||||
enc_cross = self.cross_attend_prompt(enc_normed, prompt, prompt, mask=prompt_mask)
|
||||
encoder_hidden_states = enc_cross + encoder_hidden_states
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
B, H, W = encoder_hidden_states.shape[0], backbone_features[-1].shape[-2], backbone_features[-1].shape[-1]
|
||||
encoder_visual = encoder_hidden_states[:, :H * W].permute(0, 2, 1).view(B, self.d_model, H, W)
|
||||
backbone_features = list(backbone_features)
|
||||
backbone_features[-1] = encoder_visual
|
||||
|
||||
pixel_features = self.pixel_decoder(backbone_features)
|
||||
instance_features = self.instance_seg_head(pixel_features)
|
||||
masks = self.mask_predictor(query_embeddings, instance_features)
|
||||
return masks
|
||||
|
||||
|
||||
class DotProductScoring(nn.Module):
|
||||
def __init__(self, d_model=256, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.hs_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
|
||||
self.prompt_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
|
||||
self.prompt_mlp = MLPWithNorm(d_model, 2048, d_model, 2, device=device, dtype=dtype, operations=operations)
|
||||
self.scale = 1.0 / (d_model ** 0.5)
|
||||
|
||||
def forward(self, query_embeddings, prompt_embeddings, prompt_mask=None):
|
||||
prompt = self.prompt_mlp(prompt_embeddings)
|
||||
if prompt_mask is not None:
|
||||
weight = prompt_mask.unsqueeze(-1).to(dtype=prompt.dtype)
|
||||
pooled = (prompt * weight).sum(dim=1) / weight.sum(dim=1).clamp(min=1)
|
||||
else:
|
||||
pooled = prompt.mean(dim=1)
|
||||
hs = self.hs_proj(query_embeddings)
|
||||
pp = self.prompt_proj(pooled).unsqueeze(-1).to(hs.dtype)
|
||||
scores = torch.matmul(hs, pp)
|
||||
return (scores * self.scale).clamp(-12.0, 12.0).squeeze(-1)
|
||||
|
||||
|
||||
class SAM3Detector(nn.Module):
|
||||
def __init__(self, d_model=256, embed_dim=1024, num_queries=200, device=None, dtype=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
image_model = kwargs.pop("image_model", "SAM3")
|
||||
for k in ("num_heads", "num_head_channels"):
|
||||
kwargs.pop(k, None)
|
||||
multiplex = image_model == "SAM31"
|
||||
# SAM3: 4 FPN levels, drop last (scalp=1); SAM3.1: 3 levels, use all (scalp=0)
|
||||
self.scalp = 0 if multiplex else 1
|
||||
self.backbone = nn.ModuleDict({
|
||||
"vision_backbone": SAM3VisionBackbone(embed_dim=embed_dim, d_model=d_model, multiplex=multiplex, device=device, dtype=dtype, operations=operations, **kwargs),
|
||||
"language_backbone": nn.ModuleDict({"resizer": operations.Linear(embed_dim, d_model, device=device, dtype=dtype)}),
|
||||
})
|
||||
self.transformer = Transformer(d_model=d_model, num_queries=num_queries, device=device, dtype=dtype, operations=operations)
|
||||
self.segmentation_head = SegmentationHead(d_model=d_model, device=device, dtype=dtype, operations=operations)
|
||||
self.geometry_encoder = GeometryEncoder(d_model=d_model, device=device, dtype=dtype, operations=operations)
|
||||
self.dot_prod_scoring = DotProductScoring(d_model=d_model, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
def _get_backbone_features(self, images):
|
||||
"""Run backbone and return (detector_features, detector_positions, tracker_features, tracker_positions)."""
|
||||
bb = self.backbone["vision_backbone"]
|
||||
if bb.multiplex:
|
||||
all_f, all_p, tf, tp = bb(images, tracker_mode="propagation")
|
||||
else:
|
||||
all_f, all_p, tf, tp = bb(images, need_tracker=True)
|
||||
return all_f, all_p, tf, tp
|
||||
|
||||
@staticmethod
|
||||
def _run_geo_layer(layer, x, memory, memory_pos):
|
||||
x = x + layer.self_attn(layer.norm1(x))
|
||||
x = x + layer.cross_attn_image(layer.norm2(x), memory + memory_pos, memory)
|
||||
x = x + layer.linear2(F.relu(layer.linear1(layer.norm3(x))))
|
||||
return x
|
||||
|
||||
def _detect(self, features, positions, text_embeddings=None, text_mask=None,
|
||||
points=None, boxes=None):
|
||||
"""Shared detection: geometry encoding, transformer, scoring, segmentation."""
|
||||
B = features[0].shape[0]
|
||||
# Scalp for encoder (use top-level feature), but keep all levels for segmentation head
|
||||
seg_features = features
|
||||
if self.scalp > 0:
|
||||
features = features[:-self.scalp]
|
||||
positions = positions[:-self.scalp]
|
||||
enc_feat, enc_pos = features[-1], positions[-1]
|
||||
_, _, H, W = enc_feat.shape
|
||||
img_flat = enc_feat.flatten(2).permute(0, 2, 1)
|
||||
pos_flat = enc_pos.flatten(2).permute(0, 2, 1)
|
||||
|
||||
has_prompts = text_embeddings is not None or points is not None or boxes is not None
|
||||
if has_prompts:
|
||||
geo_enc = self.geometry_encoder
|
||||
geo_prompts = geo_enc(points=points, boxes=boxes, image_features=img_flat)
|
||||
geo_cls = geo_enc.norm(geo_enc.final_proj(cast_to_input(geo_enc.cls_embed.weight, img_flat).view(1, 1, -1).expand(B, -1, -1)))
|
||||
for layer in geo_enc.encode:
|
||||
geo_cls = self._run_geo_layer(layer, geo_cls, img_flat, pos_flat)
|
||||
geo_cls = geo_enc.encode_norm(geo_cls)
|
||||
if text_embeddings is not None and text_embeddings.shape[0] != B:
|
||||
text_embeddings = text_embeddings.expand(B, -1, -1)
|
||||
if text_mask is not None and text_mask.shape[0] != B:
|
||||
text_mask = text_mask.expand(B, -1)
|
||||
parts = [t for t in [text_embeddings, geo_prompts, geo_cls] if t is not None]
|
||||
text_embeddings = torch.cat(parts, dim=1)
|
||||
n_new = text_embeddings.shape[1] - (text_mask.shape[1] if text_mask is not None else 0)
|
||||
if text_mask is not None:
|
||||
text_mask = torch.cat([text_mask, torch.ones(B, n_new, dtype=torch.bool, device=text_mask.device)], dim=1)
|
||||
else:
|
||||
text_mask = torch.ones(B, text_embeddings.shape[1], dtype=torch.bool, device=text_embeddings.device)
|
||||
|
||||
memory = self.transformer.encoder(img_flat, pos_flat, text_embeddings, text_mask)
|
||||
dec_out = self.transformer.decoder(memory, pos_flat, text_embeddings, text_mask, H, W)
|
||||
query_out, pred_boxes = dec_out["decoder_output"], dec_out["pred_boxes"]
|
||||
|
||||
if text_embeddings is not None:
|
||||
scores = self.dot_prod_scoring(query_out, text_embeddings, text_mask)
|
||||
else:
|
||||
scores = torch.zeros(B, query_out.shape[1], device=query_out.device)
|
||||
|
||||
masks = self.segmentation_head(query_out, seg_features, encoder_hidden_states=memory, prompt=text_embeddings, prompt_mask=text_mask)
|
||||
return box_cxcywh_to_xyxy(pred_boxes), scores, masks, dec_out
|
||||
|
||||
def forward(self, images, text_embeddings=None, text_mask=None, points=None, boxes=None, threshold=0.3, orig_size=None):
|
||||
features, positions, _, _ = self._get_backbone_features(images)
|
||||
|
||||
if text_embeddings is not None:
|
||||
text_embeddings = self.backbone["language_backbone"]["resizer"](text_embeddings)
|
||||
if text_mask is not None:
|
||||
text_mask = text_mask.bool()
|
||||
|
||||
boxes_xyxy, scores, masks, dec_out = self._detect(
|
||||
features, positions, text_embeddings, text_mask, points, boxes)
|
||||
|
||||
if orig_size is not None:
|
||||
oh, ow = orig_size
|
||||
boxes_xyxy = boxes_xyxy * torch.tensor([ow, oh, ow, oh], device=boxes_xyxy.device, dtype=boxes_xyxy.dtype)
|
||||
masks = F.interpolate(masks, size=orig_size, mode="bilinear", align_corners=False)
|
||||
|
||||
return {
|
||||
"boxes": boxes_xyxy,
|
||||
"scores": scores,
|
||||
"masks": masks,
|
||||
"presence": dec_out.get("presence"),
|
||||
}
|
||||
|
||||
def forward_from_trunk(self, trunk_out, text_embeddings, text_mask):
|
||||
"""Run detection using a pre-computed ViTDet trunk output.
|
||||
|
||||
text_embeddings must already be resized through language_backbone.resizer.
|
||||
Returns dict with boxes (normalized xyxy), scores, masks at detector resolution.
|
||||
"""
|
||||
bb = self.backbone["vision_backbone"]
|
||||
features = [conv(trunk_out) for conv in bb.convs]
|
||||
positions = [cast_to_input(bb.position_encoding(f), f) for f in features]
|
||||
|
||||
if text_mask is not None:
|
||||
text_mask = text_mask.bool()
|
||||
|
||||
boxes_xyxy, scores, masks, _ = self._detect(features, positions, text_embeddings, text_mask)
|
||||
return {"boxes": boxes_xyxy, "scores": scores, "masks": masks}
|
||||
|
||||
|
||||
class SAM3Model(nn.Module):
|
||||
def __init__(self, device=None, dtype=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
image_model = kwargs.get("image_model", "SAM3")
|
||||
tracker_cls = TRACKER_CLASSES[image_model]
|
||||
self.detector = SAM3Detector(device=device, dtype=dtype, operations=operations, **kwargs)
|
||||
self.tracker = tracker_cls(device=device, dtype=dtype, operations=operations, **kwargs)
|
||||
|
||||
def forward(self, images, **kwargs):
|
||||
return self.detector(images, **kwargs)
|
||||
|
||||
def forward_segment(self, images, point_inputs=None, box_inputs=None, mask_inputs=None):
|
||||
"""Interactive segmentation using SAM decoder with point/box/mask prompts.
|
||||
|
||||
Args:
|
||||
images: [B, 3, 1008, 1008] preprocessed images
|
||||
point_inputs: {"point_coords": [B, N, 2], "point_labels": [B, N]} in 1008x1008 pixel space
|
||||
box_inputs: [B, 2, 2] box corners (top-left, bottom-right) in 1008x1008 pixel space
|
||||
mask_inputs: [B, 1, H, W] coarse mask logits to refine
|
||||
Returns:
|
||||
[B, 1, image_size, image_size] high-res mask logits
|
||||
"""
|
||||
bb = self.detector.backbone["vision_backbone"]
|
||||
if bb.multiplex:
|
||||
_, _, tracker_features, tracker_positions = bb(images, tracker_mode="interactive")
|
||||
else:
|
||||
_, _, tracker_features, tracker_positions = bb(images, need_tracker=True)
|
||||
if self.detector.scalp > 0:
|
||||
tracker_features = tracker_features[:-self.detector.scalp]
|
||||
tracker_positions = tracker_positions[:-self.detector.scalp]
|
||||
|
||||
high_res = list(tracker_features[:-1])
|
||||
backbone_feat = tracker_features[-1]
|
||||
B, C, H, W = backbone_feat.shape
|
||||
# Add no-memory embedding (init frame path)
|
||||
no_mem = getattr(self.tracker, 'interactivity_no_mem_embed', None)
|
||||
if no_mem is None:
|
||||
no_mem = getattr(self.tracker, 'no_mem_embed', None)
|
||||
if no_mem is not None:
|
||||
feat_flat = backbone_feat.flatten(2).permute(0, 2, 1)
|
||||
feat_flat = feat_flat + cast_to_input(no_mem, feat_flat)
|
||||
backbone_feat = feat_flat.view(B, H, W, C).permute(0, 3, 1, 2)
|
||||
|
||||
num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
|
||||
_, high_res_masks, _, _ = self.tracker._forward_sam_heads(
|
||||
backbone_features=backbone_feat,
|
||||
point_inputs=point_inputs,
|
||||
mask_inputs=mask_inputs,
|
||||
box_inputs=box_inputs,
|
||||
high_res_features=high_res,
|
||||
multimask_output=(0 < num_pts <= 1),
|
||||
)
|
||||
return high_res_masks
|
||||
|
||||
def forward_video(self, images, initial_masks, pbar=None, text_prompts=None,
|
||||
new_det_thresh=0.5, max_objects=0, detect_interval=1):
|
||||
"""Track video with optional per-frame text-prompted detection."""
|
||||
bb = self.detector.backbone["vision_backbone"]
|
||||
|
||||
def backbone_fn(frame, frame_idx=None):
|
||||
trunk_out = bb.trunk(frame)
|
||||
if bb.multiplex:
|
||||
_, _, tf, tp = bb(frame, tracker_mode="propagation", cached_trunk=trunk_out, tracker_only=True)
|
||||
else:
|
||||
_, _, tf, tp = bb(frame, need_tracker=True, cached_trunk=trunk_out, tracker_only=True)
|
||||
return tf, tp, trunk_out
|
||||
|
||||
detect_fn = None
|
||||
if text_prompts:
|
||||
resizer = self.detector.backbone["language_backbone"]["resizer"]
|
||||
resized = [(resizer(emb), m.bool() if m is not None else None) for emb, m in text_prompts]
|
||||
def detect_fn(trunk_out):
|
||||
all_scores, all_masks = [], []
|
||||
for emb, mask in resized:
|
||||
det = self.detector.forward_from_trunk(trunk_out, emb, mask)
|
||||
all_scores.append(det["scores"])
|
||||
all_masks.append(det["masks"])
|
||||
return {"scores": torch.cat(all_scores, dim=1), "masks": torch.cat(all_masks, dim=1)}
|
||||
|
||||
if hasattr(self.tracker, 'track_video_with_detection'):
|
||||
return self.tracker.track_video_with_detection(
|
||||
backbone_fn, images, initial_masks, detect_fn,
|
||||
new_det_thresh=new_det_thresh, max_objects=max_objects,
|
||||
detect_interval=detect_interval, backbone_obj=bb, pbar=pbar)
|
||||
# SAM3 (non-multiplex) — no detection support, requires initial masks
|
||||
if initial_masks is None:
|
||||
raise ValueError("SAM3 (non-multiplex) requires initial_mask for video tracking")
|
||||
return self.tracker.track_video(backbone_fn, images, initial_masks, pbar=pbar, backbone_obj=bb)
|
||||
425
comfy/ldm/sam3/sam.py
Normal file
425
comfy/ldm/sam3/sam.py
Normal file
@@ -0,0 +1,425 @@
|
||||
# SAM3 shared components: primitives, ViTDet backbone, FPN neck, position encodings.
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.flux.math import apply_rope
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ops import cast_to_input
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, input_dim, hidden_dim, output_dim, num_layers, sigmoid_output=False, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
|
||||
self.layers = nn.ModuleList([operations.Linear(dims[i], dims[i + 1], device=device, dtype=dtype) for i in range(num_layers)])
|
||||
self.sigmoid_output = sigmoid_output
|
||||
|
||||
def forward(self, x):
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = F.relu(layer(x)) if i < len(self.layers) - 1 else layer(x)
|
||||
return torch.sigmoid(x) if self.sigmoid_output else x
|
||||
|
||||
|
||||
class SAMAttention(nn.Module):
|
||||
def __init__(self, embedding_dim, num_heads, downsample_rate=1, kv_in_dim=None, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
internal_dim = embedding_dim // downsample_rate
|
||||
kv_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
|
||||
self.q_proj = operations.Linear(embedding_dim, internal_dim, device=device, dtype=dtype)
|
||||
self.k_proj = operations.Linear(kv_dim, internal_dim, device=device, dtype=dtype)
|
||||
self.v_proj = operations.Linear(kv_dim, internal_dim, device=device, dtype=dtype)
|
||||
self.out_proj = operations.Linear(internal_dim, embedding_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, q, k, v):
|
||||
q = self.q_proj(q)
|
||||
k = self.k_proj(k)
|
||||
v = self.v_proj(v)
|
||||
return self.out_proj(optimized_attention(q, k, v, self.num_heads))
|
||||
|
||||
|
||||
class TwoWayAttentionBlock(nn.Module):
|
||||
def __init__(self, embedding_dim, num_heads, mlp_dim=2048, attention_downsample_rate=2, skip_first_layer_pe=False, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.skip_first_layer_pe = skip_first_layer_pe
|
||||
self.self_attn = SAMAttention(embedding_dim, num_heads, device=device, dtype=dtype, operations=operations)
|
||||
self.cross_attn_token_to_image = SAMAttention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate, device=device, dtype=dtype, operations=operations)
|
||||
self.cross_attn_image_to_token = SAMAttention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate, device=device, dtype=dtype, operations=operations)
|
||||
self.mlp = nn.Sequential(operations.Linear(embedding_dim, mlp_dim, device=device, dtype=dtype), nn.ReLU(), operations.Linear(mlp_dim, embedding_dim, device=device, dtype=dtype))
|
||||
self.norm1 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype)
|
||||
self.norm2 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype)
|
||||
self.norm3 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype)
|
||||
self.norm4 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, queries, keys, query_pe, key_pe):
|
||||
if self.skip_first_layer_pe:
|
||||
queries = self.norm1(self.self_attn(queries, queries, queries))
|
||||
else:
|
||||
q = queries + query_pe
|
||||
queries = self.norm1(queries + self.self_attn(q, q, queries))
|
||||
q, k = queries + query_pe, keys + key_pe
|
||||
queries = self.norm2(queries + self.cross_attn_token_to_image(q, k, keys))
|
||||
queries = self.norm3(queries + self.mlp(queries))
|
||||
q, k = queries + query_pe, keys + key_pe
|
||||
keys = self.norm4(keys + self.cross_attn_image_to_token(k, q, queries))
|
||||
return queries, keys
|
||||
|
||||
|
||||
class TwoWayTransformer(nn.Module):
|
||||
def __init__(self, depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048, attention_downsample_rate=2, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([
|
||||
TwoWayAttentionBlock(embedding_dim, num_heads, mlp_dim, attention_downsample_rate,
|
||||
skip_first_layer_pe=(i == 0), device=device, dtype=dtype, operations=operations)
|
||||
for i in range(depth)
|
||||
])
|
||||
self.final_attn_token_to_image = SAMAttention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate, device=device, dtype=dtype, operations=operations)
|
||||
self.norm_final = operations.LayerNorm(embedding_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, image_embedding, image_pe, point_embedding):
|
||||
queries, keys = point_embedding, image_embedding
|
||||
for layer in self.layers:
|
||||
queries, keys = layer(queries, keys, point_embedding, image_pe)
|
||||
q, k = queries + point_embedding, keys + image_pe
|
||||
queries = self.norm_final(queries + self.final_attn_token_to_image(q, k, keys))
|
||||
return queries, keys
|
||||
|
||||
|
||||
class PositionEmbeddingRandom(nn.Module):
|
||||
"""Fourier feature positional encoding with random gaussian projection."""
|
||||
def __init__(self, num_pos_feats=64, scale=None):
|
||||
super().__init__()
|
||||
self.register_buffer("positional_encoding_gaussian_matrix", (scale or 1.0) * torch.randn(2, num_pos_feats))
|
||||
|
||||
def _encode(self, normalized_coords):
|
||||
"""Map normalized [0,1] coordinates to fourier features via random projection. Computes in fp32."""
|
||||
orig_dtype = normalized_coords.dtype
|
||||
proj_matrix = self.positional_encoding_gaussian_matrix.to(device=normalized_coords.device, dtype=torch.float32)
|
||||
projected = 2 * math.pi * (2 * normalized_coords.float() - 1) @ proj_matrix
|
||||
return torch.cat([projected.sin(), projected.cos()], dim=-1).to(orig_dtype)
|
||||
|
||||
def forward(self, size, device=None):
|
||||
h, w = size
|
||||
dev = device if device is not None else self.positional_encoding_gaussian_matrix.device
|
||||
ones = torch.ones((h, w), device=dev, dtype=torch.float32)
|
||||
norm_xy = torch.stack([(ones.cumsum(1) - 0.5) / w, (ones.cumsum(0) - 0.5) / h], dim=-1)
|
||||
return self._encode(norm_xy).permute(2, 0, 1).unsqueeze(0)
|
||||
|
||||
def forward_with_coords(self, pixel_coords, image_size):
|
||||
norm = pixel_coords.clone()
|
||||
norm[:, :, 0] /= image_size[1]
|
||||
norm[:, :, 1] /= image_size[0]
|
||||
return self._encode(norm)
|
||||
|
||||
|
||||
# ViTDet backbone + FPN neck
|
||||
|
||||
def window_partition(x: torch.Tensor, window_size: int):
|
||||
B, H, W, C = x.shape
|
||||
pad_h = (window_size - H % window_size) % window_size
|
||||
pad_w = (window_size - W % window_size) % window_size
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
|
||||
Hp, Wp = H + pad_h, W + pad_w
|
||||
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
|
||||
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
||||
return windows, (Hp, Wp)
|
||||
|
||||
|
||||
def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw, hw):
|
||||
Hp, Wp = pad_hw
|
||||
H, W = hw
|
||||
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
|
||||
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
|
||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
|
||||
if Hp > H or Wp > W:
|
||||
x = x[:, :H, :W, :].contiguous()
|
||||
return x
|
||||
|
||||
|
||||
def rope_2d(end_x: int, end_y: int, dim: int, theta: float = 10000.0, scale_pos: float = 1.0):
|
||||
"""Generate 2D axial RoPE using flux EmbedND. Returns [1, 1, HW, dim//2, 2, 2]."""
|
||||
t = torch.arange(end_x * end_y, dtype=torch.float32)
|
||||
ids = torch.stack([(t % end_x) * scale_pos,
|
||||
torch.div(t, end_x, rounding_mode="floor") * scale_pos], dim=-1)
|
||||
return EmbedND(dim=dim, theta=theta, axes_dim=[dim // 2, dim // 2])(ids.unsqueeze(0))
|
||||
|
||||
|
||||
class _ViTMLP(nn.Module):
|
||||
def __init__(self, dim, mlp_ratio=4.0, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
hidden = int(dim * mlp_ratio)
|
||||
self.fc1 = operations.Linear(dim, hidden, device=device, dtype=dtype)
|
||||
self.act = nn.GELU()
|
||||
self.fc2 = operations.Linear(hidden, dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
return self.fc2(self.act(self.fc1(x)))
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""ViTDet multi-head attention with fused QKV projection."""
|
||||
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=True, use_rope=False, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.use_rope = use_rope
|
||||
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, device=device, dtype=dtype)
|
||||
self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, freqs_cis=None):
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
|
||||
q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0)
|
||||
if self.use_rope and freqs_cis is not None:
|
||||
q, k = apply_rope(q, k, freqs_cis)
|
||||
return self.proj(optimized_attention(q, k, v, self.num_heads, skip_reshape=True))
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, dim, num_heads, mlp_ratio=4.0, qkv_bias=True, window_size=0, use_rope=False, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
self.norm1 = operations.LayerNorm(dim, device=device, dtype=dtype)
|
||||
self.attn = Attention(dim, num_heads, qkv_bias, use_rope, device=device, dtype=dtype, operations=operations)
|
||||
self.norm2 = operations.LayerNorm(dim, device=device, dtype=dtype)
|
||||
self.mlp = _ViTMLP(dim, mlp_ratio, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
def forward(self, x, freqs_cis=None):
|
||||
shortcut = x
|
||||
x = self.norm1(x)
|
||||
if self.window_size > 0:
|
||||
H, W = x.shape[1], x.shape[2]
|
||||
x, pad_hw = window_partition(x, self.window_size)
|
||||
x = x.view(x.shape[0], self.window_size * self.window_size, -1)
|
||||
x = self.attn(x, freqs_cis=freqs_cis)
|
||||
x = x.view(-1, self.window_size, self.window_size, x.shape[-1])
|
||||
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
|
||||
else:
|
||||
B, H, W, C = x.shape
|
||||
x = x.view(B, H * W, C)
|
||||
x = self.attn(x, freqs_cis=freqs_cis)
|
||||
x = x.view(B, H, W, C)
|
||||
x = shortcut + x
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
def __init__(self, patch_size=14, in_chans=3, embed_dim=1024, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
return self.proj(x)
|
||||
|
||||
|
||||
class ViTDet(nn.Module):
|
||||
def __init__(self, img_size=1008, patch_size=14, embed_dim=1024, depth=32, num_heads=16, mlp_ratio=4.625, qkv_bias=True, window_size=24,
|
||||
global_att_blocks=(7, 15, 23, 31), use_rope=True, pretrain_img_size=336, device=None, dtype=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.global_att_blocks = set(global_att_blocks)
|
||||
|
||||
self.patch_embed = PatchEmbed(patch_size, 3, embed_dim, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
num_patches = (pretrain_img_size // patch_size) ** 2 + 1 # +1 for cls token
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim, device=device, dtype=dtype))
|
||||
|
||||
self.ln_pre = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
|
||||
|
||||
grid_size = img_size // patch_size
|
||||
pretrain_grid = pretrain_img_size // patch_size
|
||||
|
||||
self.blocks = nn.ModuleList()
|
||||
for i in range(depth):
|
||||
is_global = i in self.global_att_blocks
|
||||
self.blocks.append(Block(
|
||||
embed_dim, num_heads, mlp_ratio, qkv_bias,
|
||||
window_size=0 if is_global else window_size,
|
||||
use_rope=use_rope,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
))
|
||||
|
||||
if use_rope:
|
||||
rope_scale = pretrain_grid / grid_size
|
||||
self.register_buffer("freqs_cis", rope_2d(grid_size, grid_size, embed_dim // num_heads, scale_pos=rope_scale), persistent=False)
|
||||
self.register_buffer("freqs_cis_window", rope_2d(window_size, window_size, embed_dim // num_heads), persistent=False)
|
||||
else:
|
||||
self.freqs_cis = None
|
||||
self.freqs_cis_window = None
|
||||
|
||||
def _get_pos_embed(self, num_tokens):
|
||||
pos = self.pos_embed
|
||||
if pos.shape[1] == num_tokens:
|
||||
return pos
|
||||
cls_pos = pos[:, :1]
|
||||
spatial_pos = pos[:, 1:]
|
||||
old_size = int(math.sqrt(spatial_pos.shape[1]))
|
||||
new_size = int(math.sqrt(num_tokens - 1)) if num_tokens > 1 else old_size
|
||||
spatial_2d = spatial_pos.reshape(1, old_size, old_size, -1).permute(0, 3, 1, 2)
|
||||
tiles_h = new_size // old_size + 1
|
||||
tiles_w = new_size // old_size + 1
|
||||
tiled = spatial_2d.tile([1, 1, tiles_h, tiles_w])[:, :, :new_size, :new_size]
|
||||
tiled = tiled.permute(0, 2, 3, 1).reshape(1, new_size * new_size, -1)
|
||||
return torch.cat([cls_pos, tiled], dim=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.patch_embed(x)
|
||||
B, C, Hp, Wp = x.shape
|
||||
x = x.permute(0, 2, 3, 1).reshape(B, Hp * Wp, C)
|
||||
|
||||
pos = cast_to_input(self._get_pos_embed(Hp * Wp + 1), x)
|
||||
x = x + pos[:, 1:Hp * Wp + 1]
|
||||
|
||||
x = x.view(B, Hp, Wp, C)
|
||||
x = self.ln_pre(x)
|
||||
|
||||
freqs_cis_global = self.freqs_cis
|
||||
freqs_cis_win = self.freqs_cis_window
|
||||
if freqs_cis_global is not None:
|
||||
freqs_cis_global = cast_to_input(freqs_cis_global, x)
|
||||
if freqs_cis_win is not None:
|
||||
freqs_cis_win = cast_to_input(freqs_cis_win, x)
|
||||
|
||||
for block in self.blocks:
|
||||
fc = freqs_cis_win if block.window_size > 0 else freqs_cis_global
|
||||
x = block(x, freqs_cis=fc)
|
||||
|
||||
return x.permute(0, 3, 1, 2)
|
||||
|
||||
|
||||
class FPNScaleConv(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, scale, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
if scale == 4.0:
|
||||
self.dconv_2x2_0 = operations.ConvTranspose2d(in_dim, in_dim // 2, kernel_size=2, stride=2, device=device, dtype=dtype)
|
||||
self.dconv_2x2_1 = operations.ConvTranspose2d(in_dim // 2, in_dim // 4, kernel_size=2, stride=2, device=device, dtype=dtype)
|
||||
proj_in = in_dim // 4
|
||||
elif scale == 2.0:
|
||||
self.dconv_2x2 = operations.ConvTranspose2d(in_dim, in_dim // 2, kernel_size=2, stride=2, device=device, dtype=dtype)
|
||||
proj_in = in_dim // 2
|
||||
elif scale == 1.0:
|
||||
proj_in = in_dim
|
||||
elif scale == 0.5:
|
||||
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
proj_in = in_dim
|
||||
self.scale = scale
|
||||
self.conv_1x1 = operations.Conv2d(proj_in, out_dim, kernel_size=1, device=device, dtype=dtype)
|
||||
self.conv_3x3 = operations.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
if self.scale == 4.0:
|
||||
x = F.gelu(self.dconv_2x2_0(x))
|
||||
x = self.dconv_2x2_1(x)
|
||||
elif self.scale == 2.0:
|
||||
x = self.dconv_2x2(x)
|
||||
elif self.scale == 0.5:
|
||||
x = self.pool(x)
|
||||
x = self.conv_1x1(x)
|
||||
x = self.conv_3x3(x)
|
||||
return x
|
||||
|
||||
|
||||
class PositionEmbeddingSine(nn.Module):
|
||||
"""2D sinusoidal position encoding (DETR-style) with result caching."""
|
||||
def __init__(self, num_pos_feats=256, temperature=10000.0, normalize=True, scale=None):
|
||||
super().__init__()
|
||||
assert num_pos_feats % 2 == 0
|
||||
self.half_dim = num_pos_feats // 2
|
||||
self.temperature = temperature
|
||||
self.normalize = normalize
|
||||
self.scale = scale if scale is not None else 2 * math.pi
|
||||
self._cache = {}
|
||||
|
||||
def _sincos(self, vals):
|
||||
"""Encode 1D values to interleaved sin/cos features."""
|
||||
freqs = self.temperature ** (2 * (torch.arange(self.half_dim, dtype=torch.float32, device=vals.device) // 2) / self.half_dim)
|
||||
raw = vals[..., None] * self.scale / freqs
|
||||
return torch.stack((raw[..., 0::2].sin(), raw[..., 1::2].cos()), dim=-1).flatten(-2)
|
||||
|
||||
def _encode_xy(self, x, y):
|
||||
"""Encode normalized x, y coordinates to sinusoidal features. Returns (pos_x, pos_y) each [N, half_dim]."""
|
||||
dim_t = self.temperature ** (2 * (torch.arange(self.half_dim, dtype=torch.float32, device=x.device) // 2) / self.half_dim)
|
||||
pos_x = x[:, None] * self.scale / dim_t
|
||||
pos_y = y[:, None] * self.scale / dim_t
|
||||
pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1)
|
||||
pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1)
|
||||
return pos_x, pos_y
|
||||
|
||||
def encode_boxes(self, cx, cy, w, h):
|
||||
"""Encode box center + size to [N, d_model+2] features."""
|
||||
pos_x, pos_y = self._encode_xy(cx, cy)
|
||||
return torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
key = (H, W, x.device)
|
||||
if key not in self._cache:
|
||||
gy = torch.arange(H, dtype=torch.float32, device=x.device)
|
||||
gx = torch.arange(W, dtype=torch.float32, device=x.device)
|
||||
if self.normalize:
|
||||
gy, gx = gy / (H - 1 + 1e-6), gx / (W - 1 + 1e-6)
|
||||
yy, xx = torch.meshgrid(gy, gx, indexing="ij")
|
||||
self._cache[key] = torch.cat((self._sincos(yy), self._sincos(xx)), dim=-1).permute(2, 0, 1).unsqueeze(0)
|
||||
return self._cache[key].expand(B, -1, -1, -1)
|
||||
|
||||
|
||||
class SAM3VisionBackbone(nn.Module):
|
||||
def __init__(self, embed_dim=1024, d_model=256, multiplex=False, device=None, dtype=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.trunk = ViTDet(embed_dim=embed_dim, device=device, dtype=dtype, operations=operations, **kwargs)
|
||||
self.position_encoding = PositionEmbeddingSine(num_pos_feats=d_model, normalize=True)
|
||||
self.multiplex = multiplex
|
||||
|
||||
fpn_args = dict(device=device, dtype=dtype, operations=operations)
|
||||
if multiplex:
|
||||
scales = [4.0, 2.0, 1.0]
|
||||
self.convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales])
|
||||
self.propagation_convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales])
|
||||
self.interactive_convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales])
|
||||
else:
|
||||
scales = [4.0, 2.0, 1.0, 0.5]
|
||||
self.convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales])
|
||||
self.sam2_convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales])
|
||||
|
||||
def forward(self, images, need_tracker=False, tracker_mode=None, cached_trunk=None, tracker_only=False):
|
||||
backbone_out = cached_trunk if cached_trunk is not None else self.trunk(images)
|
||||
|
||||
if tracker_only:
|
||||
# Skip detector FPN when only tracker features are needed (video tracking)
|
||||
if self.multiplex:
|
||||
tracker_convs = self.propagation_convs if tracker_mode == "propagation" else self.interactive_convs
|
||||
else:
|
||||
tracker_convs = self.sam2_convs
|
||||
tracker_features = [conv(backbone_out) for conv in tracker_convs]
|
||||
tracker_positions = [cast_to_input(self.position_encoding(f), f) for f in tracker_features]
|
||||
return None, None, tracker_features, tracker_positions
|
||||
|
||||
features = [conv(backbone_out) for conv in self.convs]
|
||||
positions = [cast_to_input(self.position_encoding(f), f) for f in features]
|
||||
|
||||
if self.multiplex:
|
||||
if tracker_mode == "propagation":
|
||||
tracker_convs = self.propagation_convs
|
||||
elif tracker_mode == "interactive":
|
||||
tracker_convs = self.interactive_convs
|
||||
else:
|
||||
return features, positions, None, None
|
||||
elif need_tracker:
|
||||
tracker_convs = self.sam2_convs
|
||||
else:
|
||||
return features, positions, None, None
|
||||
|
||||
tracker_features = [conv(backbone_out) for conv in tracker_convs]
|
||||
tracker_positions = [cast_to_input(self.position_encoding(f), f) for f in tracker_features]
|
||||
return features, positions, tracker_features, tracker_positions
|
||||
1785
comfy/ldm/sam3/tracker.py
Normal file
1785
comfy/ldm/sam3/tracker.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -54,6 +54,7 @@ import comfy.ldm.anima.model
|
||||
import comfy.ldm.ace.ace_step15
|
||||
import comfy.ldm.rt_detr.rtdetr_v4
|
||||
import comfy.ldm.ernie.model
|
||||
import comfy.ldm.sam3.detector
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@@ -578,8 +579,8 @@ class Stable_Zero123(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
self.cc_projection = comfy.ops.manual_cast.Linear(cc_projection_weight.shape[1], cc_projection_weight.shape[0], dtype=self.get_dtype(), device=device)
|
||||
self.cc_projection.weight.copy_(cc_projection_weight)
|
||||
self.cc_projection.bias.copy_(cc_projection_bias)
|
||||
self.cc_projection.weight = torch.nn.Parameter(cc_projection_weight.clone())
|
||||
self.cc_projection.bias = torch.nn.Parameter(cc_projection_bias.clone())
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
@@ -1974,3 +1975,7 @@ class ErnieImage(BaseModel):
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
||||
class SAM3(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.sam3.detector.SAM3Model)
|
||||
|
||||
@@ -718,6 +718,14 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["image_model"] = "ernie"
|
||||
return dit_config
|
||||
|
||||
if 'detector.backbone.vision_backbone.trunk.blocks.0.attn.qkv.weight' in state_dict_keys: # SAM3 / SAM3.1
|
||||
if 'detector.transformer.decoder.query_embed.weight' in state_dict_keys:
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "SAM3"
|
||||
if 'detector.backbone.vision_backbone.propagation_convs.0.conv_1x1.weight' in state_dict_keys:
|
||||
dit_config["image_model"] = "SAM31"
|
||||
return dit_config
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
return None
|
||||
|
||||
@@ -873,6 +881,10 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
|
||||
return model_config
|
||||
|
||||
def unet_prefix_from_state_dict(state_dict):
|
||||
# SAM3: detector.* and tracker.* at top level, no common prefix
|
||||
if any(k.startswith("detector.") for k in state_dict) and any(k.startswith("tracker.") for k in state_dict):
|
||||
return ""
|
||||
|
||||
candidates = ["model.diffusion_model.", #ldm/sgm models
|
||||
"model.model.", #audio models
|
||||
"net.", #cosmos
|
||||
|
||||
@@ -1801,7 +1801,7 @@ def debug_memory_summary():
|
||||
return torch.cuda.memory.memory_summary()
|
||||
return ""
|
||||
|
||||
class InterruptProcessingException(Exception):
|
||||
class InterruptProcessingException(BaseException):
|
||||
pass
|
||||
|
||||
interrupt_processing_mutex = threading.RLock()
|
||||
|
||||
@@ -685,9 +685,9 @@ class ModelPatcher:
|
||||
sd.pop(k)
|
||||
return sd
|
||||
|
||||
def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False):
|
||||
def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False, force_cast=False):
|
||||
weight, set_func, convert_func = get_key_weight(self.model, key)
|
||||
if key not in self.patches:
|
||||
if key not in self.patches and not force_cast:
|
||||
return weight
|
||||
|
||||
inplace_update = self.weight_inplace_update or inplace_update
|
||||
@@ -695,7 +695,7 @@ class ModelPatcher:
|
||||
if key not in self.backup and not return_weight:
|
||||
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
|
||||
|
||||
temp_dtype = comfy.model_management.lora_compute_dtype(device_to)
|
||||
temp_dtype = comfy.model_management.lora_compute_dtype(device_to) if key in self.patches else None
|
||||
if device_to is not None:
|
||||
temp_weight = comfy.model_management.cast_to_device(weight, device_to, temp_dtype, copy=True)
|
||||
else:
|
||||
@@ -703,8 +703,9 @@ class ModelPatcher:
|
||||
if convert_func is not None:
|
||||
temp_weight = convert_func(temp_weight, inplace=True)
|
||||
|
||||
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
|
||||
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key) if key in self.patches else temp_weight
|
||||
if set_func is None:
|
||||
if key in self.patches:
|
||||
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key))
|
||||
if return_weight:
|
||||
return out_weight
|
||||
@@ -1584,7 +1585,7 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
key = key_param_name_to_key(n, param_key)
|
||||
if key in self.backup:
|
||||
comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
|
||||
self.patch_weight_to_device(key, device_to=device_to)
|
||||
self.patch_weight_to_device(key, device_to=device_to, force_cast=True)
|
||||
weight, _, _ = get_key_weight(self.model, key)
|
||||
if weight is not None:
|
||||
self.model.model_loaded_weight_memory += weight.numel() * weight.element_size()
|
||||
@@ -1609,6 +1610,10 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
m._v = vbar.alloc(v_weight_size)
|
||||
allocated_size += v_weight_size
|
||||
|
||||
for param in params:
|
||||
if param not in ("weight", "bias"):
|
||||
force_load_param(self, param, device_to)
|
||||
|
||||
else:
|
||||
for param in params:
|
||||
key = key_param_name_to_key(n, param)
|
||||
|
||||
19
comfy/sd.py
19
comfy/sd.py
@@ -12,6 +12,7 @@ from .ldm.cascade.stage_c_coder import StageC_coder
|
||||
from .ldm.audio.autoencoder import AudioOobleckVAE
|
||||
import comfy.ldm.genmo.vae.model
|
||||
import comfy.ldm.lightricks.vae.causal_video_autoencoder
|
||||
import comfy.ldm.lightricks.vae.audio_vae
|
||||
import comfy.ldm.cosmos.vae
|
||||
import comfy.ldm.wan.vae
|
||||
import comfy.ldm.wan.vae2_2
|
||||
@@ -805,6 +806,24 @@ class VAE:
|
||||
self.downscale_index_formula = (4, 8, 8)
|
||||
self.memory_used_encode = lambda shape, dtype: (700 * (max(1, (shape[-3] ** 0.66 * 0.11)) * shape[-2] * shape[-1]) * model_management.dtype_size(dtype))
|
||||
self.memory_used_decode = lambda shape, dtype: (50 * (max(1, (shape[-3] ** 0.65 * 0.26)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
|
||||
elif "vocoder.resblocks.0.convs1.0.weight" in sd or "vocoder.vocoder.resblocks.0.convs1.0.weight" in sd: # LTX Audio
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"audio_vae.": "autoencoder."})
|
||||
self.first_stage_model = comfy.ldm.lightricks.vae.audio_vae.AudioVAE(metadata=metadata)
|
||||
self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
|
||||
self.latent_channels = self.first_stage_model.latent_channels
|
||||
self.audio_sample_rate_output = self.first_stage_model.output_sample_rate
|
||||
self.autoencoder = self.first_stage_model.autoencoder # TODO: remove hack for ltxv custom nodes
|
||||
self.output_channels = 2
|
||||
self.pad_channel_value = "replicate"
|
||||
self.upscale_ratio = 4096
|
||||
self.downscale_ratio = 4096
|
||||
self.latent_dim = 2
|
||||
self.process_output = lambda audio: audio
|
||||
self.process_input = lambda audio: audio
|
||||
self.working_dtypes = [torch.float32]
|
||||
self.disable_offload = True
|
||||
self.extra_1d_channel = 16
|
||||
else:
|
||||
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
||||
self.first_stage_model = None
|
||||
|
||||
@@ -1781,6 +1781,57 @@ class ErnieImage(supported_models_base.BASE):
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.ernie.ErnieTokenizer, comfy.text_encoders.ernie.te(**hunyuan_detect))
|
||||
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage]
|
||||
class SAM3(supported_models_base.BASE):
|
||||
unet_config = {"image_model": "SAM3"}
|
||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
text_encoder_key_prefix = ["detector.backbone.language_backbone."]
|
||||
unet_extra_prefix = ""
|
||||
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
clip_keys = getattr(self, "_clip_stash", {})
|
||||
clip_keys = utils.state_dict_prefix_replace(clip_keys, {"detector.backbone.language_backbone.": "", "backbone.language_backbone.": ""}, filter_keys=True)
|
||||
clip_keys = utils.clip_text_transformers_convert(clip_keys, "encoder.", "sam3_clip.transformer.")
|
||||
return {k: v for k, v in clip_keys.items() if not k.startswith("resizer.")}
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
self._clip_stash = {k: state_dict.pop(k) for k in list(state_dict.keys()) if "language_backbone" in k and "resizer" not in k}
|
||||
# SAM3.1: remap tracker.model.* -> tracker.*
|
||||
for k in list(state_dict.keys()):
|
||||
if k.startswith("tracker.model."):
|
||||
state_dict["tracker." + k[len("tracker.model."):]] = state_dict.pop(k)
|
||||
# SAM3.1: remove per-block freqs_cis buffers (computed dynamically)
|
||||
for k in [k for k in list(state_dict.keys()) if ".attn.freqs_cis" in k]:
|
||||
state_dict.pop(k)
|
||||
# Split fused QKV projections
|
||||
for k in [k for k in list(state_dict.keys()) if k.endswith((".in_proj_weight", ".in_proj_bias"))]:
|
||||
t = state_dict.pop(k)
|
||||
base, suffix = k.rsplit(".in_proj_", 1)
|
||||
s = ".weight" if suffix == "weight" else ".bias"
|
||||
d = t.shape[0] // 3
|
||||
state_dict[base + ".q_proj" + s] = t[:d]
|
||||
state_dict[base + ".k_proj" + s] = t[d:2*d]
|
||||
state_dict[base + ".v_proj" + s] = t[2*d:]
|
||||
# Remap tracker SAM decoder transformer key names to match sam.py TwoWayTransformer
|
||||
for k in list(state_dict.keys()):
|
||||
if "sam_mask_decoder.transformer." not in k:
|
||||
continue
|
||||
new_k = k.replace(".mlp.lin1.", ".mlp.0.").replace(".mlp.lin2.", ".mlp.2.").replace(".norm_final_attn.", ".norm_final.")
|
||||
if new_k != k:
|
||||
state_dict[new_k] = state_dict.pop(k)
|
||||
return state_dict
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.SAM3(self, device=device)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
import comfy.text_encoders.sam3_clip
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.sam3_clip.SAM3TokenizerWrapper, comfy.text_encoders.sam3_clip.SAM3ClipModelWrapper)
|
||||
|
||||
|
||||
class SAM31(SAM3):
|
||||
unet_config = {"image_model": "SAM31"}
|
||||
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage, SAM3, SAM31]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
97
comfy/text_encoders/sam3_clip.py
Normal file
97
comfy/text_encoders/sam3_clip.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import re
|
||||
from comfy import sd1_clip
|
||||
|
||||
SAM3_CLIP_CONFIG = {
|
||||
"architectures": ["CLIPTextModel"],
|
||||
"hidden_act": "quick_gelu",
|
||||
"hidden_size": 1024,
|
||||
"intermediate_size": 4096,
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 24,
|
||||
"max_position_embeddings": 32,
|
||||
"projection_dim": 512,
|
||||
"vocab_size": 49408,
|
||||
"layer_norm_eps": 1e-5,
|
||||
"eos_token_id": 49407,
|
||||
}
|
||||
|
||||
|
||||
class SAM3ClipModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, max_length=32, layer="last", textmodel_json_config=SAM3_CLIP_CONFIG, special_tokens={"start": 49406, "end": 49407, "pad": 0}, return_projected_pooled=False, return_attention_masks=True, enable_attention_masks=True, model_options=model_options)
|
||||
|
||||
|
||||
class SAM3Tokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(max_length=32, pad_with_end=False, pad_token=0, embedding_directory=embedding_directory, embedding_size=1024, embedding_key="sam3_clip", tokenizer_data=tokenizer_data)
|
||||
self.disable_weights = True
|
||||
|
||||
|
||||
def _parse_prompts(text):
|
||||
"""Split comma-separated prompts with optional :N max detections per category"""
|
||||
text = text.replace("(", "").replace(")", "")
|
||||
parts = [p.strip() for p in text.split(",") if p.strip()]
|
||||
result = []
|
||||
for part in parts:
|
||||
m = re.match(r'^(.+?)\s*:\s*([\d.]+)\s*$', part)
|
||||
if m:
|
||||
text_part = m.group(1).strip()
|
||||
val = m.group(2)
|
||||
max_det = max(1, round(float(val)))
|
||||
result.append((text_part, max_det))
|
||||
else:
|
||||
result.append((part, 1))
|
||||
return result
|
||||
|
||||
|
||||
class SAM3TokenizerWrapper(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="l", tokenizer=SAM3Tokenizer, name="sam3_clip")
|
||||
|
||||
def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
|
||||
parsed = _parse_prompts(text)
|
||||
if len(parsed) <= 1 and (not parsed or parsed[0][1] == 1):
|
||||
return super().tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
# Tokenize each prompt part separately, store per-part batches and metadata
|
||||
inner = getattr(self, self.clip)
|
||||
per_prompt = []
|
||||
for prompt_text, max_det in parsed:
|
||||
batches = inner.tokenize_with_weights(prompt_text, return_word_ids, **kwargs)
|
||||
per_prompt.append((batches, max_det))
|
||||
# Main output uses first prompt's tokens (for compatibility)
|
||||
out = {self.clip_name: per_prompt[0][0], "sam3_per_prompt": per_prompt}
|
||||
return out
|
||||
|
||||
|
||||
class SAM3ClipModelWrapper(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_name="l", clip_model=SAM3ClipModel, name="sam3_clip")
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
per_prompt = token_weight_pairs.pop("sam3_per_prompt", None)
|
||||
if per_prompt is None:
|
||||
return super().encode_token_weights(token_weight_pairs)
|
||||
|
||||
# Encode each prompt separately, pack into extra dict
|
||||
inner = getattr(self, self.clip)
|
||||
multi_cond = []
|
||||
first_pooled = None
|
||||
for batches, max_det in per_prompt:
|
||||
out = inner.encode_token_weights(batches)
|
||||
cond, pooled = out[0], out[1]
|
||||
extra = out[2] if len(out) > 2 else {}
|
||||
if first_pooled is None:
|
||||
first_pooled = pooled
|
||||
multi_cond.append({
|
||||
"cond": cond,
|
||||
"attention_mask": extra.get("attention_mask"),
|
||||
"max_detections": max_det,
|
||||
})
|
||||
|
||||
# Return first prompt as main (for non-SAM3 consumers), all prompts in metadata
|
||||
main = multi_cond[0]
|
||||
main_extra = {}
|
||||
if main["attention_mask"] is not None:
|
||||
main_extra["attention_mask"] = main["attention_mask"]
|
||||
main_extra["sam3_multi_cond"] = multi_cond
|
||||
return (main["cond"], first_pooled, main_extra)
|
||||
@@ -15,7 +15,6 @@ from comfy_execution.progress import get_progress_state, PreviewImageTuple
|
||||
from PIL import Image
|
||||
from comfy.cli_args import args
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
|
||||
class ComfyAPI_latest(ComfyAPIBase):
|
||||
@@ -26,7 +25,6 @@ class ComfyAPI_latest(ComfyAPIBase):
|
||||
super().__init__()
|
||||
self.node_replacement = self.NodeReplacement()
|
||||
self.execution = self.Execution()
|
||||
self.environment = self.Environment()
|
||||
self.caching = self.Caching()
|
||||
|
||||
class NodeReplacement(ProxiedSingleton):
|
||||
@@ -87,27 +85,6 @@ class ComfyAPI_latest(ComfyAPIBase):
|
||||
image=to_display,
|
||||
)
|
||||
|
||||
class Environment(ProxiedSingleton):
|
||||
"""
|
||||
Query the current execution environment.
|
||||
|
||||
Managed deployments set the ``COMFY_EXECUTION_ENVIRONMENT`` env var
|
||||
so custom nodes can adapt their behaviour at runtime.
|
||||
|
||||
Example::
|
||||
|
||||
from comfy_api.latest import api
|
||||
|
||||
env = api.environment.get() # "local" | "cloud" | "remote"
|
||||
"""
|
||||
|
||||
_VALID = {"local", "cloud", "remote"}
|
||||
|
||||
async def get(self) -> str:
|
||||
"""Return the execution environment: ``"local"``, ``"cloud"``, or ``"remote"``."""
|
||||
value = os.environ.get("COMFY_EXECUTION_ENVIRONMENT", "local").lower().strip()
|
||||
return value if value in self._VALID else "local"
|
||||
|
||||
class Caching(ProxiedSingleton):
|
||||
"""
|
||||
External cache provider API for sharing cached node outputs
|
||||
|
||||
@@ -122,6 +122,41 @@ class TaskStatusResponse(BaseModel):
|
||||
usage: TaskStatusUsage | None = Field(None)
|
||||
|
||||
|
||||
class GetAssetResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
name: str | None = Field(None)
|
||||
url: str | None = Field(None)
|
||||
asset_type: str = Field(...)
|
||||
group_id: str = Field(...)
|
||||
status: str = Field(...)
|
||||
error: TaskStatusError | None = Field(None)
|
||||
|
||||
|
||||
class SeedanceCreateVisualValidateSessionResponse(BaseModel):
|
||||
session_id: str = Field(...)
|
||||
h5_link: str = Field(...)
|
||||
|
||||
|
||||
class SeedanceGetVisualValidateSessionResponse(BaseModel):
|
||||
session_id: str = Field(...)
|
||||
status: str = Field(...)
|
||||
group_id: str | None = Field(None)
|
||||
error_code: str | None = Field(None)
|
||||
error_message: str | None = Field(None)
|
||||
|
||||
|
||||
class SeedanceCreateAssetRequest(BaseModel):
|
||||
group_id: str = Field(...)
|
||||
url: str = Field(...)
|
||||
asset_type: str = Field(...)
|
||||
name: str | None = Field(None, max_length=64)
|
||||
project_name: str | None = Field(None)
|
||||
|
||||
|
||||
class SeedanceCreateAssetResponse(BaseModel):
|
||||
asset_id: str = Field(...)
|
||||
|
||||
|
||||
# Dollars per 1K tokens, keyed by (model_id, has_video_input).
|
||||
SEEDANCE2_PRICE_PER_1K_TOKENS = {
|
||||
("dreamina-seedance-2-0-260128", False): 0.007,
|
||||
@@ -158,10 +193,17 @@ RECOMMENDED_PRESETS_SEEDREAM_4 = [
|
||||
("Custom", None, None),
|
||||
]
|
||||
|
||||
# Seedance 2.0 reference video pixel count limits per model.
|
||||
# Seedance 2.0 reference video pixel count limits per model and output resolution.
|
||||
SEEDANCE2_REF_VIDEO_PIXEL_LIMITS = {
|
||||
"dreamina-seedance-2-0-260128": {"min": 409_600, "max": 927_408},
|
||||
"dreamina-seedance-2-0-fast-260128": {"min": 409_600, "max": 927_408},
|
||||
"dreamina-seedance-2-0-260128": {
|
||||
"480p": {"min": 409_600, "max": 927_408},
|
||||
"720p": {"min": 409_600, "max": 927_408},
|
||||
"1080p": {"min": 409_600, "max": 2_073_600},
|
||||
},
|
||||
"dreamina-seedance-2-0-fast-260128": {
|
||||
"480p": {"min": 409_600, "max": 927_408},
|
||||
"720p": {"min": 409_600, "max": 927_408},
|
||||
},
|
||||
}
|
||||
|
||||
# The time in this dictionary are given for 10 seconds duration.
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
@@ -11,9 +12,14 @@ from comfy_api_nodes.apis.bytedance import (
|
||||
SEEDANCE2_PRICE_PER_1K_TOKENS,
|
||||
SEEDANCE2_REF_VIDEO_PIXEL_LIMITS,
|
||||
VIDEO_TASKS_EXECUTION_TIME,
|
||||
GetAssetResponse,
|
||||
Image2VideoTaskCreationRequest,
|
||||
ImageTaskCreationResponse,
|
||||
Seedance2TaskCreationRequest,
|
||||
SeedanceCreateAssetRequest,
|
||||
SeedanceCreateAssetResponse,
|
||||
SeedanceCreateVisualValidateSessionResponse,
|
||||
SeedanceGetVisualValidateSessionResponse,
|
||||
Seedream4Options,
|
||||
Seedream4TaskCreationRequest,
|
||||
TaskAudioContent,
|
||||
@@ -35,6 +41,7 @@ from comfy_api_nodes.util import (
|
||||
get_number_of_images,
|
||||
image_tensor_pair_to_batch,
|
||||
poll_op,
|
||||
resize_video_to_pixel_budget,
|
||||
sync_op,
|
||||
upload_audio_to_comfyapi,
|
||||
upload_image_to_comfyapi,
|
||||
@@ -43,10 +50,16 @@ from comfy_api_nodes.util import (
|
||||
validate_image_aspect_ratio,
|
||||
validate_image_dimensions,
|
||||
validate_string,
|
||||
validate_video_dimensions,
|
||||
validate_video_duration,
|
||||
)
|
||||
from server import PromptServer
|
||||
|
||||
BYTEPLUS_IMAGE_ENDPOINT = "/proxy/byteplus/api/v3/images/generations"
|
||||
|
||||
_VERIFICATION_POLL_TIMEOUT_SEC = 120
|
||||
_VERIFICATION_POLL_INTERVAL_SEC = 3
|
||||
|
||||
SEEDREAM_MODELS = {
|
||||
"seedream 5.0 lite": "seedream-5-0-260128",
|
||||
"seedream-4-5-251128": "seedream-4-5-251128",
|
||||
@@ -69,9 +82,12 @@ DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-2504
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _validate_ref_video_pixels(video: Input.Video, model_id: str, index: int) -> None:
|
||||
"""Validate reference video pixel count against Seedance 2.0 model limits."""
|
||||
limits = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id)
|
||||
def _validate_ref_video_pixels(video: Input.Video, model_id: str, resolution: str, index: int) -> None:
|
||||
"""Validate reference video pixel count against Seedance 2.0 model limits for the selected resolution."""
|
||||
model_limits = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id)
|
||||
if not model_limits:
|
||||
return
|
||||
limits = model_limits.get(resolution)
|
||||
if not limits:
|
||||
return
|
||||
try:
|
||||
@@ -92,6 +108,169 @@ def _validate_ref_video_pixels(video: Input.Video, model_id: str, index: int) ->
|
||||
)
|
||||
|
||||
|
||||
async def _resolve_reference_assets(
|
||||
cls: type[IO.ComfyNode],
|
||||
asset_ids: list[str],
|
||||
) -> tuple[dict[str, str], dict[str, str], dict[str, str]]:
|
||||
"""Look up each asset, validate Active status, group by asset_type.
|
||||
|
||||
Returns (image_assets, video_assets, audio_assets), each mapping asset_id -> "asset://<asset_id>".
|
||||
"""
|
||||
image_assets: dict[str, str] = {}
|
||||
video_assets: dict[str, str] = {}
|
||||
audio_assets: dict[str, str] = {}
|
||||
for i, raw_id in enumerate(asset_ids, 1):
|
||||
asset_id = (raw_id or "").strip()
|
||||
if not asset_id:
|
||||
continue
|
||||
result = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/seedance/assets/{asset_id}"),
|
||||
response_model=GetAssetResponse,
|
||||
)
|
||||
if result.status != "Active":
|
||||
extra = f" {result.error.code}: {result.error.message}" if result.error else ""
|
||||
raise ValueError(f"Reference asset {i} (Id={asset_id}) is not Active (Status={result.status}).{extra}")
|
||||
asset_uri = f"asset://{asset_id}"
|
||||
if result.asset_type == "Image":
|
||||
image_assets[asset_id] = asset_uri
|
||||
elif result.asset_type == "Video":
|
||||
video_assets[asset_id] = asset_uri
|
||||
elif result.asset_type == "Audio":
|
||||
audio_assets[asset_id] = asset_uri
|
||||
return image_assets, video_assets, audio_assets
|
||||
|
||||
|
||||
_ASSET_REF_RE = re.compile(r"\basset ?(\d{1,2})\b", re.IGNORECASE)
|
||||
|
||||
|
||||
def _build_asset_labels(
|
||||
reference_assets: dict[str, str],
|
||||
image_asset_uris: dict[str, str],
|
||||
video_asset_uris: dict[str, str],
|
||||
audio_asset_uris: dict[str, str],
|
||||
n_reference_images: int,
|
||||
n_reference_videos: int,
|
||||
n_reference_audios: int,
|
||||
) -> dict[int, str]:
|
||||
"""Map asset slot number (from 'asset_N' keys) to its positional label.
|
||||
|
||||
Asset entries are appended to `content` after the reference_images/videos/audios,
|
||||
so their 1-indexed labels continue from the count of existing same-type refs:
|
||||
one reference_images entry + one Image-type asset -> asset labelled "Image 2".
|
||||
"""
|
||||
image_n = n_reference_images
|
||||
video_n = n_reference_videos
|
||||
audio_n = n_reference_audios
|
||||
labels: dict[int, str] = {}
|
||||
for slot_key, raw_id in reference_assets.items():
|
||||
asset_id = (raw_id or "").strip()
|
||||
if not asset_id:
|
||||
continue
|
||||
try:
|
||||
slot_num = int(slot_key.rsplit("_", 1)[-1])
|
||||
except ValueError:
|
||||
continue
|
||||
if asset_id in image_asset_uris:
|
||||
image_n += 1
|
||||
labels[slot_num] = f"Image {image_n}"
|
||||
elif asset_id in video_asset_uris:
|
||||
video_n += 1
|
||||
labels[slot_num] = f"Video {video_n}"
|
||||
elif asset_id in audio_asset_uris:
|
||||
audio_n += 1
|
||||
labels[slot_num] = f"Audio {audio_n}"
|
||||
return labels
|
||||
|
||||
|
||||
def _rewrite_asset_refs(prompt: str, labels: dict[int, str]) -> str:
|
||||
"""Case-insensitively replace 'assetNN' (1-2 digit) tokens with their labels."""
|
||||
if not labels:
|
||||
return prompt
|
||||
|
||||
def _sub(m: "re.Match[str]") -> str:
|
||||
return labels.get(int(m.group(1)), m.group(0))
|
||||
|
||||
return _ASSET_REF_RE.sub(_sub, prompt)
|
||||
|
||||
|
||||
async def _obtain_group_id_via_h5_auth(cls: type[IO.ComfyNode]) -> str:
|
||||
session = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/seedance/visual-validate/sessions", method="POST"),
|
||||
response_model=SeedanceCreateVisualValidateSessionResponse,
|
||||
)
|
||||
logger.warning("Seedance authentication required. Open link: %s", session.h5_link)
|
||||
|
||||
h5_text = f"Open this link in your browser and complete face verification:\n\n{session.h5_link}"
|
||||
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/seedance/visual-validate/sessions/{session.session_id}"),
|
||||
response_model=SeedanceGetVisualValidateSessionResponse,
|
||||
status_extractor=lambda r: r.status,
|
||||
completed_statuses=["completed"],
|
||||
failed_statuses=["failed"],
|
||||
poll_interval=_VERIFICATION_POLL_INTERVAL_SEC,
|
||||
max_poll_attempts=(_VERIFICATION_POLL_TIMEOUT_SEC // _VERIFICATION_POLL_INTERVAL_SEC) - 1,
|
||||
estimated_duration=_VERIFICATION_POLL_TIMEOUT_SEC - 1,
|
||||
extra_text=h5_text,
|
||||
)
|
||||
|
||||
if not result.group_id:
|
||||
raise RuntimeError(f"Seedance session {session.session_id} completed without a group_id")
|
||||
|
||||
logger.warning("Seedance authentication complete. New GroupId: %s", result.group_id)
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Authentication complete. New GroupId: {result.group_id}", cls.hidden.unique_id
|
||||
)
|
||||
return result.group_id
|
||||
|
||||
|
||||
async def _resolve_group_id(cls: type[IO.ComfyNode], group_id: str) -> str:
|
||||
if group_id and group_id.strip():
|
||||
return group_id.strip()
|
||||
return await _obtain_group_id_via_h5_auth(cls)
|
||||
|
||||
|
||||
async def _create_seedance_asset(
|
||||
cls: type[IO.ComfyNode],
|
||||
*,
|
||||
group_id: str,
|
||||
url: str,
|
||||
name: str,
|
||||
asset_type: str,
|
||||
) -> str:
|
||||
req = SeedanceCreateAssetRequest(
|
||||
group_id=group_id,
|
||||
url=url,
|
||||
asset_type=asset_type,
|
||||
name=name or None,
|
||||
)
|
||||
result = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/seedance/assets", method="POST"),
|
||||
response_model=SeedanceCreateAssetResponse,
|
||||
data=req,
|
||||
)
|
||||
return result.asset_id
|
||||
|
||||
|
||||
async def _wait_for_asset_active(cls: type[IO.ComfyNode], asset_id: str, group_id: str) -> GetAssetResponse:
|
||||
"""Poll the newly created asset until its status becomes Active."""
|
||||
return await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/seedance/assets/{asset_id}"),
|
||||
response_model=GetAssetResponse,
|
||||
status_extractor=lambda r: r.status,
|
||||
completed_statuses=["Active"],
|
||||
failed_statuses=["Failed"],
|
||||
poll_interval=5,
|
||||
max_poll_attempts=1200,
|
||||
extra_text=f"Waiting for asset pre-processing...\n\nasset_id: {asset_id}\n\ngroup_id: {group_id}",
|
||||
)
|
||||
|
||||
|
||||
def _seedance2_price_extractor(model_id: str, has_video_input: bool):
|
||||
"""Returns a price_extractor closure for Seedance 2.0 poll_op."""
|
||||
rate = SEEDANCE2_PRICE_PER_1K_TOKENS.get((model_id, has_video_input))
|
||||
@@ -1224,12 +1403,27 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
||||
IO.Image.Input(
|
||||
"first_frame",
|
||||
tooltip="First frame image for the video.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Image.Input(
|
||||
"last_frame",
|
||||
tooltip="Last frame image for the video.",
|
||||
optional=True,
|
||||
),
|
||||
IO.String.Input(
|
||||
"first_frame_asset_id",
|
||||
default="",
|
||||
tooltip="Seedance asset_id to use as the first frame. "
|
||||
"Mutually exclusive with the first_frame image input.",
|
||||
optional=True,
|
||||
),
|
||||
IO.String.Input(
|
||||
"last_frame_asset_id",
|
||||
default="",
|
||||
tooltip="Seedance asset_id to use as the last frame. "
|
||||
"Mutually exclusive with the last_frame image input.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
@@ -1282,24 +1476,54 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
||||
async def execute(
|
||||
cls,
|
||||
model: dict,
|
||||
first_frame: Input.Image,
|
||||
seed: int,
|
||||
watermark: bool,
|
||||
first_frame: Input.Image | None = None,
|
||||
last_frame: Input.Image | None = None,
|
||||
first_frame_asset_id: str = "",
|
||||
last_frame_asset_id: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(model["prompt"], strip_whitespace=True, min_length=1)
|
||||
model_id = SEEDANCE_MODELS[model["model"]]
|
||||
|
||||
first_frame_asset_id = first_frame_asset_id.strip()
|
||||
last_frame_asset_id = last_frame_asset_id.strip()
|
||||
|
||||
if first_frame is not None and first_frame_asset_id:
|
||||
raise ValueError("Provide only one of first_frame or first_frame_asset_id, not both.")
|
||||
if first_frame is None and not first_frame_asset_id:
|
||||
raise ValueError("Either first_frame or first_frame_asset_id is required.")
|
||||
if last_frame is not None and last_frame_asset_id:
|
||||
raise ValueError("Provide only one of last_frame or last_frame_asset_id, not both.")
|
||||
|
||||
asset_ids_to_resolve = [a for a in (first_frame_asset_id, last_frame_asset_id) if a]
|
||||
image_assets: dict[str, str] = {}
|
||||
if asset_ids_to_resolve:
|
||||
image_assets, _, _ = await _resolve_reference_assets(cls, asset_ids_to_resolve)
|
||||
for aid in asset_ids_to_resolve:
|
||||
if aid not in image_assets:
|
||||
raise ValueError(f"Asset {aid} is not an Image asset.")
|
||||
|
||||
if first_frame_asset_id:
|
||||
first_frame_url = image_assets[first_frame_asset_id]
|
||||
else:
|
||||
first_frame_url = await upload_image_to_comfyapi(cls, first_frame, wait_label="Uploading first frame.")
|
||||
|
||||
content: list[TaskTextContent | TaskImageContent] = [
|
||||
TaskTextContent(text=model["prompt"]),
|
||||
TaskImageContent(
|
||||
image_url=TaskImageContentUrl(
|
||||
url=await upload_image_to_comfyapi(cls, first_frame, wait_label="Uploading first frame.")
|
||||
),
|
||||
image_url=TaskImageContentUrl(url=first_frame_url),
|
||||
role="first_frame",
|
||||
),
|
||||
]
|
||||
if last_frame is not None:
|
||||
if last_frame_asset_id:
|
||||
content.append(
|
||||
TaskImageContent(
|
||||
image_url=TaskImageContentUrl(url=image_assets[last_frame_asset_id]),
|
||||
role="last_frame",
|
||||
),
|
||||
)
|
||||
elif last_frame is not None:
|
||||
content.append(
|
||||
TaskImageContent(
|
||||
image_url=TaskImageContentUrl(
|
||||
@@ -1373,6 +1597,32 @@ def _seedance2_reference_inputs(resolutions: list[str]):
|
||||
min=0,
|
||||
),
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"auto_downscale",
|
||||
default=False,
|
||||
advanced=True,
|
||||
optional=True,
|
||||
tooltip="Automatically downscale reference videos that exceed the model's pixel budget "
|
||||
"for the selected resolution. Aspect ratio is preserved; videos already within limits are untouched.",
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"reference_assets",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.String.Input("reference_asset"),
|
||||
names=[
|
||||
"asset_1",
|
||||
"asset_2",
|
||||
"asset_3",
|
||||
"asset_4",
|
||||
"asset_5",
|
||||
"asset_6",
|
||||
"asset_7",
|
||||
"asset_8",
|
||||
"asset_9",
|
||||
],
|
||||
min=0,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@@ -1474,16 +1724,47 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
|
||||
reference_images = model.get("reference_images", {})
|
||||
reference_videos = model.get("reference_videos", {})
|
||||
reference_audios = model.get("reference_audios", {})
|
||||
reference_assets = model.get("reference_assets", {})
|
||||
|
||||
if not reference_images and not reference_videos:
|
||||
raise ValueError("At least one reference image or video is required.")
|
||||
reference_image_assets, reference_video_assets, reference_audio_assets = await _resolve_reference_assets(
|
||||
cls, list(reference_assets.values())
|
||||
)
|
||||
|
||||
if not reference_images and not reference_videos and not reference_image_assets and not reference_video_assets:
|
||||
raise ValueError("At least one reference image or video or asset is required.")
|
||||
|
||||
total_images = len(reference_images) + len(reference_image_assets)
|
||||
if total_images > 9:
|
||||
raise ValueError(
|
||||
f"Too many reference images: {total_images} "
|
||||
f"(images={len(reference_images)}, image assets={len(reference_image_assets)}). Maximum is 9."
|
||||
)
|
||||
total_videos = len(reference_videos) + len(reference_video_assets)
|
||||
if total_videos > 3:
|
||||
raise ValueError(
|
||||
f"Too many reference videos: {total_videos} "
|
||||
f"(videos={len(reference_videos)}, video assets={len(reference_video_assets)}). Maximum is 3."
|
||||
)
|
||||
total_audios = len(reference_audios) + len(reference_audio_assets)
|
||||
if total_audios > 3:
|
||||
raise ValueError(
|
||||
f"Too many reference audios: {total_audios} "
|
||||
f"(audios={len(reference_audios)}, audio assets={len(reference_audio_assets)}). Maximum is 3."
|
||||
)
|
||||
|
||||
model_id = SEEDANCE_MODELS[model["model"]]
|
||||
has_video_input = len(reference_videos) > 0
|
||||
has_video_input = total_videos > 0
|
||||
|
||||
if model.get("auto_downscale") and reference_videos:
|
||||
max_px = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id, {}).get(model["resolution"], {}).get("max")
|
||||
if max_px:
|
||||
for key in reference_videos:
|
||||
reference_videos[key] = resize_video_to_pixel_budget(reference_videos[key], max_px)
|
||||
|
||||
total_video_duration = 0.0
|
||||
for i, key in enumerate(reference_videos, 1):
|
||||
video = reference_videos[key]
|
||||
_validate_ref_video_pixels(video, model_id, i)
|
||||
_validate_ref_video_pixels(video, model_id, model["resolution"], i)
|
||||
try:
|
||||
dur = video.get_duration()
|
||||
if dur < 1.8:
|
||||
@@ -1506,8 +1787,19 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
|
||||
if total_audio_duration > 15.1:
|
||||
raise ValueError(f"Total reference audio duration is {total_audio_duration:.1f}s. Maximum is 15.1 seconds.")
|
||||
|
||||
asset_labels = _build_asset_labels(
|
||||
reference_assets,
|
||||
reference_image_assets,
|
||||
reference_video_assets,
|
||||
reference_audio_assets,
|
||||
len(reference_images),
|
||||
len(reference_videos),
|
||||
len(reference_audios),
|
||||
)
|
||||
prompt_text = _rewrite_asset_refs(model["prompt"], asset_labels)
|
||||
|
||||
content: list[TaskTextContent | TaskImageContent | TaskVideoContent | TaskAudioContent] = [
|
||||
TaskTextContent(text=model["prompt"]),
|
||||
TaskTextContent(text=prompt_text),
|
||||
]
|
||||
for i, key in enumerate(reference_images, 1):
|
||||
content.append(
|
||||
@@ -1548,6 +1840,21 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
|
||||
),
|
||||
),
|
||||
)
|
||||
for url in reference_image_assets.values():
|
||||
content.append(
|
||||
TaskImageContent(
|
||||
image_url=TaskImageContentUrl(url=url),
|
||||
role="reference_image",
|
||||
),
|
||||
)
|
||||
for url in reference_video_assets.values():
|
||||
content.append(
|
||||
TaskVideoContent(video_url=TaskVideoContentUrl(url=url)),
|
||||
)
|
||||
for url in reference_audio_assets.values():
|
||||
content.append(
|
||||
TaskAudioContent(audio_url=TaskAudioContentUrl(url=url)),
|
||||
)
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
|
||||
@@ -1602,6 +1909,156 @@ async def process_video_task(
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
||||
|
||||
|
||||
class ByteDanceCreateImageAsset(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceCreateImageAsset",
|
||||
display_name="ByteDance Create Image Asset",
|
||||
category="api node/image/ByteDance",
|
||||
description=(
|
||||
"Create a Seedance 2.0 personal image asset. Uploads the input image and "
|
||||
"registers it in the given asset group. If group_id is empty, runs a real-person "
|
||||
"H5 authentication flow to create a new group before adding the asset."
|
||||
),
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="Image to register as a personal asset."),
|
||||
IO.String.Input(
|
||||
"group_id",
|
||||
default="",
|
||||
tooltip="Reuse an existing Seedance asset group ID to skip repeated human verification for the "
|
||||
"same person. Leave empty to run real-person authentication in the browser and create a new group.",
|
||||
),
|
||||
# IO.String.Input(
|
||||
# "name",
|
||||
# default="",
|
||||
# tooltip="Asset name (up to 64 characters).",
|
||||
# ),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="asset_id"),
|
||||
IO.String.Output(display_name="group_id"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
# is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
group_id: str = "",
|
||||
# name: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
# if len(name) > 64:
|
||||
# raise ValueError("Name of asset can not be greater then 64 symbols")
|
||||
validate_image_dimensions(image, min_width=300, max_width=6000, min_height=300, max_height=6000)
|
||||
validate_image_aspect_ratio(image, min_ratio=(0.4, 1), max_ratio=(2.5, 1))
|
||||
resolved_group = await _resolve_group_id(cls, group_id)
|
||||
asset_id = await _create_seedance_asset(
|
||||
cls,
|
||||
group_id=resolved_group,
|
||||
url=await upload_image_to_comfyapi(cls, image),
|
||||
name="",
|
||||
asset_type="Image",
|
||||
)
|
||||
await _wait_for_asset_active(cls, asset_id, resolved_group)
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Please save the asset_id and group_id for reuse.\n\nasset_id: {asset_id}\n\n"
|
||||
f"group_id: {resolved_group}",
|
||||
cls.hidden.unique_id,
|
||||
)
|
||||
return IO.NodeOutput(asset_id, resolved_group)
|
||||
|
||||
|
||||
class ByteDanceCreateVideoAsset(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceCreateVideoAsset",
|
||||
display_name="ByteDance Create Video Asset",
|
||||
category="api node/video/ByteDance",
|
||||
description=(
|
||||
"Create a Seedance 2.0 personal video asset. Uploads the input video and "
|
||||
"registers it in the given asset group. If group_id is empty, runs a real-person "
|
||||
"H5 authentication flow to create a new group before adding the asset."
|
||||
),
|
||||
inputs=[
|
||||
IO.Video.Input("video", tooltip="Video to register as a personal asset."),
|
||||
IO.String.Input(
|
||||
"group_id",
|
||||
default="",
|
||||
tooltip="Reuse an existing Seedance asset group ID to skip repeated human verification for the "
|
||||
"same person. Leave empty to run real-person authentication in the browser and create a new group.",
|
||||
),
|
||||
# IO.String.Input(
|
||||
# "name",
|
||||
# default="",
|
||||
# tooltip="Asset name (up to 64 characters).",
|
||||
# ),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="asset_id"),
|
||||
IO.String.Output(display_name="group_id"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
# is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
video: Input.Video,
|
||||
group_id: str = "",
|
||||
# name: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
# if len(name) > 64:
|
||||
# raise ValueError("Name of asset can not be greater then 64 symbols")
|
||||
validate_video_duration(video, min_duration=2, max_duration=15)
|
||||
validate_video_dimensions(video, min_width=300, max_width=6000, min_height=300, max_height=6000)
|
||||
|
||||
w, h = video.get_dimensions()
|
||||
if h > 0:
|
||||
ratio = w / h
|
||||
if not (0.4 <= ratio <= 2.5):
|
||||
raise ValueError(f"Asset video aspect ratio (W/H) must be in [0.4, 2.5], got {ratio:.3f} ({w}x{h}).")
|
||||
pixels = w * h
|
||||
if not (409_600 <= pixels <= 927_408):
|
||||
raise ValueError(
|
||||
f"Asset video total pixels (W×H) must be in [409600, 927408], " f"got {pixels:,} ({w}x{h})."
|
||||
)
|
||||
|
||||
fps = float(video.get_frame_rate())
|
||||
if not (24 <= fps <= 60):
|
||||
raise ValueError(f"Asset video FPS must be in [24, 60], got {fps:.2f}.")
|
||||
|
||||
resolved_group = await _resolve_group_id(cls, group_id)
|
||||
asset_id = await _create_seedance_asset(
|
||||
cls,
|
||||
group_id=resolved_group,
|
||||
url=await upload_video_to_comfyapi(cls, video),
|
||||
name="",
|
||||
asset_type="Video",
|
||||
)
|
||||
await _wait_for_asset_active(cls, asset_id, resolved_group)
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Please save the asset_id and group_id for reuse.\n\nasset_id: {asset_id}\n\n"
|
||||
f"group_id: {resolved_group}",
|
||||
cls.hidden.unique_id,
|
||||
)
|
||||
return IO.NodeOutput(asset_id, resolved_group)
|
||||
|
||||
|
||||
class ByteDanceExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@@ -1615,6 +2072,8 @@ class ByteDanceExtension(ComfyExtension):
|
||||
ByteDance2TextToVideoNode,
|
||||
ByteDance2FirstLastFrameNode,
|
||||
ByteDance2ReferenceNode,
|
||||
ByteDanceCreateImageAsset,
|
||||
ByteDanceCreateVideoAsset,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -357,13 +357,17 @@ def calculate_tokens_price_image_1_5(response: OpenAIImageGenerationResponse) ->
|
||||
return ((response.usage.input_tokens * 8.0) + (response.usage.output_tokens * 32.0)) / 1_000_000.0
|
||||
|
||||
|
||||
def calculate_tokens_price_image_2_0(response: OpenAIImageGenerationResponse) -> float | None:
|
||||
return ((response.usage.input_tokens * 8.0) + (response.usage.output_tokens * 30.0)) / 1_000_000.0
|
||||
|
||||
|
||||
class OpenAIGPTImage1(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="OpenAIGPTImage1",
|
||||
display_name="OpenAI GPT Image 1.5",
|
||||
display_name="OpenAI GPT Image 2",
|
||||
category="api node/image/OpenAI",
|
||||
description="Generates images synchronously via OpenAI's GPT Image endpoint.",
|
||||
inputs=[
|
||||
@@ -401,7 +405,17 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
||||
IO.Combo.Input(
|
||||
"size",
|
||||
default="auto",
|
||||
options=["auto", "1024x1024", "1024x1536", "1536x1024"],
|
||||
options=[
|
||||
"auto",
|
||||
"1024x1024",
|
||||
"1024x1536",
|
||||
"1536x1024",
|
||||
"2048x2048",
|
||||
"2048x1152",
|
||||
"1152x2048",
|
||||
"3840x2160",
|
||||
"2160x3840",
|
||||
],
|
||||
tooltip="Image size",
|
||||
optional=True,
|
||||
),
|
||||
@@ -427,8 +441,8 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["gpt-image-1", "gpt-image-1.5"],
|
||||
default="gpt-image-1.5",
|
||||
options=["gpt-image-1", "gpt-image-1.5", "gpt-image-2"],
|
||||
default="gpt-image-2",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
@@ -442,23 +456,36 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["quality", "n"]),
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["quality", "n", "model"]),
|
||||
expr="""
|
||||
(
|
||||
$ranges := {
|
||||
"gpt-image-1": {
|
||||
"low": [0.011, 0.02],
|
||||
"medium": [0.046, 0.07],
|
||||
"high": [0.167, 0.3]
|
||||
"medium": [0.042, 0.07],
|
||||
"high": [0.167, 0.25]
|
||||
},
|
||||
"gpt-image-1.5": {
|
||||
"low": [0.009, 0.02],
|
||||
"medium": [0.034, 0.062],
|
||||
"high": [0.133, 0.22]
|
||||
},
|
||||
"gpt-image-2": {
|
||||
"low": [0.0048, 0.012],
|
||||
"medium": [0.041, 0.112],
|
||||
"high": [0.165, 0.43]
|
||||
}
|
||||
};
|
||||
$range := $lookup($ranges, widgets.quality);
|
||||
$n := widgets.n;
|
||||
$range := $lookup($lookup($ranges, widgets.model), widgets.quality);
|
||||
$nRaw := widgets.n;
|
||||
$n := ($nRaw != null and $nRaw != 0) ? $nRaw : 1;
|
||||
($n = 1)
|
||||
? {"type":"range_usd","min_usd": $range[0], "max_usd": $range[1]}
|
||||
? {"type":"range_usd","min_usd": $range[0], "max_usd": $range[1], "format": {"approximate": true}}
|
||||
: {
|
||||
"type":"range_usd",
|
||||
"min_usd": $range[0],
|
||||
"max_usd": $range[1],
|
||||
"format": { "suffix": " x " & $string($n) & "/Run" }
|
||||
"min_usd": $range[0] * $n,
|
||||
"max_usd": $range[1] * $n,
|
||||
"format": { "suffix": "/Run", "approximate": true }
|
||||
}
|
||||
)
|
||||
""",
|
||||
@@ -483,10 +510,18 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
||||
if mask is not None and image is None:
|
||||
raise ValueError("Cannot use a mask without an input image")
|
||||
|
||||
if model in ("gpt-image-1", "gpt-image-1.5"):
|
||||
if size not in ("auto", "1024x1024", "1024x1536", "1536x1024"):
|
||||
raise ValueError(f"Resolution {size} is only supported by GPT Image 2 model")
|
||||
|
||||
if model == "gpt-image-1":
|
||||
price_extractor = calculate_tokens_price_image_1
|
||||
elif model == "gpt-image-1.5":
|
||||
price_extractor = calculate_tokens_price_image_1_5
|
||||
elif model == "gpt-image-2":
|
||||
price_extractor = calculate_tokens_price_image_2_0
|
||||
if background == "transparent":
|
||||
raise ValueError("Transparent background is not supported for GPT Image 2 model")
|
||||
else:
|
||||
raise ValueError(f"Unknown model: {model}")
|
||||
|
||||
|
||||
@@ -24,8 +24,9 @@ from comfy_api_nodes.util import (
|
||||
AVERAGE_DURATION_VIDEO_GEN = 32
|
||||
MODELS_MAP = {
|
||||
"veo-2.0-generate-001": "veo-2.0-generate-001",
|
||||
"veo-3.1-generate": "veo-3.1-generate-preview",
|
||||
"veo-3.1-fast-generate": "veo-3.1-fast-generate-preview",
|
||||
"veo-3.1-generate": "veo-3.1-generate-001",
|
||||
"veo-3.1-fast-generate": "veo-3.1-fast-generate-001",
|
||||
"veo-3.1-lite": "veo-3.1-lite-generate-001",
|
||||
"veo-3.0-generate-001": "veo-3.0-generate-001",
|
||||
"veo-3.0-fast-generate-001": "veo-3.0-fast-generate-001",
|
||||
}
|
||||
@@ -247,17 +248,8 @@ class VeoVideoGenerationNode(IO.ComfyNode):
|
||||
raise Exception("Video generation completed but no video was returned")
|
||||
|
||||
|
||||
class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
||||
"""
|
||||
Generates videos from text prompts using Google's Veo 3 API.
|
||||
|
||||
Supported models:
|
||||
- veo-3.0-generate-001
|
||||
- veo-3.0-fast-generate-001
|
||||
|
||||
This node extends the base Veo node with Veo 3 specific features including
|
||||
audio generation and fixed 8-second duration.
|
||||
"""
|
||||
class Veo3VideoGenerationNode(IO.ComfyNode):
|
||||
"""Generates videos from text prompts using Google's Veo 3 API."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@@ -279,6 +271,13 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
||||
default="16:9",
|
||||
tooltip="Aspect ratio of the output video",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["720p", "1080p", "4k"],
|
||||
default="720p",
|
||||
tooltip="Output video resolution. 4K is not available for veo-3.1-lite and veo-3.0 models.",
|
||||
optional=True,
|
||||
),
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
@@ -289,11 +288,11 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
||||
IO.Int.Input(
|
||||
"duration_seconds",
|
||||
default=8,
|
||||
min=8,
|
||||
min=4,
|
||||
max=8,
|
||||
step=1,
|
||||
step=2,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Duration of the output video in seconds (Veo 3 only supports 8 seconds)",
|
||||
tooltip="Duration of the output video in seconds",
|
||||
optional=True,
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
@@ -332,10 +331,10 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
||||
options=[
|
||||
"veo-3.1-generate",
|
||||
"veo-3.1-fast-generate",
|
||||
"veo-3.1-lite",
|
||||
"veo-3.0-generate-001",
|
||||
"veo-3.0-fast-generate-001",
|
||||
],
|
||||
default="veo-3.0-generate-001",
|
||||
tooltip="Veo 3 model to use for video generation",
|
||||
optional=True,
|
||||
),
|
||||
@@ -356,21 +355,111 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio"]),
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio", "resolution", "duration_seconds"]),
|
||||
expr="""
|
||||
(
|
||||
$m := widgets.model;
|
||||
$r := widgets.resolution;
|
||||
$a := widgets.generate_audio;
|
||||
($contains($m,"veo-3.0-fast-generate-001") or $contains($m,"veo-3.1-fast-generate"))
|
||||
? {"type":"usd","usd": ($a ? 1.2 : 0.8)}
|
||||
: ($contains($m,"veo-3.0-generate-001") or $contains($m,"veo-3.1-generate"))
|
||||
? {"type":"usd","usd": ($a ? 3.2 : 1.6)}
|
||||
: {"type":"range_usd","min_usd":0.8,"max_usd":3.2}
|
||||
$seconds := widgets.duration_seconds;
|
||||
$pps :=
|
||||
$contains($m, "lite")
|
||||
? ($r = "1080p" ? ($a ? 0.08 : 0.05) : ($a ? 0.05 : 0.03))
|
||||
: $contains($m, "3.1-fast")
|
||||
? ($r = "4k" ? ($a ? 0.30 : 0.25) : $r = "1080p" ? ($a ? 0.12 : 0.10) : ($a ? 0.10 : 0.08))
|
||||
: $contains($m, "3.1-generate")
|
||||
? ($r = "4k" ? ($a ? 0.60 : 0.40) : ($a ? 0.40 : 0.20))
|
||||
: $contains($m, "3.0-fast")
|
||||
? ($a ? 0.15 : 0.10)
|
||||
: ($a ? 0.40 : 0.20);
|
||||
{"type":"usd","usd": $pps * $seconds}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt,
|
||||
aspect_ratio="16:9",
|
||||
resolution="720p",
|
||||
negative_prompt="",
|
||||
duration_seconds=8,
|
||||
enhance_prompt=True,
|
||||
person_generation="ALLOW",
|
||||
seed=0,
|
||||
image=None,
|
||||
model="veo-3.0-generate-001",
|
||||
generate_audio=False,
|
||||
):
|
||||
if resolution == "4k" and ("lite" in model or "3.0" in model):
|
||||
raise Exception("4K resolution is not supported by the veo-3.1-lite or veo-3.0 models.")
|
||||
|
||||
model = MODELS_MAP[model]
|
||||
|
||||
instances = [{"prompt": prompt}]
|
||||
if image is not None:
|
||||
image_base64 = tensor_to_base64_string(image)
|
||||
if image_base64:
|
||||
instances[0]["image"] = {"bytesBase64Encoded": image_base64, "mimeType": "image/png"}
|
||||
|
||||
parameters = {
|
||||
"aspectRatio": aspect_ratio,
|
||||
"personGeneration": person_generation,
|
||||
"durationSeconds": duration_seconds,
|
||||
"enhancePrompt": True,
|
||||
"generateAudio": generate_audio,
|
||||
}
|
||||
if negative_prompt:
|
||||
parameters["negativePrompt"] = negative_prompt
|
||||
if seed > 0:
|
||||
parameters["seed"] = seed
|
||||
if "veo-3.1" in model:
|
||||
parameters["resolution"] = resolution
|
||||
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"),
|
||||
response_model=VeoGenVidResponse,
|
||||
data=VeoGenVidRequest(
|
||||
instances=instances,
|
||||
parameters=parameters,
|
||||
),
|
||||
)
|
||||
|
||||
poll_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"),
|
||||
response_model=VeoGenVidPollResponse,
|
||||
status_extractor=lambda r: "completed" if r.done else "pending",
|
||||
data=VeoGenVidPollRequest(operationName=initial_response.name),
|
||||
poll_interval=9.0,
|
||||
estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
|
||||
)
|
||||
|
||||
if poll_response.error:
|
||||
raise Exception(f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})")
|
||||
|
||||
response = poll_response.response
|
||||
filtered_count = response.raiMediaFilteredCount
|
||||
if filtered_count:
|
||||
reasons = response.raiMediaFilteredReasons or []
|
||||
reason_part = f": {reasons[0]}" if reasons else ""
|
||||
raise Exception(
|
||||
f"Content blocked by Google's Responsible AI filters{reason_part} "
|
||||
f"({filtered_count} video{'s' if filtered_count != 1 else ''} filtered)."
|
||||
)
|
||||
|
||||
if response.videos:
|
||||
video = response.videos[0]
|
||||
if video.bytesBase64Encoded:
|
||||
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
|
||||
if video.gcsUri:
|
||||
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
|
||||
raise Exception("Video returned but no data or URL was provided")
|
||||
raise Exception("Video generation completed but no video was returned")
|
||||
|
||||
|
||||
class Veo3FirstLastFrameNode(IO.ComfyNode):
|
||||
|
||||
@@ -394,7 +483,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
|
||||
default="",
|
||||
tooltip="Negative text prompt to guide what to avoid in the video",
|
||||
),
|
||||
IO.Combo.Input("resolution", options=["720p", "1080p"]),
|
||||
IO.Combo.Input("resolution", options=["720p", "1080p", "4k"]),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=["16:9", "9:16"],
|
||||
@@ -424,8 +513,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
|
||||
IO.Image.Input("last_frame", tooltip="End frame"),
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["veo-3.1-generate", "veo-3.1-fast-generate"],
|
||||
default="veo-3.1-fast-generate",
|
||||
options=["veo-3.1-generate", "veo-3.1-fast-generate", "veo-3.1-lite"],
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"generate_audio",
|
||||
@@ -443,26 +531,20 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio", "duration"]),
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio", "duration", "resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$prices := {
|
||||
"veo-3.1-fast-generate": { "audio": 0.15, "no_audio": 0.10 },
|
||||
"veo-3.1-generate": { "audio": 0.40, "no_audio": 0.20 }
|
||||
};
|
||||
$m := widgets.model;
|
||||
$ga := (widgets.generate_audio = "true");
|
||||
$r := widgets.resolution;
|
||||
$ga := widgets.generate_audio;
|
||||
$seconds := widgets.duration;
|
||||
$modelKey :=
|
||||
$contains($m, "veo-3.1-fast-generate") ? "veo-3.1-fast-generate" :
|
||||
$contains($m, "veo-3.1-generate") ? "veo-3.1-generate" :
|
||||
"";
|
||||
$audioKey := $ga ? "audio" : "no_audio";
|
||||
$modelPrices := $lookup($prices, $modelKey);
|
||||
$pps := $lookup($modelPrices, $audioKey);
|
||||
($pps != null)
|
||||
? {"type":"usd","usd": $pps * $seconds}
|
||||
: {"type":"range_usd","min_usd": 0.4, "max_usd": 3.2}
|
||||
$pps :=
|
||||
$contains($m, "lite")
|
||||
? ($r = "1080p" ? ($ga ? 0.08 : 0.05) : ($ga ? 0.05 : 0.03))
|
||||
: $contains($m, "fast")
|
||||
? ($r = "4k" ? ($ga ? 0.30 : 0.25) : $r = "1080p" ? ($ga ? 0.12 : 0.10) : ($ga ? 0.10 : 0.08))
|
||||
: ($r = "4k" ? ($ga ? 0.60 : 0.40) : ($ga ? 0.40 : 0.20));
|
||||
{"type":"usd","usd": $pps * $seconds}
|
||||
)
|
||||
""",
|
||||
),
|
||||
@@ -482,6 +564,9 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
|
||||
model: str,
|
||||
generate_audio: bool,
|
||||
):
|
||||
if "lite" in model and resolution == "4k":
|
||||
raise Exception("4K resolution is not supported by the veo-3.1-lite model.")
|
||||
|
||||
model = MODELS_MAP[model]
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
@@ -519,7 +604,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
|
||||
data=VeoGenVidPollRequest(
|
||||
operationName=initial_response.name,
|
||||
),
|
||||
poll_interval=5.0,
|
||||
poll_interval=9.0,
|
||||
estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
|
||||
)
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from .conversions import (
|
||||
image_tensor_pair_to_batch,
|
||||
pil_to_bytesio,
|
||||
resize_mask_to_image,
|
||||
resize_video_to_pixel_budget,
|
||||
tensor_to_base64_string,
|
||||
tensor_to_bytesio,
|
||||
tensor_to_pil,
|
||||
@@ -90,6 +91,7 @@ __all__ = [
|
||||
"image_tensor_pair_to_batch",
|
||||
"pil_to_bytesio",
|
||||
"resize_mask_to_image",
|
||||
"resize_video_to_pixel_budget",
|
||||
"tensor_to_base64_string",
|
||||
"tensor_to_bytesio",
|
||||
"tensor_to_pil",
|
||||
|
||||
@@ -156,6 +156,7 @@ async def poll_op(
|
||||
estimated_duration: int | None = None,
|
||||
cancel_endpoint: ApiEndpoint | None = None,
|
||||
cancel_timeout: float = 10.0,
|
||||
extra_text: str | None = None,
|
||||
) -> M:
|
||||
raw = await poll_op_raw(
|
||||
cls,
|
||||
@@ -176,6 +177,7 @@ async def poll_op(
|
||||
estimated_duration=estimated_duration,
|
||||
cancel_endpoint=cancel_endpoint,
|
||||
cancel_timeout=cancel_timeout,
|
||||
extra_text=extra_text,
|
||||
)
|
||||
if not isinstance(raw, dict):
|
||||
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
|
||||
@@ -260,6 +262,7 @@ async def poll_op_raw(
|
||||
estimated_duration: int | None = None,
|
||||
cancel_endpoint: ApiEndpoint | None = None,
|
||||
cancel_timeout: float = 10.0,
|
||||
extra_text: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Polls an endpoint until the task reaches a terminal state. Displays time while queued/processing,
|
||||
@@ -299,6 +302,7 @@ async def poll_op_raw(
|
||||
price=state.price,
|
||||
is_queued=state.is_queued,
|
||||
processing_elapsed_seconds=int(proc_elapsed),
|
||||
extra_text=extra_text,
|
||||
)
|
||||
await asyncio.sleep(1.0)
|
||||
except Exception as exc:
|
||||
@@ -389,6 +393,7 @@ async def poll_op_raw(
|
||||
price=state.price,
|
||||
is_queued=False,
|
||||
processing_elapsed_seconds=int(state.base_processing_elapsed),
|
||||
extra_text=extra_text,
|
||||
)
|
||||
return resp_json
|
||||
|
||||
@@ -462,6 +467,7 @@ def _display_time_progress(
|
||||
price: float | None = None,
|
||||
is_queued: bool | None = None,
|
||||
processing_elapsed_seconds: int | None = None,
|
||||
extra_text: str | None = None,
|
||||
) -> None:
|
||||
if estimated_total is not None and estimated_total > 0 and is_queued is False:
|
||||
pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
|
||||
@@ -469,7 +475,8 @@ def _display_time_progress(
|
||||
time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)"
|
||||
else:
|
||||
time_line = f"Time elapsed: {int(elapsed_seconds)}s"
|
||||
_display_text(node_cls, time_line, status=status, price=price)
|
||||
text = f"{time_line}\n\n{extra_text}" if extra_text else time_line
|
||||
_display_text(node_cls, text, status=status, price=price)
|
||||
|
||||
|
||||
async def _diagnose_connectivity() -> dict[str, bool]:
|
||||
|
||||
@@ -129,19 +129,35 @@ def pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
|
||||
return img_byte_arr
|
||||
|
||||
|
||||
def downscale_image_tensor(image: torch.Tensor, total_pixels: int = 1536 * 1024) -> torch.Tensor:
|
||||
"""Downscale input image tensor to roughly the specified total pixels."""
|
||||
samples = image.movedim(-1, 1)
|
||||
total = int(total_pixels)
|
||||
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
|
||||
if scale_by >= 1:
|
||||
return image
|
||||
width = round(samples.shape[3] * scale_by)
|
||||
height = round(samples.shape[2] * scale_by)
|
||||
def _compute_downscale_dims(src_w: int, src_h: int, total_pixels: int) -> tuple[int, int] | None:
|
||||
"""Return downscaled (w, h) with even dims fitting ``total_pixels``, or None if already fits.
|
||||
|
||||
s = common_upscale(samples, width, height, "lanczos", "disabled")
|
||||
s = s.movedim(1, -1)
|
||||
return s
|
||||
Source aspect ratio is preserved; output may drift by a fraction of a percent because both dimensions
|
||||
are rounded down to even values (many codecs require divisible-by-2).
|
||||
"""
|
||||
pixels = src_w * src_h
|
||||
if pixels <= total_pixels:
|
||||
return None
|
||||
scale = math.sqrt(total_pixels / pixels)
|
||||
new_w = max(2, int(src_w * scale))
|
||||
new_h = max(2, int(src_h * scale))
|
||||
new_w -= new_w % 2
|
||||
new_h -= new_h % 2
|
||||
return new_w, new_h
|
||||
|
||||
|
||||
def downscale_image_tensor(image: torch.Tensor, total_pixels: int = 1536 * 1024) -> torch.Tensor:
|
||||
"""Downscale input image tensor to roughly the specified total pixels.
|
||||
|
||||
Output dimensions are rounded down to even values so that the result is guaranteed to fit within ``total_pixels``
|
||||
and is compatible with codecs that require even dimensions (e.g. yuv420p).
|
||||
"""
|
||||
samples = image.movedim(-1, 1)
|
||||
dims = _compute_downscale_dims(samples.shape[3], samples.shape[2], int(total_pixels))
|
||||
if dims is None:
|
||||
return image
|
||||
new_w, new_h = dims
|
||||
return common_upscale(samples, new_w, new_h, "lanczos", "disabled").movedim(1, -1)
|
||||
|
||||
|
||||
def downscale_image_tensor_by_max_side(image: torch.Tensor, *, max_side: int) -> torch.Tensor:
|
||||
@@ -399,6 +415,72 @@ def trim_video(video: Input.Video, duration_sec: float) -> Input.Video:
|
||||
raise RuntimeError(f"Failed to trim video: {str(e)}") from e
|
||||
|
||||
|
||||
def resize_video_to_pixel_budget(video: Input.Video, total_pixels: int) -> Input.Video:
|
||||
"""Downscale a video to fit within ``total_pixels`` (w * h), preserving aspect ratio.
|
||||
|
||||
Returns the original video object untouched when it already fits. Preserves frame rate, duration, and audio.
|
||||
Aspect ratio is preserved up to a fraction of a percent (even-dim rounding).
|
||||
"""
|
||||
src_w, src_h = video.get_dimensions()
|
||||
scale_dims = _compute_downscale_dims(src_w, src_h, total_pixels)
|
||||
if scale_dims is None:
|
||||
return video
|
||||
return _apply_video_scale(video, scale_dims)
|
||||
|
||||
|
||||
def _apply_video_scale(video: Input.Video, scale_dims: tuple[int, int]) -> Input.Video:
|
||||
"""Re-encode ``video`` scaled to ``scale_dims`` with a single decode/encode pass."""
|
||||
out_w, out_h = scale_dims
|
||||
output_buffer = BytesIO()
|
||||
input_container = None
|
||||
output_container = None
|
||||
|
||||
try:
|
||||
input_source = video.get_stream_source()
|
||||
input_container = av.open(input_source, mode="r")
|
||||
output_container = av.open(output_buffer, mode="w", format="mp4")
|
||||
|
||||
video_stream = output_container.add_stream("h264", rate=video.get_frame_rate())
|
||||
video_stream.width = out_w
|
||||
video_stream.height = out_h
|
||||
video_stream.pix_fmt = "yuv420p"
|
||||
|
||||
audio_stream = None
|
||||
for stream in input_container.streams:
|
||||
if isinstance(stream, av.AudioStream):
|
||||
audio_stream = output_container.add_stream("aac", rate=stream.sample_rate)
|
||||
audio_stream.sample_rate = stream.sample_rate
|
||||
audio_stream.layout = stream.layout
|
||||
break
|
||||
|
||||
for frame in input_container.decode(video=0):
|
||||
frame = frame.reformat(width=out_w, height=out_h, format="yuv420p")
|
||||
for packet in video_stream.encode(frame):
|
||||
output_container.mux(packet)
|
||||
for packet in video_stream.encode():
|
||||
output_container.mux(packet)
|
||||
|
||||
if audio_stream is not None:
|
||||
input_container.seek(0)
|
||||
for audio_frame in input_container.decode(audio=0):
|
||||
for packet in audio_stream.encode(audio_frame):
|
||||
output_container.mux(packet)
|
||||
for packet in audio_stream.encode():
|
||||
output_container.mux(packet)
|
||||
|
||||
output_container.close()
|
||||
input_container.close()
|
||||
output_buffer.seek(0)
|
||||
return InputImpl.VideoFromFile(output_buffer)
|
||||
|
||||
except Exception as e:
|
||||
if input_container is not None:
|
||||
input_container.close()
|
||||
if output_container is not None:
|
||||
output_container.close()
|
||||
raise RuntimeError(f"Failed to resize video: {str(e)}") from e
|
||||
|
||||
|
||||
def _f32_pcm(wav: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file."""
|
||||
if wav.dtype.is_floating_point:
|
||||
|
||||
258
comfy_extras/frame_interpolation_models/film_net.py
Normal file
258
comfy_extras/frame_interpolation_models/film_net.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""FILM: Frame Interpolation for Large Motion (ECCV 2022)."""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.ops
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
class FilmConv2d(nn.Module):
|
||||
"""Conv2d with optional LeakyReLU and FILM-style padding."""
|
||||
|
||||
def __init__(self, in_channels, out_channels, size, activation=True, device=None, dtype=None, operations=ops):
|
||||
super().__init__()
|
||||
self.even_pad = not size % 2
|
||||
self.conv = operations.Conv2d(in_channels, out_channels, kernel_size=size, padding=size // 2 if size % 2 else 0, device=device, dtype=dtype)
|
||||
self.activation = nn.LeakyReLU(0.2) if activation else None
|
||||
|
||||
def forward(self, x):
|
||||
if self.even_pad:
|
||||
x = F.pad(x, (0, 1, 0, 1))
|
||||
x = self.conv(x)
|
||||
if self.activation is not None:
|
||||
x = self.activation(x)
|
||||
return x
|
||||
|
||||
|
||||
def _warp_core(image, flow, grid_x, grid_y):
|
||||
dtype = image.dtype
|
||||
H, W = flow.shape[2], flow.shape[3]
|
||||
dx = flow[:, 0].float() / (W * 0.5)
|
||||
dy = flow[:, 1].float() / (H * 0.5)
|
||||
grid = torch.stack([grid_x[None, None, :] + dx, grid_y[None, :, None] + dy], dim=3)
|
||||
return F.grid_sample(image.float(), grid, mode="bilinear", padding_mode="border", align_corners=False).to(dtype)
|
||||
|
||||
|
||||
def build_image_pyramid(image, pyramid_levels):
|
||||
pyramid = [image]
|
||||
for _ in range(1, pyramid_levels):
|
||||
image = F.avg_pool2d(image, 2, 2)
|
||||
pyramid.append(image)
|
||||
return pyramid
|
||||
|
||||
|
||||
def flow_pyramid_synthesis(residual_pyramid):
|
||||
flow = residual_pyramid[-1]
|
||||
flow_pyramid = [flow]
|
||||
for residual_flow in residual_pyramid[:-1][::-1]:
|
||||
flow = F.interpolate(flow, size=residual_flow.shape[2:4], mode="bilinear", scale_factor=None).mul_(2).add_(residual_flow)
|
||||
flow_pyramid.append(flow)
|
||||
flow_pyramid.reverse()
|
||||
return flow_pyramid
|
||||
|
||||
|
||||
def multiply_pyramid(pyramid, scalar):
|
||||
return [image * scalar[:, None, None, None] for image in pyramid]
|
||||
|
||||
|
||||
def pyramid_warp(feature_pyramid, flow_pyramid, warp_fn):
|
||||
return [warp_fn(features, flow) for features, flow in zip(feature_pyramid, flow_pyramid)]
|
||||
|
||||
|
||||
def concatenate_pyramids(pyramid1, pyramid2):
|
||||
return [torch.cat([f1, f2], dim=1) for f1, f2 in zip(pyramid1, pyramid2)]
|
||||
|
||||
|
||||
class SubTreeExtractor(nn.Module):
|
||||
def __init__(self, in_channels=3, channels=64, n_layers=4, device=None, dtype=None, operations=ops):
|
||||
super().__init__()
|
||||
convs = []
|
||||
for i in range(n_layers):
|
||||
out_ch = channels << i
|
||||
convs.append(nn.Sequential(
|
||||
FilmConv2d(in_channels, out_ch, 3, device=device, dtype=dtype, operations=operations),
|
||||
FilmConv2d(out_ch, out_ch, 3, device=device, dtype=dtype, operations=operations)))
|
||||
in_channels = out_ch
|
||||
self.convs = nn.ModuleList(convs)
|
||||
|
||||
def forward(self, image, n):
|
||||
head = image
|
||||
pyramid = []
|
||||
for i, layer in enumerate(self.convs):
|
||||
head = layer(head)
|
||||
pyramid.append(head)
|
||||
if i < n - 1:
|
||||
head = F.avg_pool2d(head, 2, 2)
|
||||
return pyramid
|
||||
|
||||
|
||||
class FeatureExtractor(nn.Module):
|
||||
def __init__(self, in_channels=3, channels=64, sub_levels=4, device=None, dtype=None, operations=ops):
|
||||
super().__init__()
|
||||
self.extract_sublevels = SubTreeExtractor(in_channels, channels, sub_levels, device=device, dtype=dtype, operations=operations)
|
||||
self.sub_levels = sub_levels
|
||||
|
||||
def forward(self, image_pyramid):
|
||||
sub_pyramids = [self.extract_sublevels(image_pyramid[i], min(len(image_pyramid) - i, self.sub_levels))
|
||||
for i in range(len(image_pyramid))]
|
||||
feature_pyramid = []
|
||||
for i in range(len(image_pyramid)):
|
||||
features = sub_pyramids[i][0]
|
||||
for j in range(1, self.sub_levels):
|
||||
if j <= i:
|
||||
features = torch.cat([features, sub_pyramids[i - j][j]], dim=1)
|
||||
feature_pyramid.append(features)
|
||||
# Free sub-pyramids no longer needed by future levels
|
||||
if i >= self.sub_levels - 1:
|
||||
sub_pyramids[i - self.sub_levels + 1] = None
|
||||
return feature_pyramid
|
||||
|
||||
|
||||
class FlowEstimator(nn.Module):
|
||||
def __init__(self, in_channels, num_convs, num_filters, device=None, dtype=None, operations=ops):
|
||||
super().__init__()
|
||||
self._convs = nn.ModuleList()
|
||||
for _ in range(num_convs):
|
||||
self._convs.append(FilmConv2d(in_channels, num_filters, 3, device=device, dtype=dtype, operations=operations))
|
||||
in_channels = num_filters
|
||||
self._convs.append(FilmConv2d(in_channels, num_filters // 2, 1, device=device, dtype=dtype, operations=operations))
|
||||
self._convs.append(FilmConv2d(num_filters // 2, 2, 1, activation=False, device=device, dtype=dtype, operations=operations))
|
||||
|
||||
def forward(self, features_a, features_b):
|
||||
net = torch.cat([features_a, features_b], dim=1)
|
||||
for conv in self._convs:
|
||||
net = conv(net)
|
||||
return net
|
||||
|
||||
|
||||
class PyramidFlowEstimator(nn.Module):
|
||||
def __init__(self, filters=64, flow_convs=(3, 3, 3, 3), flow_filters=(32, 64, 128, 256), device=None, dtype=None, operations=ops):
|
||||
super().__init__()
|
||||
in_channels = filters << 1
|
||||
predictors = []
|
||||
for i in range(len(flow_convs)):
|
||||
predictors.append(FlowEstimator(in_channels, flow_convs[i], flow_filters[i], device=device, dtype=dtype, operations=operations))
|
||||
in_channels += filters << (i + 2)
|
||||
self._predictor = predictors[-1]
|
||||
self._predictors = nn.ModuleList(predictors[:-1][::-1])
|
||||
|
||||
def forward(self, feature_pyramid_a, feature_pyramid_b, warp_fn):
|
||||
levels = len(feature_pyramid_a)
|
||||
v = self._predictor(feature_pyramid_a[-1], feature_pyramid_b[-1])
|
||||
residuals = [v]
|
||||
# Coarse-to-fine: shared predictor for deep levels, then specialized predictors for fine levels
|
||||
steps = [(i, self._predictor) for i in range(levels - 2, len(self._predictors) - 1, -1)]
|
||||
steps += [(len(self._predictors) - 1 - k, p) for k, p in enumerate(self._predictors)]
|
||||
for i, predictor in steps:
|
||||
v = F.interpolate(v, size=feature_pyramid_a[i].shape[2:4], mode="bilinear").mul_(2)
|
||||
v_residual = predictor(feature_pyramid_a[i], warp_fn(feature_pyramid_b[i], v))
|
||||
residuals.append(v_residual)
|
||||
v = v.add_(v_residual)
|
||||
residuals.reverse()
|
||||
return residuals
|
||||
|
||||
|
||||
def _get_fusion_channels(level, filters):
|
||||
# Per direction: multi-scale features + RGB image (3ch) + flow (2ch), doubled for both directions
|
||||
return (sum(filters << i for i in range(level)) + 3 + 2) * 2
|
||||
|
||||
|
||||
class Fusion(nn.Module):
|
||||
def __init__(self, n_layers=4, specialized_layers=3, filters=64, device=None, dtype=None, operations=ops):
|
||||
super().__init__()
|
||||
self.output_conv = operations.Conv2d(filters, 3, kernel_size=1, device=device, dtype=dtype)
|
||||
self.convs = nn.ModuleList()
|
||||
in_channels = _get_fusion_channels(n_layers, filters)
|
||||
increase = 0
|
||||
for i in range(n_layers)[::-1]:
|
||||
num_filters = (filters << i) if i < specialized_layers else (filters << specialized_layers)
|
||||
self.convs.append(nn.ModuleList([
|
||||
FilmConv2d(in_channels, num_filters, 2, activation=False, device=device, dtype=dtype, operations=operations),
|
||||
FilmConv2d(in_channels + (increase or num_filters), num_filters, 3, device=device, dtype=dtype, operations=operations),
|
||||
FilmConv2d(num_filters, num_filters, 3, device=device, dtype=dtype, operations=operations)]))
|
||||
in_channels = num_filters
|
||||
increase = _get_fusion_channels(i, filters) - num_filters // 2
|
||||
|
||||
def forward(self, pyramid):
|
||||
net = pyramid[-1]
|
||||
for k, layers in enumerate(self.convs):
|
||||
i = len(self.convs) - 1 - k
|
||||
net = layers[0](F.interpolate(net, size=pyramid[i].shape[2:4], mode="nearest"))
|
||||
net = layers[2](layers[1](torch.cat([pyramid[i], net], dim=1)))
|
||||
return self.output_conv(net)
|
||||
|
||||
|
||||
class FILMNet(nn.Module):
|
||||
def __init__(self, pyramid_levels=7, fusion_pyramid_levels=5, specialized_levels=3, sub_levels=4,
|
||||
filters=64, flow_convs=(3, 3, 3, 3), flow_filters=(32, 64, 128, 256), device=None, dtype=None, operations=ops):
|
||||
super().__init__()
|
||||
self.pyramid_levels = pyramid_levels
|
||||
self.fusion_pyramid_levels = fusion_pyramid_levels
|
||||
self.extract = FeatureExtractor(3, filters, sub_levels, device=device, dtype=dtype, operations=operations)
|
||||
self.predict_flow = PyramidFlowEstimator(filters, flow_convs, flow_filters, device=device, dtype=dtype, operations=operations)
|
||||
self.fuse = Fusion(sub_levels, specialized_levels, filters, device=device, dtype=dtype, operations=operations)
|
||||
self._warp_grids = {}
|
||||
|
||||
def get_dtype(self):
|
||||
return self.extract.extract_sublevels.convs[0][0].conv.weight.dtype
|
||||
|
||||
def _build_warp_grids(self, H, W, device):
|
||||
"""Pre-compute warp grids for all pyramid levels."""
|
||||
if (H, W) in self._warp_grids:
|
||||
return
|
||||
self._warp_grids = {} # clear old resolution grids to prevent memory leaks
|
||||
for _ in range(self.pyramid_levels):
|
||||
self._warp_grids[(H, W)] = (
|
||||
torch.linspace(-(1 - 1 / W), 1 - 1 / W, W, dtype=torch.float32, device=device),
|
||||
torch.linspace(-(1 - 1 / H), 1 - 1 / H, H, dtype=torch.float32, device=device),
|
||||
)
|
||||
H, W = H // 2, W // 2
|
||||
|
||||
def warp(self, image, flow):
|
||||
grid_x, grid_y = self._warp_grids[(flow.shape[2], flow.shape[3])]
|
||||
return _warp_core(image, flow, grid_x, grid_y)
|
||||
|
||||
def extract_features(self, img):
|
||||
"""Extract image and feature pyramids for a single frame. Can be cached across pairs."""
|
||||
image_pyramid = build_image_pyramid(img, self.pyramid_levels)
|
||||
feature_pyramid = self.extract(image_pyramid)
|
||||
return image_pyramid, feature_pyramid
|
||||
|
||||
def forward(self, img0, img1, timestep=0.5, cache=None):
|
||||
# FILM uses a scalar timestep per batch element (spatially-varying timesteps not supported)
|
||||
t = timestep.mean(dim=(1, 2, 3)).item() if isinstance(timestep, torch.Tensor) else timestep
|
||||
return self.forward_multi_timestep(img0, img1, [t], cache=cache)
|
||||
|
||||
def forward_multi_timestep(self, img0, img1, timesteps, cache=None):
|
||||
"""Compute flow once, synthesize at multiple timesteps. Expects batch=1 inputs."""
|
||||
self._build_warp_grids(img0.shape[2], img0.shape[3], img0.device)
|
||||
|
||||
image_pyr0, feat_pyr0 = cache["img0"] if cache and "img0" in cache else self.extract_features(img0)
|
||||
image_pyr1, feat_pyr1 = cache["img1"] if cache and "img1" in cache else self.extract_features(img1)
|
||||
|
||||
fwd_flow = flow_pyramid_synthesis(self.predict_flow(feat_pyr0, feat_pyr1, self.warp))[:self.fusion_pyramid_levels]
|
||||
bwd_flow = flow_pyramid_synthesis(self.predict_flow(feat_pyr1, feat_pyr0, self.warp))[:self.fusion_pyramid_levels]
|
||||
|
||||
# Build warp targets and free full pyramids (only first fpl levels needed from here)
|
||||
fpl = self.fusion_pyramid_levels
|
||||
p2w = [concatenate_pyramids(image_pyr0[:fpl], feat_pyr0[:fpl]),
|
||||
concatenate_pyramids(image_pyr1[:fpl], feat_pyr1[:fpl])]
|
||||
del image_pyr0, image_pyr1, feat_pyr0, feat_pyr1
|
||||
|
||||
results = []
|
||||
dt_tensors = torch.tensor(timesteps, device=img0.device, dtype=img0.dtype)
|
||||
for idx in range(len(timesteps)):
|
||||
batch_dt = dt_tensors[idx:idx + 1]
|
||||
bwd_scaled = multiply_pyramid(bwd_flow, batch_dt)
|
||||
fwd_scaled = multiply_pyramid(fwd_flow, 1 - batch_dt)
|
||||
fwd_warped = pyramid_warp(p2w[0], bwd_scaled, self.warp)
|
||||
bwd_warped = pyramid_warp(p2w[1], fwd_scaled, self.warp)
|
||||
aligned = [torch.cat([fw, bw, bf, ff], dim=1)
|
||||
for fw, bw, bf, ff in zip(fwd_warped, bwd_warped, bwd_scaled, fwd_scaled)]
|
||||
del fwd_warped, bwd_warped, bwd_scaled, fwd_scaled
|
||||
results.append(self.fuse(aligned))
|
||||
del aligned
|
||||
return torch.cat(results, dim=0)
|
||||
128
comfy_extras/frame_interpolation_models/ifnet.py
Normal file
128
comfy_extras/frame_interpolation_models/ifnet.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.ops
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
def _warp(img, flow, warp_grids):
|
||||
B, _, H, W = img.shape
|
||||
base_grid, flow_div = warp_grids[(H, W)]
|
||||
flow_norm = torch.cat([flow[:, 0:1] / flow_div[0], flow[:, 1:2] / flow_div[1]], 1).float()
|
||||
grid = (base_grid.expand(B, -1, -1, -1) + flow_norm).permute(0, 2, 3, 1)
|
||||
return F.grid_sample(img.float(), grid, mode="bilinear", padding_mode="border", align_corners=True).to(img.dtype)
|
||||
|
||||
|
||||
class Head(nn.Module):
|
||||
def __init__(self, out_ch=4, device=None, dtype=None, operations=ops):
|
||||
super().__init__()
|
||||
self.cnn0 = operations.Conv2d(3, 16, 3, 2, 1, device=device, dtype=dtype)
|
||||
self.cnn1 = operations.Conv2d(16, 16, 3, 1, 1, device=device, dtype=dtype)
|
||||
self.cnn2 = operations.Conv2d(16, 16, 3, 1, 1, device=device, dtype=dtype)
|
||||
self.cnn3 = operations.ConvTranspose2d(16, out_ch, 4, 2, 1, device=device, dtype=dtype)
|
||||
self.relu = nn.LeakyReLU(0.2, True)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.relu(self.cnn0(x))
|
||||
x = self.relu(self.cnn1(x))
|
||||
x = self.relu(self.cnn2(x))
|
||||
return self.cnn3(x)
|
||||
|
||||
|
||||
class ResConv(nn.Module):
|
||||
def __init__(self, c, device=None, dtype=None, operations=ops):
|
||||
super().__init__()
|
||||
self.conv = operations.Conv2d(c, c, 3, 1, 1, device=device, dtype=dtype)
|
||||
self.beta = nn.Parameter(torch.ones((1, c, 1, 1), device=device, dtype=dtype))
|
||||
self.relu = nn.LeakyReLU(0.2, True)
|
||||
|
||||
def forward(self, x):
|
||||
return self.relu(torch.addcmul(x, self.conv(x), self.beta))
|
||||
|
||||
|
||||
class IFBlock(nn.Module):
|
||||
def __init__(self, in_planes, c=64, device=None, dtype=None, operations=ops):
|
||||
super().__init__()
|
||||
self.conv0 = nn.Sequential(
|
||||
nn.Sequential(operations.Conv2d(in_planes, c // 2, 3, 2, 1, device=device, dtype=dtype), nn.LeakyReLU(0.2, True)),
|
||||
nn.Sequential(operations.Conv2d(c // 2, c, 3, 2, 1, device=device, dtype=dtype), nn.LeakyReLU(0.2, True)))
|
||||
self.convblock = nn.Sequential(*(ResConv(c, device=device, dtype=dtype, operations=operations) for _ in range(8)))
|
||||
self.lastconv = nn.Sequential(operations.ConvTranspose2d(c, 4 * 13, 4, 2, 1, device=device, dtype=dtype), nn.PixelShuffle(2))
|
||||
|
||||
def forward(self, x, flow=None, scale=1):
|
||||
x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear")
|
||||
if flow is not None:
|
||||
flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear").div_(scale)
|
||||
x = torch.cat((x, flow), 1)
|
||||
feat = self.convblock(self.conv0(x))
|
||||
tmp = F.interpolate(self.lastconv(feat), scale_factor=scale, mode="bilinear")
|
||||
return tmp[:, :4] * scale, tmp[:, 4:5], tmp[:, 5:]
|
||||
|
||||
|
||||
class IFNet(nn.Module):
|
||||
def __init__(self, head_ch=4, channels=(192, 128, 96, 64, 32), device=None, dtype=None, operations=ops):
|
||||
super().__init__()
|
||||
self.encode = Head(out_ch=head_ch, device=device, dtype=dtype, operations=operations)
|
||||
block_in = [7 + 2 * head_ch] + [8 + 4 + 8 + 2 * head_ch] * 4
|
||||
self.blocks = nn.ModuleList([IFBlock(block_in[i], channels[i], device=device, dtype=dtype, operations=operations) for i in range(5)])
|
||||
self.scale_list = [16, 8, 4, 2, 1]
|
||||
self.pad_align = 64
|
||||
self._warp_grids = {}
|
||||
|
||||
def get_dtype(self):
|
||||
return self.encode.cnn0.weight.dtype
|
||||
|
||||
def _build_warp_grids(self, H, W, device):
|
||||
if (H, W) in self._warp_grids:
|
||||
return
|
||||
self._warp_grids = {} # clear old resolution grids to prevent memory leaks
|
||||
grid_y, grid_x = torch.meshgrid(
|
||||
torch.linspace(-1.0, 1.0, H, device=device, dtype=torch.float32),
|
||||
torch.linspace(-1.0, 1.0, W, device=device, dtype=torch.float32), indexing="ij")
|
||||
self._warp_grids[(H, W)] = (
|
||||
torch.stack((grid_x, grid_y), dim=0).unsqueeze(0),
|
||||
torch.tensor([(W - 1.0) / 2.0, (H - 1.0) / 2.0], dtype=torch.float32, device=device))
|
||||
|
||||
def warp(self, img, flow):
|
||||
return _warp(img, flow, self._warp_grids)
|
||||
|
||||
def extract_features(self, img):
|
||||
"""Extract head features for a single frame. Can be cached across pairs."""
|
||||
return self.encode(img)
|
||||
|
||||
def forward(self, img0, img1, timestep=0.5, cache=None):
|
||||
if not isinstance(timestep, torch.Tensor):
|
||||
timestep = torch.full((img0.shape[0], 1, img0.shape[2], img0.shape[3]), timestep, device=img0.device, dtype=img0.dtype)
|
||||
|
||||
self._build_warp_grids(img0.shape[2], img0.shape[3], img0.device)
|
||||
|
||||
B = img0.shape[0]
|
||||
f0 = cache["img0"].expand(B, -1, -1, -1) if cache and "img0" in cache else self.encode(img0)
|
||||
f1 = cache["img1"].expand(B, -1, -1, -1) if cache and "img1" in cache else self.encode(img1)
|
||||
flow = mask = feat = None
|
||||
warped_img0, warped_img1 = img0, img1
|
||||
for i, block in enumerate(self.blocks):
|
||||
if flow is None:
|
||||
flow, mask, feat = block(torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=self.scale_list[i])
|
||||
else:
|
||||
fd, mask, feat = block(
|
||||
torch.cat((warped_img0, warped_img1, self.warp(f0, flow[:, :2]), self.warp(f1, flow[:, 2:4]), timestep, mask, feat), 1),
|
||||
flow, scale=self.scale_list[i])
|
||||
flow = flow.add_(fd)
|
||||
warped_img0 = self.warp(img0, flow[:, :2])
|
||||
warped_img1 = self.warp(img1, flow[:, 2:4])
|
||||
return torch.lerp(warped_img1, warped_img0, torch.sigmoid(mask))
|
||||
|
||||
|
||||
def detect_rife_config(state_dict):
|
||||
head_ch = state_dict["encode.cnn3.weight"].shape[1] # ConvTranspose2d: (in_ch, out_ch, kH, kW)
|
||||
channels = []
|
||||
for i in range(5):
|
||||
key = f"blocks.{i}.conv0.1.0.weight"
|
||||
if key in state_dict:
|
||||
channels.append(state_dict[key].shape[0])
|
||||
if len(channels) != 5:
|
||||
raise ValueError(f"Unsupported RIFE model: expected 5 blocks, found {len(channels)}")
|
||||
return head_ch, channels
|
||||
@@ -3,136 +3,136 @@ from typing_extensions import override
|
||||
|
||||
import comfy.model_management
|
||||
import node_helpers
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from comfy_api.latest import ComfyExtension, IO
|
||||
|
||||
|
||||
class TextEncodeAceStepAudio(io.ComfyNode):
|
||||
class TextEncodeAceStepAudio(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
return IO.Schema(
|
||||
node_id="TextEncodeAceStepAudio",
|
||||
category="conditioning",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.String.Input("tags", multiline=True, dynamic_prompts=True),
|
||||
io.String.Input("lyrics", multiline=True, dynamic_prompts=True),
|
||||
io.Float.Input("lyrics_strength", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
IO.Clip.Input("clip"),
|
||||
IO.String.Input("tags", multiline=True, dynamic_prompts=True),
|
||||
IO.String.Input("lyrics", multiline=True, dynamic_prompts=True),
|
||||
IO.Float.Input("lyrics_strength", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
],
|
||||
outputs=[io.Conditioning.Output()],
|
||||
outputs=[IO.Conditioning.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, tags, lyrics, lyrics_strength) -> io.NodeOutput:
|
||||
def execute(cls, clip, tags, lyrics, lyrics_strength) -> IO.NodeOutput:
|
||||
tokens = clip.tokenize(tags, lyrics=lyrics)
|
||||
conditioning = clip.encode_from_tokens_scheduled(tokens)
|
||||
conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength})
|
||||
return io.NodeOutput(conditioning)
|
||||
return IO.NodeOutput(conditioning)
|
||||
|
||||
class TextEncodeAceStepAudio15(io.ComfyNode):
|
||||
class TextEncodeAceStepAudio15(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
return IO.Schema(
|
||||
node_id="TextEncodeAceStepAudio1.5",
|
||||
category="conditioning",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.String.Input("tags", multiline=True, dynamic_prompts=True),
|
||||
io.String.Input("lyrics", multiline=True, dynamic_prompts=True),
|
||||
io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True),
|
||||
io.Int.Input("bpm", default=120, min=10, max=300),
|
||||
io.Float.Input("duration", default=120.0, min=0.0, max=2000.0, step=0.1),
|
||||
io.Combo.Input("timesignature", options=['2', '3', '4', '6']),
|
||||
io.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]),
|
||||
io.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]),
|
||||
io.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True),
|
||||
io.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True),
|
||||
io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
|
||||
io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
|
||||
io.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
|
||||
io.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True),
|
||||
IO.Clip.Input("clip"),
|
||||
IO.String.Input("tags", multiline=True, dynamic_prompts=True),
|
||||
IO.String.Input("lyrics", multiline=True, dynamic_prompts=True),
|
||||
IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True),
|
||||
IO.Int.Input("bpm", default=120, min=10, max=300),
|
||||
IO.Float.Input("duration", default=120.0, min=0.0, max=2000.0, step=0.1),
|
||||
IO.Combo.Input("timesignature", options=['2', '3', '4', '6']),
|
||||
IO.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]),
|
||||
IO.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]),
|
||||
IO.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True),
|
||||
IO.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True),
|
||||
IO.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
|
||||
IO.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
|
||||
IO.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
|
||||
IO.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True),
|
||||
],
|
||||
outputs=[io.Conditioning.Output()],
|
||||
outputs=[IO.Conditioning.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> io.NodeOutput:
|
||||
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> IO.NodeOutput:
|
||||
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p)
|
||||
conditioning = clip.encode_from_tokens_scheduled(tokens)
|
||||
return io.NodeOutput(conditioning)
|
||||
return IO.NodeOutput(conditioning)
|
||||
|
||||
|
||||
class EmptyAceStepLatentAudio(io.ComfyNode):
|
||||
class EmptyAceStepLatentAudio(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
return IO.Schema(
|
||||
node_id="EmptyAceStepLatentAudio",
|
||||
display_name="Empty Ace Step 1.0 Latent Audio",
|
||||
category="latent/audio",
|
||||
inputs=[
|
||||
io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1),
|
||||
io.Int.Input(
|
||||
IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1),
|
||||
IO.Int.Input(
|
||||
"batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
|
||||
),
|
||||
],
|
||||
outputs=[io.Latent.Output()],
|
||||
outputs=[IO.Latent.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, seconds, batch_size) -> io.NodeOutput:
|
||||
def execute(cls, seconds, batch_size) -> IO.NodeOutput:
|
||||
length = int(seconds * 44100 / 512 / 8)
|
||||
latent = torch.zeros([batch_size, 8, 16, length], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||
return io.NodeOutput({"samples": latent, "type": "audio"})
|
||||
return IO.NodeOutput({"samples": latent, "type": "audio"})
|
||||
|
||||
|
||||
class EmptyAceStep15LatentAudio(io.ComfyNode):
|
||||
class EmptyAceStep15LatentAudio(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
return IO.Schema(
|
||||
node_id="EmptyAceStep1.5LatentAudio",
|
||||
display_name="Empty Ace Step 1.5 Latent Audio",
|
||||
category="latent/audio",
|
||||
inputs=[
|
||||
io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01),
|
||||
io.Int.Input(
|
||||
IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01),
|
||||
IO.Int.Input(
|
||||
"batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
|
||||
),
|
||||
],
|
||||
outputs=[io.Latent.Output()],
|
||||
outputs=[IO.Latent.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, seconds, batch_size) -> io.NodeOutput:
|
||||
def execute(cls, seconds, batch_size) -> IO.NodeOutput:
|
||||
length = round((seconds * 48000 / 1920))
|
||||
latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||
return io.NodeOutput({"samples": latent, "type": "audio"})
|
||||
return IO.NodeOutput({"samples": latent, "type": "audio"})
|
||||
|
||||
class ReferenceAudio(io.ComfyNode):
|
||||
class ReferenceAudio(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
return IO.Schema(
|
||||
node_id="ReferenceTimbreAudio",
|
||||
display_name="Reference Audio",
|
||||
category="advanced/conditioning/audio",
|
||||
is_experimental=True,
|
||||
description="This node sets the reference audio for ace step 1.5",
|
||||
inputs=[
|
||||
io.Conditioning.Input("conditioning"),
|
||||
io.Latent.Input("latent", optional=True),
|
||||
IO.Conditioning.Input("conditioning"),
|
||||
IO.Latent.Input("latent", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(),
|
||||
IO.Conditioning.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, conditioning, latent=None) -> io.NodeOutput:
|
||||
def execute(cls, conditioning, latent=None) -> IO.NodeOutput:
|
||||
if latent is not None:
|
||||
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_audio_timbre_latents": [latent["samples"]]}, append=True)
|
||||
return io.NodeOutput(conditioning)
|
||||
return IO.NodeOutput(conditioning)
|
||||
|
||||
class AceExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
TextEncodeAceStepAudio,
|
||||
EmptyAceStepLatentAudio,
|
||||
|
||||
@@ -104,7 +104,7 @@ def vae_decode_audio(vae, samples, tile=None, overlap=None):
|
||||
std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0
|
||||
std[std < 1.0] = 1.0
|
||||
audio /= std
|
||||
vae_sample_rate = getattr(vae, "audio_sample_rate", 44100)
|
||||
vae_sample_rate = getattr(vae, "audio_sample_rate_output", getattr(vae, "audio_sample_rate", 44100))
|
||||
return {"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]}
|
||||
|
||||
|
||||
|
||||
211
comfy_extras/nodes_frame_interpolation.py
Normal file
211
comfy_extras/nodes_frame_interpolation.py
Normal file
@@ -0,0 +1,211 @@
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from typing_extensions import override
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
from comfy import model_management
|
||||
from comfy_extras.frame_interpolation_models.ifnet import IFNet, detect_rife_config
|
||||
from comfy_extras.frame_interpolation_models.film_net import FILMNet
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
FrameInterpolationModel = io.Custom("INTERP_MODEL")
|
||||
|
||||
|
||||
class FrameInterpolationModelLoader(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="FrameInterpolationModelLoader",
|
||||
display_name="Load Frame Interpolation Model",
|
||||
category="loaders",
|
||||
inputs=[
|
||||
io.Combo.Input("model_name", options=folder_paths.get_filename_list("frame_interpolation"),
|
||||
tooltip="Select a frame interpolation model to load. Models must be placed in the 'frame_interpolation' folder."),
|
||||
],
|
||||
outputs=[
|
||||
FrameInterpolationModel.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model_name) -> io.NodeOutput:
|
||||
model_path = folder_paths.get_full_path_or_raise("frame_interpolation", model_name)
|
||||
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
|
||||
|
||||
model = cls._detect_and_load(sd)
|
||||
dtype = torch.float16 if model_management.should_use_fp16(model_management.get_torch_device()) else torch.float32
|
||||
model.eval().to(dtype)
|
||||
patcher = comfy.model_patcher.ModelPatcher(
|
||||
model,
|
||||
load_device=model_management.get_torch_device(),
|
||||
offload_device=model_management.unet_offload_device(),
|
||||
)
|
||||
return io.NodeOutput(patcher)
|
||||
|
||||
@classmethod
|
||||
def _detect_and_load(cls, sd):
|
||||
# Try FILM
|
||||
if "extract.extract_sublevels.convs.0.0.conv.weight" in sd:
|
||||
model = FILMNet()
|
||||
model.load_state_dict(sd)
|
||||
return model
|
||||
|
||||
# Try RIFE (needs key remapping for raw checkpoints)
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"module.": "", "flownet.": ""})
|
||||
key_map = {}
|
||||
for k in sd:
|
||||
for i in range(5):
|
||||
if k.startswith(f"block{i}."):
|
||||
key_map[k] = f"blocks.{i}.{k[len(f'block{i}.'):]}"
|
||||
if key_map:
|
||||
sd = {key_map.get(k, k): v for k, v in sd.items()}
|
||||
sd = {k: v for k, v in sd.items() if not k.startswith(("teacher.", "caltime."))}
|
||||
|
||||
try:
|
||||
head_ch, channels = detect_rife_config(sd)
|
||||
except (KeyError, ValueError):
|
||||
raise ValueError("Unrecognized frame interpolation model format")
|
||||
model = IFNet(head_ch=head_ch, channels=channels)
|
||||
model.load_state_dict(sd)
|
||||
return model
|
||||
|
||||
|
||||
class FrameInterpolate(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="FrameInterpolate",
|
||||
display_name="Frame Interpolate",
|
||||
category="image/video",
|
||||
search_aliases=["rife", "film", "frame interpolation", "slow motion", "interpolate frames", "vfi"],
|
||||
inputs=[
|
||||
FrameInterpolationModel.Input("interp_model"),
|
||||
io.Image.Input("images"),
|
||||
io.Int.Input("multiplier", default=2, min=2, max=16),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, interp_model, images, multiplier) -> io.NodeOutput:
|
||||
offload_device = model_management.intermediate_device()
|
||||
|
||||
num_frames = images.shape[0]
|
||||
if num_frames < 2 or multiplier < 2:
|
||||
return io.NodeOutput(images)
|
||||
|
||||
model_management.load_model_gpu(interp_model)
|
||||
device = interp_model.load_device
|
||||
dtype = interp_model.model_dtype()
|
||||
inference_model = interp_model.model
|
||||
|
||||
# Free VRAM for inference activations (model weights + ~20x a single frame's worth)
|
||||
H, W = images.shape[1], images.shape[2]
|
||||
activation_mem = H * W * 3 * images.element_size() * 20
|
||||
model_management.free_memory(activation_mem, device)
|
||||
align = getattr(inference_model, "pad_align", 1)
|
||||
|
||||
# Prepare a single padded frame on device for determining output dimensions
|
||||
def prepare_frame(idx):
|
||||
frame = images[idx:idx + 1].movedim(-1, 1).to(dtype=dtype, device=device)
|
||||
if align > 1:
|
||||
from comfy.ldm.common_dit import pad_to_patch_size
|
||||
frame = pad_to_patch_size(frame, (align, align), padding_mode="reflect")
|
||||
return frame
|
||||
|
||||
# Count total interpolation passes for progress bar
|
||||
total_pairs = num_frames - 1
|
||||
num_interp = multiplier - 1
|
||||
total_steps = total_pairs * num_interp
|
||||
pbar = comfy.utils.ProgressBar(total_steps)
|
||||
tqdm_bar = tqdm(total=total_steps, desc="Frame interpolation")
|
||||
|
||||
batch = num_interp # reduced on OOM and persists across pairs (same resolution = same limit)
|
||||
t_values = [t / multiplier for t in range(1, multiplier)]
|
||||
|
||||
out_dtype = model_management.intermediate_dtype()
|
||||
total_out_frames = total_pairs * multiplier + 1
|
||||
result = torch.empty((total_out_frames, 3, H, W), dtype=out_dtype, device=offload_device)
|
||||
result[0] = images[0].movedim(-1, 0).to(out_dtype)
|
||||
out_idx = 1
|
||||
|
||||
# Pre-compute timestep tensor on device (padded dimensions needed)
|
||||
sample = prepare_frame(0)
|
||||
pH, pW = sample.shape[2], sample.shape[3]
|
||||
ts_full = torch.tensor(t_values, device=device, dtype=dtype).reshape(num_interp, 1, 1, 1)
|
||||
ts_full = ts_full.expand(-1, 1, pH, pW)
|
||||
del sample
|
||||
|
||||
multi_fn = getattr(inference_model, "forward_multi_timestep", None)
|
||||
feat_cache = {}
|
||||
prev_frame = None
|
||||
|
||||
try:
|
||||
for i in range(total_pairs):
|
||||
img0_single = prev_frame if prev_frame is not None else prepare_frame(i)
|
||||
img1_single = prepare_frame(i + 1)
|
||||
prev_frame = img1_single
|
||||
|
||||
# Cache features: img1 of pair N becomes img0 of pair N+1
|
||||
feat_cache["img0"] = feat_cache.pop("next") if "next" in feat_cache else inference_model.extract_features(img0_single)
|
||||
feat_cache["img1"] = inference_model.extract_features(img1_single)
|
||||
feat_cache["next"] = feat_cache["img1"]
|
||||
|
||||
used_multi = False
|
||||
if multi_fn is not None:
|
||||
# Models with timestep-independent flow can compute it once for all timesteps
|
||||
try:
|
||||
mids = multi_fn(img0_single, img1_single, t_values, cache=feat_cache)
|
||||
result[out_idx:out_idx + num_interp] = mids[:, :, :H, :W].to(out_dtype)
|
||||
out_idx += num_interp
|
||||
pbar.update(num_interp)
|
||||
tqdm_bar.update(num_interp)
|
||||
used_multi = True
|
||||
except model_management.OOM_EXCEPTION:
|
||||
model_management.soft_empty_cache()
|
||||
multi_fn = None # fall through to single-timestep path
|
||||
|
||||
if not used_multi:
|
||||
j = 0
|
||||
while j < num_interp:
|
||||
b = min(batch, num_interp - j)
|
||||
try:
|
||||
img0 = img0_single.expand(b, -1, -1, -1)
|
||||
img1 = img1_single.expand(b, -1, -1, -1)
|
||||
mids = inference_model(img0, img1, timestep=ts_full[j:j + b], cache=feat_cache)
|
||||
result[out_idx:out_idx + b] = mids[:, :, :H, :W].to(out_dtype)
|
||||
out_idx += b
|
||||
pbar.update(b)
|
||||
tqdm_bar.update(b)
|
||||
j += b
|
||||
except model_management.OOM_EXCEPTION:
|
||||
if batch <= 1:
|
||||
raise
|
||||
batch = max(1, batch // 2)
|
||||
model_management.soft_empty_cache()
|
||||
|
||||
result[out_idx] = images[i + 1].movedim(-1, 0).to(out_dtype)
|
||||
out_idx += 1
|
||||
finally:
|
||||
tqdm_bar.close()
|
||||
|
||||
# BCHW -> BHWC
|
||||
result = result.movedim(1, -1).clamp_(0.0, 1.0)
|
||||
return io.NodeOutput(result)
|
||||
|
||||
|
||||
class FrameInterpolationExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
FrameInterpolationModelLoader,
|
||||
FrameInterpolate,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> FrameInterpolationExtension:
|
||||
return FrameInterpolationExtension()
|
||||
@@ -3,9 +3,8 @@ import comfy.utils
|
||||
import comfy.model_management
|
||||
import torch
|
||||
|
||||
from comfy.ldm.lightricks.vae.audio_vae import AudioVAE
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
from comfy_extras.nodes_audio import VAEEncodeAudio
|
||||
|
||||
class LTXVAudioVAELoader(io.ComfyNode):
|
||||
@classmethod
|
||||
@@ -28,10 +27,14 @@ class LTXVAudioVAELoader(io.ComfyNode):
|
||||
def execute(cls, ckpt_name: str) -> io.NodeOutput:
|
||||
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
|
||||
return io.NodeOutput(AudioVAE(sd, metadata))
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"audio_vae.": "autoencoder.", "vocoder.": "vocoder."}, filter_keys=True)
|
||||
vae = comfy.sd.VAE(sd=sd, metadata=metadata)
|
||||
vae.throw_exception_if_invalid()
|
||||
|
||||
return io.NodeOutput(vae)
|
||||
|
||||
|
||||
class LTXVAudioVAEEncode(io.ComfyNode):
|
||||
class LTXVAudioVAEEncode(VAEEncodeAudio):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
@@ -50,15 +53,8 @@ class LTXVAudioVAEEncode(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, audio, audio_vae: AudioVAE) -> io.NodeOutput:
|
||||
audio_latents = audio_vae.encode(audio)
|
||||
return io.NodeOutput(
|
||||
{
|
||||
"samples": audio_latents,
|
||||
"sample_rate": int(audio_vae.sample_rate),
|
||||
"type": "audio",
|
||||
}
|
||||
)
|
||||
def execute(cls, audio, audio_vae) -> io.NodeOutput:
|
||||
return super().execute(audio_vae, audio)
|
||||
|
||||
|
||||
class LTXVAudioVAEDecode(io.ComfyNode):
|
||||
@@ -80,12 +76,12 @@ class LTXVAudioVAEDecode(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, samples, audio_vae: AudioVAE) -> io.NodeOutput:
|
||||
def execute(cls, samples, audio_vae) -> io.NodeOutput:
|
||||
audio_latent = samples["samples"]
|
||||
if audio_latent.is_nested:
|
||||
audio_latent = audio_latent.unbind()[-1]
|
||||
audio = audio_vae.decode(audio_latent).to(audio_latent.device)
|
||||
output_audio_sample_rate = audio_vae.output_sample_rate
|
||||
audio = audio_vae.decode(audio_latent).movedim(-1, 1).to(audio_latent.device)
|
||||
output_audio_sample_rate = audio_vae.first_stage_model.output_sample_rate
|
||||
return io.NodeOutput(
|
||||
{
|
||||
"waveform": audio,
|
||||
@@ -143,17 +139,17 @@ class LTXVEmptyLatentAudio(io.ComfyNode):
|
||||
frames_number: int,
|
||||
frame_rate: int,
|
||||
batch_size: int,
|
||||
audio_vae: AudioVAE,
|
||||
audio_vae,
|
||||
) -> io.NodeOutput:
|
||||
"""Generate empty audio latents matching the reference pipeline structure."""
|
||||
|
||||
assert audio_vae is not None, "Audio VAE model is required"
|
||||
|
||||
z_channels = audio_vae.latent_channels
|
||||
audio_freq = audio_vae.latent_frequency_bins
|
||||
sampling_rate = int(audio_vae.sample_rate)
|
||||
audio_freq = audio_vae.first_stage_model.latent_frequency_bins
|
||||
sampling_rate = int(audio_vae.first_stage_model.sample_rate)
|
||||
|
||||
num_audio_latents = audio_vae.num_of_latents_from_frames(frames_number, frame_rate)
|
||||
num_audio_latents = audio_vae.first_stage_model.num_of_latents_from_frames(frames_number, frame_rate)
|
||||
|
||||
audio_latents = torch.zeros(
|
||||
(batch_size, z_channels, num_audio_latents, audio_freq),
|
||||
|
||||
529
comfy_extras/nodes_sam3.py
Normal file
529
comfy_extras/nodes_sam3.py
Normal file
@@ -0,0 +1,529 @@
|
||||
"""
|
||||
SAM3 (Segment Anything 3) nodes for detection, segmentation, and video tracking.
|
||||
"""
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
import json
|
||||
import os
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
from comfy_api.latest import ComfyExtension, io, ui
|
||||
import av
|
||||
from fractions import Fraction
|
||||
|
||||
|
||||
def _extract_text_prompts(conditioning, device, dtype):
|
||||
"""Extract list of (text_embeddings, text_mask) from conditioning."""
|
||||
cond_meta = conditioning[0][1]
|
||||
multi = cond_meta.get("sam3_multi_cond")
|
||||
prompts = []
|
||||
if multi is not None:
|
||||
for entry in multi:
|
||||
emb = entry["cond"].to(device=device, dtype=dtype)
|
||||
mask = entry["attention_mask"].to(device) if entry["attention_mask"] is not None else None
|
||||
if mask is None:
|
||||
mask = torch.ones(emb.shape[0], emb.shape[1], dtype=torch.int64, device=device)
|
||||
prompts.append((emb, mask, entry.get("max_detections", 1)))
|
||||
else:
|
||||
emb = conditioning[0][0].to(device=device, dtype=dtype)
|
||||
mask = cond_meta.get("attention_mask")
|
||||
if mask is not None:
|
||||
mask = mask.to(device)
|
||||
else:
|
||||
mask = torch.ones(emb.shape[0], emb.shape[1], dtype=torch.int64, device=device)
|
||||
prompts.append((emb, mask, 1))
|
||||
return prompts
|
||||
|
||||
|
||||
def _refine_mask(sam3_model, orig_image_hwc, coarse_mask, box_xyxy, H, W, device, dtype, iterations):
|
||||
"""Refine a coarse detector mask via SAM decoder, cropping to the detection box.
|
||||
|
||||
Returns: [1, H, W] binary mask
|
||||
"""
|
||||
def _coarse_fallback():
|
||||
return (F.interpolate(coarse_mask.unsqueeze(0).unsqueeze(0), size=(H, W),
|
||||
mode="bilinear", align_corners=False)[0] > 0).float()
|
||||
|
||||
if iterations <= 0:
|
||||
return _coarse_fallback()
|
||||
|
||||
pad_frac = 0.1
|
||||
x1, y1, x2, y2 = box_xyxy.tolist()
|
||||
bw, bh = x2 - x1, y2 - y1
|
||||
cx1 = max(0, int(x1 - bw * pad_frac))
|
||||
cy1 = max(0, int(y1 - bh * pad_frac))
|
||||
cx2 = min(W, int(x2 + bw * pad_frac))
|
||||
cy2 = min(H, int(y2 + bh * pad_frac))
|
||||
if cx2 <= cx1 or cy2 <= cy1:
|
||||
return _coarse_fallback()
|
||||
|
||||
crop = orig_image_hwc[cy1:cy2, cx1:cx2, :3]
|
||||
crop_1008 = comfy.utils.common_upscale(crop.unsqueeze(0).movedim(-1, 1), 1008, 1008, "bilinear", crop="disabled")
|
||||
crop_frame = crop_1008.to(device=device, dtype=dtype)
|
||||
crop_h, crop_w = cy2 - cy1, cx2 - cx1
|
||||
|
||||
# Crop coarse mask and refine via SAM on the cropped image
|
||||
mask_h, mask_w = coarse_mask.shape[-2:]
|
||||
mx1, my1 = int(cx1 / W * mask_w), int(cy1 / H * mask_h)
|
||||
mx2, my2 = int(cx2 / W * mask_w), int(cy2 / H * mask_h)
|
||||
if mx2 <= mx1 or my2 <= my1:
|
||||
return _coarse_fallback()
|
||||
mask_logit = coarse_mask[..., my1:my2, mx1:mx2].unsqueeze(0).unsqueeze(0)
|
||||
for _ in range(iterations):
|
||||
coarse_input = F.interpolate(mask_logit, size=(1008, 1008), mode="bilinear", align_corners=False)
|
||||
mask_logit = sam3_model.forward_segment(crop_frame, mask_inputs=coarse_input)
|
||||
|
||||
refined_crop = F.interpolate(mask_logit, size=(crop_h, crop_w), mode="bilinear", align_corners=False)
|
||||
full_mask = torch.zeros(1, 1, H, W, device=device, dtype=dtype)
|
||||
full_mask[:, :, cy1:cy2, cx1:cx2] = refined_crop
|
||||
coarse_full = F.interpolate(coarse_mask.unsqueeze(0).unsqueeze(0), size=(H, W), mode="bilinear", align_corners=False)
|
||||
return ((full_mask[0] > 0) | (coarse_full[0] > 0)).float()
|
||||
|
||||
|
||||
|
||||
class SAM3_Detect(io.ComfyNode):
|
||||
"""Open-vocabulary detection and segmentation using text, box, or point prompts."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SAM3_Detect",
|
||||
display_name="SAM3 Detect",
|
||||
category="detection/",
|
||||
search_aliases=["sam3", "segment anything", "open vocabulary", "text detection", "segment"],
|
||||
inputs=[
|
||||
io.Model.Input("model", display_name="model"),
|
||||
io.Image.Input("image", display_name="image"),
|
||||
io.Conditioning.Input("conditioning", display_name="conditioning", optional=True, tooltip="Text conditioning from CLIPTextEncode"),
|
||||
io.BoundingBox.Input("bboxes", display_name="bboxes", force_input=True, optional=True, tooltip="Bounding boxes to segment within"),
|
||||
io.String.Input("positive_coords", display_name="positive_coords", force_input=True, optional=True, tooltip="Positive point prompts as JSON [{\"x\": int, \"y\": int}, ...] (pixel coords)"),
|
||||
io.String.Input("negative_coords", display_name="negative_coords", force_input=True, optional=True, tooltip="Negative point prompts as JSON [{\"x\": int, \"y\": int}, ...] (pixel coords)"),
|
||||
io.Float.Input("threshold", display_name="threshold", default=0.5, min=0.0, max=1.0, step=0.01),
|
||||
io.Int.Input("refine_iterations", display_name="refine_iterations", default=2, min=0, max=5, tooltip="SAM decoder refinement passes (0=use raw detector masks)"),
|
||||
io.Boolean.Input("individual_masks", display_name="individual_masks", default=False, tooltip="Output per-object masks instead of union"),
|
||||
],
|
||||
outputs=[
|
||||
io.Mask.Output("masks"),
|
||||
io.BoundingBox.Output("bboxes"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, image, conditioning=None, bboxes=None, positive_coords=None, negative_coords=None, threshold=0.5, refine_iterations=2, individual_masks=False) -> io.NodeOutput:
|
||||
B, H, W, C = image.shape
|
||||
image_in = comfy.utils.common_upscale(image[..., :3].movedim(-1, 1), 1008, 1008, "bilinear", crop="disabled")
|
||||
|
||||
# Convert bboxes to normalized cxcywh format, per-frame list of [1, N, 4] tensors.
|
||||
# Supports: single dict (all frames), list[dict] (all frames), list[list[dict]] (per-frame).
|
||||
def _boxes_to_tensor(box_list):
|
||||
coords = []
|
||||
for d in box_list:
|
||||
cx = (d["x"] + d["width"] / 2) / W
|
||||
cy = (d["y"] + d["height"] / 2) / H
|
||||
coords.append([cx, cy, d["width"] / W, d["height"] / H])
|
||||
return torch.tensor([coords], dtype=torch.float32) # [1, N, 4]
|
||||
|
||||
per_frame_boxes = None
|
||||
if bboxes is not None:
|
||||
if isinstance(bboxes, dict):
|
||||
# Single box → same for all frames
|
||||
shared = _boxes_to_tensor([bboxes])
|
||||
per_frame_boxes = [shared] * B
|
||||
elif isinstance(bboxes, list) and len(bboxes) > 0 and isinstance(bboxes[0], list):
|
||||
# list[list[dict]] → per-frame boxes
|
||||
per_frame_boxes = [_boxes_to_tensor(frame_boxes) if frame_boxes else None for frame_boxes in bboxes]
|
||||
# Pad to B if fewer frames provided
|
||||
while len(per_frame_boxes) < B:
|
||||
per_frame_boxes.append(per_frame_boxes[-1] if per_frame_boxes else None)
|
||||
elif isinstance(bboxes, list) and len(bboxes) > 0:
|
||||
# list[dict] → same boxes for all frames
|
||||
shared = _boxes_to_tensor(bboxes)
|
||||
per_frame_boxes = [shared] * B
|
||||
|
||||
# Parse point prompts from JSON (KJNodes PointsEditor format: [{"x": int, "y": int}, ...])
|
||||
pos_pts = json.loads(positive_coords) if positive_coords else []
|
||||
neg_pts = json.loads(negative_coords) if negative_coords else []
|
||||
has_points = len(pos_pts) > 0 or len(neg_pts) > 0
|
||||
|
||||
comfy.model_management.load_model_gpu(model)
|
||||
device = comfy.model_management.get_torch_device()
|
||||
dtype = model.model.get_dtype()
|
||||
sam3_model = model.model.diffusion_model
|
||||
|
||||
# Build point inputs for tracker SAM decoder path
|
||||
point_inputs = None
|
||||
if has_points:
|
||||
all_coords = [[p["x"] / W * 1008, p["y"] / H * 1008] for p in pos_pts] + \
|
||||
[[p["x"] / W * 1008, p["y"] / H * 1008] for p in neg_pts]
|
||||
all_labels = [1] * len(pos_pts) + [0] * len(neg_pts)
|
||||
point_inputs = {
|
||||
"point_coords": torch.tensor([all_coords], dtype=dtype, device=device),
|
||||
"point_labels": torch.tensor([all_labels], dtype=torch.int32, device=device),
|
||||
}
|
||||
|
||||
cond_list = _extract_text_prompts(conditioning, device, dtype) if conditioning is not None and len(conditioning) > 0 else []
|
||||
has_text = len(cond_list) > 0
|
||||
|
||||
# Run per-image through detector (text/boxes) and/or tracker (points)
|
||||
all_bbox_dicts = []
|
||||
all_masks = []
|
||||
pbar = comfy.utils.ProgressBar(B)
|
||||
|
||||
for b in range(B):
|
||||
frame = image_in[b:b+1].to(device=device, dtype=dtype)
|
||||
b_boxes = None
|
||||
if per_frame_boxes is not None and per_frame_boxes[b] is not None:
|
||||
b_boxes = per_frame_boxes[b].to(device=device, dtype=dtype)
|
||||
|
||||
frame_bbox_dicts = []
|
||||
frame_masks = []
|
||||
|
||||
# Point prompts: tracker SAM decoder path with iterative refinement
|
||||
if point_inputs is not None:
|
||||
mask_logit = sam3_model.forward_segment(frame, point_inputs=point_inputs)
|
||||
for _ in range(max(0, refine_iterations - 1)):
|
||||
mask_logit = sam3_model.forward_segment(frame, mask_inputs=mask_logit)
|
||||
mask = F.interpolate(mask_logit, size=(H, W), mode="bilinear", align_corners=False)
|
||||
frame_masks.append((mask[0] > 0).float())
|
||||
|
||||
# Box prompts: SAM decoder path (segment inside each box)
|
||||
if b_boxes is not None and not has_text:
|
||||
for box_cxcywh in b_boxes[0]:
|
||||
cx, cy, bw, bh = box_cxcywh.tolist()
|
||||
# Convert cxcywh normalized → xyxy in 1008 space → [1, 2, 2] corners
|
||||
sam_box = torch.tensor([[[(cx - bw/2) * 1008, (cy - bh/2) * 1008],
|
||||
[(cx + bw/2) * 1008, (cy + bh/2) * 1008]]],
|
||||
device=device, dtype=dtype)
|
||||
mask_logit = sam3_model.forward_segment(frame, box_inputs=sam_box)
|
||||
for _ in range(max(0, refine_iterations - 1)):
|
||||
mask_logit = sam3_model.forward_segment(frame, mask_inputs=mask_logit)
|
||||
mask = F.interpolate(mask_logit, size=(H, W), mode="bilinear", align_corners=False)
|
||||
frame_masks.append((mask[0] > 0).float())
|
||||
|
||||
# Text prompts: run detector per text prompt (each detects one category)
|
||||
for text_embeddings, text_mask, max_det in cond_list:
|
||||
results = sam3_model(
|
||||
frame, text_embeddings=text_embeddings, text_mask=text_mask,
|
||||
boxes=b_boxes, threshold=threshold, orig_size=(H, W))
|
||||
|
||||
pred_boxes = results["boxes"][0]
|
||||
scores = results["scores"][0]
|
||||
masks = results["masks"][0]
|
||||
|
||||
probs = scores.sigmoid()
|
||||
keep = probs > threshold
|
||||
kept_boxes = pred_boxes[keep].cpu()
|
||||
kept_scores = probs[keep].cpu()
|
||||
kept_masks = masks[keep]
|
||||
|
||||
order = kept_scores.argsort(descending=True)[:max_det]
|
||||
kept_boxes = kept_boxes[order]
|
||||
kept_scores = kept_scores[order]
|
||||
kept_masks = kept_masks[order]
|
||||
|
||||
for box, score in zip(kept_boxes, kept_scores):
|
||||
frame_bbox_dicts.append({
|
||||
"x": float(box[0]), "y": float(box[1]),
|
||||
"width": float(box[2] - box[0]), "height": float(box[3] - box[1]),
|
||||
"score": float(score),
|
||||
})
|
||||
for m, box in zip(kept_masks, kept_boxes):
|
||||
frame_masks.append(_refine_mask(
|
||||
sam3_model, image[b], m, box, H, W, device, dtype, refine_iterations))
|
||||
|
||||
all_bbox_dicts.append(frame_bbox_dicts)
|
||||
if len(frame_masks) > 0:
|
||||
combined = torch.cat(frame_masks, dim=0) # [N_obj, H, W]
|
||||
if individual_masks:
|
||||
all_masks.append(combined)
|
||||
else:
|
||||
all_masks.append((combined > 0).any(dim=0).float())
|
||||
else:
|
||||
if individual_masks:
|
||||
all_masks.append(torch.zeros(0, H, W, device=comfy.model_management.intermediate_device()))
|
||||
else:
|
||||
all_masks.append(torch.zeros(H, W, device=comfy.model_management.intermediate_device()))
|
||||
pbar.update(1)
|
||||
|
||||
idev = comfy.model_management.intermediate_device()
|
||||
all_masks = [m.to(idev) for m in all_masks]
|
||||
mask_out = torch.cat(all_masks, dim=0) if individual_masks else torch.stack(all_masks)
|
||||
return io.NodeOutput(mask_out, all_bbox_dicts)
|
||||
|
||||
|
||||
SAM3TrackData = io.Custom("SAM3_TRACK_DATA")
|
||||
|
||||
class SAM3_VideoTrack(io.ComfyNode):
|
||||
"""Track objects across video frames using SAM3's memory-based tracker."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SAM3_VideoTrack",
|
||||
display_name="SAM3 Video Track",
|
||||
category="detection/",
|
||||
search_aliases=["sam3", "video", "track", "propagate"],
|
||||
inputs=[
|
||||
io.Image.Input("images", display_name="images", tooltip="Video frames as batched images"),
|
||||
io.Model.Input("model", display_name="model"),
|
||||
io.Mask.Input("initial_mask", display_name="initial_mask", optional=True, tooltip="Mask(s) for the first frame to track (one per object)"),
|
||||
io.Conditioning.Input("conditioning", display_name="conditioning", optional=True, tooltip="Text conditioning for detecting new objects during tracking"),
|
||||
io.Float.Input("detection_threshold", display_name="detection_threshold", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Score threshold for text-prompted detection"),
|
||||
io.Int.Input("max_objects", display_name="max_objects", default=0, min=0, tooltip="Max tracked objects (0=unlimited). Initial masks count toward this limit."),
|
||||
io.Int.Input("detect_interval", display_name="detect_interval", default=1, min=1, tooltip="Run detection every N frames (1=every frame). Higher values save compute."),
|
||||
],
|
||||
outputs=[
|
||||
SAM3TrackData.Output("track_data", display_name="track_data"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, images, model, initial_mask=None, conditioning=None, detection_threshold=0.5, max_objects=0, detect_interval=1) -> io.NodeOutput:
|
||||
N, H, W, C = images.shape
|
||||
|
||||
comfy.model_management.load_model_gpu(model)
|
||||
device = comfy.model_management.get_torch_device()
|
||||
dtype = model.model.get_dtype()
|
||||
sam3_model = model.model.diffusion_model
|
||||
|
||||
frames = images[..., :3].movedim(-1, 1)
|
||||
frames_in = comfy.utils.common_upscale(frames, 1008, 1008, "bilinear", crop="disabled").to(device=device, dtype=dtype)
|
||||
|
||||
init_masks = None
|
||||
if initial_mask is not None:
|
||||
init_masks = initial_mask.unsqueeze(1).to(device=device, dtype=dtype)
|
||||
|
||||
pbar = comfy.utils.ProgressBar(N)
|
||||
|
||||
text_prompts = None
|
||||
if conditioning is not None and len(conditioning) > 0:
|
||||
text_prompts = [(emb, mask) for emb, mask, _ in _extract_text_prompts(conditioning, device, dtype)]
|
||||
elif initial_mask is None:
|
||||
raise ValueError("Either initial_mask or conditioning must be provided")
|
||||
|
||||
result = sam3_model.forward_video(
|
||||
images=frames_in, initial_masks=init_masks, pbar=pbar, text_prompts=text_prompts,
|
||||
new_det_thresh=detection_threshold, max_objects=max_objects,
|
||||
detect_interval=detect_interval)
|
||||
result["orig_size"] = (H, W)
|
||||
return io.NodeOutput(result)
|
||||
|
||||
|
||||
class SAM3_TrackPreview(io.ComfyNode):
|
||||
"""Visualize tracked objects with distinct colors as a video preview. No tensor output — saves to temp video."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SAM3_TrackPreview",
|
||||
display_name="SAM3 Track Preview",
|
||||
category="detection/",
|
||||
inputs=[
|
||||
SAM3TrackData.Input("track_data", display_name="track_data"),
|
||||
io.Image.Input("images", display_name="images", optional=True),
|
||||
io.Float.Input("opacity", display_name="opacity", default=0.5, min=0.0, max=1.0, step=0.05),
|
||||
io.Float.Input("fps", display_name="fps", default=24.0, min=1.0, max=120.0, step=1.0),
|
||||
],
|
||||
is_output_node=True,
|
||||
)
|
||||
|
||||
COLORS = [
|
||||
(0.12, 0.47, 0.71), (1.0, 0.5, 0.05), (0.17, 0.63, 0.17), (0.84, 0.15, 0.16),
|
||||
(0.58, 0.4, 0.74), (0.55, 0.34, 0.29), (0.89, 0.47, 0.76), (0.5, 0.5, 0.5),
|
||||
(0.74, 0.74, 0.13), (0.09, 0.75, 0.81), (0.94, 0.76, 0.06), (0.42, 0.68, 0.84),
|
||||
]
|
||||
|
||||
# 5x3 bitmap font atlas for digits 0-9 [10, 5, 3]
|
||||
_glyph_cache = {} # (device, scale) -> (glyphs, outlines, gh, gw, oh, ow)
|
||||
|
||||
@staticmethod
|
||||
def _get_glyphs(device, scale=3):
|
||||
key = (device, scale)
|
||||
if key in SAM3_TrackPreview._glyph_cache:
|
||||
return SAM3_TrackPreview._glyph_cache[key]
|
||||
atlas = torch.tensor([
|
||||
[[1,1,1],[1,0,1],[1,0,1],[1,0,1],[1,1,1]],
|
||||
[[0,1,0],[1,1,0],[0,1,0],[0,1,0],[1,1,1]],
|
||||
[[1,1,1],[0,0,1],[1,1,1],[1,0,0],[1,1,1]],
|
||||
[[1,1,1],[0,0,1],[1,1,1],[0,0,1],[1,1,1]],
|
||||
[[1,0,1],[1,0,1],[1,1,1],[0,0,1],[0,0,1]],
|
||||
[[1,1,1],[1,0,0],[1,1,1],[0,0,1],[1,1,1]],
|
||||
[[1,1,1],[1,0,0],[1,1,1],[1,0,1],[1,1,1]],
|
||||
[[1,1,1],[0,0,1],[0,0,1],[0,0,1],[0,0,1]],
|
||||
[[1,1,1],[1,0,1],[1,1,1],[1,0,1],[1,1,1]],
|
||||
[[1,1,1],[1,0,1],[1,1,1],[0,0,1],[1,1,1]],
|
||||
], dtype=torch.bool)
|
||||
glyphs, outlines = [], []
|
||||
for d in range(10):
|
||||
g = atlas[d].repeat_interleave(scale, 0).repeat_interleave(scale, 1)
|
||||
padded = F.pad(g.float().unsqueeze(0).unsqueeze(0), (1,1,1,1))
|
||||
o = (F.max_pool2d(padded, 3, stride=1, padding=1)[0, 0] > 0)
|
||||
glyphs.append(g.to(device))
|
||||
outlines.append(o.to(device))
|
||||
gh, gw = glyphs[0].shape
|
||||
oh, ow = outlines[0].shape
|
||||
SAM3_TrackPreview._glyph_cache[key] = (glyphs, outlines, gh, gw, oh, ow)
|
||||
return SAM3_TrackPreview._glyph_cache[key]
|
||||
|
||||
@staticmethod
|
||||
def _draw_number_gpu(frame, number, cx, cy, color, scale=3):
|
||||
"""Draw a number on a GPU tensor [H, W, 3] float 0-1 at (cx, cy) with outline."""
|
||||
H, W = frame.shape[:2]
|
||||
device = frame.device
|
||||
glyphs, outlines, gh, gw, oh, ow = SAM3_TrackPreview._get_glyphs(device, scale)
|
||||
color_t = torch.tensor(color, device=device, dtype=frame.dtype)
|
||||
digs = [int(d) for d in str(number)]
|
||||
total_w = len(digs) * (gw + scale) - scale
|
||||
x0 = cx - total_w // 2
|
||||
y0 = cy - gh // 2
|
||||
for i, d in enumerate(digs):
|
||||
dx = x0 + i * (gw + scale)
|
||||
# Black outline
|
||||
oy0, ox0 = y0 - 1, dx - 1
|
||||
osy1, osx1 = max(0, -oy0), max(0, -ox0)
|
||||
osy2, osx2 = min(oh, H - oy0), min(ow, W - ox0)
|
||||
if osy2 > osy1 and osx2 > osx1:
|
||||
fy1, fx1 = oy0 + osy1, ox0 + osx1
|
||||
frame[fy1:fy1+(osy2-osy1), fx1:fx1+(osx2-osx1)][outlines[d][osy1:osy2, osx1:osx2]] = 0
|
||||
# Colored fill
|
||||
sy1, sx1 = max(0, -y0), max(0, -dx)
|
||||
sy2, sx2 = min(gh, H - y0), min(gw, W - dx)
|
||||
if sy2 > sy1 and sx2 > sx1:
|
||||
fy1, fx1 = y0 + sy1, dx + sx1
|
||||
frame[fy1:fy1+(sy2-sy1), fx1:fx1+(sx2-sx1)][glyphs[d][sy1:sy2, sx1:sx2]] = color_t
|
||||
|
||||
@classmethod
|
||||
def execute(cls, track_data, images=None, opacity=0.5, fps=24.0) -> io.NodeOutput:
|
||||
|
||||
from comfy.ldm.sam3.tracker import unpack_masks
|
||||
packed = track_data["packed_masks"]
|
||||
H, W = track_data["orig_size"]
|
||||
if images is not None:
|
||||
H, W = images.shape[1], images.shape[2]
|
||||
if packed is None:
|
||||
N, N_obj = track_data["n_frames"], 0
|
||||
else:
|
||||
N, N_obj = packed.shape[0], packed.shape[1]
|
||||
|
||||
import uuid
|
||||
gpu = comfy.model_management.get_torch_device()
|
||||
temp_dir = folder_paths.get_temp_directory()
|
||||
filename = f"sam3_track_preview_{uuid.uuid4().hex[:8]}.mp4"
|
||||
filepath = os.path.join(temp_dir, filename)
|
||||
with av.open(filepath, mode='w') as output:
|
||||
stream = output.add_stream('h264', rate=Fraction(round(fps * 1000), 1000))
|
||||
stream.width = W
|
||||
stream.height = H
|
||||
stream.pix_fmt = 'yuv420p'
|
||||
|
||||
frame_cpu = torch.empty(H, W, 3, dtype=torch.uint8)
|
||||
frame_np = frame_cpu.numpy()
|
||||
if N_obj > 0:
|
||||
colors_t = torch.tensor([cls.COLORS[i % len(cls.COLORS)] for i in range(N_obj)],
|
||||
device=gpu, dtype=torch.float32)
|
||||
grid_y = torch.arange(H, device=gpu).view(1, H, 1)
|
||||
grid_x = torch.arange(W, device=gpu).view(1, 1, W)
|
||||
for t in range(N):
|
||||
if images is not None and t < images.shape[0]:
|
||||
frame = images[t].clone()
|
||||
else:
|
||||
frame = torch.zeros(H, W, 3)
|
||||
|
||||
if N_obj > 0:
|
||||
frame_binary = unpack_masks(packed[t:t+1].to(gpu)) # [1, N_obj, H, W] bool
|
||||
frame_masks = F.interpolate(frame_binary.float(), size=(H, W), mode="nearest")[0]
|
||||
frame_gpu = frame.to(gpu)
|
||||
bool_masks = frame_masks > 0.5
|
||||
any_mask = bool_masks.any(dim=0)
|
||||
if any_mask.any():
|
||||
obj_idx_map = bool_masks.to(torch.uint8).argmax(dim=0)
|
||||
color_overlay = colors_t[obj_idx_map]
|
||||
mask_3d = any_mask.unsqueeze(-1)
|
||||
frame_gpu = torch.where(mask_3d, frame_gpu * (1 - opacity) + color_overlay * opacity, frame_gpu)
|
||||
area = bool_masks.sum(dim=(-1, -2)).clamp_(min=1)
|
||||
cy = (bool_masks * grid_y).sum(dim=(-1, -2)) // area
|
||||
cx = (bool_masks * grid_x).sum(dim=(-1, -2)) // area
|
||||
has = area > 1
|
||||
scores = track_data.get("scores", [])
|
||||
for obj_idx in range(N_obj):
|
||||
if has[obj_idx]:
|
||||
_cx, _cy = int(cx[obj_idx]), int(cy[obj_idx])
|
||||
color = cls.COLORS[obj_idx % len(cls.COLORS)]
|
||||
SAM3_TrackPreview._draw_number_gpu(frame_gpu, obj_idx, _cx, _cy, color)
|
||||
if obj_idx < len(scores) and scores[obj_idx] < 1.0:
|
||||
SAM3_TrackPreview._draw_number_gpu(frame_gpu, int(scores[obj_idx] * 100),
|
||||
_cx, _cy + 5 * 3 + 3, color, scale=2)
|
||||
frame_cpu.copy_(frame_gpu.clamp_(0, 1).mul_(255).byte())
|
||||
else:
|
||||
frame_cpu.copy_(frame.clamp_(0, 1).mul_(255).byte())
|
||||
|
||||
vframe = av.VideoFrame.from_ndarray(frame_np, format='rgb24')
|
||||
output.mux(stream.encode(vframe.reformat(format='yuv420p')))
|
||||
output.mux(stream.encode(None))
|
||||
return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(filename, "", io.FolderType.temp)]))
|
||||
|
||||
|
||||
class SAM3_TrackToMask(io.ComfyNode):
|
||||
"""Select tracked objects by index and output as mask."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SAM3_TrackToMask",
|
||||
display_name="SAM3 Track to Mask",
|
||||
category="detection/",
|
||||
inputs=[
|
||||
SAM3TrackData.Input("track_data", display_name="track_data"),
|
||||
io.String.Input("object_indices", display_name="object_indices", default="",
|
||||
tooltip="Comma-separated object indices to include (e.g. '0,2,3'). Empty = all objects."),
|
||||
],
|
||||
outputs=[
|
||||
io.Mask.Output("masks", display_name="masks"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, track_data, object_indices="") -> io.NodeOutput:
|
||||
from comfy.ldm.sam3.tracker import unpack_masks
|
||||
packed = track_data["packed_masks"]
|
||||
H, W = track_data["orig_size"]
|
||||
|
||||
if packed is None:
|
||||
N = track_data["n_frames"]
|
||||
return io.NodeOutput(torch.zeros(N, H, W, device=comfy.model_management.intermediate_device()))
|
||||
|
||||
N, N_obj = packed.shape[0], packed.shape[1]
|
||||
|
||||
if object_indices.strip():
|
||||
indices = [int(i.strip()) for i in object_indices.split(",") if i.strip().isdigit()]
|
||||
indices = [i for i in indices if 0 <= i < N_obj]
|
||||
else:
|
||||
indices = list(range(N_obj))
|
||||
|
||||
if not indices:
|
||||
return io.NodeOutput(torch.zeros(N, H, W, device=comfy.model_management.intermediate_device()))
|
||||
|
||||
selected = packed[:, indices]
|
||||
binary = unpack_masks(selected) # [N, len(indices), Hm, Wm] bool
|
||||
union = binary.any(dim=1, keepdim=True).float()
|
||||
mask_out = F.interpolate(union, size=(H, W), mode="bilinear", align_corners=False)[:, 0]
|
||||
return io.NodeOutput(mask_out)
|
||||
|
||||
|
||||
class SAM3Extension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
SAM3_Detect,
|
||||
SAM3_VideoTrack,
|
||||
SAM3_TrackPreview,
|
||||
SAM3_TrackToMask,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> SAM3Extension:
|
||||
return SAM3Extension()
|
||||
@@ -52,6 +52,8 @@ folder_names_and_paths["model_patches"] = ([os.path.join(models_dir, "model_patc
|
||||
|
||||
folder_names_and_paths["audio_encoders"] = ([os.path.join(models_dir, "audio_encoders")], supported_pt_extensions)
|
||||
|
||||
folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "frame_interpolation")], supported_pt_extensions)
|
||||
|
||||
output_directory = os.path.join(base_path, "output")
|
||||
temp_directory = os.path.join(base_path, "temp")
|
||||
input_directory = os.path.join(base_path, "input")
|
||||
|
||||
4
main.py
4
main.py
@@ -9,6 +9,8 @@ import folder_paths
|
||||
import time
|
||||
from comfy.cli_args import args, enables_dynamic_vram
|
||||
from app.logger import setup_logger
|
||||
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
|
||||
|
||||
from app.assets.seeder import asset_seeder
|
||||
from app.assets.services import register_output_files
|
||||
import itertools
|
||||
@@ -27,8 +29,6 @@ if __name__ == "__main__":
|
||||
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
||||
os.environ['DO_NOT_TRACK'] = '1'
|
||||
|
||||
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
|
||||
|
||||
faulthandler.enable(file=sys.stderr, all_threads=False)
|
||||
|
||||
import comfy_aimdo.control
|
||||
|
||||
@@ -1 +1 @@
|
||||
comfyui_manager==4.1
|
||||
comfyui_manager==4.2.1
|
||||
|
||||
4
nodes.py
4
nodes.py
@@ -2457,7 +2457,9 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_number_convert.py",
|
||||
"nodes_painter.py",
|
||||
"nodes_curve.py",
|
||||
"nodes_rtdetr.py"
|
||||
"nodes_rtdetr.py",
|
||||
"nodes_frame_interpolation.py",
|
||||
"nodes_sam3.py"
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.42.11
|
||||
comfyui-workflow-templates==0.9.57
|
||||
comfyui-frontend-package==1.42.14
|
||||
comfyui-workflow-templates==0.9.59
|
||||
comfyui-embedded-docs==0.4.3
|
||||
torch
|
||||
torchsde
|
||||
|
||||
@@ -39,7 +39,7 @@ def get_required_packages_versions():
|
||||
if len(s) == 2:
|
||||
version_str = s[-1]
|
||||
if not is_valid_version(version_str):
|
||||
logging.error(f"Invalid version format in requirements.txt: {version_str}")
|
||||
logging.debug(f"Invalid version format for {s[0]} in requirements.txt: {version_str}")
|
||||
continue
|
||||
out[s[0]] = version_str
|
||||
return out.copy()
|
||||
|
||||
Reference in New Issue
Block a user