Compare commits

...

36 Commits

Author SHA1 Message Date
comfyanonymous
e18f53cca9 ComfyUI version 0.3.43 2025-06-27 17:22:02 -04:00
comfyanonymous
c36be0ea09 Fix memory estimation bug with kontext. (#8709) 2025-06-27 17:21:12 -04:00
comfyanonymous
9093301a49 Don't add tiny bit of random noise when VAE encoding. (#8705)
Shouldn't change outputs but might make things a tiny bit more
deterministic.
2025-06-27 14:14:56 -04:00
comfyanonymous
bd951a714f Add Flux Kontext and Omnigen 2 models to readme. (#8682) 2025-06-26 12:26:29 -04:00
comfyanonymous
6493709d6a ComfyUI version 0.3.42 2025-06-26 11:47:07 -04:00
filtered
b976f934ae Update frontend to 1.23.4 (#8681) 2025-06-26 11:44:12 -04:00
comfyanonymous
7d8cf4cacc Update requirements.txt (#8680) 2025-06-26 11:39:40 -04:00
filtered
68f4496b8e Update frontend to 1.23.3 (#8678) 2025-06-26 11:29:03 -04:00
comfyanonymous
ef5266b1c1 Support Flux Kontext Dev model. (#8679) 2025-06-26 11:28:41 -04:00
comfyanonymous
a96e65df18 Disable omnigen2 fp16 on older pytorch versions. (#8672) 2025-06-26 03:39:09 -04:00
comfyanonymous
93a49a45de Bump minimum transformers version. (#8671) 2025-06-26 02:33:02 -04:00
comfyanonymous
ec70ed6aea Omnigen2 model implementation. (#8669) 2025-06-25 19:35:57 -04:00
comfyanonymous
7a13f74220 unet -> diffusion model (#8659) 2025-06-25 04:52:34 -04:00
chaObserv
8042eb20c6 Singlestep DPM++ SDE for RF (#8627)
Refactor the algorithm, and apply alpha scaling.
2025-06-24 14:59:09 -04:00
comfyanonymous
bd9f166c12 Cosmos predict2 model merging nodes. (#8647) 2025-06-24 05:17:16 -04:00
comfyanonymous
dd94416db2 Indicate that directml is not recommended in the README. (#8644) 2025-06-23 14:04:49 -04:00
comfyanonymous
ae0e7c4dff Resize and pad image node. (#8636) 2025-06-22 17:59:31 -04:00
comfyanonymous
78f79266a9 Allow padding in ImageStitch node to be white. (#8631) 2025-06-22 00:19:41 -04:00
comfyanonymous
1883e70b43 Fix exception when using a noise mask with cosmos predict2. (#8621)
* Fix exception when using a noise mask with cosmos predict2.

* Fix ruff.
2025-06-21 03:30:39 -04:00
Lucas - BLOCK33
31ca603ccb Improve the log time function for 10 minute + renders (#6207)
* modified:   main.py

* Update main.py
2025-06-20 23:04:55 -04:00
comfyanonymous
f7fb193712 Small flux optimization. (#8611) 2025-06-20 05:37:32 -04:00
comfyanonymous
7e9267fa77 Make flux controlnet work with sd3 text enc. (#8599) 2025-06-19 18:50:05 -04:00
comfyanonymous
91d40086db Fix pytorch warning. (#8593) 2025-06-19 11:04:52 -04:00
coderfromthenorth93
5b12b55e32 Add new fields to the config types (#8507) 2025-06-18 15:12:29 -04:00
comfyanonymous
e9e9a031a8 Show a better error when the workflow OOMs. (#8574) 2025-06-18 06:55:21 -04:00
filtered
d7430c529a Update frontend to 1.22.2 (#8567) 2025-06-17 18:58:28 -04:00
ComfyUI Wiki
cd88f709ab Update template version (#8563) 2025-06-17 04:11:59 -07:00
comfyanonymous
4459a17e82 Add Cosmos Predict2 to README. (#8562) 2025-06-17 05:18:01 -04:00
comfyanonymous
483b3e62e0 ComfyUI version v0.3.41 2025-06-16 23:34:46 -04:00
chaObserv
8e81c507d2 Multistep DPM++ SDE samplers for RF (#8541)
Include alpha in sampling and minor refactoring
2025-06-16 14:47:10 -04:00
comfyanonymous
e1c6dc720e Allow setting min_length with tokenizer_data. (#8547) 2025-06-16 13:43:52 -04:00
comfyanonymous
7ea79ebb9d Add correct eps to ltxv rmsnorm. (#8542) 2025-06-15 12:21:25 -04:00
comfyanonymous
ae75a084df SaveLora now saves in the same filename format as all the other nodes. (#8538) 2025-06-15 03:44:59 -04:00
comfyanonymous
d6a2137fc3 Support Cosmos predict2 image to video models. (#8535)
Use the CosmosPredict2ImageToVideoLatent node.
2025-06-14 21:37:07 -04:00
chaObserv
53e8d8193c Generalize SEEDS samplers (#8529)
Restore VP algorithm for RF and refactor noise_coeffs and half-logSNR calculations
2025-06-14 16:58:16 -04:00
comfyanonymous
29596bd53f Small cosmos attention code refactor. (#8530) 2025-06-14 05:02:05 -04:00
35 changed files with 152886 additions and 161 deletions

View File

@@ -65,12 +65,16 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
- [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
- [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
- [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
- Image Editing Models
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
- Video Models
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/)
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) and [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
- Audio Models
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
@@ -272,6 +276,8 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve
#### DirectML (AMD Cards on Windows)
This is very badly supported and is not recommended. There are some unofficial builds of pytorch ROCm on windows that exist that will give you a much better experience than this. This readme will be updated once official pytorch ROCm builds for windows come out.
```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```
#### Ascend NPUs

View File

@@ -1,4 +1,5 @@
import math
from functools import partial
from scipy import integrate
import torch
@@ -142,6 +143,33 @@ class BrownianTreeNoiseSampler:
return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
def sigma_to_half_log_snr(sigma, model_sampling):
"""Convert sigma to half-logSNR log(alpha_t / sigma_t)."""
if isinstance(model_sampling, comfy.model_sampling.CONST):
# log((1 - t) / t) = log((1 - sigma) / sigma)
return sigma.logit().neg()
return sigma.log().neg()
def half_log_snr_to_sigma(half_log_snr, model_sampling):
"""Convert half-logSNR log(alpha_t / sigma_t) to sigma."""
if isinstance(model_sampling, comfy.model_sampling.CONST):
# 1 / (1 + exp(half_log_snr))
return half_log_snr.neg().sigmoid()
return half_log_snr.neg().exp()
def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4):
"""Adjust the first sigma to avoid invalid logSNR."""
if len(sigmas) <= 1:
return sigmas
if isinstance(model_sampling, comfy.model_sampling.CONST):
if sigmas[0] >= 1:
sigmas = sigmas.clone()
sigmas[0] = model_sampling.percent_to_sigma(percent_offset)
return sigmas
@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
@@ -682,6 +710,7 @@ def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=Non
# logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0)
return x
@torch.no_grad()
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
"""DPM-Solver++ (stochastic)."""
@@ -693,38 +722,49 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
seed = extra_args.get("seed", None)
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
# Euler method
d = to_d(x, sigmas[i], denoised)
dt = sigmas[i + 1] - sigmas[i]
x = x + d * dt
# Denoising step
x = denoised
else:
# DPM-Solver++
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
h = t_next - t
s = t + h * r
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
lambda_s_1 = lambda_s + r * h
fac = 1 / (2 * r)
sigma_s_1 = sigma_fn(lambda_s_1)
alpha_s = sigmas[i] * lambda_s.exp()
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
alpha_t = sigmas[i + 1] * lambda_t.exp()
# Step 1
sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
s_ = t_fn(sd)
x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_s_1.neg().exp(), eta)
lambda_s_1_ = sd.log().neg()
h_ = lambda_s_1_ - lambda_s
x_2 = (alpha_s_1 / alpha_s) * (-h_).exp() * x - alpha_s_1 * (-h_).expm1() * denoised
if eta > 0 and s_noise > 0:
x_2 = x_2 + alpha_s_1 * noise_sampler(sigmas[i], sigma_s_1) * s_noise * su
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 2
sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
t_next_ = t_fn(sd)
sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_t.neg().exp(), eta)
lambda_t_ = sd.log().neg()
h_ = lambda_t_ - lambda_s
denoised_d = (1 - fac) * denoised + fac * denoised_2
x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
x = (alpha_t / alpha_s) * (-h_).exp() * x - alpha_t * (-h_).expm1() * denoised_d
if eta > 0 and s_noise > 0:
x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * su
return x
@@ -753,6 +793,7 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
old_denoised = denoised
return x
@torch.no_grad()
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
"""DPM-Solver++(2M) SDE."""
@@ -768,9 +809,12 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
old_denoised = None
h_last = None
h = None
h, h_last = None, None
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
@@ -781,26 +825,29 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
x = denoised
else:
# DPM-Solver++(2M) SDE
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
h = s - t
eta_h = eta * h
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
h_eta = h * (eta + 1)
x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
alpha_t = sigmas[i + 1] * lambda_t.exp()
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised
if old_denoised is not None:
r = h_last / h
if solver_type == 'heun':
x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
x = x + alpha_t * ((-h_eta).expm1().neg() / (-h_eta) + 1) * (1 / r) * (denoised - old_denoised)
elif solver_type == 'midpoint':
x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
x = x + 0.5 * alpha_t * (-h_eta).expm1().neg() * (1 / r) * (denoised - old_denoised)
if eta:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
if eta > 0 and s_noise > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
old_denoised = denoised
h_last = h
return x
@torch.no_grad()
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""DPM-Solver++(3M) SDE."""
@@ -814,6 +861,10 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
denoised_1, denoised_2 = None, None
h, h_1, h_2 = None, None, None
@@ -825,13 +876,16 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
# Denoising step
x = denoised
else:
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
h = s - t
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
h_eta = h * (eta + 1)
x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised
alpha_t = sigmas[i + 1] * lambda_t.exp()
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised
if h_2 is not None:
# DPM-Solver++(3M) SDE
r0 = h_1 / h
r1 = h_2 / h
d1_0 = (denoised - denoised_1) / r0
@@ -840,20 +894,22 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
d2 = (d1_0 - d1_1) / (r0 + r1)
phi_2 = h_eta.neg().expm1() / h_eta + 1
phi_3 = phi_2 / h_eta - 0.5
x = x + phi_2 * d1 - phi_3 * d2
x = x + (alpha_t * phi_2) * d1 - (alpha_t * phi_3) * d2
elif h_1 is not None:
# DPM-Solver++(2M) SDE
r = h_1 / h
d = (denoised - denoised_1) / r
phi_2 = h_eta.neg().expm1() / h_eta + 1
x = x + phi_2 * d
x = x + (alpha_t * phi_2) * d
if eta:
if eta > 0 and s_noise > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
denoised_1, denoised_2 = denoised, denoised_1
h_1, h_2 = h, h_1
return x
@torch.no_grad()
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
if len(sigmas) <= 1:
@@ -863,6 +919,7 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
@torch.no_grad()
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
if len(sigmas) <= 1:
@@ -872,6 +929,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
@torch.no_grad()
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
if len(sigmas) <= 1:
@@ -1449,12 +1507,12 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
old_denoised = denoised
return x
@torch.no_grad()
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
'''
SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 2
Arxiv: https://arxiv.org/abs/2305.14267
'''
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
arXiv: https://arxiv.org/abs/2305.14267
"""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
@@ -1462,6 +1520,11 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
inject_noise = eta > 0 and s_noise > 0
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
@@ -1469,80 +1532,96 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
if sigmas[i + 1] == 0:
x = denoised
else:
t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
h = t_next - t
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
h_eta = h * (eta + 1)
s = t + r * h
lambda_s_1 = lambda_s + r * h
fac = 1 / (2 * r)
sigma_s = s.neg().exp()
sigma_s_1 = sigma_fn(lambda_s_1)
# alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
alpha_t = sigmas[i + 1] * lambda_t.exp()
coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1()
if inject_noise:
# 0 < r < 1
noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
noise_coeff_2 = ((-2 * r * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s), noise_sampler(sigma_s, sigmas[i + 1])
noise_coeff_2 = (-r * h * eta).exp() * (-2 * (1 - r) * h * eta).expm1().neg().sqrt()
noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigmas[i + 1])
# Step 1
x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
if inject_noise:
x_2 = x_2 + sigma_s * (noise_coeff_1 * noise_1) * s_noise
denoised_2 = model(x_2, sigma_s * s_in, **extra_args)
# Step 2
denoised_d = (1 - fac) * denoised + fac * denoised_2
x = (coeff_2 + 1) * x - coeff_2 * denoised_d
if inject_noise:
x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
return x
@torch.no_grad()
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
'''
SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 3
Arxiv: https://arxiv.org/abs/2305.14267
'''
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
inject_noise = eta > 0 and s_noise > 0
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
x = denoised
else:
t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
h = t_next - t
h_eta = h * (eta + 1)
s_1 = t + r_1 * h
s_2 = t + r_2 * h
sigma_s_1, sigma_s_2 = s_1.neg().exp(), s_2.neg().exp()
coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
if inject_noise:
noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
noise_coeff_2 = ((-2 * r_1 * h * eta).expm1() - (-2 * r_2 * h * eta).expm1()).sqrt()
noise_coeff_3 = ((-2 * r_2 * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
# Step 1
x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
if inject_noise:
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 2
x_3 = (coeff_2 + 1) * x - coeff_2 * denoised + (r_2 / r_1) * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
denoised_d = (1 - fac) * denoised + fac * denoised_2
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_2 * denoised_d
if inject_noise:
x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
return x
@torch.no_grad()
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
"""SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3.
arXiv: https://arxiv.org/abs/2305.14267
"""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
inject_noise = eta > 0 and s_noise > 0
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
x = denoised
else:
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
h_eta = h * (eta + 1)
lambda_s_1 = lambda_s + r_1 * h
lambda_s_2 = lambda_s + r_2 * h
sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
# alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
alpha_t = sigmas[i + 1] * lambda_t.exp()
coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
if inject_noise:
# 0 < r_1 < r_2 < 1
noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
noise_coeff_2 = (-r_1 * h * eta).exp() * (-2 * (r_2 - r_1) * h * eta).expm1().neg().sqrt()
noise_coeff_3 = (-r_2 * h * eta).exp() * (-2 * (1 - r_2) * h * eta).expm1().neg().sqrt()
noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
# Step 1
x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
if inject_noise:
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 2
x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * coeff_2 * denoised + (r_2 / r_1) * alpha_s_2 * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
if inject_noise:
x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
# Step 3
x = (coeff_3 + 1) * x - coeff_3 * denoised + (1. / r_2) * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_3 * denoised + (1. / r_2) * alpha_t * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
if inject_noise:
x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
return x

View File

@@ -70,11 +70,7 @@ def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H
q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
result_B_S_HD = rearrange(
optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True, skip_output_reshape=True), "b h ... l -> b ... (h l)"
)
return result_B_S_HD
return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True)
class Attention(nn.Module):

View File

@@ -123,6 +123,8 @@ class ControlNetFlux(Flux):
if y is None:
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
else:
y = y[:, :self.params.vec_in_dim]
# running on sequences img
img = self.img_in(img)

View File

@@ -118,7 +118,7 @@ class Modulation(nn.Module):
def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
if modulation_dims is None:
if m_add is not None:
return tensor * m_mult + m_add
return torch.addcmul(m_add, tensor, m_mult)
else:
return tensor * m_mult
else:

View File

@@ -195,20 +195,50 @@ class Flux(nn.Module):
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img
def forward(self, x, timestep, context, y=None, guidance=None, control=None, transformer_options={}, **kwargs):
def process_img(self, x, index=0, h_offset=0, w_offset=0):
bs, c, h, w = x.shape
patch_size = self.patch_size
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
img_ids[:, :, 0] = img_ids[:, :, 1] + index
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
bs, c, h_orig, w_orig = x.shape
patch_size = self.patch_size
h_len = ((h_orig + (patch_size // 2)) // patch_size)
w_len = ((w_orig + (patch_size // 2)) // patch_size)
img, img_ids = self.process_img(x)
img_tokens = img.shape[1]
if ref_latents is not None:
h = 0
w = 0
for ref in ref_latents:
h_offset = 0
w_offset = 0
if ref.shape[-2] + h > ref.shape[-1] + w:
w_offset = w
else:
h_offset = h
kontext, kontext_ids = self.process_img(ref, index=1, h_offset=h_offset, w_offset=w_offset)
img = torch.cat([img, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
h = max(h, ref.shape[-2] + h_offset)
w = max(w, ref.shape[-1] + w_offset)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
out = out[:, :img_tokens]
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig]

View File

@@ -261,8 +261,8 @@ class CrossAttention(nn.Module):
self.heads = heads
self.dim_head = dim_head
self.q_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device)
self.k_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device)
self.q_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device)
self.k_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device)
self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)

View File

@@ -11,7 +11,7 @@ from comfy.ldm.modules.ema import LitEma
import comfy.ops
class DiagonalGaussianRegularizer(torch.nn.Module):
def __init__(self, sample: bool = True):
def __init__(self, sample: bool = False):
super().__init__()
self.sample = sample
@@ -19,16 +19,12 @@ class DiagonalGaussianRegularizer(torch.nn.Module):
yield from ()
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
log = dict()
posterior = DiagonalGaussianDistribution(z)
if self.sample:
z = posterior.sample()
else:
z = posterior.mode()
kl_loss = posterior.kl()
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
log["kl_loss"] = kl_loss
return z, log
return z, None
class AbstractAutoencoder(torch.nn.Module):

View File

@@ -31,7 +31,7 @@ def dynamic_slice(
starts: List[int],
sizes: List[int],
) -> Tensor:
slicing = [slice(start, start + size) for start, size in zip(starts, sizes)]
slicing = tuple(slice(start, start + size) for start, size in zip(starts, sizes))
return x[slicing]
class AttnChunk(NamedTuple):

View File

@@ -0,0 +1,469 @@
# Original code: https://github.com/VectorSpaceLab/OmniGen2
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from comfy.ldm.lightricks.model import Timesteps
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.modules.attention import optimized_attention_masked
import comfy.model_management
import comfy.ldm.common_dit
def apply_rotary_emb(x, freqs_cis):
if x.shape[1] == 0:
return x
t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
return t_out.reshape(*x.shape).to(dtype=x.dtype)
def swiglu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return F.silu(x) * y
class TimestepEmbedding(nn.Module):
def __init__(self, in_channels: int, time_embed_dim: int, dtype=None, device=None, operations=None):
super().__init__()
self.linear_1 = operations.Linear(in_channels, time_embed_dim, dtype=dtype, device=device)
self.act = nn.SiLU()
self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device)
def forward(self, sample: torch.Tensor) -> torch.Tensor:
sample = self.linear_1(sample)
sample = self.act(sample)
sample = self.linear_2(sample)
return sample
class LuminaRMSNormZero(nn.Module):
def __init__(self, embedding_dim: int, norm_eps: float = 1e-5, dtype=None, device=None, operations=None):
super().__init__()
self.silu = nn.SiLU()
self.linear = operations.Linear(min(embedding_dim, 1024), 4 * embedding_dim, dtype=dtype, device=device)
self.norm = operations.RMSNorm(embedding_dim, eps=norm_eps, dtype=dtype, device=device)
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
emb = self.linear(self.silu(emb))
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None])
return x, gate_msa, scale_mlp, gate_mlp
class LuminaLayerNormContinuous(nn.Module):
def __init__(self, embedding_dim: int, conditioning_embedding_dim: int, elementwise_affine: bool = False, eps: float = 1e-6, out_dim: Optional[int] = None, dtype=None, device=None, operations=None):
super().__init__()
self.silu = nn.SiLU()
self.linear_1 = operations.Linear(conditioning_embedding_dim, embedding_dim, dtype=dtype, device=device)
self.norm = operations.LayerNorm(embedding_dim, eps, elementwise_affine, dtype=dtype, device=device)
self.linear_2 = operations.Linear(embedding_dim, out_dim, bias=True, dtype=dtype, device=device) if out_dim is not None else None
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
x = self.norm(x) * (1 + emb)[:, None, :]
if self.linear_2 is not None:
x = self.linear_2(x)
return x
class LuminaFeedForward(nn.Module):
def __init__(self, dim: int, inner_dim: int, multiple_of: int = 256, dtype=None, device=None, operations=None):
super().__init__()
inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
self.linear_1 = operations.Linear(dim, inner_dim, bias=False, dtype=dtype, device=device)
self.linear_2 = operations.Linear(inner_dim, dim, bias=False, dtype=dtype, device=device)
self.linear_3 = operations.Linear(dim, inner_dim, bias=False, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
h1, h2 = self.linear_1(x), self.linear_3(x)
return self.linear_2(swiglu(h1, h2))
class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
def __init__(self, hidden_size: int = 4096, text_feat_dim: int = 2048, frequency_embedding_size: int = 256, norm_eps: float = 1e-5, timestep_scale: float = 1.0, dtype=None, device=None, operations=None):
super().__init__()
self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=timestep_scale)
self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024), dtype=dtype, device=device, operations=operations)
self.caption_embedder = nn.Sequential(
operations.RMSNorm(text_feat_dim, eps=norm_eps, dtype=dtype, device=device),
operations.Linear(text_feat_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
def forward(self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype) -> Tuple[torch.Tensor, torch.Tensor]:
timestep_proj = self.time_proj(timestep).to(dtype=dtype)
time_embed = self.timestep_embedder(timestep_proj)
caption_embed = self.caption_embedder(text_hidden_states)
return time_embed, caption_embed
class Attention(nn.Module):
def __init__(self, query_dim: int, dim_head: int, heads: int, kv_heads: int, eps: float = 1e-5, bias: bool = False, dtype=None, device=None, operations=None):
super().__init__()
self.heads = heads
self.kv_heads = kv_heads
self.dim_head = dim_head
self.scale = dim_head ** -0.5
self.to_q = operations.Linear(query_dim, heads * dim_head, bias=bias, dtype=dtype, device=device)
self.to_k = operations.Linear(query_dim, kv_heads * dim_head, bias=bias, dtype=dtype, device=device)
self.to_v = operations.Linear(query_dim, kv_heads * dim_head, bias=bias, dtype=dtype, device=device)
self.norm_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
self.norm_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
self.to_out = nn.Sequential(
operations.Linear(heads * dim_head, query_dim, bias=bias, dtype=dtype, device=device),
nn.Dropout(0.0)
)
def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None) -> torch.Tensor:
batch_size, sequence_length, _ = hidden_states.shape
query = self.to_q(hidden_states)
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
query = query.view(batch_size, -1, self.heads, self.dim_head)
key = key.view(batch_size, -1, self.kv_heads, self.dim_head)
value = value.view(batch_size, -1, self.kv_heads, self.dim_head)
query = self.norm_q(query)
key = self.norm_k(key)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
if self.kv_heads < self.heads:
key = key.repeat_interleave(self.heads // self.kv_heads, dim=1)
value = value.repeat_interleave(self.heads // self.kv_heads, dim=1)
hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True)
hidden_states = self.to_out[0](hidden_states)
return hidden_states
class OmniGen2TransformerBlock(nn.Module):
def __init__(self, dim: int, num_attention_heads: int, num_kv_heads: int, multiple_of: int, ffn_dim_multiplier: float, norm_eps: float, modulation: bool = True, dtype=None, device=None, operations=None):
super().__init__()
self.modulation = modulation
self.attn = Attention(
query_dim=dim,
dim_head=dim // num_attention_heads,
heads=num_attention_heads,
kv_heads=num_kv_heads,
eps=1e-5,
bias=False,
dtype=dtype, device=device, operations=operations,
)
self.feed_forward = LuminaFeedForward(
dim=dim,
inner_dim=4 * dim,
multiple_of=multiple_of,
dtype=dtype, device=device, operations=operations
)
if modulation:
self.norm1 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, dtype=dtype, device=device, operations=operations)
else:
self.norm1 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
self.ffn_norm1 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
self.norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
self.ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.modulation:
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
else:
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
hidden_states = hidden_states + self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
hidden_states = hidden_states + self.ffn_norm2(mlp_output)
return hidden_states
class OmniGen2RotaryPosEmbed(nn.Module):
def __init__(self, theta: int, axes_dim: Tuple[int, int, int], axes_lens: Tuple[int, int, int] = (300, 512, 512), patch_size: int = 2):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
self.axes_lens = axes_lens
self.patch_size = patch_size
self.rope_embedder = EmbedND(dim=sum(axes_dim), theta=self.theta, axes_dim=axes_dim)
def forward(self, batch_size, encoder_seq_len, l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, ref_img_sizes, img_sizes, device):
p = self.patch_size
seq_lengths = [cap_len + sum(ref_img_len) + img_len for cap_len, ref_img_len, img_len in zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len)]
max_seq_len = max(seq_lengths)
max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
max_img_len = max(l_effective_img_len)
position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
position_ids[i, :cap_seq_len] = repeat(torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3")
pe_shift = cap_seq_len
pe_shift_len = cap_seq_len
if ref_img_sizes[i] is not None:
for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]):
H, W = ref_img_size
ref_H_tokens, ref_W_tokens = H // p, W // p
row_ids = repeat(torch.arange(ref_H_tokens, dtype=torch.int32, device=device), "h -> h w", w=ref_W_tokens).flatten()
col_ids = repeat(torch.arange(ref_W_tokens, dtype=torch.int32, device=device), "w -> h w", h=ref_H_tokens).flatten()
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 0] = pe_shift
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 1] = row_ids
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 2] = col_ids
pe_shift += max(ref_H_tokens, ref_W_tokens)
pe_shift_len += ref_img_len
H, W = img_sizes[i]
H_tokens, W_tokens = H // p, W // p
row_ids = repeat(torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens).flatten()
col_ids = repeat(torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens).flatten()
position_ids[i, pe_shift_len: seq_len, 0] = pe_shift
position_ids[i, pe_shift_len: seq_len, 1] = row_ids
position_ids[i, pe_shift_len: seq_len, 2] = col_ids
freqs_cis = self.rope_embedder(position_ids).movedim(1, 2)
cap_freqs_cis_shape = list(freqs_cis.shape)
cap_freqs_cis_shape[1] = encoder_seq_len
cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
ref_img_freqs_cis_shape = list(freqs_cis.shape)
ref_img_freqs_cis_shape[1] = max_ref_img_len
ref_img_freqs_cis = torch.zeros(*ref_img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
img_freqs_cis_shape = list(freqs_cis.shape)
img_freqs_cis_shape[1] = max_img_len
img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate(zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths)):
cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
ref_img_freqs_cis[i, :sum(ref_img_len)] = freqs_cis[i, cap_seq_len:cap_seq_len + sum(ref_img_len)]
img_freqs_cis[i, :img_len] = freqs_cis[i, cap_seq_len + sum(ref_img_len):cap_seq_len + sum(ref_img_len) + img_len]
return cap_freqs_cis, ref_img_freqs_cis, img_freqs_cis, freqs_cis, l_effective_cap_len, seq_lengths
class OmniGen2Transformer2DModel(nn.Module):
def __init__(
self,
patch_size: int = 2,
in_channels: int = 16,
out_channels: Optional[int] = None,
hidden_size: int = 2304,
num_layers: int = 26,
num_refiner_layers: int = 2,
num_attention_heads: int = 24,
num_kv_heads: int = 8,
multiple_of: int = 256,
ffn_dim_multiplier: Optional[float] = None,
norm_eps: float = 1e-5,
axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
axes_lens: Tuple[int, int, int] = (300, 512, 512),
text_feat_dim: int = 1024,
timestep_scale: float = 1.0,
image_model=None,
device=None,
dtype=None,
operations=None,
):
super().__init__()
self.patch_size = patch_size
self.out_channels = out_channels or in_channels
self.hidden_size = hidden_size
self.dtype = dtype
self.rope_embedder = OmniGen2RotaryPosEmbed(
theta=10000,
axes_dim=axes_dim_rope,
axes_lens=axes_lens,
patch_size=patch_size,
)
self.x_embedder = operations.Linear(patch_size * patch_size * in_channels, hidden_size, dtype=dtype, device=device)
self.ref_image_patch_embedder = operations.Linear(patch_size * patch_size * in_channels, hidden_size, dtype=dtype, device=device)
self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
hidden_size=hidden_size,
text_feat_dim=text_feat_dim,
norm_eps=norm_eps,
timestep_scale=timestep_scale, dtype=dtype, device=device, operations=operations
)
self.noise_refiner = nn.ModuleList([
OmniGen2TransformerBlock(
hidden_size, num_attention_heads, num_kv_heads,
multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations
) for _ in range(num_refiner_layers)
])
self.ref_image_refiner = nn.ModuleList([
OmniGen2TransformerBlock(
hidden_size, num_attention_heads, num_kv_heads,
multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations
) for _ in range(num_refiner_layers)
])
self.context_refiner = nn.ModuleList([
OmniGen2TransformerBlock(
hidden_size, num_attention_heads, num_kv_heads,
multiple_of, ffn_dim_multiplier, norm_eps, modulation=False, dtype=dtype, device=device, operations=operations
) for _ in range(num_refiner_layers)
])
self.layers = nn.ModuleList([
OmniGen2TransformerBlock(
hidden_size, num_attention_heads, num_kv_heads,
multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations
) for _ in range(num_layers)
])
self.norm_out = LuminaLayerNormContinuous(
embedding_dim=hidden_size,
conditioning_embedding_dim=min(hidden_size, 1024),
elementwise_affine=False,
eps=1e-6,
out_dim=patch_size * patch_size * self.out_channels, dtype=dtype, device=device, operations=operations
)
self.image_index_embedding = nn.Parameter(torch.empty(5, hidden_size, device=device, dtype=dtype))
def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states):
batch_size = len(hidden_states)
p = self.patch_size
img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes]
if ref_image_hidden_states is not None:
ref_image_hidden_states = list(map(lambda ref: comfy.ldm.common_dit.pad_to_patch_size(ref, (p, p)), ref_image_hidden_states))
ref_img_sizes = [[(imgs.size(2), imgs.size(3)) if imgs is not None else None for imgs in ref_image_hidden_states]] * batch_size
l_effective_ref_img_len = [[(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes] if _ref_img_sizes is not None else [0] for _ref_img_sizes in ref_img_sizes]
else:
ref_img_sizes = [None for _ in range(batch_size)]
l_effective_ref_img_len = [[0] for _ in range(batch_size)]
flat_ref_img_hidden_states = None
if ref_image_hidden_states is not None:
imgs = []
for ref_img in ref_image_hidden_states:
B, C, H, W = ref_img.size()
ref_img = rearrange(ref_img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
imgs.append(ref_img)
flat_ref_img_hidden_states = torch.cat(imgs, dim=1)
img = hidden_states
B, C, H, W = img.size()
flat_hidden_states = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
return (
flat_hidden_states, flat_ref_img_hidden_states,
None, None,
l_effective_ref_img_len, l_effective_img_len,
ref_img_sizes, img_sizes,
)
def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb):
batch_size = len(hidden_states)
hidden_states = self.x_embedder(hidden_states)
if ref_image_hidden_states is not None:
ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)
image_index_embedding = comfy.model_management.cast_to(self.image_index_embedding, dtype=hidden_states.dtype, device=hidden_states.device)
for i in range(batch_size):
shift = 0
for j, ref_img_len in enumerate(l_effective_ref_img_len[i]):
ref_image_hidden_states[i, shift:shift + ref_img_len, :] = ref_image_hidden_states[i, shift:shift + ref_img_len, :] + image_index_embedding[j]
shift += ref_img_len
for layer in self.noise_refiner:
hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
if ref_image_hidden_states is not None:
for layer in self.ref_image_refiner:
ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb)
hidden_states = torch.cat([ref_image_hidden_states, hidden_states], dim=1)
return hidden_states
def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, **kwargs):
B, C, H, W = x.shape
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
_, _, H_padded, W_padded = hidden_states.shape
timestep = 1.0 - timesteps
text_hidden_states = context
text_attention_mask = attention_mask
ref_image_hidden_states = ref_latents
device = hidden_states.device
temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype)
(
hidden_states, ref_image_hidden_states,
img_mask, ref_img_mask,
l_effective_ref_img_len, l_effective_img_len,
ref_img_sizes, img_sizes,
) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)
(
context_rotary_emb, ref_img_rotary_emb, noise_rotary_emb,
rotary_emb, encoder_seq_lengths, seq_lengths,
) = self.rope_embedder(
hidden_states.shape[0], text_hidden_states.shape[1], [num_tokens] * text_hidden_states.shape[0],
l_effective_ref_img_len, l_effective_img_len,
ref_img_sizes, img_sizes, device,
)
for layer in self.context_refiner:
text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)
img_len = hidden_states.shape[1]
combined_img_hidden_states = self.img_patch_embed_and_refine(
hidden_states, ref_image_hidden_states,
img_mask, ref_img_mask,
noise_rotary_emb, ref_img_rotary_emb,
l_effective_ref_img_len, l_effective_img_len,
temb,
)
hidden_states = torch.cat([text_hidden_states, combined_img_hidden_states], dim=1)
attention_mask = None
for layer in self.layers:
hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
hidden_states = self.norm_out(hidden_states, temb)
p = self.patch_size
output = rearrange(hidden_states[:, -img_len:], 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)', h=H_padded // p, w=W_padded// p, p1=p, p2=p)[:, :, :H, :W]
return -output

View File

@@ -41,6 +41,7 @@ import comfy.ldm.hunyuan3d.model
import comfy.ldm.hidream.model
import comfy.ldm.chroma.model
import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2
import comfy.model_management
import comfy.patcher_extension
@@ -815,6 +816,7 @@ class PixArt(BaseModel):
class Flux(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=comfy.ldm.flux.model.Flux):
super().__init__(model_config, model_type, device=device, unet_model=unet_model)
self.memory_usage_factor_conds = ("ref_latents",)
def concat_cond(self, **kwargs):
try:
@@ -875,8 +877,23 @@ class Flux(BaseModel):
guidance = kwargs.get("guidance", 3.5)
if guidance is not None:
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
latents = []
for lat in ref_latents:
latents.append(self.process_latent_in(lat))
out['ref_latents'] = comfy.conds.CONDList(latents)
return out
def extra_conds_shapes(self, **kwargs):
out = {}
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
return out
class GenmoMochi(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint)
@@ -1014,9 +1031,32 @@ class CosmosPredict2(BaseModel):
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if denoise_mask is not None:
out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
out['fps'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", None))
return out
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
if denoise_mask is None:
return timestep
if denoise_mask.ndim <= 4:
return timestep
condition_video_mask_B_1_T_1_1 = denoise_mask.mean(dim=[1, 3, 4], keepdim=True)
c_noise_B_1_T_1_1 = 0.0 * (1.0 - condition_video_mask_B_1_T_1_1) + timestep.reshape(timestep.shape[0], 1, 1, 1, 1) * condition_video_mask_B_1_T_1_1
out = c_noise_B_1_T_1_1.squeeze(dim=[1, 3, 4])
return out
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1))
sigma_noise_augmentation = 0 #TODO
if sigma_noise_augmentation != 0:
latent_image = latent_image + noise
latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image)
sigma = (sigma / (sigma + 1))
return latent_image / (1.0 - sigma)
class Lumina2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT)
@@ -1207,3 +1247,33 @@ class ACEStep(BaseModel):
out['speaker_embeds'] = comfy.conds.CONDRegular(torch.zeros(noise.shape[0], 512, device=noise.device, dtype=noise.dtype))
out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0))
return out
class Omnigen2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.omnigen.omnigen2.OmniGen2Transformer2DModel)
self.memory_usage_factor_conds = ("ref_latents",)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
if torch.numel(attention_mask) != attention_mask.sum():
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item()))
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
latents = []
for lat in ref_latents:
latents.append(self.process_latent_in(lat))
out['ref_latents'] = comfy.conds.CONDList(latents)
return out
def extra_conds_shapes(self, **kwargs):
out = {}
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
return out

View File

@@ -441,11 +441,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["rope_h_extrapolation_ratio"] = 4.0
dit_config["rope_w_extrapolation_ratio"] = 4.0
dit_config["rope_t_extrapolation_ratio"] = 1.0
elif dit_config["in_channels"] == 17:
dit_config["extra_per_block_abs_pos_emb"] = False
dit_config["rope_h_extrapolation_ratio"] = 3.0
dit_config["rope_w_extrapolation_ratio"] = 3.0
dit_config["rope_t_extrapolation_ratio"] = 1.0
elif dit_config["in_channels"] == 17: # img to video
if dit_config["model_channels"] == 2048:
dit_config["extra_per_block_abs_pos_emb"] = False
dit_config["rope_h_extrapolation_ratio"] = 3.0
dit_config["rope_w_extrapolation_ratio"] = 3.0
dit_config["rope_t_extrapolation_ratio"] = 1.0
elif dit_config["model_channels"] == 5120:
dit_config["rope_h_extrapolation_ratio"] = 2.0
dit_config["rope_w_extrapolation_ratio"] = 2.0
dit_config["rope_t_extrapolation_ratio"] = 0.8333333333333334
dit_config["extra_h_extrapolation_ratio"] = 1.0
dit_config["extra_w_extrapolation_ratio"] = 1.0
@@ -454,6 +459,26 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
return dit_config
if '{}time_caption_embed.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: # Omnigen2
dit_config = {}
dit_config["image_model"] = "omnigen2"
dit_config["axes_dim_rope"] = [40, 40, 40]
dit_config["axes_lens"] = [1024, 1664, 1664]
dit_config["ffn_dim_multiplier"] = None
dit_config["hidden_size"] = 2520
dit_config["in_channels"] = 16
dit_config["multiple_of"] = 256
dit_config["norm_eps"] = 1e-05
dit_config["num_attention_heads"] = 21
dit_config["num_kv_heads"] = 7
dit_config["num_layers"] = 32
dit_config["num_refiner_layers"] = 2
dit_config["out_channels"] = None
dit_config["patch_size"] = 2
dit_config["text_feat_dim"] = 2048
dit_config["timestep_scale"] = 1000.0
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None

View File

@@ -1290,6 +1290,13 @@ def supports_fp8_compute(device=None):
return True
def extended_fp16_support():
# TODO: check why some models work with fp16 on newer torch versions but not on older
if torch_version_numeric < (2, 7):
return False
return True
def soft_empty_cache(force=False):
global cpu_state
if cpu_state == CPUState.MPS:

View File

@@ -44,6 +44,7 @@ import comfy.text_encoders.lumina2
import comfy.text_encoders.wan
import comfy.text_encoders.hidream
import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
import comfy.model_patcher
import comfy.lora
@@ -754,6 +755,7 @@ class CLIPType(Enum):
HIDREAM = 14
CHROMA = 15
ACE = 16
OMNIGEN2 = 17
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -773,6 +775,7 @@ class TEModel(Enum):
LLAMA3_8 = 7
T5_XXL_OLD = 8
GEMMA_2_2B = 9
QWEN25_3B = 10
def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@@ -793,6 +796,8 @@ def detect_te_model(sd):
return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
return TEModel.GEMMA_2_2B
if 'model.layers.0.self_attn.k_proj.bias' in sd:
return TEModel.QWEN25_3B
if "model.layers.0.post_attention_layernorm.weight" in sd:
return TEModel.LLAMA3_8
return None
@@ -894,6 +899,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
elif te_model == TEModel.QWEN25_3B:
clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.omnigen2.Omnigen2Tokenizer
else:
# clip_l
if clip_type == CLIPType.SD3:
@@ -1160,7 +1168,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
model.load_model_weights(new_sd, "")
left_over = sd.keys()
if len(left_over) > 0:
logging.info("left over keys in unet: {}".format(left_over))
logging.info("left over keys in diffusion model: {}".format(left_over))
return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
@@ -1168,7 +1176,7 @@ def load_diffusion_model(unet_path, model_options={}):
sd = comfy.utils.load_torch_file(unet_path)
model = load_diffusion_model_state_dict(sd, model_options=model_options)
if model is None:
logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
return model

View File

@@ -462,7 +462,7 @@ class SDTokenizer:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length)
self.min_length = min_length
self.min_length = tokenizer_data.get("{}_min_length".format(embedding_key), min_length)
self.end_token = None
self.min_padding = min_padding
@@ -482,7 +482,8 @@ class SDTokenizer:
if end_token is not None:
self.end_token = end_token
else:
self.end_token = empty[0]
if has_end_token:
self.end_token = empty[0]
if pad_token is not None:
self.pad_token = pad_token

View File

@@ -18,6 +18,7 @@ import comfy.text_encoders.cosmos
import comfy.text_encoders.lumina2
import comfy.text_encoders.wan
import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
from . import supported_models_base
from . import latent_formats
@@ -1181,6 +1182,41 @@ class ACEStep(supported_models_base.BASE):
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model)
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]
class Omnigen2(supported_models_base.BASE):
unet_config = {
"image_model": "omnigen2",
}
sampling_settings = {
"multiplier": 1.0,
"shift": 2.6,
}
memory_usage_factor = 1.65 #TODO
unet_extra_config = {}
latent_format = latent_formats.Flux
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def __init__(self, unet_config):
super().__init__(unet_config)
if comfy.model_management.extended_fp16_support():
self.supported_inference_dtypes = [torch.float16] + self.supported_inference_dtypes
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Omnigen2(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.LuminaTokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2]
models += [SVD_img2vid]

View File

@@ -24,6 +24,24 @@ class Llama2Config:
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = False
@dataclass
class Qwen25_3BConfig:
vocab_size: int = 151936
hidden_size: int = 2048
intermediate_size: int = 11008
num_hidden_layers: int = 36
num_attention_heads: int = 16
num_key_value_heads: int = 2
max_position_embeddings: int = 128000
rms_norm_eps: float = 1e-6
rope_theta: float = 1000000.0
transformer_type: str = "llama"
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = True
@dataclass
class Gemma2_2B_Config:
@@ -40,6 +58,7 @@ class Gemma2_2B_Config:
head_dim = 256
rms_norm_add = True
mlp_activation = "gelu_pytorch_tanh"
qkv_bias = False
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
@@ -98,9 +117,9 @@ class Attention(nn.Module):
self.inner_size = self.num_heads * self.head_dim
ops = ops or nn
self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=False, device=device, dtype=dtype)
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=config.qkv_bias, device=device, dtype=dtype)
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
def forward(
@@ -320,6 +339,14 @@ class Llama2(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen25_3B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen25_3BConfig(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Gemma2_2B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):

View File

@@ -0,0 +1,44 @@
from transformers import Qwen2Tokenizer
from comfy import sd1_clip
import comfy.text_encoders.llama
import os
class Qwen25_3BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='qwen25_3b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
class Omnigen2Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_3b", tokenizer=Qwen25_3BTokenizer)
self.llama_template = '<|im_start|>system\nYou are a helpful assistant that generates high-quality images based on user instructions.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n'
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None,**kwargs):
if llama_template is None:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs)
class Qwen25_3BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_3B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class Omnigen2Model(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="qwen25_3b", clip_model=Qwen25_3BModel, model_options=model_options)
def te(dtype_llama=None, llama_scaled_fp8=None):
class Omnigen2TEModel_(Omnigen2Model):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
return Omnigen2TEModel_

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,241 @@
{
"add_bos_token": false,
"add_prefix_space": false,
"added_tokens_decoder": {
"151643": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151644": {
"content": "<|im_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151645": {
"content": "<|im_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151646": {
"content": "<|object_ref_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151647": {
"content": "<|object_ref_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151648": {
"content": "<|box_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151649": {
"content": "<|box_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151650": {
"content": "<|quad_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151651": {
"content": "<|quad_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151652": {
"content": "<|vision_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151653": {
"content": "<|vision_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151654": {
"content": "<|vision_pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151655": {
"content": "<|image_pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151656": {
"content": "<|video_pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151657": {
"content": "<tool_call>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151658": {
"content": "</tool_call>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151659": {
"content": "<|fim_prefix|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151660": {
"content": "<|fim_middle|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151661": {
"content": "<|fim_suffix|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151662": {
"content": "<|fim_pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151663": {
"content": "<|repo_name|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151664": {
"content": "<|file_sep|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151665": {
"content": "<|img|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151666": {
"content": "<|endofimg|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151667": {
"content": "<|meta|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151668": {
"content": "<|endofmeta|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
}
},
"additional_special_tokens": [
"<|im_start|>",
"<|im_end|>",
"<|object_ref_start|>",
"<|object_ref_end|>",
"<|box_start|>",
"<|box_end|>",
"<|quad_start|>",
"<|quad_end|>",
"<|vision_start|>",
"<|vision_end|>",
"<|vision_pad|>",
"<|image_pad|>",
"<|video_pad|>"
],
"bos_token": null,
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
"clean_up_tokenization_spaces": false,
"eos_token": "<|im_end|>",
"errors": "replace",
"extra_special_tokens": {},
"model_max_length": 131072,
"pad_token": "<|endoftext|>",
"processor_class": "Qwen2_5_VLProcessor",
"split_special_tokens": false,
"tokenizer_class": "Qwen2Tokenizer",
"unk_token": null
}

File diff suppressed because one or more lines are too long

View File

@@ -11,6 +11,43 @@ from comfy_config.types import (
PyProjectSettings
)
def validate_and_extract_os_classifiers(classifiers: list) -> list:
os_classifiers = [c for c in classifiers if c.startswith("Operating System :: ")]
if not os_classifiers:
return []
os_values = [c[len("Operating System :: ") :] for c in os_classifiers]
valid_os_prefixes = {"Microsoft", "POSIX", "MacOS", "OS Independent"}
for os_value in os_values:
if not any(os_value.startswith(prefix) for prefix in valid_os_prefixes):
return []
return os_values
def validate_and_extract_accelerator_classifiers(classifiers: list) -> list:
accelerator_classifiers = [c for c in classifiers if c.startswith("Environment ::")]
if not accelerator_classifiers:
return []
accelerator_values = [c[len("Environment :: ") :] for c in accelerator_classifiers]
valid_accelerators = {
"GPU :: NVIDIA CUDA",
"GPU :: AMD ROCm",
"GPU :: Intel Arc",
"NPU :: Huawei Ascend",
"GPU :: Apple Metal",
}
for accelerator_value in accelerator_values:
if accelerator_value not in valid_accelerators:
return []
return accelerator_values
"""
Extract configuration from a custom node directory's pyproject.toml file or a Python file.
@@ -78,6 +115,24 @@ def extract_node_configuration(path) -> Optional[PyProjectConfig]:
tool_data = raw_settings.tool
comfy_data = tool_data.get("comfy", {}) if tool_data else {}
dependencies = project_data.get("dependencies", [])
supported_comfyui_frontend_version = ""
for dep in dependencies:
if isinstance(dep, str) and dep.startswith("comfyui-frontend-package"):
supported_comfyui_frontend_version = dep.removeprefix("comfyui-frontend-package")
break
supported_comfyui_version = comfy_data.get("requires-comfyui", "")
classifiers = project_data.get('classifiers', [])
supported_os = validate_and_extract_os_classifiers(classifiers)
supported_accelerators = validate_and_extract_accelerator_classifiers(classifiers)
project_data['supported_os'] = supported_os
project_data['supported_accelerators'] = supported_accelerators
project_data['supported_comfyui_frontend_version'] = supported_comfyui_frontend_version
project_data['supported_comfyui_version'] = supported_comfyui_version
return PyProjectConfig(project=project_data, tool_comfy=comfy_data)

View File

@@ -51,7 +51,7 @@ class ComfyConfig(BaseModel):
models: List[Model] = Field(default_factory=list, alias="Models")
includes: List[str] = Field(default_factory=list)
web: Optional[str] = None
banner_url: str = ""
class License(BaseModel):
file: str = ""
@@ -66,6 +66,10 @@ class ProjectConfig(BaseModel):
dependencies: List[str] = Field(default_factory=list)
license: License = Field(default_factory=License)
urls: URLs = Field(default_factory=URLs)
supported_os: List[str] = Field(default_factory=list)
supported_accelerators: List[str] = Field(default_factory=list)
supported_comfyui_version: str = ""
supported_comfyui_frontend_version: str = ""
@field_validator('license', mode='before')
@classmethod

View File

@@ -2,6 +2,7 @@ import nodes
import torch
import comfy.model_management
import comfy.utils
import comfy.latent_formats
class EmptyCosmosLatentVideo:
@@ -75,8 +76,53 @@ class CosmosImageToVideoLatent:
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
return (out_latent,)
class CosmosPredict2ImageToVideoLatent:
@classmethod
def INPUT_TYPES(s):
return {"required": {"vae": ("VAE", ),
"width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 93, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"start_image": ("IMAGE", ),
"end_image": ("IMAGE", ),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "encode"
CATEGORY = "conditioning/inpaint"
def encode(self, vae, width, height, length, batch_size, start_image=None, end_image=None):
latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
if start_image is None and end_image is None:
out_latent = {}
out_latent["samples"] = latent
return (out_latent,)
mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
if start_image is not None:
latent_temp = vae_encode_with_padding(vae, start_image, width, height, length, padding=1)
latent[:, :, :latent_temp.shape[-3]] = latent_temp
mask[:, :, :latent_temp.shape[-3]] *= 0.0
if end_image is not None:
latent_temp = vae_encode_with_padding(vae, end_image, width, height, length, padding=0)
latent[:, :, -latent_temp.shape[-3]:] = latent_temp
mask[:, :, -latent_temp.shape[-3]:] *= 0.0
out_latent = {}
latent_format = comfy.latent_formats.Wan21()
latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask)
out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
return (out_latent,)
NODE_CLASS_MAPPINGS = {
"EmptyCosmosLatentVideo": EmptyCosmosLatentVideo,
"CosmosImageToVideoLatent": CosmosImageToVideoLatent,
"CosmosPredict2ImageToVideoLatent": CosmosPredict2ImageToVideoLatent,
}

View File

@@ -0,0 +1,26 @@
import node_helpers
class ReferenceLatent:
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
},
"optional": {"latent": ("LATENT", ),}
}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
CATEGORY = "advanced/conditioning/edit_models"
DESCRIPTION = "This node sets the guiding latent for an edit model. If the model supports it you can chain multiple to set multiple reference images."
def append(self, conditioning, latent=None):
if latent is not None:
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [latent["samples"]]}, append=True)
return (conditioning, )
NODE_CLASS_MAPPINGS = {
"ReferenceLatent": ReferenceLatent,
}

View File

@@ -1,4 +1,5 @@
import node_helpers
import comfy.utils
class CLIPTextEncodeFlux:
@classmethod
@@ -56,8 +57,52 @@ class FluxDisableGuidance:
return (c, )
PREFERED_KONTEXT_RESOLUTIONS = [
(672, 1568),
(688, 1504),
(720, 1456),
(752, 1392),
(800, 1328),
(832, 1248),
(880, 1184),
(944, 1104),
(1024, 1024),
(1104, 944),
(1184, 880),
(1248, 832),
(1328, 800),
(1392, 752),
(1456, 720),
(1504, 688),
(1568, 672),
]
class FluxKontextImageScale:
@classmethod
def INPUT_TYPES(s):
return {"required": {"image": ("IMAGE", ),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "scale"
CATEGORY = "advanced/conditioning/flux"
DESCRIPTION = "This node resizes the image to one that is more optimal for flux kontext."
def scale(self, image):
width = image.shape[2]
height = image.shape[1]
aspect_ratio = width / height
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
return (image, )
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeFlux": CLIPTextEncodeFlux,
"FluxGuidance": FluxGuidance,
"FluxDisableGuidance": FluxDisableGuidance,
"FluxKontextImageScale": FluxKontextImageScale,
}

View File

@@ -304,10 +304,23 @@ Optional spacing can be added between images.
image2.movedim(-1, 1), target_w, target_h, "lanczos", "disabled"
).movedim(1, -1)
color_map = {
"white": 1.0,
"black": 0.0,
"red": (1.0, 0.0, 0.0),
"green": (0.0, 1.0, 0.0),
"blue": (0.0, 0.0, 1.0),
}
color_val = color_map[spacing_color]
# When not matching sizes, pad to align non-concat dimensions
if not match_image_size:
h1, w1 = image1.shape[1:3]
h2, w2 = image2.shape[1:3]
pad_value = 0.0
if not isinstance(color_val, tuple):
pad_value = color_val
if direction in ["left", "right"]:
# For horizontal concat, pad heights to match
@@ -316,11 +329,11 @@ Optional spacing can be added between images.
if h1 < target_h:
pad_h = target_h - h1
pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
image1 = torch.nn.functional.pad(image1, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0)
image1 = torch.nn.functional.pad(image1, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=pad_value)
if h2 < target_h:
pad_h = target_h - h2
pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
image2 = torch.nn.functional.pad(image2, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0)
image2 = torch.nn.functional.pad(image2, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=pad_value)
else: # up, down
# For vertical concat, pad widths to match
if w1 != w2:
@@ -328,11 +341,11 @@ Optional spacing can be added between images.
if w1 < target_w:
pad_w = target_w - w1
pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
image1 = torch.nn.functional.pad(image1, (0, 0, pad_left, pad_right), mode='constant', value=0.0)
image1 = torch.nn.functional.pad(image1, (0, 0, pad_left, pad_right), mode='constant', value=pad_value)
if w2 < target_w:
pad_w = target_w - w2
pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
image2 = torch.nn.functional.pad(image2, (0, 0, pad_left, pad_right), mode='constant', value=0.0)
image2 = torch.nn.functional.pad(image2, (0, 0, pad_left, pad_right), mode='constant', value=pad_value)
# Ensure same number of channels
if image1.shape[-1] != image2.shape[-1]:
@@ -366,15 +379,6 @@ Optional spacing can be added between images.
if spacing_width > 0:
spacing_width = spacing_width + (spacing_width % 2) # Ensure even
color_map = {
"white": 1.0,
"black": 0.0,
"red": (1.0, 0.0, 0.0),
"green": (0.0, 1.0, 0.0),
"blue": (0.0, 0.0, 1.0),
}
color_val = color_map[spacing_color]
if direction in ["left", "right"]:
spacing_shape = (
image1.shape[0],
@@ -410,6 +414,62 @@ Optional spacing can be added between images.
concat_dim = 2 if direction in ["left", "right"] else 1
return (torch.cat(images, dim=concat_dim),)
class ResizeAndPadImage:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
"target_width": ("INT", {
"default": 512,
"min": 1,
"max": MAX_RESOLUTION,
"step": 1
}),
"target_height": ("INT", {
"default": 512,
"min": 1,
"max": MAX_RESOLUTION,
"step": 1
}),
"padding_color": (["white", "black"],),
"interpolation": (["area", "bicubic", "nearest-exact", "bilinear", "lanczos"],),
}
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "resize_and_pad"
CATEGORY = "image/transform"
def resize_and_pad(self, image, target_width, target_height, padding_color, interpolation):
batch_size, orig_height, orig_width, channels = image.shape
scale_w = target_width / orig_width
scale_h = target_height / orig_height
scale = min(scale_w, scale_h)
new_width = int(orig_width * scale)
new_height = int(orig_height * scale)
image_permuted = image.permute(0, 3, 1, 2)
resized = comfy.utils.common_upscale(image_permuted, new_width, new_height, interpolation, "disabled")
pad_value = 0.0 if padding_color == "black" else 1.0
padded = torch.full(
(batch_size, channels, target_height, target_width),
pad_value,
dtype=image.dtype,
device=image.device
)
y_offset = (target_height - new_height) // 2
x_offset = (target_width - new_width) // 2
padded[:, :, y_offset:y_offset + new_height, x_offset:x_offset + new_width] = resized
output = padded.permute(0, 2, 3, 1)
return (output,)
class SaveSVGNode:
"""
@@ -532,5 +592,6 @@ NODE_CLASS_MAPPINGS = {
"SaveAnimatedPNG": SaveAnimatedPNG,
"SaveSVGNode": SaveSVGNode,
"ImageStitch": ImageStitch,
"ResizeAndPadImage": ResizeAndPadImage,
"GetImageSize": GetImageSize,
}

View File

@@ -268,6 +268,52 @@ class ModelMergeWAN2_1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
return {"required": arg_dict}
class ModelMergeCosmosPredict2_2B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["pos_embedder."] = argument
arg_dict["x_embedder."] = argument
arg_dict["t_embedder."] = argument
arg_dict["t_embedding_norm."] = argument
for i in range(28):
arg_dict["blocks.{}.".format(i)] = argument
arg_dict["final_layer."] = argument
return {"required": arg_dict}
class ModelMergeCosmosPredict2_14B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["pos_embedder."] = argument
arg_dict["x_embedder."] = argument
arg_dict["t_embedder."] = argument
arg_dict["t_embedding_norm."] = argument
for i in range(36):
arg_dict["blocks.{}.".format(i)] = argument
arg_dict["final_layer."] = argument
return {"required": arg_dict}
NODE_CLASS_MAPPINGS = {
"ModelMergeSD1": ModelMergeSD1,
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
@@ -281,4 +327,6 @@ NODE_CLASS_MAPPINGS = {
"ModelMergeCosmos7B": ModelMergeCosmos7B,
"ModelMergeCosmos14B": ModelMergeCosmos14B,
"ModelMergeWAN2_1": ModelMergeWAN2_1,
"ModelMergeCosmosPredict2_2B": ModelMergeCosmosPredict2_2B,
"ModelMergeCosmosPredict2_14B": ModelMergeCosmosPredict2_14B,
}

View File

@@ -552,6 +552,9 @@ class LoraModelLoader:
class SaveLoRA:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@classmethod
def INPUT_TYPES(s):
return {
@@ -565,7 +568,7 @@ class SaveLoRA:
"prefix": (
"STRING",
{
"default": "trained_lora",
"default": "loras/ComfyUI_trained_lora",
"tooltip": "The prefix to use for the saved LoRA file.",
},
),
@@ -588,12 +591,13 @@ class SaveLoRA:
OUTPUT_NODE = True
def save(self, lora, prefix, steps=None):
date = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(prefix, self.output_dir)
if steps is None:
output_file = f"models/loras/{prefix}_{date}_lora.safetensors"
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
else:
output_file = f"models/loras/{prefix}_{steps}_steps_{date}_lora.safetensors"
safetensors.torch.save_file(lora, output_file)
output_checkpoint = f"{filename}_{steps}_steps_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
safetensors.torch.save_file(lora, output_checkpoint)
return {}

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.40"
__version__ = "0.3.43"

View File

@@ -429,17 +429,20 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
logging.error(f"!!! Exception during processing !!! {ex}")
logging.error(traceback.format_exc())
tips = ""
if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
logging.error("Got an OOM, unloading all loaded models.")
comfy.model_management.unload_all_models()
error_details = {
"node_id": real_node_id,
"exception_message": str(ex),
"exception_message": "{}\n{}".format(ex, tips),
"exception_type": exception_type,
"traceback": traceback.format_tb(tb),
"current_inputs": input_data_formatted
}
if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
logging.error("Got an OOM, unloading all loaded models.")
comfy.model_management.unload_all_models()
return (ExecutionResult.FAILURE, error_details, ex)

View File

@@ -185,7 +185,13 @@ def prompt_worker(q, server_instance):
current_time = time.perf_counter()
execution_time = current_time - execution_start_time
logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
# Log Time in a more readable way after 10 minutes
if execution_time > 600:
execution_time = time.strftime("%H:%M:%S", time.gmtime(execution_time))
logging.info(f"Prompt executed in {execution_time}")
else:
logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
flags = q.get_flags()
free_memory = flags.get("free_memory", False)

View File

@@ -920,7 +920,7 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace"], ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
@@ -930,7 +930,7 @@ class CLIPLoader:
CATEGORY = "advanced/loaders"
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5"
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B"
def load_clip(self, clip_name, type="stable_diffusion", device="default"):
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
@@ -2279,6 +2279,7 @@ def init_builtin_extra_nodes():
"nodes_ace.py",
"nodes_string.py",
"nodes_camera_trajectory.py",
"nodes_edit_model.py",
]
import_failed = []

View File

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.40"
version = "0.3.43"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

View File

@@ -1,13 +1,13 @@
comfyui-frontend-package==1.21.7
comfyui-workflow-templates==0.1.28
comfyui-embedded-docs==0.2.2
comfyui-frontend-package==1.23.4
comfyui-workflow-templates==0.1.30
comfyui-embedded-docs==0.2.3
torch
torchsde
torchvision
torchaudio
numpy>=1.25.0
einops
transformers>=4.28.1
transformers>=4.37.2
tokenizers>=0.13.3
sentencepiece
safetensors>=0.4.2