Mirror of https://github.com/comfyanonymous/ComfyUI.git
Synced 2026-05-01 03:41:31 +00:00

Compare commits
27 Commits
Branch: v3/nodes_l...
Author: luke-mino-
| Author | SHA1 | Date |
|---|---|---|
|  | ad5604fb0b |  |
|  | 7f00f48c96 |  |
|  | 42edf71854 |  |
|  | 9c2a423aec |  |
|  | 731a95eb13 |  |
|  | 781d451355 |  |
|  | 1f1894608d |  |
|  | 657bf5a55e |  |
|  | 14183b3c21 |  |
|  | 373b2a735e |  |
|  | a130ccc942 |  |
|  | a8371ef1bc |  |
|  | d653b86bd7 |  |
|  | f26384f371 |  |
|  | bfdb78da05 |  |
|  | e59fbc101d |  |
|  | defd97d8b8 |  |
|  | a611444b82 |  |
|  | 7a54eb33ca |  |
|  | c3cc3ba24f |  |
|  | 09730315d2 |  |
|  | 05ed9e774a |  |
|  | 0fff4c980f |  |
|  | b98e727582 |  |
|  | 315aa8c3bf |  |
|  | d621657143 |  |
|  | d280ae140f |  |
@@ -3,12 +3,8 @@ import os
 from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import IO, Any, Callable, Iterator
-import logging
 
-try:
-    from blake3 import blake3
-except ModuleNotFoundError:
-    logging.warning("WARNING: blake3 package not installed")
+from blake3 import blake3
 
 DEFAULT_CHUNK = 8 * 1024 * 1024
 

@@ -223,19 +223,12 @@ class DoubleStreamBlock(nn.Module):
         del txt_k, img_k
         v = torch.cat((txt_v, img_v), dim=2)
         del txt_v, img_v
-
-        extra_options["img_slice"] = [txt.shape[1], q.shape[2]]
-        if "attn1_patch" in transformer_patches:
-            patch = transformer_patches["attn1_patch"]
-            for p in patch:
-                out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options)
-                q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask)
-
         # run actual attention
         attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
         del q, k, v
 
         if "attn1_output_patch" in transformer_patches:
+            extra_options["img_slice"] = [txt.shape[1], attn.shape[1]]
             patch = transformer_patches["attn1_output_patch"]
             for p in patch:
                 attn = p(attn, extra_options)

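Note: on the side that still carries it, an "attn1_patch" entry is a plain callable registered under transformer_options["patches"]. A minimal sketch matching the call convention visible in the hunk above (the patch name and the 1.1 scaling are hypothetical, not from the source):

def scale_queries_patch(q, k, v, pe=None, attn_mask=None, extra_options=None):
    # Return only the tensors you want to override; the caller falls back to
    # the originals via out.get("q", q), out.get("k", k), and so on.
    return {"q": q * 1.1}

# Assumed registration point, mirroring how the hunk reads the patch list:
# transformer_options.setdefault("patches", {}).setdefault("attn1_patch", []).append(scale_queries_patch)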
@@ -328,12 +321,6 @@ class SingleStreamBlock(nn.Module):
         del qkv
         q, k = self.norm(q, k, v)
-
-        if "attn1_patch" in transformer_patches:
-            patch = transformer_patches["attn1_patch"]
-            for p in patch:
-                out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options)
-                q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask)
 
         # compute attention
         attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
         del q, k, v

@@ -31,8 +31,6 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
 
 def _apply_rope1(x: Tensor, freqs_cis: Tensor):
     x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
-    if x_.shape[2] != 1 and freqs_cis.shape[2] != 1 and x_.shape[2] != freqs_cis.shape[2]:
-        freqs_cis = freqs_cis[:, :, :x_.shape[2]]
 
     x_out = freqs_cis[..., 0] * x_[..., 0]
     x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])

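Note: a standalone sketch of the rotation _apply_rope1 performs, assuming the Flux-style freqs_cis layout of shape (..., dim/2, 2, 2) where the trailing axes hold the 2x2 rotation matrix and x is viewed as (real, imag) pairs:

import torch

def apply_rope_reference(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # View the last dim of x as pairs: (..., dim) -> (..., dim/2, 1, 2)
    x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
    # Apply the rotation, same arithmetic as x_out = f0*x0 then addcmul_(f1, x1)
    x_out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]
    return x_out.reshape(*x.shape).type_as(x)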
@@ -170,7 +170,7 @@ class Flux(nn.Module):
 
         if "post_input" in patches:
             for p in patches["post_input"]:
-                out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
+                out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})
                 img = out["img"]
                 txt = out["txt"]
                 img_ids = out["img_ids"]

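Note: a "post_input" patch receives the packed model inputs as a dict and must hand the same keys back. A hypothetical example that pads the text stream (the function and the multiple-of-8 rule are illustrative only, not from the source):

import torch.nn.functional as F

def pad_txt_post_input(args):
    txt, txt_ids = args["txt"], args["txt_ids"]
    pad = (-txt.shape[1]) % 8  # hypothetical alignment requirement
    if pad:
        txt = F.pad(txt, (0, 0, 0, pad))        # pad token axis of (B, T, C)
        txt_ids = F.pad(txt_ids, (0, 0, 0, pad))  # keep ids aligned, (B, T, 3)
    return {"img": args["img"], "txt": txt, "img_ids": args["img_ids"], "txt_ids": txt_ids}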
@@ -372,8 +372,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
                 r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
                 del s2
             break
-        except Exception as e:
-            model_management.raise_non_oom(e)
+        except model_management.OOM_EXCEPTION as e:
             if first_op_done == False:
                 model_management.soft_empty_cache(True)
                 if cleared_cache == False:

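Note: the two sides differ in how OOMs are classified. The raise_non_oom side deliberately catches every exception, immediately re-raises anything that is not an out-of-memory condition, and lets the handler body run only for true OOMs. Roughly, as a sketch (run_big_kernel and run_smaller are hypothetical):

try:
    out = run_big_kernel()                    # hypothetical op that may OOM
except Exception as e:
    model_management.raise_non_oom(e)         # non-OOM errors propagate here
    model_management.soft_empty_cache(True)   # free cached blocks before retrying
    out = run_smaller()                       # hypothetical lower-memory fallback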
@@ -258,8 +258,7 @@ def slice_attention(q, k, v):
             r1[:, :, i:end] = torch.bmm(v, s2)
             del s2
             break
-        except Exception as e:
-            model_management.raise_non_oom(e)
+        except model_management.OOM_EXCEPTION as e:
             model_management.soft_empty_cache(True)
             steps *= 2
             if steps > 128:

@@ -315,8 +314,7 @@ def pytorch_attention(q, k, v):
     try:
         out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
         out = out.transpose(2, 3).reshape(orig_shape)
-    except Exception as e:
-        model_management.raise_non_oom(e)
+    except model_management.OOM_EXCEPTION:
         logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
         oom_fallback = True
     if oom_fallback:

@@ -169,8 +169,7 @@ def _get_attention_scores_no_kv_chunking(
     try:
         attn_probs = attn_scores.softmax(dim=-1)
         del attn_scores
-    except Exception as e:
-        model_management.raise_non_oom(e)
+    except model_management.OOM_EXCEPTION:
         logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
         attn_scores -= attn_scores.max(dim=-1, keepdim=True).values  # noqa: F821 attn_scores is not defined
         torch.exp(attn_scores, out=attn_scores)

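Note: the in-place fallback continues past the last context line; a self-contained sketch of the whole trick, which avoids allocating a second attn_scores-sized buffer (the final normalization step is assumed from context):

import torch

def softmax_inplace(attn_scores: torch.Tensor) -> torch.Tensor:
    attn_scores -= attn_scores.max(dim=-1, keepdim=True).values  # stabilize
    torch.exp(attn_scores, out=attn_scores)                      # exponentiate in place
    attn_scores /= attn_scores.sum(dim=-1, keepdim=True)         # normalize in place
    return attn_scores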
@@ -149,9 +149,6 @@ class Attention(nn.Module):
         seq_img = hidden_states.shape[1]
         seq_txt = encoder_hidden_states.shape[1]
 
-        transformer_patches = transformer_options.get("patches", {})
-        extra_options = transformer_options.copy()
-
         # Project and reshape to BHND format (batch, heads, seq, dim)
         img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
         img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()

@@ -170,22 +167,15 @@ class Attention(nn.Module):
         joint_key = torch.cat([txt_key, img_key], dim=2)
         joint_value = torch.cat([txt_value, img_value], dim=2)
 
+        joint_query = apply_rope1(joint_query, image_rotary_emb)
+        joint_key = apply_rope1(joint_key, image_rotary_emb)
+
         if encoder_hidden_states_mask is not None:
             attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device)
             attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask
         else:
             attn_mask = None
 
-        extra_options["img_slice"] = [txt_query.shape[2], joint_query.shape[2]]
-        if "attn1_patch" in transformer_patches:
-            patch = transformer_patches["attn1_patch"]
-            for p in patch:
-                out = p(joint_query, joint_key, joint_value, pe=image_rotary_emb, attn_mask=encoder_hidden_states_mask, extra_options=extra_options)
-                joint_query, joint_key, joint_value, image_rotary_emb, encoder_hidden_states_mask = out.get("q", joint_query), out.get("k", joint_key), out.get("v", joint_value), out.get("pe", image_rotary_emb), out.get("attn_mask", encoder_hidden_states_mask)
-
-        joint_query = apply_rope1(joint_query, image_rotary_emb)
-        joint_key = apply_rope1(joint_key, image_rotary_emb)
-
         joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
                                                          attn_mask, transformer_options=transformer_options,
                                                          skip_reshape=True)

@@ -454,7 +444,6 @@ class QwenImageTransformer2DModel(nn.Module):
 
         timestep_zero_index = None
         if ref_latents is not None:
-            ref_num_tokens = []
             h = 0
             w = 0
             index = 0

@@ -485,16 +474,16 @@ class QwenImageTransformer2DModel(nn.Module):
                 kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
                 hidden_states = torch.cat([hidden_states, kontext], dim=1)
                 img_ids = torch.cat([img_ids, kontext_ids], dim=1)
-                ref_num_tokens.append(kontext.shape[1])
                 if timestep_zero:
                     if index > 0:
                         timestep = torch.cat([timestep, timestep * 0], dim=0)
                     timestep_zero_index = num_embeds
-            transformer_options = transformer_options.copy()
-            transformer_options["reference_image_num_tokens"] = ref_num_tokens
 
         txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
         txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
+        ids = torch.cat((txt_ids, img_ids), dim=1)
+        image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
+        del ids, txt_ids, img_ids
 
         hidden_states = self.img_in(hidden_states)
         encoder_hidden_states = self.txt_norm(encoder_hidden_states)

@@ -506,18 +495,6 @@ class QwenImageTransformer2DModel(nn.Module):
         patches = transformer_options.get("patches", {})
         blocks_replace = patches_replace.get("dit", {})
 
-        if "post_input" in patches:
-            for p in patches["post_input"]:
-                out = p({"img": hidden_states, "txt": encoder_hidden_states, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
-                hidden_states = out["img"]
-                encoder_hidden_states = out["txt"]
-                img_ids = out["img_ids"]
-                txt_ids = out["txt_ids"]
-
-        ids = torch.cat((txt_ids, img_ids), dim=1)
-        image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
-        del ids, txt_ids, img_ids
-
         transformer_options["total_blocks"] = len(self.transformer_blocks)
         transformer_options["block_type"] = "double"
         for i, block in enumerate(self.transformer_blocks):

@@ -99,9 +99,6 @@ def model_lora_keys_clip(model, key_map={}):
     for k in sdk:
         if k.endswith(".weight"):
             key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
-            tp = k.find(".transformer.") #also map without wrapper prefix for composite text encoder models
-            if tp > 0 and not k.startswith("clip_"):
-                key_map["text_encoders.{}".format(k[tp + 1:-len(".weight")])] = k
 
     text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
     clip_l_present = False

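Note: what the removed mapping does for a composite text encoder key, assuming a hypothetical non-clip_* key: a key embedding a ".transformer." wrapper also gets a prefix-free alias.

# Hypothetical checkpoint key, for illustration only:
k = "t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.q.weight"
tp = k.find(".transformer.")
alias = "text_encoders.{}".format(k[tp + 1:-len(".weight")])
# alias == "text_encoders.transformer.encoder.block.0.layer.0.SelfAttention.q"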
@@ -1,5 +1,4 @@
 import json
-import comfy.memory_management
 import comfy.supported_models
 import comfy.supported_models_base
 import comfy.utils

@@ -1119,13 +1118,8 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
                     new[:old_weight.shape[0]] = old_weight
                     old_weight = new
 
-                if old_weight is out_sd.get(t[0], None) and comfy.memory_management.aimdo_enabled:
-                    old_weight = old_weight.clone()
-
                 w = old_weight.narrow(offset[0], offset[1], offset[2])
             else:
-                if comfy.memory_management.aimdo_enabled:
-                    weight = weight.clone()
                 old_weight = weight
                 w = weight
             w[:] = fun(weight)

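Note: the clone on the removed side guards against aliasing, since narrow() returns a view that shares storage with the tensor already placed in out_sd. A minimal demonstration of the hazard:

import torch

base = torch.zeros(4, 2)
w = base.narrow(0, 0, 2)   # a view into base, no copy is made
w[:] = 1.0                 # writing through the view...
print(base[0, 0].item())   # ...mutates base: prints 1.0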
@@ -270,23 +270,6 @@ try:
 except:
     OOM_EXCEPTION = Exception
 
-try:
-    ACCELERATOR_ERROR = torch.AcceleratorError
-except AttributeError:
-    ACCELERATOR_ERROR = RuntimeError
-
-def is_oom(e):
-    if isinstance(e, OOM_EXCEPTION):
-        return True
-    if isinstance(e, ACCELERATOR_ERROR) and (getattr(e, 'error_code', None) == 2 or "out of memory" in str(e).lower()):
-        discard_cuda_async_error()
-        return True
-    return False
-
-def raise_non_oom(e):
-    if not is_oom(e):
-        raise e
-
 XFORMERS_VERSION = ""
 XFORMERS_ENABLED_VAE = True
 if args.disable_xformers:

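Note: is_oom() widens OOM detection beyond OOM_EXCEPTION to torch.AcceleratorError, where error_code 2 corresponds to CUDA's cudaErrorMemoryAllocation. A quick classification check, assuming OOM_EXCEPTION resolves to torch.cuda.OutOfMemoryError:

e1 = torch.cuda.OutOfMemoryError("CUDA out of memory")
print(model_management.is_oom(e1))  # True: it is an OOM_EXCEPTION
e2 = RuntimeError("device-side assert triggered")
print(model_management.is_oom(e2))  # False, so raise_non_oom(e2) re-raises it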
@@ -1280,7 +1263,7 @@ def discard_cuda_async_error():
         b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
         _ = a + b
         synchronize()
-    except RuntimeError:
+    except torch.AcceleratorError:
         #Dump it! We already know about it from the synchronous return
         pass
 

@@ -599,27 +599,6 @@ class ModelPatcher:
 
         return models
 
-    def model_patches_call_function(self, function_name="cleanup", arguments={}):
-        to = self.model_options["transformer_options"]
-        if "patches" in to:
-            patches = to["patches"]
-            for name in patches:
-                patch_list = patches[name]
-                for i in range(len(patch_list)):
-                    if hasattr(patch_list[i], function_name):
-                        getattr(patch_list[i], function_name)(**arguments)
-        if "patches_replace" in to:
-            patches = to["patches_replace"]
-            for name in patches:
-                patch_list = patches[name]
-                for k in patch_list:
-                    if hasattr(patch_list[k], function_name):
-                        getattr(patch_list[k], function_name)(**arguments)
-        if "model_function_wrapper" in self.model_options:
-            wrap_func = self.model_options["model_function_wrapper"]
-            if hasattr(wrap_func, function_name):
-                getattr(wrap_func, function_name)(**arguments)
-
     def model_dtype(self):
         if hasattr(self.model, "get_dtype"):
             return self.model.get_dtype()

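Note: model_patches_call_function duck-types its targets: anything registered under patches, patches_replace, or model_function_wrapper that has an attribute named function_name gets called with **arguments. A compatible patch object, as a hypothetical sketch:

class MyAttnPatch:
    def __call__(self, q, k, v, pe=None, attn_mask=None, extra_options=None):
        return {"q": q}

    def cleanup(self):
        # invoked via model_patches_call_function(function_name="cleanup")
        pass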
@@ -1083,7 +1062,6 @@ class ModelPatcher:
         return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)
 
     def cleanup(self):
-        self.model_patches_call_function(function_name="cleanup")
         self.clean_hooks()
         if hasattr(self.model, "current_patcher"):
             self.model.current_patcher = None

@@ -954,8 +954,7 @@ class VAE:
                 if pixel_samples is None:
                     pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
                 pixel_samples[x:x+batch_number] = out
-        except Exception as e:
-            model_management.raise_non_oom(e)
+        except model_management.OOM_EXCEPTION:
             logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
             #NOTE: We don't know what tensors were allocated to stack variables at the time of the
             #exception and the exception itself refs them all until we get out of this except block.

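Note: the overall decode structure around this hunk, as a sketch; decode_all_at_once and decode_tiled are hypothetical stand-ins for the real batched and tiled paths:

try:
    pixel_samples = decode_all_at_once(samples_in)   # hypothetical batched path
except model_management.OOM_EXCEPTION:
    logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
    pixel_samples = decode_tiled(samples_in)         # hypothetical tiled fallback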
@@ -1030,8 +1029,7 @@ class VAE:
                     samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
                 samples[x:x + batch_number] = out
 
-        except Exception as e:
-            model_management.raise_non_oom(e)
+        except model_management.OOM_EXCEPTION:
             logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
             #NOTE: We don't know what tensors were allocated to stack variables at the time of the
             #exception and the exception itself refs them all until we get out of this except block.

@@ -1,68 +0,0 @@
-from pydantic import BaseModel, Field
-
-
-class RevePostprocessingOperation(BaseModel):
-    process: str = Field(..., description="The postprocessing operation: upscale or remove_background.")
-    upscale_factor: int | None = Field(
-        None,
-        description="Upscale factor (2, 3, or 4). Only used when process is upscale.",
-        ge=2,
-        le=4,
-    )
-
-
-class ReveImageCreateRequest(BaseModel):
-    prompt: str = Field(...)
-    aspect_ratio: str | None = Field(...)
-    version: str = Field(...)
-    test_time_scaling: int = Field(
-        ...,
-        description="If included, the model will spend more effort making better images. Values between 1 and 15.",
-        ge=1,
-        le=15,
-    )
-    postprocessing: list[RevePostprocessingOperation] | None = Field(
-        None, description="Optional postprocessing operations to apply after generation."
-    )
-
-
-class ReveImageEditRequest(BaseModel):
-    edit_instruction: str = Field(...)
-    reference_image: str = Field(..., description="A base64 encoded image to use as reference for the edit.")
-    aspect_ratio: str | None = Field(...)
-    version: str = Field(...)
-    test_time_scaling: int | None = Field(
-        ...,
-        description="If included, the model will spend more effort making better images. Values between 1 and 15.",
-        ge=1,
-        le=15,
-    )
-    postprocessing: list[RevePostprocessingOperation] | None = Field(
-        None, description="Optional postprocessing operations to apply after generation."
-    )
-
-
-class ReveImageRemixRequest(BaseModel):
-    prompt: str = Field(...)
-    reference_images: list[str] = Field(..., description="A list of 1-6 base64 encoded reference images.")
-    aspect_ratio: str | None = Field(...)
-    version: str = Field(...)
-    test_time_scaling: int | None = Field(
-        ...,
-        description="If included, the model will spend more effort making better images. Values between 1 and 15.",
-        ge=1,
-        le=15,
-    )
-    postprocessing: list[RevePostprocessingOperation] | None = Field(
-        None, description="Optional postprocessing operations to apply after generation."
-    )
-
-
-class ReveImageResponse(BaseModel):
-    image: str | None = Field(None, description="The base64 encoded image data.")
-    request_id: str | None = Field(None, description="A unique id for the request.")
-    credits_used: float | None = Field(None, description="The number of credits used for this request.")
-    version: str | None = Field(None, description="The specific model version used.")
-    content_violation: bool | None = Field(
-        None, description="Indicates whether the generated image violates the content policy."
-    )

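Note: these request models are serialized with model_dump(exclude_none=True) before hitting the proxy (see the sync_op_raw hunk further down), so None-valued optionals never reach the wire. A quick check, assuming pydantic v2:

req = ReveImageCreateRequest(
    prompt="a lighthouse at dusk",
    aspect_ratio="16:9",
    version="reve-create@20250915",
    test_time_scaling=1,
    postprocessing=None,
)
print(req.model_dump(exclude_none=True))
# {'prompt': 'a lighthouse at dusk', 'aspect_ratio': '16:9',
#  'version': 'reve-create@20250915', 'test_time_scaling': 1}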
@@ -1,395 +0,0 @@
-from io import BytesIO
-
-from typing_extensions import override
-
-from comfy_api.latest import IO, ComfyExtension, Input
-from comfy_api_nodes.apis.reve import (
-    ReveImageCreateRequest,
-    ReveImageEditRequest,
-    ReveImageRemixRequest,
-    RevePostprocessingOperation,
-)
-from comfy_api_nodes.util import (
-    ApiEndpoint,
-    bytesio_to_image_tensor,
-    sync_op_raw,
-    tensor_to_base64_string,
-    validate_string,
-)
-
-
-def _build_postprocessing(upscale: dict, remove_background: bool) -> list[RevePostprocessingOperation] | None:
-    ops = []
-    if upscale["upscale"] == "enabled":
-        ops.append(
-            RevePostprocessingOperation(
-                process="upscale",
-                upscale_factor=upscale["upscale_factor"],
-            )
-        )
-    if remove_background:
-        ops.append(RevePostprocessingOperation(process="remove_background"))
-    return ops or None
-
-
-def _postprocessing_inputs():
-    return [
-        IO.DynamicCombo.Input(
-            "upscale",
-            options=[
-                IO.DynamicCombo.Option("disabled", []),
-                IO.DynamicCombo.Option(
-                    "enabled",
-                    [
-                        IO.Int.Input(
-                            "upscale_factor",
-                            default=2,
-                            min=2,
-                            max=4,
-                            step=1,
-                            tooltip="Upscale factor (2x, 3x, or 4x).",
-                        ),
-                    ],
-                ),
-            ],
-            tooltip="Upscale the generated image. May add additional cost.",
-        ),
-        IO.Boolean.Input(
-            "remove_background",
-            default=False,
-            tooltip="Remove the background from the generated image. May add additional cost.",
-        ),
-    ]
-
-
-def _reve_price_extractor(headers: dict) -> float | None:
-    credits_used = headers.get("x-reve-credits-used")
-    if credits_used is not None:
-        return float(credits_used) / 524.48
-    return None
-
-
-def _reve_response_header_validator(headers: dict) -> None:
-    error_code = headers.get("x-reve-error-code")
-    if error_code:
-        raise ValueError(f"Reve API error: {error_code}")
-    if headers.get("x-reve-content-violation", "").lower() == "true":
-        raise ValueError("The generated image was flagged for content policy violation.")
-
-
-def _model_inputs(versions: list[str], aspect_ratios: list[str]):
-    return [
-        IO.DynamicCombo.Option(
-            version,
-            [
-                IO.Combo.Input(
-                    "aspect_ratio",
-                    options=aspect_ratios,
-                    tooltip="Aspect ratio of the output image.",
-                ),
-                IO.Int.Input(
-                    "test_time_scaling",
-                    default=1,
-                    min=1,
-                    max=5,
-                    step=1,
-                    tooltip="Higher values produce better images but cost more credits.",
-                    advanced=True,
-                ),
-            ],
-        )
-        for version in versions
-    ]
-
-
-class ReveImageCreateNode(IO.ComfyNode):
-
-    @classmethod
-    def define_schema(cls):
-        return IO.Schema(
-            node_id="ReveImageCreateNode",
-            display_name="Reve Image Create",
-            category="api node/image/Reve",
-            description="Generate images from text descriptions using Reve.",
-            inputs=[
-                IO.String.Input(
-                    "prompt",
-                    multiline=True,
-                    default="",
-                    tooltip="Text description of the desired image. Maximum 2560 characters.",
-                ),
-                IO.DynamicCombo.Input(
-                    "model",
-                    options=_model_inputs(
-                        ["reve-create@20250915"],
-                        aspect_ratios=["3:2", "16:9", "9:16", "2:3", "4:3", "3:4", "1:1"],
-                    ),
-                    tooltip="Model version to use for generation.",
-                ),
-                *_postprocessing_inputs(),
-                IO.Int.Input(
-                    "seed",
-                    default=0,
-                    min=0,
-                    max=2147483647,
-                    control_after_generate=True,
-                    tooltip="Seed controls whether the node should re-run; "
-                            "results are non-deterministic regardless of seed.",
-                ),
-            ],
-            outputs=[IO.Image.Output()],
-            hidden=[
-                IO.Hidden.auth_token_comfy_org,
-                IO.Hidden.api_key_comfy_org,
-                IO.Hidden.unique_id,
-            ],
-            is_api_node=True,
-            price_badge=IO.PriceBadge(
-                expr="""{"type":"usd","usd":0.03432,"format":{"approximate":true,"note":"(base)"}}""",
-            ),
-        )
-
-    @classmethod
-    async def execute(
-        cls,
-        prompt: str,
-        model: dict,
-        upscale: dict,
-        remove_background: bool,
-        seed: int,
-    ) -> IO.NodeOutput:
-        validate_string(prompt, min_length=1, max_length=2560)
-        response = await sync_op_raw(
-            cls,
-            ApiEndpoint(
-                path="/proxy/reve/v1/image/create",
-                method="POST",
-                headers={"Accept": "image/webp"},
-            ),
-            as_binary=True,
-            price_extractor=_reve_price_extractor,
-            response_header_validator=_reve_response_header_validator,
-            data=ReveImageCreateRequest(
-                prompt=prompt,
-                aspect_ratio=model["aspect_ratio"],
-                version=model["model"],
-                test_time_scaling=model["test_time_scaling"],
-                postprocessing=_build_postprocessing(upscale, remove_background),
-            ),
-        )
-        return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
-
-
-class ReveImageEditNode(IO.ComfyNode):
-
-    @classmethod
-    def define_schema(cls):
-        return IO.Schema(
-            node_id="ReveImageEditNode",
-            display_name="Reve Image Edit",
-            category="api node/image/Reve",
-            description="Edit images using natural language instructions with Reve.",
-            inputs=[
-                IO.Image.Input("image", tooltip="The image to edit."),
-                IO.String.Input(
-                    "edit_instruction",
-                    multiline=True,
-                    default="",
-                    tooltip="Text description of how to edit the image. Maximum 2560 characters.",
-                ),
-                IO.DynamicCombo.Input(
-                    "model",
-                    options=_model_inputs(
-                        ["reve-edit@20250915", "reve-edit-fast@20251030"],
-                        aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"],
-                    ),
-                    tooltip="Model version to use for editing.",
-                ),
-                *_postprocessing_inputs(),
-                IO.Int.Input(
-                    "seed",
-                    default=0,
-                    min=0,
-                    max=2147483647,
-                    control_after_generate=True,
-                    tooltip="Seed controls whether the node should re-run; "
-                            "results are non-deterministic regardless of seed.",
-                ),
-            ],
-            outputs=[IO.Image.Output()],
-            hidden=[
-                IO.Hidden.auth_token_comfy_org,
-                IO.Hidden.api_key_comfy_org,
-                IO.Hidden.unique_id,
-            ],
-            is_api_node=True,
-            price_badge=IO.PriceBadge(
-                depends_on=IO.PriceBadgeDepends(
-                    widgets=["model"],
-                ),
-                expr="""
-                (
-                    $isFast := $contains(widgets.model, "fast");
-                    $base := $isFast ? 0.01001 : 0.0572;
-                    {"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
-                )
-                """,
-            ),
-        )
-
-    @classmethod
-    async def execute(
-        cls,
-        image: Input.Image,
-        edit_instruction: str,
-        model: dict,
-        upscale: dict,
-        remove_background: bool,
-        seed: int,
-    ) -> IO.NodeOutput:
-        validate_string(edit_instruction, min_length=1, max_length=2560)
-        tts = model["test_time_scaling"]
-        ar = model["aspect_ratio"]
-        response = await sync_op_raw(
-            cls,
-            ApiEndpoint(
-                path="/proxy/reve/v1/image/edit",
-                method="POST",
-                headers={"Accept": "image/webp"},
-            ),
-            as_binary=True,
-            price_extractor=_reve_price_extractor,
-            response_header_validator=_reve_response_header_validator,
-            data=ReveImageEditRequest(
-                edit_instruction=edit_instruction,
-                reference_image=tensor_to_base64_string(image),
-                aspect_ratio=ar if ar != "auto" else None,
-                version=model["model"],
-                test_time_scaling=tts if tts and tts > 1 else None,
-                postprocessing=_build_postprocessing(upscale, remove_background),
-            ),
-        )
-        return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
-
-
-class ReveImageRemixNode(IO.ComfyNode):
-
-    @classmethod
-    def define_schema(cls):
-        return IO.Schema(
-            node_id="ReveImageRemixNode",
-            display_name="Reve Image Remix",
-            category="api node/image/Reve",
-            description="Combine reference images with text prompts to create new images using Reve.",
-            inputs=[
-                IO.Autogrow.Input(
-                    "reference_images",
-                    template=IO.Autogrow.TemplatePrefix(
-                        IO.Image.Input("image"),
-                        prefix="image_",
-                        min=1,
-                        max=6,
-                    ),
-                ),
-                IO.String.Input(
-                    "prompt",
-                    multiline=True,
-                    default="",
-                    tooltip="Text description of the desired image. "
-                            "May include XML img tags to reference specific images by index, "
-                            "e.g. <img>0</img>, <img>1</img>, etc.",
-                ),
-                IO.DynamicCombo.Input(
-                    "model",
-                    options=_model_inputs(
-                        ["reve-remix@20250915", "reve-remix-fast@20251030"],
-                        aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"],
-                    ),
-                    tooltip="Model version to use for remixing.",
-                ),
-                *_postprocessing_inputs(),
-                IO.Int.Input(
-                    "seed",
-                    default=0,
-                    min=0,
-                    max=2147483647,
-                    control_after_generate=True,
-                    tooltip="Seed controls whether the node should re-run; "
-                            "results are non-deterministic regardless of seed.",
-                ),
-            ],
-            outputs=[IO.Image.Output()],
-            hidden=[
-                IO.Hidden.auth_token_comfy_org,
-                IO.Hidden.api_key_comfy_org,
-                IO.Hidden.unique_id,
-            ],
-            is_api_node=True,
-            price_badge=IO.PriceBadge(
-                depends_on=IO.PriceBadgeDepends(
-                    widgets=["model"],
-                ),
-                expr="""
-                (
-                    $isFast := $contains(widgets.model, "fast");
-                    $base := $isFast ? 0.01001 : 0.0572;
-                    {"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
-                )
-                """,
-            ),
-        )
-
-    @classmethod
-    async def execute(
-        cls,
-        reference_images: IO.Autogrow.Type,
-        prompt: str,
-        model: dict,
-        upscale: dict,
-        remove_background: bool,
-        seed: int,
-    ) -> IO.NodeOutput:
-        validate_string(prompt, min_length=1, max_length=2560)
-        if not reference_images:
-            raise ValueError("At least one reference image is required.")
-        ref_base64_list = []
-        for key in reference_images:
-            ref_base64_list.append(tensor_to_base64_string(reference_images[key]))
-        if len(ref_base64_list) > 6:
-            raise ValueError("Maximum 6 reference images are allowed.")
-        tts = model["test_time_scaling"]
-        ar = model["aspect_ratio"]
-        response = await sync_op_raw(
-            cls,
-            ApiEndpoint(
-                path="/proxy/reve/v1/image/remix",
-                method="POST",
-                headers={"Accept": "image/webp"},
-            ),
-            as_binary=True,
-            price_extractor=_reve_price_extractor,
-            response_header_validator=_reve_response_header_validator,
-            data=ReveImageRemixRequest(
-                prompt=prompt,
-                reference_images=ref_base64_list,
-                aspect_ratio=ar if ar != "auto" else None,
-                version=model["model"],
-                test_time_scaling=tts if tts and tts > 1 else None,
-                postprocessing=_build_postprocessing(upscale, remove_background),
-            ),
-        )
-        return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
-
-
-class ReveExtension(ComfyExtension):
-    @override
-    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
-        return [
-            ReveImageCreateNode,
-            ReveImageEditNode,
-            ReveImageRemixNode,
-        ]
-
-
-async def comfy_entrypoint() -> ReveExtension:
-    return ReveExtension()

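Note: the postprocessing helper flattens the widget state into the API operation list; for a typical enabled state:

upscale = {"upscale": "enabled", "upscale_factor": 2}
ops = _build_postprocessing(upscale, remove_background=True)
# -> [RevePostprocessingOperation(process="upscale", upscale_factor=2),
#     RevePostprocessingOperation(process="remove_background")]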
@@ -67,7 +67,6 @@ class _RequestConfig:
     progress_origin_ts: float | None = None
     price_extractor: Callable[[dict[str, Any]], float | None] | None = None
     is_rate_limited: Callable[[int, Any], bool] | None = None
-    response_header_validator: Callable[[dict[str, str]], None] | None = None
 
 
 @dataclass

@@ -203,13 +202,11 @@ async def sync_op_raw(
     monitor_progress: bool = True,
     max_retries_on_rate_limit: int = 16,
     is_rate_limited: Callable[[int, Any], bool] | None = None,
-    response_header_validator: Callable[[dict[str, str]], None] | None = None,
 ) -> dict[str, Any] | bytes:
     """
     Make a single network request.
     - If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
     - If as_binary=True: returns bytes.
-    - response_header_validator: optional callback receiving response headers dict
     """
     if isinstance(data, BaseModel):
         data = data.model_dump(exclude_none=True)

@@ -235,7 +232,6 @@ async def sync_op_raw(
         price_extractor=price_extractor,
         max_retries_on_rate_limit=max_retries_on_rate_limit,
         is_rate_limited=is_rate_limited,
-        response_header_validator=response_header_validator,
     )
     return await _request_base(cfg, expect_binary=as_binary)
 

@@ -773,12 +769,6 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
                         cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total
                     )
                 bytes_payload = bytes(buff)
-                resp_headers = {k.lower(): v for k, v in resp.headers.items()}
-                if cfg.price_extractor:
-                    with contextlib.suppress(Exception):
-                        extracted_price = cfg.price_extractor(resp_headers)
-                if cfg.response_header_validator:
-                    cfg.response_header_validator(resp_headers)
                 operation_succeeded = True
                 final_elapsed_seconds = int(time.monotonic() - start_time)
                 request_logger.log_request_response(

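Note: the removed response_header_validator hook receives the lower-cased response header dict and signals failure by raising; the contract, sketched (the header name is hypothetical):

def my_validator(headers: dict[str, str]) -> None:
    if headers.get("x-some-error-code"):   # hypothetical vendor header
        raise ValueError(headers["x-some-error-code"])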
@@ -786,7 +776,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
                     request_method=method,
                     request_url=url,
                     response_status_code=resp.status,
-                    response_headers=resp_headers,
+                    response_headers=dict(resp.headers),
                     response_content=bytes_payload,
                 )
                 return bytes_payload

@@ -1,32 +1,32 @@
 from comfy import model_management
-from comfy_api.latest import ComfyExtension, IO
-from typing_extensions import override
 import math
 
-class LTXVLatentUpsampler(IO.ComfyNode):
+class LTXVLatentUpsampler:
     """
     Upsamples a video latent by a factor of 2.
     """
 
     @classmethod
-    def define_schema(cls):
-        return IO.Schema(
-            node_id="LTXVLatentUpsampler",
-            category="latent/video",
-            is_experimental=True,
-            inputs=[
-                IO.Latent.Input("samples"),
-                IO.LatentUpscaleModel.Input("upscale_model"),
-                IO.Vae.Input("vae"),
-            ],
-            outputs=[
-                IO.Latent.Output(),
-            ],
-        )
-
-    @classmethod
-    def execute(cls, samples, upscale_model, vae) -> IO.NodeOutput:
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "samples": ("LATENT",),
+                "upscale_model": ("LATENT_UPSCALE_MODEL",),
+                "vae": ("VAE",),
+            }
+        }
+
+    RETURN_TYPES = ("LATENT",)
+    FUNCTION = "upsample_latent"
+    CATEGORY = "latent/video"
+    EXPERIMENTAL = True
+
+    def upsample_latent(
+        self,
+        samples: dict,
+        upscale_model,
+        vae,
+    ) -> tuple:
         """
         Upsample the input latent using the provided model.
 

@@ -34,6 +34,7 @@ class LTXVLatentUpsampler(IO.ComfyNode):
             samples (dict): Input latent samples
             upscale_model (LatentUpsampler): Loaded upscale model
             vae: VAE model for normalization
+            auto_tiling (bool): Whether to automatically tile the input for processing
 
         Returns:
             tuple: Tuple containing the upsampled latent

@@ -66,16 +67,9 @@ class LTXVLatentUpsampler(IO.ComfyNode):
         return_dict = samples.copy()
         return_dict["samples"] = upsampled_latents
         return_dict.pop("noise_mask", None)
-        return IO.NodeOutput(return_dict)
+        return (return_dict,)
 
-    upsample_latent = execute # TODO: remove
-
 
-class LTXVLatentUpsamplerExtension(ComfyExtension):
-    @override
-    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
-        return [LTXVLatentUpsampler]
-
-
-async def comfy_entrypoint() -> LTXVLatentUpsamplerExtension:
-    return LTXVLatentUpsamplerExtension()
+NODE_CLASS_MAPPINGS = {
+    "LTXVLatentUpsampler": LTXVLatentUpsampler,
+}

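Note: the two node-registration styles in play above, side by side as a sketch. The "-" side uses the V3 schema API (an extension entrypoint returning node classes); the "+" side uses the V1 dict API (a module-level mapping):

# V1 style: nodes are discovered from a module-level dict
NODE_CLASS_MAPPINGS = {"LTXVLatentUpsampler": LTXVLatentUpsampler}

# V3 style: nodes are exposed through an async extension entrypoint
async def comfy_entrypoint():
    return LTXVLatentUpsamplerExtension()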
@@ -86,8 +86,7 @@ class ImageUpscaleWithModel(io.ComfyNode):
                 pbar = comfy.utils.ProgressBar(steps)
                 s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
                 oom = False
-            except Exception as e:
-                model_management.raise_non_oom(e)
+            except model_management.OOM_EXCEPTION as e:
                 tile //= 2
                 if tile < 128:
                     raise e

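Note: the surrounding retry loop, reconstructed as a sketch from the visible lines: halve the tile size on each OOM until tiles get too small to be useful, then give up.

tile = 512
overlap = 32
while True:
    try:
        s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile,
                                    overlap=overlap, upscale_amount=upscale_model.scale)
        break
    except model_management.OOM_EXCEPTION as e:
        tile //= 2
        if tile < 128:
            raise e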
@@ -612,7 +612,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
         logging.error(traceback.format_exc())
         tips = ""
 
-        if comfy.model_management.is_oom(ex):
+        if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
             tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
             logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary()))
             logging.error("Got an OOM, unloading all loaded models.")

main.py (24 changes)
@@ -3,16 +3,14 @@ comfy.options.enable_args_parsing()
 
 import os
 import importlib.util
-import shutil
-import importlib.metadata
 import folder_paths
 import time
 from comfy.cli_args import args, enables_dynamic_vram
 from app.logger import setup_logger
+from app.assets.seeder import asset_seeder
 import itertools
 import utils.extra_config
 from utils.mime_types import init_mime_types
-import faulthandler
 import logging
 import sys
 from comfy_execution.progress import get_progress_state

@@ -27,8 +25,6 @@ if __name__ == "__main__":
 
     setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
 
-    faulthandler.enable(file=sys.stderr, all_threads=False)
-
     import comfy_aimdo.control
 
     if enables_dynamic_vram():

@@ -68,15 +64,8 @@ if __name__ == "__main__":
 
 
 def handle_comfyui_manager_unavailable():
-    manager_req_path = os.path.join(os.path.dirname(os.path.abspath(folder_paths.__file__)), "manager_requirements.txt")
-    uv_available = shutil.which("uv") is not None
-
-    pip_cmd = f"{sys.executable} -m pip install -r {manager_req_path}"
-    msg = f"\n\nTo use the `--enable-manager` feature, the `comfyui-manager` package must be installed first.\ncommand:\n\t{pip_cmd}"
-    if uv_available:
-        msg += f"\nor using uv:\n\tuv pip install -r {manager_req_path}"
-    msg += "\n"
-    logging.warning(msg)
+    if not args.windows_standalone_build:
+        logging.warning(f"\n\nYou appear to be running comfyui-manager from source, this is not recommended. Please install comfyui-manager using the following command:\ncommand:\n\t{sys.executable} -m pip install --pre comfyui_manager\n")
     args.enable_manager = False
 
 

@@ -184,6 +173,7 @@ execute_prestartup_script()
 
 # Main code
 import asyncio
+import shutil
 import threading
 import gc
 

@@ -192,7 +182,6 @@ if 'torch' in sys.modules:
 
 
 import comfy.utils
-from app.assets.seeder import asset_seeder
 
 import execution
 import server

@@ -462,11 +451,6 @@ if __name__ == "__main__":
     # Running directly, just start ComfyUI.
     logging.info("Python version: {}".format(sys.version))
     logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
-    for package in ("comfy-aimdo", "comfy-kitchen"):
-        try:
-            logging.info("{} version: {}".format(package, importlib.metadata.version(package)))
-        except:
-            pass
 
     if sys.version_info.major == 3 and sys.version_info.minor < 10:
         logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.")

@@ -1 +1 @@
-comfyui_manager==4.1b2
+comfyui_manager==4.1b1

@@ -1,5 +1,5 @@
 comfyui-frontend-package==1.39.19
-comfyui-workflow-templates==0.9.18
+comfyui-workflow-templates==0.9.11
 comfyui-embedded-docs==0.4.3
 torch
 torchsde

@@ -23,7 +23,7 @@ SQLAlchemy
 filelock
 av>=14.2.0
 comfy-kitchen>=0.2.7
-comfy-aimdo>=0.2.10
+comfy-aimdo>=0.2.7
 requests
 simpleeval>=1.0.0
 blake3