mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-29 17:07:28 +00:00
Compare commits
9 Commits
master
...
pysssss/an
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8114516ee6 | ||
|
|
3eb624ce6c | ||
|
|
54ff5464bd | ||
|
|
333ff2e8a0 | ||
|
|
c821d8ee2a | ||
|
|
27b6f8a927 | ||
|
|
9ad848bd59 | ||
|
|
efe6439ad0 | ||
|
|
8d76bb94fd |
@@ -38,12 +38,9 @@ void main() {
|
||||
// GIMP order: per-channel curves first, then RGB master curve.
|
||||
// See gimp_curve_map_pixels() default case in gimpcurve-map.c:
|
||||
// dest = colors_curve( channel_curve( src ) )
|
||||
float tmp_r = applyCurve(u_curve1, color.r);
|
||||
float tmp_g = applyCurve(u_curve2, color.g);
|
||||
float tmp_b = applyCurve(u_curve3, color.b);
|
||||
color.r = applyCurve(u_curve0, tmp_r);
|
||||
color.g = applyCurve(u_curve0, tmp_g);
|
||||
color.b = applyCurve(u_curve0, tmp_b);
|
||||
color.r = applyCurve(u_curve0, applyCurve(u_curve1, color.r));
|
||||
color.g = applyCurve(u_curve0, applyCurve(u_curve2, color.g));
|
||||
color.b = applyCurve(u_curve0, applyCurve(u_curve3, color.b));
|
||||
|
||||
fragColor0 = vec4(color.rgb, color.a);
|
||||
}
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -110,13 +110,11 @@ parser.add_argument("--preview-method", type=LatentPreviewMethod, default=Latent
|
||||
|
||||
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
|
||||
|
||||
CACHE_RAM_AUTO_GB = -1.0
|
||||
|
||||
cache_group = parser.add_mutually_exclusive_group()
|
||||
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
||||
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
||||
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
|
||||
cache_group.add_argument("--cache-ram", nargs='?', const=CACHE_RAM_AUTO_GB, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threshold the cache removes large items to free RAM. Default (when no value is provided): 25%% of system RAM (min 4GB, max 32GB).")
|
||||
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
|
||||
|
||||
attn_group = parser.add_mutually_exclusive_group()
|
||||
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
||||
|
||||
@@ -1,725 +0,0 @@
|
||||
from collections import OrderedDict
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
import comfy.model_management
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
|
||||
COCO_CLASSES = [
|
||||
'person','bicycle','car','motorcycle','airplane','bus','train','truck','boat',
|
||||
'traffic light','fire hydrant','stop sign','parking meter','bench','bird','cat',
|
||||
'dog','horse','sheep','cow','elephant','bear','zebra','giraffe','backpack',
|
||||
'umbrella','handbag','tie','suitcase','frisbee','skis','snowboard','sports ball',
|
||||
'kite','baseball bat','baseball glove','skateboard','surfboard','tennis racket',
|
||||
'bottle','wine glass','cup','fork','knife','spoon','bowl','banana','apple',
|
||||
'sandwich','orange','broccoli','carrot','hot dog','pizza','donut','cake','chair',
|
||||
'couch','potted plant','bed','dining table','toilet','tv','laptop','mouse',
|
||||
'remote','keyboard','cell phone','microwave','oven','toaster','sink',
|
||||
'refrigerator','book','clock','vase','scissors','teddy bear','hair drier','toothbrush',
|
||||
]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HGNetv2 backbone
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class ConvBNAct(nn.Module):
|
||||
"""Conv→BN→ReLU. padding='same' adds asymmetric zero-pad (stem)."""
|
||||
def __init__(self, ic, oc, k=3, s=1, groups=1, use_act=True, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
self.conv = operations.Conv2d(ic, oc, k, s, (k - 1) // 2, groups=groups, bias=False, device=device, dtype=dtype)
|
||||
self.bn = nn.BatchNorm2d(oc, device=device, dtype=dtype)
|
||||
self.act = nn.ReLU() if use_act else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
return self.act(self.bn(self.conv(x)))
|
||||
|
||||
class LightConvBNAct(nn.Module):
|
||||
def __init__(self, ic, oc, k, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.conv1 = ConvBNAct(ic, oc, 1, use_act=False, device=device, dtype=dtype, operations=operations)
|
||||
self.conv2 = ConvBNAct(oc, oc, k, groups=oc, use_act=True, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv2(self.conv1(x))
|
||||
|
||||
class _StemBlock(nn.Module):
|
||||
def __init__(self, ic, mc, oc, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.stem1 = ConvBNAct(ic, mc, 3, 2, device=device, dtype=dtype, operations=operations)
|
||||
# stem2a/stem2b use kernel=2, stride=1, no internal padding;
|
||||
# padding is applied manually in forward (matching PaddlePaddle original)
|
||||
self.stem2a = ConvBNAct(mc, mc//2, 2, 1, device=device, dtype=dtype, operations=operations)
|
||||
self.stem2b = ConvBNAct(mc//2, mc, 2, 1, device=device, dtype=dtype, operations=operations)
|
||||
self.stem3 = ConvBNAct(mc*2, mc, 3, 2, device=device, dtype=dtype, operations=operations)
|
||||
self.stem4 = ConvBNAct(mc, oc, 1, device=device, dtype=dtype, operations=operations)
|
||||
self.pool = nn.MaxPool2d(2, 1, ceil_mode=True)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.stem1(x)
|
||||
x = F.pad(x, (0, 1, 0, 1)) # pad before pool and stem2a
|
||||
x2 = self.stem2a(x)
|
||||
x2 = F.pad(x2, (0, 1, 0, 1)) # pad before stem2b
|
||||
x2 = self.stem2b(x2)
|
||||
x1 = self.pool(x)
|
||||
return self.stem4(self.stem3(torch.cat([x1, x2], 1)))
|
||||
|
||||
|
||||
class _HG_Block(nn.Module):
|
||||
def __init__(self, ic, mc, oc, layer_num, k=3, residual=False, light=False, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.residual = residual
|
||||
if light:
|
||||
self.layers = nn.ModuleList(
|
||||
[LightConvBNAct(ic if i == 0 else mc, mc, k, device=device, dtype=dtype, operations=operations) for i in range(layer_num)])
|
||||
else:
|
||||
self.layers = nn.ModuleList(
|
||||
[ConvBNAct(ic if i == 0 else mc, mc, k, device=device, dtype=dtype, operations=operations) for i in range(layer_num)])
|
||||
total = ic + layer_num * mc
|
||||
|
||||
self.aggregation = nn.Sequential(
|
||||
ConvBNAct(total, oc // 2, 1, device=device, dtype=dtype, operations=operations),
|
||||
ConvBNAct(oc // 2, oc, 1, device=device, dtype=dtype, operations=operations))
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
outs = [x]
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
outs.append(x)
|
||||
x = self.aggregation(torch.cat(outs, 1))
|
||||
return x + identity if self.residual else x
|
||||
|
||||
|
||||
class _HG_Stage(nn.Module):
|
||||
# config order: ic, mc, oc, num_blocks, downsample, light, k, layer_num
|
||||
def __init__(self, ic, mc, oc, num_blocks, downsample=True, light=False, k=3, layer_num=6, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
if downsample:
|
||||
self.downsample = ConvBNAct(ic, ic, 3, 2, groups=ic, use_act=False, device=device, dtype=dtype, operations=operations)
|
||||
else:
|
||||
self.downsample = nn.Identity()
|
||||
self.blocks = nn.Sequential(*[
|
||||
_HG_Block(ic if i == 0 else oc, mc, oc, layer_num,
|
||||
k=k, residual=(i != 0), light=light, device=device, dtype=dtype, operations=operations)
|
||||
for i in range(num_blocks)
|
||||
])
|
||||
|
||||
def forward(self, x):
|
||||
return self.blocks(self.downsample(x))
|
||||
|
||||
|
||||
class HGNetv2(nn.Module):
|
||||
# B5 config: stem=[3,32,64], stages=[ic, mc, oc, blocks, down, light, k, layers]
|
||||
_STAGE_CFGS = [[64, 64, 128, 1, False, False, 3, 6],
|
||||
[128, 128, 512, 2, True, False, 3, 6],
|
||||
[512, 256, 1024, 5, True, True, 5, 6],
|
||||
[1024,512, 2048, 2, True, True, 5, 6]]
|
||||
|
||||
def __init__(self, return_idx=(1, 2, 3), device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.stem = _StemBlock(3, 32, 64, device=device, dtype=dtype, operations=operations)
|
||||
self.stages = nn.ModuleList([_HG_Stage(*cfg, device=device, dtype=dtype, operations=operations) for cfg in self._STAGE_CFGS])
|
||||
self.return_idx = list(return_idx)
|
||||
self.out_channels = [self._STAGE_CFGS[i][2] for i in return_idx]
|
||||
|
||||
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
x = self.stem(x)
|
||||
outs = []
|
||||
for i, stage in enumerate(self.stages):
|
||||
x = stage(x)
|
||||
if i in self.return_idx:
|
||||
outs.append(x)
|
||||
return outs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Encoder — HybridEncoder (dfine version: RepNCSPELAN4 + SCDown PAN)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class ConvNormLayer(nn.Module):
|
||||
"""Conv→act (expects pre-fused BN weights)."""
|
||||
def __init__(self, ic, oc, k, s, g=1, padding=None, act=None, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
p = (k - 1) // 2 if padding is None else padding
|
||||
self.conv = operations.Conv2d(ic, oc, k, s, p, groups=g, bias=True, device=device, dtype=dtype)
|
||||
self.act = nn.SiLU() if act == 'silu' else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
return self.act(self.conv(x))
|
||||
|
||||
|
||||
class VGGBlock(nn.Module):
|
||||
"""Rep-VGG block (expects pre-fused weights)."""
|
||||
def __init__(self, ic, oc, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.conv = operations.Conv2d(ic, oc, 3, 1, padding=1, bias=True, device=device, dtype=dtype)
|
||||
self.act = nn.SiLU()
|
||||
|
||||
def forward(self, x):
|
||||
return self.act(self.conv(x))
|
||||
|
||||
|
||||
class CSPLayer(nn.Module):
|
||||
def __init__(self, ic, oc, num_blocks=3, expansion=1.0, act='silu', device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
h = int(oc * expansion)
|
||||
self.conv1 = ConvNormLayer(ic, h, 1, 1, act=act, device=device, dtype=dtype, operations=operations)
|
||||
self.conv2 = ConvNormLayer(ic, h, 1, 1, act=act, device=device, dtype=dtype, operations=operations)
|
||||
self.bottlenecks = nn.Sequential(*[VGGBlock(h, h, device=device, dtype=dtype, operations=operations) for _ in range(num_blocks)])
|
||||
self.conv3 = ConvNormLayer(h, oc, 1, 1, act=act, device=device, dtype=dtype, operations=operations) if h != oc else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv3(self.bottlenecks(self.conv1(x)) + self.conv2(x))
|
||||
|
||||
|
||||
class RepNCSPELAN4(nn.Module):
|
||||
"""CSP-ELAN block — the FPN/PAN block in RTv4's HybridEncoder."""
|
||||
def __init__(self, c1, c2, c3, c4, n=3, act='silu', device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.c = c3 // 2
|
||||
self.cv1 = ConvNormLayer(c1, c3, 1, 1, act=act, device=device, dtype=dtype, operations=operations)
|
||||
self.cv2 = nn.Sequential(CSPLayer(c3 // 2, c4, n, 1.0, act=act, device=device, dtype=dtype, operations=operations), ConvNormLayer(c4, c4, 3, 1, act=act, device=device, dtype=dtype, operations=operations))
|
||||
self.cv3 = nn.Sequential(CSPLayer(c4, c4, n, 1.0, act=act, device=device, dtype=dtype, operations=operations), ConvNormLayer(c4, c4, 3, 1, act=act, device=device, dtype=dtype, operations=operations))
|
||||
self.cv4 = ConvNormLayer(c3 + 2 * c4, c2, 1, 1, act=act, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
def forward(self, x):
|
||||
y = list(self.cv1(x).split((self.c, self.c), 1))
|
||||
y.extend(m(y[-1]) for m in [self.cv2, self.cv3])
|
||||
return self.cv4(torch.cat(y, 1))
|
||||
|
||||
|
||||
class SCDown(nn.Module):
|
||||
"""Separable conv downsampling used in HybridEncoder PAN bottom-up path."""
|
||||
def __init__(self, ic, oc, k, s, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.cv1 = ConvNormLayer(ic, oc, 1, 1, device=device, dtype=dtype, operations=operations)
|
||||
self.cv2 = ConvNormLayer(oc, oc, k, s, g=oc, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
def forward(self, x):
|
||||
return self.cv2(self.cv1(x))
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, embed_dim, num_heads, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = embed_dim // num_heads
|
||||
self.q_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
|
||||
self.k_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
|
||||
self.v_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
|
||||
self.out_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, query, key, value, attn_mask=None):
|
||||
optimized_attention = optimized_attention_for_device(query.device, False, small_input=True)
|
||||
q, k, v = self.q_proj(query), self.k_proj(key), self.v_proj(value)
|
||||
out = optimized_attention(q, k, v, heads=self.num_heads, mask=attn_mask)
|
||||
return self.out_proj(out)
|
||||
|
||||
|
||||
class _TransformerEncoderLayer(nn.Module):
|
||||
"""Single AIFI encoder layer (pre- or post-norm, GELU by default)."""
|
||||
def __init__(self, d_model, nhead, dim_feedforward, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.self_attn = SelfAttention(d_model, nhead, device=device, dtype=dtype, operations=operations)
|
||||
self.linear1 = operations.Linear(d_model, dim_feedforward, device=device, dtype=dtype)
|
||||
self.linear2 = operations.Linear(dim_feedforward, d_model, device=device, dtype=dtype)
|
||||
self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
self.activation = nn.GELU()
|
||||
|
||||
def forward(self, src, src_mask=None, pos_embed=None):
|
||||
q = k = src if pos_embed is None else src + pos_embed
|
||||
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask)
|
||||
src = self.norm1(src + src2)
|
||||
src2 = self.linear2(self.activation(self.linear1(src)))
|
||||
return self.norm2(src + src2)
|
||||
|
||||
|
||||
class _TransformerEncoder(nn.Module):
|
||||
"""Thin wrapper so state-dict keys are encoder.0.layers.N.*"""
|
||||
def __init__(self, num_layers, d_model, nhead, dim_feedforward, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([
|
||||
_TransformerEncoderLayer(d_model, nhead, dim_feedforward, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
def forward(self, src, src_mask=None, pos_embed=None):
|
||||
for layer in self.layers:
|
||||
src = layer(src, src_mask=src_mask, pos_embed=pos_embed)
|
||||
return src
|
||||
|
||||
|
||||
class HybridEncoder(nn.Module):
|
||||
def __init__(self, in_channels=(512, 1024, 2048), feat_strides=(8, 16, 32), hidden_dim=256, nhead=8, dim_feedforward=2048, use_encoder_idx=(2,), num_encoder_layers=1,
|
||||
pe_temperature=10000, expansion=1.0, depth_mult=1.0, act='silu', eval_spatial_size=(640, 640), device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.in_channels = list(in_channels)
|
||||
self.feat_strides = list(feat_strides)
|
||||
self.hidden_dim = hidden_dim
|
||||
self.use_encoder_idx = list(use_encoder_idx)
|
||||
self.pe_temperature = pe_temperature
|
||||
self.eval_spatial_size = eval_spatial_size
|
||||
self.out_channels = [hidden_dim] * len(in_channels)
|
||||
self.out_strides = list(feat_strides)
|
||||
|
||||
# channel projection (expects pre-fused weights)
|
||||
self.input_proj = nn.ModuleList([
|
||||
nn.Sequential(OrderedDict([('conv', operations.Conv2d(ch, hidden_dim, 1, bias=True, device=device, dtype=dtype))]))
|
||||
for ch in in_channels
|
||||
])
|
||||
|
||||
# AIFI transformer — use _TransformerEncoder so keys are encoder.0.layers.N.*
|
||||
self.encoder = nn.ModuleList([
|
||||
_TransformerEncoder(num_encoder_layers, hidden_dim, nhead, dim_feedforward, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(len(use_encoder_idx))
|
||||
])
|
||||
|
||||
nb = round(3 * depth_mult)
|
||||
exp = expansion
|
||||
|
||||
# top-down FPN (dfine: lateral conv has no act)
|
||||
self.lateral_convs = nn.ModuleList(
|
||||
[ConvNormLayer(hidden_dim, hidden_dim, 1, 1, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(len(in_channels) - 1)])
|
||||
self.fpn_blocks = nn.ModuleList(
|
||||
[RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2, round(exp * hidden_dim // 2), nb, act=act, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(len(in_channels) - 1)])
|
||||
|
||||
# bottom-up PAN (dfine: nn.Sequential(SCDown) — keeps checkpoint key .0.cv1/.0.cv2)
|
||||
self.downsample_convs = nn.ModuleList(
|
||||
[nn.Sequential(SCDown(hidden_dim, hidden_dim, 3, 2, device=device, dtype=dtype, operations=operations))
|
||||
for _ in range(len(in_channels) - 1)])
|
||||
self.pan_blocks = nn.ModuleList(
|
||||
[RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2, round(exp * hidden_dim // 2), nb, act=act, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(len(in_channels) - 1)])
|
||||
|
||||
# cache positional embeddings for fixed spatial size
|
||||
if eval_spatial_size:
|
||||
for idx in self.use_encoder_idx:
|
||||
stride = self.feat_strides[idx]
|
||||
pe = self._build_pe(eval_spatial_size[1] // stride,
|
||||
eval_spatial_size[0] // stride,
|
||||
hidden_dim, pe_temperature)
|
||||
setattr(self, f'pos_embed{idx}', pe)
|
||||
|
||||
@staticmethod
|
||||
def _build_pe(w, h, dim=256, temp=10000.):
|
||||
assert dim % 4 == 0
|
||||
gw = torch.arange(w, dtype=torch.float32)
|
||||
gh = torch.arange(h, dtype=torch.float32)
|
||||
gw, gh = torch.meshgrid(gw, gh, indexing='ij')
|
||||
pdim = dim // 4
|
||||
omega = 1. / (temp ** (torch.arange(pdim, dtype=torch.float32) / pdim))
|
||||
ow = gw.flatten()[:, None] @ omega[None]
|
||||
oh = gh.flatten()[:, None] @ omega[None]
|
||||
return torch.cat([ow.sin(), ow.cos(), oh.sin(), oh.cos()], 1)[None]
|
||||
|
||||
def forward(self, feats: List[torch.Tensor]) -> List[torch.Tensor]:
|
||||
proj = [self.input_proj[i](f) for i, f in enumerate(feats)]
|
||||
|
||||
for i, enc_idx in enumerate(self.use_encoder_idx):
|
||||
h, w = proj[enc_idx].shape[2:]
|
||||
src = proj[enc_idx].flatten(2).permute(0, 2, 1)
|
||||
pe = getattr(self, f'pos_embed{enc_idx}').to(device=src.device, dtype=src.dtype)
|
||||
for layer in self.encoder[i].layers:
|
||||
src = layer(src, pos_embed=pe)
|
||||
proj[enc_idx] = src.permute(0, 2, 1).reshape(-1, self.hidden_dim, h, w).contiguous()
|
||||
|
||||
n = len(self.in_channels)
|
||||
inner = [proj[-1]]
|
||||
for k in range(n - 1, 0, -1):
|
||||
j = n - 1 - k
|
||||
top = self.lateral_convs[j](inner[0])
|
||||
inner[0] = top
|
||||
up = F.interpolate(top, scale_factor=2., mode='nearest')
|
||||
inner.insert(0, self.fpn_blocks[j](torch.cat([up, proj[k - 1]], 1)))
|
||||
|
||||
outs = [inner[0]]
|
||||
for k in range(n - 1):
|
||||
outs.append(self.pan_blocks[k](
|
||||
torch.cat([self.downsample_convs[k](outs[-1]), inner[k + 1]], 1)))
|
||||
return outs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Decoder — DFINETransformer
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _deformable_attn_v2(value: list, spatial_shapes, sampling_locations: torch.Tensor, attention_weights: torch.Tensor, num_points_list: List[int]) -> torch.Tensor:
|
||||
"""
|
||||
value : list of per-level tensors [bs*n_head, c, h_l, w_l]
|
||||
sampling_locations: [bs, Lq, n_head, sum(pts), 2] in [0,1]
|
||||
attention_weights : [bs, Lq, n_head, sum(pts)]
|
||||
"""
|
||||
_, c = value[0].shape[:2] # bs*n_head, c
|
||||
_, Lq, n_head, _, _ = sampling_locations.shape
|
||||
bs = sampling_locations.shape[0]
|
||||
n_h = n_head
|
||||
|
||||
grids = (2 * sampling_locations - 1) # [bs, Lq, n_head, sum_pts, 2]
|
||||
grids = grids.permute(0, 2, 1, 3, 4).flatten(0, 1) # [bs*n_head, Lq, sum_pts, 2]
|
||||
grids_per_lvl = grids.split(num_points_list, dim=2) # list of [bs*n_head, Lq, pts_l, 2]
|
||||
|
||||
sampled = []
|
||||
for lvl, (h, w) in enumerate(spatial_shapes):
|
||||
val_l = value[lvl].reshape(bs * n_h, c, h, w)
|
||||
sv = F.grid_sample(val_l, grids_per_lvl[lvl], mode='bilinear', padding_mode='zeros', align_corners=False)
|
||||
sampled.append(sv) # sv: [bs*n_head, c, Lq, pts_l]
|
||||
|
||||
attn = attention_weights.permute(0, 2, 1, 3) # [bs, n_head, Lq, sum_pts]
|
||||
attn = attn.flatten(0, 1).unsqueeze(1) # [bs*n_head, 1, Lq, sum_pts]
|
||||
out = (torch.cat(sampled, -1) * attn).sum(-1) # [bs*n_head, c, Lq]
|
||||
out = out.reshape(bs, n_h * c, Lq)
|
||||
return out.permute(0, 2, 1) # [bs, Lq, hidden]
|
||||
|
||||
|
||||
class MSDeformableAttention(nn.Module):
|
||||
def __init__(self, embed_dim=256, num_heads=8, num_levels=3, num_points=4, offset_scale=0.5, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.embed_dim, self.num_heads = embed_dim, num_heads
|
||||
self.head_dim = embed_dim // num_heads
|
||||
pts = num_points if isinstance(num_points, list) else [num_points] * num_levels
|
||||
self.num_points_list = pts
|
||||
self.offset_scale = offset_scale
|
||||
total = num_heads * sum(pts)
|
||||
self.register_buffer('num_points_scale', torch.tensor([1. / n for n in pts for _ in range(n)], dtype=torch.float32))
|
||||
self.sampling_offsets = operations.Linear(embed_dim, total * 2, device=device, dtype=dtype)
|
||||
self.attention_weights = operations.Linear(embed_dim, total, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, query, ref_pts, value, spatial_shapes):
|
||||
bs, Lq = query.shape[:2]
|
||||
offsets = self.sampling_offsets(query).reshape(
|
||||
bs, Lq, self.num_heads, sum(self.num_points_list), 2)
|
||||
attn_w = F.softmax(
|
||||
self.attention_weights(query).reshape(
|
||||
bs, Lq, self.num_heads, sum(self.num_points_list)), -1)
|
||||
scale = self.num_points_scale.to(query).unsqueeze(-1)
|
||||
offset = offsets * scale * ref_pts[:, :, None, :, 2:] * self.offset_scale
|
||||
locs = ref_pts[:, :, None, :, :2] + offset # [bs, Lq, n_head, sum_pts, 2]
|
||||
return _deformable_attn_v2(value, spatial_shapes, locs, attn_w, self.num_points_list)
|
||||
|
||||
|
||||
class Gate(nn.Module):
|
||||
def __init__(self, d_model, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.gate = operations.Linear(2 * d_model, 2 * d_model, device=device, dtype=dtype)
|
||||
self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x1, x2):
|
||||
g1, g2 = torch.sigmoid(self.gate(torch.cat([x1, x2], -1))).chunk(2, -1)
|
||||
return self.norm(g1 * x1 + g2 * x2)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, in_dim, hidden_dim, out_dim, num_layers, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
dims = [in_dim] + [hidden_dim] * (num_layers - 1) + [out_dim]
|
||||
self.layers = nn.ModuleList(operations.Linear(dims[i], dims[i + 1], device=device, dtype=dtype) for i in range(num_layers))
|
||||
|
||||
def forward(self, x):
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = nn.SiLU()(layer(x)) if i < len(self.layers) - 1 else layer(x)
|
||||
return x
|
||||
|
||||
|
||||
class TransformerDecoderLayer(nn.Module):
|
||||
def __init__(self, d_model=256, nhead=8, dim_feedforward=1024, num_levels=3, num_points=4, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.self_attn = SelfAttention(d_model, nhead, device=device, dtype=dtype, operations=operations)
|
||||
self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
self.cross_attn = MSDeformableAttention(d_model, nhead, num_levels, num_points, device=device, dtype=dtype, operations=operations)
|
||||
self.gateway = Gate(d_model, device=device, dtype=dtype, operations=operations)
|
||||
self.linear1 = operations.Linear(d_model, dim_feedforward, device=device, dtype=dtype)
|
||||
self.activation = nn.ReLU()
|
||||
self.linear2 = operations.Linear(dim_feedforward, d_model, device=device, dtype=dtype)
|
||||
self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, target, ref_pts, value, spatial_shapes, attn_mask=None, query_pos=None):
|
||||
q = k = target if query_pos is None else target + query_pos
|
||||
t2 = self.self_attn(q, k, value=target, attn_mask=attn_mask)
|
||||
target = self.norm1(target + t2)
|
||||
t2 = self.cross_attn(
|
||||
target if query_pos is None else target + query_pos,
|
||||
ref_pts, value, spatial_shapes)
|
||||
target = self.gateway(target, t2)
|
||||
t2 = self.linear2(self.activation(self.linear1(target)))
|
||||
target = self.norm3((target + t2).clamp(-65504, 65504))
|
||||
return target
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FDR utilities
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def weighting_function(reg_max, up, reg_scale):
|
||||
"""Non-uniform weighting function W(n) for FDR box regression."""
|
||||
ub1 = (abs(up[0]) * abs(reg_scale)).item()
|
||||
ub2 = ub1 * 2
|
||||
step = (ub1 + 1) ** (2 / (reg_max - 2))
|
||||
left = [-(step ** i) + 1 for i in range(reg_max // 2 - 1, 0, -1)]
|
||||
right = [ (step ** i) - 1 for i in range(1, reg_max // 2)]
|
||||
vals = [-ub2] + left + [0] + right + [ub2]
|
||||
return torch.tensor(vals, dtype=up.dtype, device=up.device)
|
||||
|
||||
|
||||
def distance2bbox(points, distance, reg_scale):
|
||||
"""Decode edge-distances → cxcywh boxes."""
|
||||
rs = abs(reg_scale).to(dtype=points.dtype)
|
||||
x1 = points[..., 0] - (0.5 * rs + distance[..., 0]) * (points[..., 2] / rs)
|
||||
y1 = points[..., 1] - (0.5 * rs + distance[..., 1]) * (points[..., 3] / rs)
|
||||
x2 = points[..., 0] + (0.5 * rs + distance[..., 2]) * (points[..., 2] / rs)
|
||||
y2 = points[..., 1] + (0.5 * rs + distance[..., 3]) * (points[..., 3] / rs)
|
||||
x0, y0, x1_, y1_ = (x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1
|
||||
return torch.stack([x0, y0, x1_, y1_], -1)
|
||||
|
||||
|
||||
class Integral(nn.Module):
|
||||
"""Sum Pr(n)·W(n) over the distribution bins."""
|
||||
def __init__(self, reg_max=32):
|
||||
super().__init__()
|
||||
self.reg_max = reg_max
|
||||
|
||||
def forward(self, x, project):
|
||||
shape = x.shape
|
||||
x = F.softmax(x.reshape(-1, self.reg_max + 1), 1)
|
||||
x = F.linear(x, project.to(device=x.device, dtype=x.dtype)).reshape(-1, 4)
|
||||
return x.reshape(list(shape[:-1]) + [-1])
|
||||
|
||||
|
||||
class LQE(nn.Module):
|
||||
"""Location Quality Estimator — refines class scores using corner distribution."""
|
||||
def __init__(self, k=4, hidden_dim=64, num_layers=2, reg_max=32, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.k, self.reg_max = k, reg_max
|
||||
self.reg_conf = MLP(4 * (k + 1), hidden_dim, 1, num_layers, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
def forward(self, scores, pred_corners):
|
||||
B, L, _ = pred_corners.shape
|
||||
prob = F.softmax(pred_corners.reshape(B, L, 4, self.reg_max + 1), -1)
|
||||
topk, _ = prob.topk(self.k, -1)
|
||||
stat = torch.cat([topk, topk.mean(-1, keepdim=True)], -1)
|
||||
return scores + self.reg_conf(stat.reshape(B, L, -1))
|
||||
|
||||
|
||||
class TransformerDecoder(nn.Module):
|
||||
def __init__(self, hidden_dim, nhead, dim_feedforward, num_levels, num_points, num_layers, reg_max, reg_scale, up, eval_idx=-1, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.hidden_dim = hidden_dim
|
||||
self.num_layers = num_layers
|
||||
self.nhead = nhead
|
||||
self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx
|
||||
self.up, self.reg_scale, self.reg_max = up, reg_scale, reg_max
|
||||
self.layers = nn.ModuleList([
|
||||
TransformerDecoderLayer(hidden_dim, nhead, dim_feedforward, num_levels, num_points, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(self.eval_idx + 1)
|
||||
])
|
||||
self.lqe_layers = nn.ModuleList([LQE(4, 64, 2, reg_max, device=device, dtype=dtype, operations=operations) for _ in range(self.eval_idx + 1)])
|
||||
self.register_buffer('project', weighting_function(reg_max, up, reg_scale))
|
||||
|
||||
def _value_op(self, memory, spatial_shapes):
|
||||
"""Reshape memory to per-level value tensors for deformable attention."""
|
||||
c = self.hidden_dim // self.nhead
|
||||
split = [h * w for h, w in spatial_shapes]
|
||||
val = memory.reshape(memory.shape[0], memory.shape[1], self.nhead, c) # memory: [bs, sum(h*w), hidden_dim]
|
||||
# → [bs, n_head, c, sum_hw]
|
||||
val = val.permute(0, 2, 3, 1).flatten(0, 1) # [bs*n_head, c, sum_hw]
|
||||
return val.split(split, dim=-1) # list of [bs*n_head, c, h_l*w_l]
|
||||
|
||||
def forward(self, target, ref_pts_unact, memory, spatial_shapes, bbox_head, score_head, query_pos_head, pre_bbox_head, integral):
|
||||
val_split_flat = self._value_op(memory, spatial_shapes) # pre-split value for deformable attention
|
||||
|
||||
# reshape to [bs*n_head, c, h_l, w_l]
|
||||
value = []
|
||||
for lvl, (h, w) in enumerate(spatial_shapes):
|
||||
v = val_split_flat[lvl] # [bs*n_head, c, h*w]
|
||||
value.append(v.reshape(v.shape[0], v.shape[1], h, w))
|
||||
|
||||
ref_pts = F.sigmoid(ref_pts_unact)
|
||||
output = target
|
||||
output_detach = pred_corners_undetach = 0
|
||||
|
||||
dec_bboxes, dec_logits = [], []
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
ref_input = ref_pts.unsqueeze(2) # [bs, Lq, 1, 4]
|
||||
query_pos = query_pos_head(ref_pts).clamp(-10, 10)
|
||||
output = layer(output, ref_input, value, spatial_shapes, query_pos=query_pos)
|
||||
|
||||
if i == 0:
|
||||
ref_unact = ref_pts.clamp(1e-5, 1 - 1e-5)
|
||||
ref_unact = torch.log(ref_unact / (1 - ref_unact))
|
||||
pre_bboxes = F.sigmoid(pre_bbox_head(output) + ref_unact)
|
||||
ref_pts_initial = pre_bboxes.detach()
|
||||
|
||||
pred_corners = bbox_head[i](output + output_detach) + pred_corners_undetach
|
||||
inter_ref_bbox = distance2bbox(ref_pts_initial, integral(pred_corners, self.project), self.reg_scale)
|
||||
|
||||
if i == self.eval_idx:
|
||||
scores = score_head[i](output)
|
||||
scores = self.lqe_layers[i](scores, pred_corners)
|
||||
dec_bboxes.append(inter_ref_bbox)
|
||||
dec_logits.append(scores)
|
||||
break
|
||||
|
||||
pred_corners_undetach = pred_corners
|
||||
ref_pts = inter_ref_bbox.detach()
|
||||
output_detach = output.detach()
|
||||
|
||||
return torch.stack(dec_bboxes), torch.stack(dec_logits)
|
||||
|
||||
|
||||
class DFINETransformer(nn.Module):
|
||||
def __init__(self, num_classes=80, hidden_dim=256, num_queries=300, feat_channels=[256, 256, 256], feat_strides=[8, 16, 32],
|
||||
num_levels=3, num_points=[3, 6, 3], nhead=8, num_layers=6, dim_feedforward=1024, eval_idx=-1, eps=1e-2, reg_max=32,
|
||||
reg_scale=8.0, eval_spatial_size=(640, 640), device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
assert len(feat_strides) == len(feat_channels)
|
||||
self.hidden_dim = hidden_dim
|
||||
self.num_queries = num_queries
|
||||
self.num_levels = num_levels
|
||||
self.eps = eps
|
||||
self.eval_spatial_size = eval_spatial_size
|
||||
|
||||
self.feat_strides = list(feat_strides)
|
||||
for i in range(num_levels - len(feat_strides)):
|
||||
self.feat_strides.append(feat_strides[-1] * 2 ** (i + 1))
|
||||
|
||||
# input projection (expects pre-fused weights)
|
||||
self.input_proj = nn.ModuleList()
|
||||
for ch in feat_channels:
|
||||
if ch == hidden_dim:
|
||||
self.input_proj.append(nn.Identity())
|
||||
else:
|
||||
self.input_proj.append(nn.Sequential(OrderedDict([
|
||||
('conv', operations.Conv2d(ch, hidden_dim, 1, bias=True, device=device, dtype=dtype))])))
|
||||
in_ch = feat_channels[-1]
|
||||
for i in range(num_levels - len(feat_channels)):
|
||||
self.input_proj.append(nn.Sequential(OrderedDict([
|
||||
('conv', operations.Conv2d(in_ch if i == 0 else hidden_dim,
|
||||
hidden_dim, 3, 2, 1, bias=True, device=device, dtype=dtype))])))
|
||||
in_ch = hidden_dim
|
||||
|
||||
# FDR parameters (non-trainable placeholders, set from config)
|
||||
self.up = nn.Parameter(torch.tensor([0.5]), requires_grad=False)
|
||||
self.reg_scale = nn.Parameter(torch.tensor([reg_scale]), requires_grad=False)
|
||||
|
||||
pts = num_points if isinstance(num_points, (list, tuple)) else [num_points] * num_levels
|
||||
self.decoder = TransformerDecoder(hidden_dim, nhead, dim_feedforward, num_levels, pts,
|
||||
num_layers, reg_max, self.reg_scale, self.up, eval_idx, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, 2, device=device, dtype=dtype, operations=operations)
|
||||
self.enc_output = nn.Sequential(OrderedDict([
|
||||
('proj', operations.Linear(hidden_dim, hidden_dim, device=device, dtype=dtype)),
|
||||
('norm', operations.LayerNorm(hidden_dim, device=device, dtype=dtype))]))
|
||||
self.enc_score_head = operations.Linear(hidden_dim, num_classes, device=device, dtype=dtype)
|
||||
self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.eval_idx_ = eval_idx if eval_idx >= 0 else num_layers + eval_idx
|
||||
self.dec_score_head = nn.ModuleList(
|
||||
[operations.Linear(hidden_dim, num_classes, device=device, dtype=dtype) for _ in range(self.eval_idx_ + 1)])
|
||||
self.pre_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3, device=device, dtype=dtype, operations=operations)
|
||||
self.dec_bbox_head = nn.ModuleList(
|
||||
[MLP(hidden_dim, hidden_dim, 4 * (reg_max + 1), 3, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(self.eval_idx_ + 1)])
|
||||
self.integral = Integral(reg_max)
|
||||
|
||||
if eval_spatial_size:
|
||||
# Register as buffers so checkpoint values override the freshly-computed defaults
|
||||
anchors, valid_mask = self._gen_anchors()
|
||||
self.register_buffer('anchors', anchors)
|
||||
self.register_buffer('valid_mask', valid_mask)
|
||||
|
||||
def _gen_anchors(self, spatial_shapes=None, grid_size=0.05, dtype=torch.float32, device='cpu'):
|
||||
if spatial_shapes is None:
|
||||
h0, w0 = self.eval_spatial_size
|
||||
spatial_shapes = [[int(h0 / s), int(w0 / s)] for s in self.feat_strides]
|
||||
anchors = []
|
||||
for lvl, (h, w) in enumerate(spatial_shapes):
|
||||
gy, gx = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
|
||||
gxy = (torch.stack([gx, gy], -1).float() + 0.5) / torch.tensor([w, h], dtype=dtype)
|
||||
wh = torch.ones_like(gxy) * grid_size * (2. ** lvl)
|
||||
anchors.append(torch.cat([gxy, wh], -1).reshape(-1, h * w, 4))
|
||||
anchors = torch.cat(anchors, 1).to(device)
|
||||
valid_mask = ((anchors > self.eps) & (anchors < 1 - self.eps)).all(-1, keepdim=True)
|
||||
anchors = torch.log(anchors / (1 - anchors))
|
||||
anchors = torch.where(valid_mask, anchors, torch.full_like(anchors, float('inf')))
|
||||
return anchors, valid_mask
|
||||
|
||||
def _encoder_input(self, feats: List[torch.Tensor]):
|
||||
proj = [self.input_proj[i](f) for i, f in enumerate(feats)]
|
||||
for i in range(len(feats), self.num_levels):
|
||||
proj.append(self.input_proj[i](feats[-1] if i == len(feats) else proj[-1]))
|
||||
flat, shapes = [], []
|
||||
for f in proj:
|
||||
_, _, h, w = f.shape
|
||||
flat.append(f.flatten(2).permute(0, 2, 1))
|
||||
shapes.append([h, w])
|
||||
return torch.cat(flat, 1), shapes
|
||||
|
||||
def _decoder_input(self, memory: torch.Tensor):
|
||||
anchors, valid_mask = self.anchors.to(memory), self.valid_mask
|
||||
if memory.shape[0] > 1:
|
||||
anchors = anchors.repeat(memory.shape[0], 1, 1)
|
||||
|
||||
mem = valid_mask.to(memory) * memory
|
||||
out_mem = self.enc_output(mem)
|
||||
logits = self.enc_score_head(out_mem)
|
||||
_, idx = torch.topk(logits.max(-1).values, self.num_queries, dim=-1)
|
||||
idx_e = idx.unsqueeze(-1)
|
||||
topk_mem = out_mem.gather(1, idx_e.expand(-1, -1, out_mem.shape[-1]))
|
||||
topk_anc = anchors.gather(1, idx_e.expand(-1, -1, anchors.shape[-1]))
|
||||
topk_ref = self.enc_bbox_head(topk_mem) + topk_anc
|
||||
return topk_mem.detach(), topk_ref.detach()
|
||||
|
||||
def forward(self, feats: List[torch.Tensor]):
|
||||
memory, shapes = self._encoder_input(feats)
|
||||
content, ref = self._decoder_input(memory)
|
||||
out_bboxes, out_logits = self.decoder(
|
||||
content, ref, memory, shapes,
|
||||
self.dec_bbox_head, self.dec_score_head,
|
||||
self.query_pos_head, self.pre_bbox_head, self.integral)
|
||||
return {'pred_logits': out_logits[-1], 'pred_boxes': out_bboxes[-1]}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class RTv4(nn.Module):
|
||||
def __init__(self, num_classes=80, num_queries=300, enc_h=256, dec_h=256, enc_ff=2048, dec_ff=1024, feat_strides=[8, 16, 32], device=None, dtype=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
self.operations = operations
|
||||
|
||||
self.backbone = HGNetv2(device=device, dtype=dtype, operations=operations)
|
||||
self.encoder = HybridEncoder(hidden_dim=enc_h, dim_feedforward=enc_ff, device=device, dtype=dtype, operations=operations)
|
||||
self.decoder = DFINETransformer(num_classes=num_classes, hidden_dim=dec_h, num_queries=num_queries,
|
||||
feat_channels=[enc_h] * len(feat_strides), feat_strides=feat_strides, dim_feedforward=dec_ff, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.num_classes = num_classes
|
||||
self.num_queries = num_queries
|
||||
self.load_device = comfy.model_management.get_torch_device()
|
||||
|
||||
def _forward(self, x: torch.Tensor):
|
||||
return self.decoder(self.encoder(self.backbone(x)))
|
||||
|
||||
def postprocess(self, outputs, orig_size: tuple = (640, 640)) -> List[dict]:
|
||||
logits = outputs['pred_logits']
|
||||
boxes = torchvision.ops.box_convert(outputs['pred_boxes'], 'cxcywh', 'xyxy')
|
||||
boxes = boxes * torch.tensor(orig_size, device=boxes.device, dtype=boxes.dtype).repeat(1, 2).unsqueeze(1)
|
||||
scores = F.sigmoid(logits)
|
||||
scores, idx = torch.topk(scores.flatten(1), self.num_queries, dim=-1)
|
||||
labels = idx % self.num_classes
|
||||
boxes = boxes.gather(1, (idx // self.num_classes).unsqueeze(-1).expand(-1, -1, 4))
|
||||
return [{'labels': lbl, 'boxes': b, 'scores': s} for lbl, b, s in zip(labels, boxes, scores)]
|
||||
|
||||
def forward(self, x: torch.Tensor, orig_size: tuple = (640, 640), **kwargs):
|
||||
outputs = self._forward(x.to(device=self.load_device, dtype=self.dtype))
|
||||
return self.postprocess(outputs, orig_size)
|
||||
@@ -141,17 +141,3 @@ def interpret_gathered_like(tensors, gathered):
|
||||
return dest_views
|
||||
|
||||
aimdo_enabled = False
|
||||
|
||||
extra_ram_release_callback = None
|
||||
RAM_CACHE_HEADROOM = 0
|
||||
|
||||
def set_ram_cache_release_state(callback, headroom):
|
||||
global extra_ram_release_callback
|
||||
global RAM_CACHE_HEADROOM
|
||||
extra_ram_release_callback = callback
|
||||
RAM_CACHE_HEADROOM = max(0, int(headroom))
|
||||
|
||||
def extra_ram_release(target):
|
||||
if extra_ram_release_callback is None:
|
||||
return 0
|
||||
return extra_ram_release_callback(target)
|
||||
|
||||
@@ -52,7 +52,6 @@ import comfy.ldm.qwen_image.model
|
||||
import comfy.ldm.kandinsky5.model
|
||||
import comfy.ldm.anima.model
|
||||
import comfy.ldm.ace.ace_step15
|
||||
import comfy.ldm.rt_detr.rtdetr_v4
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@@ -891,7 +890,7 @@ class Flux(BaseModel):
|
||||
return torch.cat((image, mask), dim=1)
|
||||
|
||||
def encode_adm(self, **kwargs):
|
||||
return kwargs.get("pooled_output", None)
|
||||
return kwargs["pooled_output"]
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
@@ -1958,7 +1957,3 @@ class Kandinsky5Image(Kandinsky5):
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
return None
|
||||
|
||||
class RT_DETR_v4(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4)
|
||||
|
||||
@@ -698,12 +698,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["audio_model"] = "ace1.5"
|
||||
return dit_config
|
||||
|
||||
if '{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix) in state_dict_keys: # RT-DETR_v4
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "RT_DETR_v4"
|
||||
dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0]
|
||||
return dit_config
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
return None
|
||||
|
||||
|
||||
@@ -669,7 +669,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
|
||||
|
||||
for i in range(len(current_loaded_models) -1, -1, -1):
|
||||
shift_model = current_loaded_models[i]
|
||||
if device is None or shift_model.device == device:
|
||||
if shift_model.device == device:
|
||||
if shift_model not in keep_loaded and not shift_model.is_dead():
|
||||
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
||||
shift_model.currently_used = False
|
||||
@@ -679,8 +679,8 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
|
||||
i = x[-1]
|
||||
memory_to_free = 1e32
|
||||
pins_to_free = 1e32
|
||||
if not DISABLE_SMART_MEMORY or device is None:
|
||||
memory_to_free = 0 if device is None else memory_required - get_free_memory(device)
|
||||
if not DISABLE_SMART_MEMORY:
|
||||
memory_to_free = memory_required - get_free_memory(device)
|
||||
pins_to_free = pins_required - get_free_ram()
|
||||
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
|
||||
#don't actually unload dynamic models for the sake of other dynamic models
|
||||
@@ -708,7 +708,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
|
||||
|
||||
if len(unloaded_model) > 0:
|
||||
soft_empty_cache()
|
||||
elif device is not None:
|
||||
else:
|
||||
if vram_state != VRAMState.HIGH_VRAM:
|
||||
mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
|
||||
if mem_free_torch > mem_free_total * 0.25:
|
||||
|
||||
@@ -300,6 +300,9 @@ class ModelPatcher:
|
||||
def model_mmap_residency(self, free=False):
|
||||
return comfy.model_management.module_mmap_residency(self.model, free=free)
|
||||
|
||||
def get_ram_usage(self):
|
||||
return self.model_size()
|
||||
|
||||
def loaded_size(self):
|
||||
return self.model.model_loaded_weight_memory
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ import comfy.model_management
|
||||
import comfy.memory_management
|
||||
import comfy_aimdo.host_buffer
|
||||
import comfy_aimdo.torch
|
||||
import psutil
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
@@ -13,11 +12,6 @@ def pin_memory(module):
|
||||
if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
|
||||
return
|
||||
#FIXME: This is a RAM cache trigger event
|
||||
ram_headroom = comfy.memory_management.RAM_CACHE_HEADROOM
|
||||
#we split the difference and assume half the RAM cache headroom is for us
|
||||
if ram_headroom > 0 and psutil.virtual_memory().available < (ram_headroom * 0.5):
|
||||
comfy.memory_management.extra_ram_release(ram_headroom)
|
||||
|
||||
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
|
||||
|
||||
if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:
|
||||
|
||||
13
comfy/sd.py
13
comfy/sd.py
@@ -280,6 +280,9 @@ class CLIP:
|
||||
n.apply_hooks_to_conds = self.apply_hooks_to_conds
|
||||
return n
|
||||
|
||||
def get_ram_usage(self):
|
||||
return self.patcher.get_ram_usage()
|
||||
|
||||
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
||||
return self.patcher.add_patches(patches, strength_patch, strength_model)
|
||||
|
||||
@@ -837,6 +840,9 @@ class VAE:
|
||||
self.size = comfy.model_management.module_size(self.first_stage_model)
|
||||
return self.size
|
||||
|
||||
def get_ram_usage(self):
|
||||
return self.model_size()
|
||||
|
||||
def throw_exception_if_invalid(self):
|
||||
if self.first_stage_model is None:
|
||||
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
|
||||
@@ -1736,16 +1742,15 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
|
||||
"""
|
||||
dtype = model_options.get("dtype", None)
|
||||
|
||||
custom_operations = model_options.get("custom_operations", None)
|
||||
if custom_operations is None:
|
||||
sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata)
|
||||
|
||||
#Allow loading unets from checkpoint files
|
||||
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
||||
temp_sd = comfy.utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True)
|
||||
if len(temp_sd) > 0:
|
||||
sd = temp_sd
|
||||
|
||||
custom_operations = model_options.get("custom_operations", None)
|
||||
if custom_operations is None:
|
||||
sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata)
|
||||
parameters = comfy.utils.calculate_parameters(sd)
|
||||
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||
|
||||
|
||||
@@ -1734,21 +1734,6 @@ class LongCatImage(supported_models_base.BASE):
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
|
||||
|
||||
|
||||
class RT_DETR_v4(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "RT_DETR_v4",
|
||||
}
|
||||
|
||||
supported_inference_dtypes = [torch.float16, torch.float32]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.RT_DETR_v4(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return None
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4]
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
@@ -1373,7 +1373,6 @@ class NodeInfoV1:
|
||||
price_badge: dict | None = None
|
||||
search_aliases: list[str]=None
|
||||
essentials_category: str=None
|
||||
has_intermediate_output: bool=None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -1497,16 +1496,6 @@ class Schema:
|
||||
"""When True, all inputs from the prompt will be passed to the node as kwargs, even if not defined in the schema."""
|
||||
essentials_category: str | None = None
|
||||
"""Optional category for the Essentials tab. Path-based like category field (e.g., 'Basic', 'Image Tools/Editing')."""
|
||||
has_intermediate_output: bool=False
|
||||
"""Flags this node as having intermediate output that should persist across page refreshes.
|
||||
|
||||
Nodes with this flag behave like output nodes (their UI results are cached and resent
|
||||
to the frontend) but do NOT automatically get added to the execution list. This means
|
||||
they will only execute if they are on the dependency path of a real output node.
|
||||
|
||||
Use this for nodes with interactive/operable UI regions that produce intermediate outputs
|
||||
(e.g., Image Crop, Painter) rather than final outputs (e.g., Save Image).
|
||||
"""
|
||||
|
||||
def validate(self):
|
||||
'''Validate the schema:
|
||||
@@ -1606,7 +1595,6 @@ class Schema:
|
||||
category=self.category,
|
||||
description=self.description,
|
||||
output_node=self.is_output_node,
|
||||
has_intermediate_output=self.has_intermediate_output,
|
||||
deprecated=self.is_deprecated,
|
||||
experimental=self.is_experimental,
|
||||
dev_only=self.is_dev_only,
|
||||
@@ -1898,14 +1886,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
||||
cls.GET_SCHEMA()
|
||||
return cls._OUTPUT_NODE
|
||||
|
||||
_HAS_INTERMEDIATE_OUTPUT = None
|
||||
@final
|
||||
@classproperty
|
||||
def HAS_INTERMEDIATE_OUTPUT(cls): # noqa
|
||||
if cls._HAS_INTERMEDIATE_OUTPUT is None:
|
||||
cls.GET_SCHEMA()
|
||||
return cls._HAS_INTERMEDIATE_OUTPUT
|
||||
|
||||
_INPUT_IS_LIST = None
|
||||
@final
|
||||
@classproperty
|
||||
@@ -1998,8 +1978,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
||||
cls._API_NODE = schema.is_api_node
|
||||
if cls._OUTPUT_NODE is None:
|
||||
cls._OUTPUT_NODE = schema.is_output_node
|
||||
if cls._HAS_INTERMEDIATE_OUTPUT is None:
|
||||
cls._HAS_INTERMEDIATE_OUTPUT = schema.has_intermediate_output
|
||||
if cls._INPUT_IS_LIST is None:
|
||||
cls._INPUT_IS_LIST = schema.is_input_list
|
||||
if cls._NOT_IDEMPOTENT is None:
|
||||
|
||||
@@ -201,16 +201,6 @@ async def get_image_from_response(response: GeminiGenerateContentResponse, thoug
|
||||
returned_image = await download_url_to_image_tensor(part.fileData.fileUri)
|
||||
image_tensors.append(returned_image)
|
||||
if len(image_tensors) == 0:
|
||||
if not thought:
|
||||
# No images generated --> extract text response for a meaningful error
|
||||
model_message = get_text_from_response(response).strip()
|
||||
if model_message:
|
||||
raise ValueError(f"Gemini did not generate an image. Model response: {model_message}")
|
||||
raise ValueError(
|
||||
"Gemini did not generate an image. "
|
||||
"Try rephrasing your prompt or changing the response modality to 'IMAGE+TEXT' "
|
||||
"to see the model's reasoning."
|
||||
)
|
||||
return torch.zeros((1, 1024, 1024, 4))
|
||||
return torch.cat(image_tensors, dim=0)
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import bisect
|
||||
import gc
|
||||
import itertools
|
||||
import psutil
|
||||
import time
|
||||
@@ -474,10 +475,6 @@ class LRUCache(BasicCache):
|
||||
self._mark_used(node_id)
|
||||
return await self._set_immediate(node_id, value)
|
||||
|
||||
def set_local(self, node_id, value):
|
||||
self._mark_used(node_id)
|
||||
BasicCache.set_local(self, node_id, value)
|
||||
|
||||
async def ensure_subcache_for(self, node_id, children_ids):
|
||||
# Just uses subcaches for tracking 'live' nodes
|
||||
await super()._ensure_subcache(node_id, children_ids)
|
||||
@@ -492,10 +489,15 @@ class LRUCache(BasicCache):
|
||||
return self
|
||||
|
||||
|
||||
#Small baseline weight used when a cache entry has no measurable CPU tensors.
|
||||
#Keeps unknown-sized entries in eviction scoring without dominating tensor-backed entries.
|
||||
#Iterating the cache for usage analysis might be expensive, so if we trigger make sure
|
||||
#to take a chunk out to give breathing space on high-node / low-ram-per-node flows.
|
||||
|
||||
RAM_CACHE_DEFAULT_RAM_USAGE = 0.05
|
||||
RAM_CACHE_HYSTERESIS = 1.1
|
||||
|
||||
#This is kinda in GB but not really. It needs to be non-zero for the below heuristic
|
||||
#and as long as Multi GB models dwarf this it will approximate OOM scoring OK
|
||||
|
||||
RAM_CACHE_DEFAULT_RAM_USAGE = 0.1
|
||||
|
||||
#Exponential bias towards evicting older workflows so garbage will be taken out
|
||||
#in constantly changing setups.
|
||||
@@ -519,17 +521,19 @@ class RAMPressureCache(LRUCache):
|
||||
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
|
||||
return await super().get(node_id)
|
||||
|
||||
def set_local(self, node_id, value):
|
||||
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
|
||||
super().set_local(node_id, value)
|
||||
def poll(self, ram_headroom):
|
||||
def _ram_gb():
|
||||
return psutil.virtual_memory().available / (1024**3)
|
||||
|
||||
def ram_release(self, target):
|
||||
if psutil.virtual_memory().available >= target:
|
||||
if _ram_gb() > ram_headroom:
|
||||
return
|
||||
gc.collect()
|
||||
if _ram_gb() > ram_headroom:
|
||||
return
|
||||
|
||||
clean_list = []
|
||||
|
||||
for key, cache_entry in self.cache.items():
|
||||
for key, (outputs, _), in self.cache.items():
|
||||
oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])
|
||||
|
||||
ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
|
||||
@@ -538,20 +542,22 @@ class RAMPressureCache(LRUCache):
|
||||
if outputs is None:
|
||||
return
|
||||
for output in outputs:
|
||||
if isinstance(output, (list, tuple)):
|
||||
if isinstance(output, list):
|
||||
scan_list_for_ram_usage(output)
|
||||
elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
|
||||
ram_usage += output.numel() * output.element_size()
|
||||
scan_list_for_ram_usage(cache_entry.outputs)
|
||||
#score Tensors at a 50% discount for RAM usage as they are likely to
|
||||
#be high value intermediates
|
||||
ram_usage += (output.numel() * output.element_size()) * 0.5
|
||||
elif hasattr(output, "get_ram_usage"):
|
||||
ram_usage += output.get_ram_usage()
|
||||
scan_list_for_ram_usage(outputs)
|
||||
|
||||
oom_score *= ram_usage
|
||||
#In the case where we have no information on the node ram usage at all,
|
||||
#break OOM score ties on the last touch timestamp (pure LRU)
|
||||
bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
|
||||
|
||||
while psutil.virtual_memory().available < target and clean_list:
|
||||
while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list:
|
||||
_, _, key = clean_list.pop()
|
||||
del self.cache[key]
|
||||
self.used_generation.pop(key, None)
|
||||
self.timestamps.pop(key, None)
|
||||
self.children.pop(key, None)
|
||||
gc.collect()
|
||||
|
||||
@@ -1,85 +1,67 @@
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
import ctypes
|
||||
import logging
|
||||
import ctypes.util
|
||||
import importlib.util
|
||||
from typing import TypedDict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import nodes
|
||||
import comfy_angle
|
||||
from comfy_api.latest import ComfyExtension, io, ui
|
||||
from typing_extensions import override
|
||||
from utils.install_util import get_missing_requirements_message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _check_opengl_availability():
|
||||
"""Early check for OpenGL availability. Raises RuntimeError if unlikely to work."""
|
||||
logger.debug("_check_opengl_availability: starting")
|
||||
missing = []
|
||||
def _preload_angle():
|
||||
egl_path = comfy_angle.get_egl_path()
|
||||
gles_path = comfy_angle.get_glesv2_path()
|
||||
|
||||
# Check Python packages (using find_spec to avoid importing)
|
||||
logger.debug("_check_opengl_availability: checking for glfw package")
|
||||
if importlib.util.find_spec("glfw") is None:
|
||||
missing.append("glfw")
|
||||
if sys.platform == "win32":
|
||||
angle_dir = comfy_angle.get_lib_dir()
|
||||
os.add_dll_directory(angle_dir)
|
||||
os.environ["PATH"] = angle_dir + os.pathsep + os.environ.get("PATH", "")
|
||||
|
||||
logger.debug("_check_opengl_availability: checking for OpenGL package")
|
||||
if importlib.util.find_spec("OpenGL") is None:
|
||||
missing.append("PyOpenGL")
|
||||
|
||||
if missing:
|
||||
raise RuntimeError(
|
||||
f"OpenGL dependencies not available.\n{get_missing_requirements_message()}\n"
|
||||
)
|
||||
|
||||
# On Linux without display, check if headless backends are available
|
||||
logger.debug(f"_check_opengl_availability: platform={sys.platform}")
|
||||
if sys.platform.startswith("linux"):
|
||||
has_display = os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY")
|
||||
logger.debug(f"_check_opengl_availability: has_display={bool(has_display)}")
|
||||
if not has_display:
|
||||
# Check for EGL or OSMesa libraries
|
||||
logger.debug("_check_opengl_availability: checking for EGL library")
|
||||
has_egl = ctypes.util.find_library("EGL")
|
||||
logger.debug("_check_opengl_availability: checking for OSMesa library")
|
||||
has_osmesa = ctypes.util.find_library("OSMesa")
|
||||
|
||||
# Error disabled for CI as it fails this check
|
||||
# if not has_egl and not has_osmesa:
|
||||
# raise RuntimeError(
|
||||
# "GLSL Shader node: No display and no headless backend (EGL/OSMesa) found.\n"
|
||||
# "See error below for installation instructions."
|
||||
# )
|
||||
logger.debug(f"Headless mode: EGL={'yes' if has_egl else 'no'}, OSMesa={'yes' if has_osmesa else 'no'}")
|
||||
|
||||
logger.debug("_check_opengl_availability: completed")
|
||||
mode = 0 if sys.platform == "win32" else ctypes.RTLD_GLOBAL
|
||||
ctypes.CDLL(str(egl_path), mode=mode)
|
||||
ctypes.CDLL(str(gles_path), mode=mode)
|
||||
|
||||
|
||||
# Run early check at import time
|
||||
logger.debug("nodes_glsl: running _check_opengl_availability at import time")
|
||||
_check_opengl_availability()
|
||||
# Pre-load ANGLE *before* any PyOpenGL import so that the EGL platform
|
||||
# plugin picks up ANGLE's libEGL / libGLESv2 instead of system libs.
|
||||
_preload_angle()
|
||||
os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
|
||||
|
||||
# OpenGL modules - initialized lazily when context is created
|
||||
gl = None
|
||||
glfw = None
|
||||
EGL = None
|
||||
import OpenGL
|
||||
OpenGL.USE_ACCELERATE = False
|
||||
|
||||
|
||||
def _import_opengl():
|
||||
"""Import OpenGL module. Called after context is created."""
|
||||
global gl
|
||||
if gl is None:
|
||||
logger.debug("_import_opengl: importing OpenGL.GL")
|
||||
import OpenGL.GL as _gl
|
||||
gl = _gl
|
||||
logger.debug("_import_opengl: import completed")
|
||||
return gl
|
||||
def _patch_find_library():
|
||||
"""PyOpenGL's EGL platform looks for 'EGL' and 'GLESv2' by short name
|
||||
via ctypes.util.find_library, but ANGLE ships as 'libEGL' and
|
||||
'libGLESv2'. Patch find_library to return the full ANGLE paths so
|
||||
PyOpenGL loads the same libraries we pre-loaded."""
|
||||
if sys.platform == "linux":
|
||||
return
|
||||
import ctypes.util
|
||||
_orig = ctypes.util.find_library
|
||||
def _patched(name):
|
||||
if name == 'EGL':
|
||||
return comfy_angle.get_egl_path()
|
||||
if name == 'GLESv2':
|
||||
return comfy_angle.get_glesv2_path()
|
||||
return _orig(name)
|
||||
ctypes.util.find_library = _patched
|
||||
|
||||
|
||||
_patch_find_library()
|
||||
|
||||
from OpenGL import EGL
|
||||
from OpenGL import GLES3 as gl
|
||||
|
||||
class SizeModeInput(TypedDict):
|
||||
size_mode: str
|
||||
width: int
|
||||
@@ -102,7 +84,7 @@ MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
|
||||
# (-1,-1)---(3,-1)
|
||||
#
|
||||
# v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1)
|
||||
VERTEX_SHADER = """#version 330 core
|
||||
VERTEX_SHADER = """#version 300 es
|
||||
out vec2 v_texCoord;
|
||||
void main() {
|
||||
vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3));
|
||||
@@ -126,14 +108,21 @@ void main() {
|
||||
"""
|
||||
|
||||
|
||||
def _convert_es_to_desktop(source: str) -> str:
|
||||
"""Convert GLSL ES (WebGL) shader source to desktop GLSL 330 core."""
|
||||
# Remove any existing #version directive
|
||||
source = re.sub(r"#version\s+\d+(\s+es)?\s*\n?", "", source, flags=re.IGNORECASE)
|
||||
# Remove precision qualifiers (not needed in desktop GLSL)
|
||||
source = re.sub(r"precision\s+(lowp|mediump|highp)\s+\w+\s*;\s*\n?", "", source)
|
||||
# Prepend desktop GLSL version
|
||||
return "#version 330 core\n" + source
|
||||
|
||||
def _egl_attribs(*values):
|
||||
"""Build an EGL_NONE-terminated EGLint attribute array."""
|
||||
vals = list(values) + [EGL.EGL_NONE]
|
||||
return (ctypes.c_int32 * len(vals))(*vals)
|
||||
|
||||
|
||||
def _gl_str(name):
|
||||
"""Get an OpenGL string parameter."""
|
||||
v = gl.glGetString(name)
|
||||
if not v:
|
||||
return "Unknown"
|
||||
if isinstance(v, bytes):
|
||||
return v.decode(errors="replace")
|
||||
return ctypes.string_at(v).decode(errors="replace")
|
||||
|
||||
|
||||
def _detect_output_count(source: str) -> int:
|
||||
@@ -159,163 +148,8 @@ def _detect_pass_count(source: str) -> int:
|
||||
return 1
|
||||
|
||||
|
||||
def _init_glfw():
|
||||
"""Initialize GLFW. Returns (window, glfw_module). Raises RuntimeError on failure."""
|
||||
logger.debug("_init_glfw: starting")
|
||||
# On macOS, glfw.init() must be called from main thread or it hangs forever
|
||||
if sys.platform == "darwin":
|
||||
logger.debug("_init_glfw: skipping on macOS")
|
||||
raise RuntimeError("GLFW backend not supported on macOS")
|
||||
|
||||
logger.debug("_init_glfw: importing glfw module")
|
||||
import glfw as _glfw
|
||||
|
||||
logger.debug("_init_glfw: calling glfw.init()")
|
||||
if not _glfw.init():
|
||||
raise RuntimeError("glfw.init() failed")
|
||||
|
||||
try:
|
||||
logger.debug("_init_glfw: setting window hints")
|
||||
_glfw.window_hint(_glfw.VISIBLE, _glfw.FALSE)
|
||||
_glfw.window_hint(_glfw.CONTEXT_VERSION_MAJOR, 3)
|
||||
_glfw.window_hint(_glfw.CONTEXT_VERSION_MINOR, 3)
|
||||
_glfw.window_hint(_glfw.OPENGL_PROFILE, _glfw.OPENGL_CORE_PROFILE)
|
||||
|
||||
logger.debug("_init_glfw: calling create_window()")
|
||||
window = _glfw.create_window(64, 64, "ComfyUI GLSL", None, None)
|
||||
if not window:
|
||||
raise RuntimeError("glfw.create_window() failed")
|
||||
|
||||
logger.debug("_init_glfw: calling make_context_current()")
|
||||
_glfw.make_context_current(window)
|
||||
logger.debug("_init_glfw: completed successfully")
|
||||
return window, _glfw
|
||||
except Exception:
|
||||
logger.debug("_init_glfw: failed, terminating glfw")
|
||||
_glfw.terminate()
|
||||
raise
|
||||
|
||||
|
||||
def _init_egl():
|
||||
"""Initialize EGL for headless rendering. Returns (display, context, surface, EGL_module). Raises RuntimeError on failure."""
|
||||
logger.debug("_init_egl: starting")
|
||||
from OpenGL import EGL as _EGL
|
||||
from OpenGL.EGL import (
|
||||
eglGetDisplay, eglInitialize, eglChooseConfig, eglCreateContext,
|
||||
eglMakeCurrent, eglCreatePbufferSurface, eglBindAPI,
|
||||
eglTerminate, eglDestroyContext, eglDestroySurface,
|
||||
EGL_DEFAULT_DISPLAY, EGL_NO_CONTEXT, EGL_NONE,
|
||||
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
|
||||
EGL_RED_SIZE, EGL_GREEN_SIZE, EGL_BLUE_SIZE, EGL_ALPHA_SIZE, EGL_DEPTH_SIZE,
|
||||
EGL_WIDTH, EGL_HEIGHT, EGL_OPENGL_API,
|
||||
)
|
||||
logger.debug("_init_egl: imports completed")
|
||||
|
||||
display = None
|
||||
context = None
|
||||
surface = None
|
||||
|
||||
try:
|
||||
logger.debug("_init_egl: calling eglGetDisplay()")
|
||||
display = eglGetDisplay(EGL_DEFAULT_DISPLAY)
|
||||
if display == _EGL.EGL_NO_DISPLAY:
|
||||
raise RuntimeError("eglGetDisplay() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglInitialize()")
|
||||
major, minor = _EGL.EGLint(), _EGL.EGLint()
|
||||
if not eglInitialize(display, major, minor):
|
||||
display = None # Not initialized, don't terminate
|
||||
raise RuntimeError("eglInitialize() failed")
|
||||
logger.debug(f"_init_egl: EGL version {major.value}.{minor.value}")
|
||||
|
||||
config_attribs = [
|
||||
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT,
|
||||
EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
|
||||
EGL_RED_SIZE, 8, EGL_GREEN_SIZE, 8, EGL_BLUE_SIZE, 8, EGL_ALPHA_SIZE, 8,
|
||||
EGL_DEPTH_SIZE, 0, EGL_NONE
|
||||
]
|
||||
configs = (_EGL.EGLConfig * 1)()
|
||||
num_configs = _EGL.EGLint()
|
||||
if not eglChooseConfig(display, config_attribs, configs, 1, num_configs) or num_configs.value == 0:
|
||||
raise RuntimeError("eglChooseConfig() failed")
|
||||
config = configs[0]
|
||||
logger.debug(f"_init_egl: config chosen, num_configs={num_configs.value}")
|
||||
|
||||
if not eglBindAPI(EGL_OPENGL_API):
|
||||
raise RuntimeError("eglBindAPI() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglCreateContext()")
|
||||
context_attribs = [
|
||||
_EGL.EGL_CONTEXT_MAJOR_VERSION, 3,
|
||||
_EGL.EGL_CONTEXT_MINOR_VERSION, 3,
|
||||
_EGL.EGL_CONTEXT_OPENGL_PROFILE_MASK, _EGL.EGL_CONTEXT_OPENGL_CORE_PROFILE_BIT,
|
||||
EGL_NONE
|
||||
]
|
||||
context = eglCreateContext(display, config, EGL_NO_CONTEXT, context_attribs)
|
||||
if context == EGL_NO_CONTEXT:
|
||||
raise RuntimeError("eglCreateContext() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglCreatePbufferSurface()")
|
||||
pbuffer_attribs = [EGL_WIDTH, 64, EGL_HEIGHT, 64, EGL_NONE]
|
||||
surface = eglCreatePbufferSurface(display, config, pbuffer_attribs)
|
||||
if surface == _EGL.EGL_NO_SURFACE:
|
||||
raise RuntimeError("eglCreatePbufferSurface() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglMakeCurrent()")
|
||||
if not eglMakeCurrent(display, surface, surface, context):
|
||||
raise RuntimeError("eglMakeCurrent() failed")
|
||||
|
||||
logger.debug("_init_egl: completed successfully")
|
||||
return display, context, surface, _EGL
|
||||
|
||||
except Exception:
|
||||
logger.debug("_init_egl: failed, cleaning up")
|
||||
# Clean up any resources on failure
|
||||
if surface is not None:
|
||||
eglDestroySurface(display, surface)
|
||||
if context is not None:
|
||||
eglDestroyContext(display, context)
|
||||
if display is not None:
|
||||
eglTerminate(display)
|
||||
raise
|
||||
|
||||
|
||||
def _init_osmesa():
|
||||
"""Initialize OSMesa for software rendering. Returns (context, buffer). Raises RuntimeError on failure."""
|
||||
import ctypes
|
||||
|
||||
logger.debug("_init_osmesa: starting")
|
||||
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
|
||||
|
||||
logger.debug("_init_osmesa: importing OpenGL.osmesa")
|
||||
from OpenGL import GL as _gl
|
||||
from OpenGL.osmesa import (
|
||||
OSMesaCreateContextExt, OSMesaMakeCurrent, OSMesaDestroyContext,
|
||||
OSMESA_RGBA,
|
||||
)
|
||||
logger.debug("_init_osmesa: imports completed")
|
||||
|
||||
ctx = OSMesaCreateContextExt(OSMESA_RGBA, 24, 0, 0, None)
|
||||
if not ctx:
|
||||
raise RuntimeError("OSMesaCreateContextExt() failed")
|
||||
|
||||
width, height = 64, 64
|
||||
buffer = (ctypes.c_ubyte * (width * height * 4))()
|
||||
|
||||
logger.debug("_init_osmesa: calling OSMesaMakeCurrent()")
|
||||
if not OSMesaMakeCurrent(ctx, buffer, _gl.GL_UNSIGNED_BYTE, width, height):
|
||||
OSMesaDestroyContext(ctx)
|
||||
raise RuntimeError("OSMesaMakeCurrent() failed")
|
||||
|
||||
logger.debug("_init_osmesa: completed successfully")
|
||||
return ctx, buffer
|
||||
|
||||
|
||||
class GLContext:
|
||||
"""Manages OpenGL context and resources for shader execution.
|
||||
|
||||
Tries backends in order: GLFW (desktop) → EGL (headless GPU) → OSMesa (software).
|
||||
"""
|
||||
"""Manages an OpenGL ES 3.0 context via EGL/ANGLE (singleton)."""
|
||||
|
||||
_instance = None
|
||||
_initialized = False
|
||||
@@ -327,131 +161,111 @@ class GLContext:
|
||||
|
||||
def __init__(self):
|
||||
if GLContext._initialized:
|
||||
logger.debug("GLContext.__init__: already initialized, skipping")
|
||||
return
|
||||
|
||||
logger.debug("GLContext.__init__: starting initialization")
|
||||
|
||||
global glfw, EGL
|
||||
|
||||
import time
|
||||
start = time.perf_counter()
|
||||
|
||||
self._backend = None
|
||||
self._window = None
|
||||
self._egl_display = None
|
||||
self._egl_context = None
|
||||
self._egl_surface = None
|
||||
self._osmesa_ctx = None
|
||||
self._osmesa_buffer = None
|
||||
self._display = None
|
||||
self._surface = None
|
||||
self._context = None
|
||||
self._vao = None
|
||||
|
||||
# Try backends in order: GLFW → EGL → OSMesa
|
||||
errors = []
|
||||
|
||||
logger.debug("GLContext.__init__: trying GLFW backend")
|
||||
try:
|
||||
self._window, glfw = _init_glfw()
|
||||
self._backend = "glfw"
|
||||
logger.debug("GLContext.__init__: GLFW backend succeeded")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: GLFW backend failed: {e}")
|
||||
errors.append(("GLFW", e))
|
||||
self._display = EGL.eglGetDisplay(EGL.EGL_DEFAULT_DISPLAY)
|
||||
if not self._display:
|
||||
raise RuntimeError("eglGetDisplay() returned no display")
|
||||
|
||||
if self._backend is None:
|
||||
logger.debug("GLContext.__init__: trying EGL backend")
|
||||
try:
|
||||
self._egl_display, self._egl_context, self._egl_surface, EGL = _init_egl()
|
||||
self._backend = "egl"
|
||||
logger.debug("GLContext.__init__: EGL backend succeeded")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: EGL backend failed: {e}")
|
||||
errors.append(("EGL", e))
|
||||
major, minor = ctypes.c_int32(0), ctypes.c_int32(0)
|
||||
if not EGL.eglInitialize(self._display, ctypes.byref(major), ctypes.byref(minor)):
|
||||
err = EGL.eglGetError()
|
||||
self._display = None
|
||||
raise RuntimeError(f"eglInitialize() failed (EGL error: 0x{err:04X})")
|
||||
|
||||
if self._backend is None:
|
||||
logger.debug("GLContext.__init__: trying OSMesa backend")
|
||||
try:
|
||||
self._osmesa_ctx, self._osmesa_buffer = _init_osmesa()
|
||||
self._backend = "osmesa"
|
||||
logger.debug("GLContext.__init__: OSMesa backend succeeded")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: OSMesa backend failed: {e}")
|
||||
errors.append(("OSMesa", e))
|
||||
if not EGL.eglBindAPI(EGL.EGL_OPENGL_ES_API):
|
||||
raise RuntimeError("eglBindAPI(EGL_OPENGL_ES_API) failed")
|
||||
|
||||
if self._backend is None:
|
||||
if sys.platform == "win32":
|
||||
platform_help = (
|
||||
"Windows: Ensure GPU drivers are installed and display is available.\n"
|
||||
" CPU-only/headless mode is not supported on Windows."
|
||||
)
|
||||
elif sys.platform == "darwin":
|
||||
platform_help = (
|
||||
"macOS: GLFW is not supported.\n"
|
||||
" Install OSMesa via Homebrew: brew install mesa\n"
|
||||
" Then: pip install PyOpenGL PyOpenGL-accelerate"
|
||||
)
|
||||
else:
|
||||
platform_help = (
|
||||
"Linux: Install one of these backends:\n"
|
||||
" Desktop: sudo apt install libgl1-mesa-glx libglfw3\n"
|
||||
" Headless with GPU: sudo apt install libegl1-mesa libgl1-mesa-dri\n"
|
||||
" Headless (CPU): sudo apt install libosmesa6"
|
||||
)
|
||||
config = EGL.EGLConfig()
|
||||
n_configs = ctypes.c_int32(0)
|
||||
if not EGL.eglChooseConfig(
|
||||
self._display,
|
||||
_egl_attribs(
|
||||
EGL.EGL_RENDERABLE_TYPE, EGL.EGL_OPENGL_ES3_BIT,
|
||||
EGL.EGL_SURFACE_TYPE, EGL.EGL_PBUFFER_BIT,
|
||||
EGL.EGL_RED_SIZE, 8, EGL.EGL_GREEN_SIZE, 8,
|
||||
EGL.EGL_BLUE_SIZE, 8, EGL.EGL_ALPHA_SIZE, 8,
|
||||
),
|
||||
ctypes.byref(config), 1, ctypes.byref(n_configs),
|
||||
) or n_configs.value == 0:
|
||||
raise RuntimeError("eglChooseConfig() failed")
|
||||
|
||||
error_details = "\n".join(f" {name}: {err}" for name, err in errors)
|
||||
raise RuntimeError(
|
||||
f"Failed to create OpenGL context.\n\n"
|
||||
f"Backend errors:\n{error_details}\n\n"
|
||||
f"{platform_help}"
|
||||
self._surface = EGL.eglCreatePbufferSurface(
|
||||
self._display, config,
|
||||
_egl_attribs(EGL.EGL_WIDTH, 64, EGL.EGL_HEIGHT, 64),
|
||||
)
|
||||
if not self._surface:
|
||||
raise RuntimeError("eglCreatePbufferSurface() failed")
|
||||
|
||||
# Now import OpenGL.GL (after context is current)
|
||||
logger.debug("GLContext.__init__: importing OpenGL.GL")
|
||||
_import_opengl()
|
||||
self._context = EGL.eglCreateContext(
|
||||
self._display, config, EGL.EGL_NO_CONTEXT,
|
||||
_egl_attribs(EGL.EGL_CONTEXT_CLIENT_VERSION, 3),
|
||||
)
|
||||
if not self._context:
|
||||
raise RuntimeError("eglCreateContext() failed")
|
||||
|
||||
# Create VAO (required for core profile, but OSMesa may use compat profile)
|
||||
logger.debug("GLContext.__init__: creating VAO")
|
||||
try:
|
||||
vao = gl.glGenVertexArrays(1)
|
||||
gl.glBindVertexArray(vao)
|
||||
self._vao = vao # Only store after successful bind
|
||||
logger.debug("GLContext.__init__: VAO created successfully")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: VAO creation failed (may be expected for OSMesa): {e}")
|
||||
# OSMesa with older Mesa may not support VAOs
|
||||
# Clean up if we created but couldn't bind
|
||||
if vao:
|
||||
try:
|
||||
gl.glDeleteVertexArrays(1, [vao])
|
||||
except Exception:
|
||||
pass
|
||||
if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context):
|
||||
raise RuntimeError("eglMakeCurrent() failed")
|
||||
|
||||
self._vao = gl.glGenVertexArrays(1)
|
||||
gl.glBindVertexArray(self._vao)
|
||||
|
||||
except Exception:
|
||||
self._cleanup()
|
||||
raise
|
||||
|
||||
elapsed = (time.perf_counter() - start) * 1000
|
||||
|
||||
# Log device info
|
||||
renderer = gl.glGetString(gl.GL_RENDERER)
|
||||
vendor = gl.glGetString(gl.GL_VENDOR)
|
||||
version = gl.glGetString(gl.GL_VERSION)
|
||||
renderer = renderer.decode() if renderer else "Unknown"
|
||||
vendor = vendor.decode() if vendor else "Unknown"
|
||||
version = version.decode() if version else "Unknown"
|
||||
renderer = _gl_str(gl.GL_RENDERER)
|
||||
vendor = _gl_str(gl.GL_VENDOR)
|
||||
version = _gl_str(gl.GL_VERSION)
|
||||
|
||||
GLContext._initialized = True
|
||||
logger.info(f"GLSL context initialized in {elapsed:.1f}ms ({self._backend}) - {renderer} ({vendor}), GL {version}")
|
||||
logger.info(f"GLSL context initialized in {elapsed:.1f}ms - {renderer} ({vendor}), GL {version}")
|
||||
|
||||
def make_current(self):
|
||||
if self._backend == "glfw":
|
||||
glfw.make_context_current(self._window)
|
||||
elif self._backend == "egl":
|
||||
from OpenGL.EGL import eglMakeCurrent
|
||||
eglMakeCurrent(self._egl_display, self._egl_surface, self._egl_surface, self._egl_context)
|
||||
elif self._backend == "osmesa":
|
||||
from OpenGL.osmesa import OSMesaMakeCurrent
|
||||
OSMesaMakeCurrent(self._osmesa_ctx, self._osmesa_buffer, gl.GL_UNSIGNED_BYTE, 64, 64)
|
||||
|
||||
EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context)
|
||||
if self._vao is not None:
|
||||
gl.glBindVertexArray(self._vao)
|
||||
|
||||
def _cleanup(self):
|
||||
if not self._display:
|
||||
return
|
||||
try:
|
||||
if self._vao is not None:
|
||||
gl.glDeleteVertexArrays(1, [self._vao])
|
||||
self._vao = None
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
EGL.eglMakeCurrent(self._display, EGL.EGL_NO_SURFACE, EGL.EGL_NO_SURFACE, EGL.EGL_NO_CONTEXT)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if self._context:
|
||||
EGL.eglDestroyContext(self._display, self._context)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if self._surface:
|
||||
EGL.eglDestroySurface(self._display, self._surface)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
EGL.eglTerminate(self._display)
|
||||
except Exception:
|
||||
pass
|
||||
self._display = None
|
||||
|
||||
|
||||
def _compile_shader(source: str, shader_type: int) -> int:
|
||||
"""Compile a shader and return its ID."""
|
||||
@@ -459,8 +273,10 @@ def _compile_shader(source: str, shader_type: int) -> int:
|
||||
gl.glShaderSource(shader, source)
|
||||
gl.glCompileShader(shader)
|
||||
|
||||
if gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS) != gl.GL_TRUE:
|
||||
error = gl.glGetShaderInfoLog(shader).decode()
|
||||
if not gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS):
|
||||
error = gl.glGetShaderInfoLog(shader)
|
||||
if isinstance(error, bytes):
|
||||
error = error.decode(errors="replace")
|
||||
gl.glDeleteShader(shader)
|
||||
raise RuntimeError(f"Shader compilation failed:\n{error}")
|
||||
|
||||
@@ -484,8 +300,10 @@ def _create_program(vertex_source: str, fragment_source: str) -> int:
|
||||
gl.glDeleteShader(vertex_shader)
|
||||
gl.glDeleteShader(fragment_shader)
|
||||
|
||||
if gl.glGetProgramiv(program, gl.GL_LINK_STATUS) != gl.GL_TRUE:
|
||||
error = gl.glGetProgramInfoLog(program).decode()
|
||||
if not gl.glGetProgramiv(program, gl.GL_LINK_STATUS):
|
||||
error = gl.glGetProgramInfoLog(program)
|
||||
if isinstance(error, bytes):
|
||||
error = error.decode(errors="replace")
|
||||
gl.glDeleteProgram(program)
|
||||
raise RuntimeError(f"Program linking failed:\n{error}")
|
||||
|
||||
@@ -530,9 +348,6 @@ def _render_shader_batch(
|
||||
ctx = GLContext()
|
||||
ctx.make_current()
|
||||
|
||||
# Convert from GLSL ES to desktop GLSL 330
|
||||
fragment_source = _convert_es_to_desktop(fragment_code)
|
||||
|
||||
# Detect how many outputs the shader actually uses
|
||||
num_outputs = _detect_output_count(fragment_code)
|
||||
|
||||
@@ -558,9 +373,9 @@ def _render_shader_batch(
|
||||
try:
|
||||
# Compile shaders (once for all batches)
|
||||
try:
|
||||
program = _create_program(VERTEX_SHADER, fragment_source)
|
||||
program = _create_program(VERTEX_SHADER, fragment_code)
|
||||
except RuntimeError:
|
||||
logger.error(f"Fragment shader:\n{fragment_source}")
|
||||
logger.error(f"Fragment shader:\n{fragment_code}")
|
||||
raise
|
||||
|
||||
gl.glUseProgram(program)
|
||||
@@ -723,13 +538,13 @@ def _render_shader_batch(
|
||||
gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3)
|
||||
|
||||
# Read back outputs for this batch
|
||||
# (glGetTexImage is synchronous, implicitly waits for rendering)
|
||||
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
|
||||
batch_outputs = []
|
||||
for tex in output_textures:
|
||||
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
|
||||
data = gl.glGetTexImage(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA, gl.GL_FLOAT)
|
||||
img = np.frombuffer(data, dtype=np.float32).reshape(height, width, 4)
|
||||
batch_outputs.append(img[::-1, :, :].copy())
|
||||
for i in range(num_outputs):
|
||||
gl.glReadBuffer(gl.GL_COLOR_ATTACHMENT0 + i)
|
||||
buf = np.empty((height, width, 4), dtype=np.float32)
|
||||
gl.glReadPixels(0, 0, width, height, gl.GL_RGBA, gl.GL_FLOAT, buf)
|
||||
batch_outputs.append(buf[::-1, :, :].copy())
|
||||
|
||||
# Pad with black images for unused outputs
|
||||
black_img = np.zeros((height, width, 4), dtype=np.float32)
|
||||
@@ -750,18 +565,18 @@ def _render_shader_batch(
|
||||
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
|
||||
gl.glUseProgram(0)
|
||||
|
||||
for tex in input_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
for tex in curve_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
for tex in output_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
for tex in ping_pong_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
if input_textures:
|
||||
gl.glDeleteTextures(len(input_textures), input_textures)
|
||||
if curve_textures:
|
||||
gl.glDeleteTextures(len(curve_textures), curve_textures)
|
||||
if output_textures:
|
||||
gl.glDeleteTextures(len(output_textures), output_textures)
|
||||
if ping_pong_textures:
|
||||
gl.glDeleteTextures(len(ping_pong_textures), ping_pong_textures)
|
||||
if fbo is not None:
|
||||
gl.glDeleteFramebuffers(1, [fbo])
|
||||
for pp_fbo in ping_pong_fbos:
|
||||
gl.glDeleteFramebuffers(1, [pp_fbo])
|
||||
if ping_pong_fbos:
|
||||
gl.glDeleteFramebuffers(len(ping_pong_fbos), ping_pong_fbos)
|
||||
if program is not None:
|
||||
gl.glDeleteProgram(program)
|
||||
|
||||
@@ -813,7 +628,6 @@ class GLSLShader(io.ComfyNode):
|
||||
"u_resolution (vec2) is always available."
|
||||
),
|
||||
is_experimental=True,
|
||||
has_intermediate_output=True,
|
||||
inputs=[
|
||||
io.String.Input(
|
||||
"fragment_shader",
|
||||
|
||||
@@ -59,7 +59,6 @@ class ImageCropV2(IO.ComfyNode):
|
||||
display_name="Image Crop",
|
||||
category="image/transform",
|
||||
essentials_category="Image Tools",
|
||||
has_intermediate_output=True,
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.BoundingBox.Input("crop_region", component="ImageCrop"),
|
||||
|
||||
@@ -30,7 +30,6 @@ class PainterNode(io.ComfyNode):
|
||||
node_id="Painter",
|
||||
display_name="Painter",
|
||||
category="image",
|
||||
has_intermediate_output=True,
|
||||
inputs=[
|
||||
io.Image.Input(
|
||||
"image",
|
||||
|
||||
@@ -1,154 +0,0 @@
|
||||
from typing_extensions import override
|
||||
|
||||
import torch
|
||||
from comfy.ldm.rt_detr.rtdetr_v4 import COCO_CLASSES
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from torchvision.transforms import ToPILImage, ToTensor
|
||||
from PIL import ImageDraw, ImageFont
|
||||
|
||||
|
||||
class RTDETR_detect(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="RTDETR_detect",
|
||||
display_name="RT-DETR Detect",
|
||||
category="detection/",
|
||||
search_aliases=["bbox", "bounding box", "object detection", "coco"],
|
||||
inputs=[
|
||||
io.Model.Input("model", display_name="model"),
|
||||
io.Image.Input("image", display_name="image"),
|
||||
io.Float.Input("threshold", display_name="threshold", default=0.5),
|
||||
io.Combo.Input("class_name", options=["all"] + COCO_CLASSES, default="all", tooltip="Filter detections by class. Set to 'all' to disable filtering."),
|
||||
io.Int.Input("max_detections", display_name="max_detections", default=100, tooltip="Maximum number of detections to return per image. In order of descending confidence score."),
|
||||
],
|
||||
outputs=[
|
||||
io.BoundingBox.Output("bboxes")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, image, threshold, class_name, max_detections) -> io.NodeOutput:
|
||||
B, H, W, C = image.shape
|
||||
|
||||
image_in = comfy.utils.common_upscale(image.movedim(-1, 1), 640, 640, "bilinear", crop="disabled")
|
||||
|
||||
comfy.model_management.load_model_gpu(model)
|
||||
results = model.model.diffusion_model(image_in, (W, H)) # list of B dicts
|
||||
|
||||
all_bbox_dicts = []
|
||||
|
||||
for det in results:
|
||||
keep = det['scores'] > threshold
|
||||
boxes = det['boxes'][keep].cpu()
|
||||
labels = det['labels'][keep].cpu()
|
||||
scores = det['scores'][keep].cpu()
|
||||
|
||||
bbox_dicts = [
|
||||
{
|
||||
"x": float(box[0]),
|
||||
"y": float(box[1]),
|
||||
"width": float(box[2] - box[0]),
|
||||
"height": float(box[3] - box[1]),
|
||||
"label": COCO_CLASSES[int(label)],
|
||||
"score": float(score)
|
||||
}
|
||||
for box, label, score in zip(boxes, labels, scores)
|
||||
if class_name == "all" or COCO_CLASSES[int(label)] == class_name
|
||||
]
|
||||
bbox_dicts.sort(key=lambda d: d["score"], reverse=True)
|
||||
all_bbox_dicts.append(bbox_dicts[:max_detections])
|
||||
|
||||
return io.NodeOutput(all_bbox_dicts)
|
||||
|
||||
|
||||
class DrawBBoxes(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="DrawBBoxes",
|
||||
display_name="Draw BBoxes",
|
||||
category="detection/",
|
||||
search_aliases=["bbox", "bounding box", "object detection", "rt_detr", "visualize detections", "coco"],
|
||||
inputs=[
|
||||
io.Image.Input("image", optional=True),
|
||||
io.BoundingBox.Input("bboxes", force_input=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output("out_image"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, bboxes, image=None) -> io.NodeOutput:
|
||||
# Normalise to list[list[dict]], then fit to batch size B.
|
||||
B = image.shape[0] if image is not None else 1
|
||||
if isinstance(bboxes, dict):
|
||||
bboxes = [[bboxes]]
|
||||
elif not isinstance(bboxes, list) or not bboxes:
|
||||
bboxes = [[]]
|
||||
elif isinstance(bboxes[0], dict):
|
||||
bboxes = [bboxes] # flat list → same detections for every image
|
||||
|
||||
if len(bboxes) == 1:
|
||||
bboxes = bboxes * B
|
||||
bboxes = (bboxes + [[]] * B)[:B]
|
||||
|
||||
if image is None:
|
||||
B = len(bboxes)
|
||||
max_w = max((int(d["x"] + d["width"]) for frame in bboxes for d in frame), default=640)
|
||||
max_h = max((int(d["y"] + d["height"]) for frame in bboxes for d in frame), default=640)
|
||||
image = torch.zeros((B, max_h, max_w, 3), dtype=torch.float32)
|
||||
|
||||
all_out_images = []
|
||||
for i in range(B):
|
||||
detections = bboxes[i]
|
||||
if detections:
|
||||
boxes = torch.tensor([[d["x"], d["y"], d["x"] + d["width"], d["y"] + d["height"]] for d in detections])
|
||||
labels = [d.get("label") if d.get("label") in COCO_CLASSES else None for d in detections]
|
||||
scores = torch.tensor([d.get("score", 1.0) for d in detections])
|
||||
else:
|
||||
boxes = torch.zeros((0, 4))
|
||||
labels = []
|
||||
scores = torch.zeros((0,))
|
||||
|
||||
pil_image = image[i].movedim(-1, 0)
|
||||
img = ToPILImage()(pil_image)
|
||||
if detections:
|
||||
img = cls.draw_detections(img, boxes, labels, scores)
|
||||
all_out_images.append(ToTensor()(img).unsqueeze(0).movedim(1, -1))
|
||||
|
||||
out_images = torch.cat(all_out_images, dim=0).to(comfy.model_management.intermediate_device())
|
||||
return io.NodeOutput(out_images)
|
||||
|
||||
@classmethod
|
||||
def draw_detections(cls, img, boxes, labels, scores):
|
||||
draw = ImageDraw.Draw(img)
|
||||
try:
|
||||
font = ImageFont.truetype('arial.ttf', 16)
|
||||
except Exception:
|
||||
font = ImageFont.load_default()
|
||||
colors = [(255,0,0),(0,200,0),(0,0,255),(255,165,0),(128,0,128),
|
||||
(0,255,255),(255,20,147),(100,149,237)]
|
||||
for box, label, score in sorted(zip(boxes, labels, scores), key=lambda x: x[2].item()):
|
||||
x1, y1, x2, y2 = box.tolist()
|
||||
color_idx = COCO_CLASSES.index(label) if label is not None else 0
|
||||
c = colors[color_idx % len(colors)]
|
||||
draw.rectangle([x1, y1, x2, y2], outline=c, width=3)
|
||||
if label is not None:
|
||||
draw.text((x1 + 2, y1 + 2), f'{label} {score:.2f}', fill=c, font=font)
|
||||
return img
|
||||
|
||||
|
||||
class RTDETRExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
RTDETR_detect,
|
||||
DrawBBoxes,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> RTDETRExtension:
|
||||
return RTDETRExtension()
|
||||
@@ -661,7 +661,6 @@ class CropByBBoxes(io.ComfyNode):
|
||||
io.Int.Input("output_width", default=512, min=64, max=4096, step=8, tooltip="Width each crop is resized to."),
|
||||
io.Int.Input("output_height", default=512, min=64, max=4096, step=8, tooltip="Height each crop is resized to."),
|
||||
io.Int.Input("padding", default=0, min=0, max=1024, step=1, tooltip="Extra padding in pixels added on each side of the bbox before cropping."),
|
||||
io.Combo.Input("keep_aspect", options=["stretch", "pad"], default="stretch", tooltip="Whether to stretch the crop to fit the output size, or pad with black pixels to preserve aspect ratio."),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(tooltip="All crops stacked into a single image batch."),
|
||||
@@ -669,7 +668,7 @@ class CropByBBoxes(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, image, bboxes, output_width, output_height, padding, keep_aspect="stretch") -> io.NodeOutput:
|
||||
def execute(cls, image, bboxes, output_width, output_height, padding) -> io.NodeOutput:
|
||||
total_frames = image.shape[0]
|
||||
img_h = image.shape[1]
|
||||
img_w = image.shape[2]
|
||||
@@ -717,19 +716,7 @@ class CropByBBoxes(io.ComfyNode):
|
||||
x1, y1, x2, y2 = fb_x1, fb_y1, fb_x2, fb_y2
|
||||
|
||||
crop_chw = frame_chw[:, :, y1:y2, x1:x2] # (1, C, crop_h, crop_w)
|
||||
|
||||
if keep_aspect == "pad":
|
||||
crop_h, crop_w = y2 - y1, x2 - x1
|
||||
scale = min(output_width / crop_w, output_height / crop_h)
|
||||
scaled_w = int(round(crop_w * scale))
|
||||
scaled_h = int(round(crop_h * scale))
|
||||
scaled = comfy.utils.common_upscale(crop_chw, scaled_w, scaled_h, upscale_method="bilinear", crop="disabled")
|
||||
pad_left = (output_width - scaled_w) // 2
|
||||
pad_top = (output_height - scaled_h) // 2
|
||||
resized = torch.zeros(1, num_ch, output_height, output_width, dtype=image.dtype, device=image.device)
|
||||
resized[:, :, pad_top:pad_top + scaled_h, pad_left:pad_left + scaled_w] = scaled
|
||||
else: # "stretch"
|
||||
resized = comfy.utils.common_upscale(crop_chw, output_width, output_height, upscale_method="bilinear", crop="disabled")
|
||||
resized = comfy.utils.common_upscale(crop_chw, output_width, output_height, upscale_method="bilinear", crop="disabled")
|
||||
crops.append(resized)
|
||||
|
||||
if not crops:
|
||||
|
||||
38
execution.py
38
execution.py
@@ -411,19 +411,6 @@ def format_value(x):
|
||||
else:
|
||||
return str(x)
|
||||
|
||||
def _is_intermediate_output(dynprompt, node_id):
|
||||
class_type = dynprompt.get_node(node_id)["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
return getattr(class_def, 'HAS_INTERMEDIATE_OUTPUT', False)
|
||||
|
||||
def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, ui_outputs):
|
||||
if server.client_id is None:
|
||||
return
|
||||
cached_ui = cached.ui or {}
|
||||
server.send_sync("executed", { "node": node_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id }, server.client_id)
|
||||
if cached.ui is not None:
|
||||
ui_outputs[node_id] = cached.ui
|
||||
|
||||
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
|
||||
unique_id = current_item
|
||||
real_node_id = dynprompt.get_real_node_id(unique_id)
|
||||
@@ -434,7 +421,11 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
cached = await caches.outputs.get(unique_id)
|
||||
if cached is not None:
|
||||
_send_cached_ui(server, unique_id, display_node_id, cached, prompt_id, ui_outputs)
|
||||
if server.client_id is not None:
|
||||
cached_ui = cached.ui or {}
|
||||
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id)
|
||||
if cached.ui is not None:
|
||||
ui_outputs[unique_id] = cached.ui
|
||||
get_progress_state().finish_progress(unique_id)
|
||||
execution_list.cache_update(unique_id, cached)
|
||||
return (ExecutionResult.SUCCESS, None, None)
|
||||
@@ -724,9 +715,6 @@ class PromptExecutor:
|
||||
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
|
||||
|
||||
self._notify_prompt_lifecycle("start", prompt_id)
|
||||
ram_headroom = int(self.cache_args["ram"] * (1024 ** 3))
|
||||
ram_release_callback = self.caches.outputs.ram_release if self.cache_type == CacheType.RAM_PRESSURE else None
|
||||
comfy.memory_management.set_ram_cache_release_state(ram_release_callback, ram_headroom)
|
||||
|
||||
try:
|
||||
with torch.inference_mode():
|
||||
@@ -776,22 +764,9 @@ class PromptExecutor:
|
||||
execution_list.unstage_node_execution()
|
||||
else: # result == ExecutionResult.SUCCESS:
|
||||
execution_list.complete_node_execution()
|
||||
|
||||
if self.cache_type == CacheType.RAM_PRESSURE:
|
||||
comfy.model_management.free_memory(0, None, pins_required=ram_headroom, ram_required=ram_headroom)
|
||||
comfy.memory_management.extra_ram_release(ram_headroom)
|
||||
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
|
||||
else:
|
||||
# Only execute when the while-loop ends without break
|
||||
# Send cached UI for intermediate output nodes that weren't executed
|
||||
for node_id in dynamic_prompt.all_node_ids():
|
||||
if node_id in executed:
|
||||
continue
|
||||
if not _is_intermediate_output(dynamic_prompt, node_id):
|
||||
continue
|
||||
cached = await self.caches.outputs.get(node_id)
|
||||
if cached is not None:
|
||||
display_node_id = dynamic_prompt.get_display_node_id(node_id)
|
||||
_send_cached_ui(self.server, node_id, display_node_id, cached, prompt_id, ui_node_outputs)
|
||||
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
|
||||
|
||||
ui_outputs = {}
|
||||
@@ -807,7 +782,6 @@ class PromptExecutor:
|
||||
if comfy.model_management.DISABLE_SMART_MEMORY:
|
||||
comfy.model_management.unload_all_models()
|
||||
finally:
|
||||
comfy.memory_management.set_ram_cache_release_state(None, 0)
|
||||
self._notify_prompt_lifecycle("end", prompt_id)
|
||||
|
||||
|
||||
|
||||
8
main.py
8
main.py
@@ -275,19 +275,15 @@ def _collect_output_absolute_paths(history_result: dict) -> list[str]:
|
||||
|
||||
def prompt_worker(q, server_instance):
|
||||
current_time: float = 0.0
|
||||
cache_ram = args.cache_ram
|
||||
if cache_ram < 0:
|
||||
cache_ram = min(32.0, max(4.0, comfy.model_management.total_ram * 0.25 / 1024.0))
|
||||
|
||||
cache_type = execution.CacheType.CLASSIC
|
||||
if args.cache_lru > 0:
|
||||
cache_type = execution.CacheType.LRU
|
||||
elif cache_ram > 0:
|
||||
elif args.cache_ram > 0:
|
||||
cache_type = execution.CacheType.RAM_PRESSURE
|
||||
elif args.cache_none:
|
||||
cache_type = execution.CacheType.NONE
|
||||
|
||||
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : cache_ram } )
|
||||
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : args.cache_ram } )
|
||||
last_gc_collect = 0
|
||||
need_gc = False
|
||||
gc_collect_interval = 10.0
|
||||
|
||||
1
nodes.py
1
nodes.py
@@ -2457,7 +2457,6 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_number_convert.py",
|
||||
"nodes_painter.py",
|
||||
"nodes_curve.py",
|
||||
"nodes_rtdetr.py"
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.42.8
|
||||
comfyui-workflow-templates==0.9.39
|
||||
comfyui-workflow-templates==0.9.38
|
||||
comfyui-embedded-docs==0.4.3
|
||||
torch
|
||||
torchsde
|
||||
@@ -33,5 +33,5 @@ kornia>=0.7.1
|
||||
spandrel
|
||||
pydantic~=2.0
|
||||
pydantic-settings~=2.0
|
||||
PyOpenGL
|
||||
glfw
|
||||
PyOpenGL>=3.1.8
|
||||
comfy-angle
|
||||
|
||||
@@ -709,11 +709,6 @@ class PromptServer():
|
||||
else:
|
||||
info['output_node'] = False
|
||||
|
||||
if hasattr(obj_class, 'HAS_INTERMEDIATE_OUTPUT') and obj_class.HAS_INTERMEDIATE_OUTPUT == True:
|
||||
info['has_intermediate_output'] = True
|
||||
else:
|
||||
info['has_intermediate_output'] = False
|
||||
|
||||
if hasattr(obj_class, 'CATEGORY'):
|
||||
info['category'] = obj_class.CATEGORY
|
||||
|
||||
|
||||
Reference in New Issue
Block a user