mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-13 00:59:59 +00:00
Compare commits
1 Commits
fix/static
...
luke-mino-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
42eda2b6fc |
@@ -92,6 +92,7 @@ class _AssetSeeder:
|
||||
self._prune_first: bool = False
|
||||
self._progress_callback: ProgressCallback | None = None
|
||||
self._disabled: bool = False
|
||||
self._pending_enrich: dict | None = None
|
||||
|
||||
def disable(self) -> None:
|
||||
"""Disable the asset seeder, preventing any scans from starting."""
|
||||
@@ -196,6 +197,42 @@ class _AssetSeeder:
|
||||
compute_hashes=compute_hashes,
|
||||
)
|
||||
|
||||
def enqueue_enrich(
|
||||
self,
|
||||
roots: tuple[RootType, ...] = ("models", "input", "output"),
|
||||
compute_hashes: bool = False,
|
||||
) -> bool:
|
||||
"""Start an enrichment scan now, or queue it for after the current scan.
|
||||
|
||||
If the seeder is idle, starts immediately. Otherwise, the enrich
|
||||
request is stored and will run automatically when the current scan
|
||||
finishes.
|
||||
|
||||
Args:
|
||||
roots: Tuple of root types to scan
|
||||
compute_hashes: If True, compute blake3 hashes
|
||||
|
||||
Returns:
|
||||
True if started immediately, False if queued for later
|
||||
"""
|
||||
if self.start_enrich(roots=roots, compute_hashes=compute_hashes):
|
||||
return True
|
||||
with self._lock:
|
||||
if self._pending_enrich is not None:
|
||||
existing_roots = set(self._pending_enrich["roots"])
|
||||
existing_roots.update(roots)
|
||||
self._pending_enrich["roots"] = tuple(existing_roots)
|
||||
self._pending_enrich["compute_hashes"] = (
|
||||
self._pending_enrich["compute_hashes"] or compute_hashes
|
||||
)
|
||||
else:
|
||||
self._pending_enrich = {
|
||||
"roots": roots,
|
||||
"compute_hashes": compute_hashes,
|
||||
}
|
||||
logging.info("Enrich scan queued (roots=%s)", self._pending_enrich["roots"])
|
||||
return False
|
||||
|
||||
def cancel(self) -> bool:
|
||||
"""Request cancellation of the current scan.
|
||||
|
||||
@@ -381,9 +418,13 @@ class _AssetSeeder:
|
||||
return marked
|
||||
finally:
|
||||
with self._lock:
|
||||
self._last_progress = self._progress
|
||||
self._state = State.IDLE
|
||||
self._progress = None
|
||||
self._reset_to_idle()
|
||||
|
||||
def _reset_to_idle(self) -> None:
|
||||
"""Reset state to IDLE, preserving last progress. Caller must hold _lock."""
|
||||
self._last_progress = self._progress
|
||||
self._state = State.IDLE
|
||||
self._progress = None
|
||||
|
||||
def _is_cancelled(self) -> bool:
|
||||
"""Check if cancellation has been requested."""
|
||||
@@ -594,9 +635,14 @@ class _AssetSeeder:
|
||||
},
|
||||
)
|
||||
with self._lock:
|
||||
self._last_progress = self._progress
|
||||
self._state = State.IDLE
|
||||
self._progress = None
|
||||
self._reset_to_idle()
|
||||
pending = self._pending_enrich
|
||||
self._pending_enrich = None
|
||||
if pending is not None:
|
||||
self.start_enrich(
|
||||
roots=pending["roots"],
|
||||
compute_hashes=pending["compute_hashes"],
|
||||
)
|
||||
|
||||
def _run_fast_phase(self, roots: tuple[RootType, ...]) -> tuple[int, int, int]:
|
||||
"""Run phase 1: fast scan to create stub records.
|
||||
|
||||
@@ -23,6 +23,8 @@ from app.assets.services.ingest import (
|
||||
DependencyMissingError,
|
||||
HashMismatchError,
|
||||
create_from_hash,
|
||||
ingest_existing_file,
|
||||
register_output_files,
|
||||
upload_from_temp_path,
|
||||
)
|
||||
from app.assets.database.queries import (
|
||||
@@ -72,6 +74,8 @@ __all__ = [
|
||||
"delete_asset_reference",
|
||||
"get_asset_by_hash",
|
||||
"get_asset_detail",
|
||||
"ingest_existing_file",
|
||||
"register_output_files",
|
||||
"get_mtime_ns",
|
||||
"get_size_and_mtime_ns",
|
||||
"list_assets_page",
|
||||
|
||||
@@ -23,9 +23,11 @@ from app.assets.database.queries import (
|
||||
validate_tags_exist,
|
||||
)
|
||||
from app.assets.helpers import normalize_tags
|
||||
from app.assets.services.bulk_ingest import batch_insert_seed_assets
|
||||
from app.assets.services.file_utils import get_size_and_mtime_ns
|
||||
from app.assets.services.path_utils import (
|
||||
compute_relative_filename,
|
||||
get_name_and_tags_from_asset_path,
|
||||
resolve_destination_from_tags,
|
||||
validate_path_within_base,
|
||||
)
|
||||
@@ -128,6 +130,59 @@ def _ingest_file_from_path(
|
||||
)
|
||||
|
||||
|
||||
def register_output_files(
|
||||
file_paths: Sequence[str],
|
||||
user_metadata: UserMetadata = None,
|
||||
) -> int:
|
||||
"""Register a batch of output file paths as assets.
|
||||
|
||||
Returns the number of files successfully registered.
|
||||
"""
|
||||
registered = 0
|
||||
for abs_path in file_paths:
|
||||
if not os.path.isfile(abs_path):
|
||||
continue
|
||||
try:
|
||||
ingest_existing_file(abs_path, user_metadata=user_metadata)
|
||||
registered += 1
|
||||
except Exception:
|
||||
logging.exception("Failed to register output: %s", abs_path)
|
||||
return registered
|
||||
|
||||
|
||||
def ingest_existing_file(
|
||||
abs_path: str,
|
||||
user_metadata: UserMetadata = None,
|
||||
extra_tags: Sequence[str] = (),
|
||||
owner_id: str = "",
|
||||
) -> None:
|
||||
"""Register an existing on-disk file as an asset stub.
|
||||
|
||||
Inserts a stub record (hash=NULL) for immediate UX visibility.
|
||||
The caller is responsible for triggering background enrichment
|
||||
(hash computation, metadata extraction) via the asset seeder.
|
||||
"""
|
||||
size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path)
|
||||
mime_type = mimetypes.guess_type(abs_path, strict=False)[0]
|
||||
name, path_tags = get_name_and_tags_from_asset_path(abs_path)
|
||||
tags = list(dict.fromkeys(path_tags + list(extra_tags)))
|
||||
|
||||
spec = {
|
||||
"abs_path": abs_path,
|
||||
"size_bytes": size_bytes,
|
||||
"mtime_ns": mtime_ns,
|
||||
"info_name": name,
|
||||
"tags": tags,
|
||||
"fname": os.path.basename(abs_path),
|
||||
"metadata": None,
|
||||
"hash": None,
|
||||
"mime_type": mime_type,
|
||||
}
|
||||
with create_session() as session:
|
||||
batch_insert_seed_assets(session, [spec], owner_id=owner_id)
|
||||
session.commit()
|
||||
|
||||
|
||||
def _register_existing_asset(
|
||||
asset_hash: str,
|
||||
name: str,
|
||||
|
||||
@@ -176,8 +176,8 @@ class InputTypeOptions(TypedDict):
|
||||
"""COMBO type only. Specifies the configuration for a multi-select widget.
|
||||
Available after ComfyUI frontend v1.13.4
|
||||
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
|
||||
gradient_stops: NotRequired[list[dict]]
|
||||
"""Gradient color stops for gradientslider display mode. Each stop is {"offset": float, "color": [r, g, b]}."""
|
||||
gradient_stops: NotRequired[list[list[float]]]
|
||||
"""Gradient color stops for gradientslider display mode. Each stop is [offset, r, g, b] (``FLOAT``)."""
|
||||
|
||||
|
||||
class HiddenInputTypeDict(TypedDict):
|
||||
|
||||
@@ -144,9 +144,9 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
|
||||
return tensor * m_mult
|
||||
else:
|
||||
for d in modulation_dims:
|
||||
tensor[:, d[0]:d[1]] *= m_mult[:, d[2]:d[2] + 1]
|
||||
tensor[:, d[0]:d[1]] *= m_mult[:, d[2]]
|
||||
if m_add is not None:
|
||||
tensor[:, d[0]:d[1]] += m_add[:, d[2]:d[2] + 1]
|
||||
tensor[:, d[0]:d[1]] += m_add[:, d[2]]
|
||||
return tensor
|
||||
|
||||
|
||||
@@ -223,19 +223,12 @@ class DoubleStreamBlock(nn.Module):
|
||||
del txt_k, img_k
|
||||
v = torch.cat((txt_v, img_v), dim=2)
|
||||
del txt_v, img_v
|
||||
|
||||
extra_options["img_slice"] = [txt.shape[1], q.shape[2]]
|
||||
if "attn1_patch" in transformer_patches:
|
||||
patch = transformer_patches["attn1_patch"]
|
||||
for p in patch:
|
||||
out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options)
|
||||
q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask)
|
||||
|
||||
# run actual attention
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
|
||||
if "attn1_output_patch" in transformer_patches:
|
||||
extra_options["img_slice"] = [txt.shape[1], attn.shape[1]]
|
||||
patch = transformer_patches["attn1_output_patch"]
|
||||
for p in patch:
|
||||
attn = p(attn, extra_options)
|
||||
@@ -328,12 +321,6 @@ class SingleStreamBlock(nn.Module):
|
||||
del qkv
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
if "attn1_patch" in transformer_patches:
|
||||
patch = transformer_patches["attn1_patch"]
|
||||
for p in patch:
|
||||
out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options)
|
||||
q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
|
||||
@@ -31,8 +31,6 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
|
||||
def _apply_rope1(x: Tensor, freqs_cis: Tensor):
|
||||
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
||||
if x_.shape[2] != 1 and freqs_cis.shape[2] != 1 and x_.shape[2] != freqs_cis.shape[2]:
|
||||
freqs_cis = freqs_cis[:, :, :x_.shape[2]]
|
||||
|
||||
x_out = freqs_cis[..., 0] * x_[..., 0]
|
||||
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
|
||||
|
||||
@@ -44,22 +44,6 @@ class FluxParams:
|
||||
txt_norm: bool = False
|
||||
|
||||
|
||||
def invert_slices(slices, length):
|
||||
sorted_slices = sorted(slices)
|
||||
result = []
|
||||
current = 0
|
||||
|
||||
for start, end in sorted_slices:
|
||||
if current < start:
|
||||
result.append((current, start))
|
||||
current = max(current, end)
|
||||
|
||||
if current < length:
|
||||
result.append((current, length))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class Flux(nn.Module):
|
||||
"""
|
||||
Transformer model for flow matching on sequences.
|
||||
@@ -154,7 +138,6 @@ class Flux(nn.Module):
|
||||
y: Tensor,
|
||||
guidance: Tensor = None,
|
||||
control = None,
|
||||
timestep_zero_index=None,
|
||||
transformer_options={},
|
||||
attn_mask: Tensor = None,
|
||||
) -> Tensor:
|
||||
@@ -181,9 +164,13 @@ class Flux(nn.Module):
|
||||
txt = self.txt_norm(txt)
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
vec_orig = vec
|
||||
if self.params.global_modulation:
|
||||
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig))
|
||||
|
||||
if "post_input" in patches:
|
||||
for p in patches["post_input"]:
|
||||
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
|
||||
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})
|
||||
img = out["img"]
|
||||
txt = out["txt"]
|
||||
img_ids = out["img_ids"]
|
||||
@@ -195,24 +182,6 @@ class Flux(nn.Module):
|
||||
else:
|
||||
pe = None
|
||||
|
||||
vec_orig = vec
|
||||
txt_vec = vec
|
||||
extra_kwargs = {}
|
||||
if timestep_zero_index is not None:
|
||||
modulation_dims = []
|
||||
batch = vec.shape[0] // 2
|
||||
vec_orig = vec_orig.reshape(2, batch, vec.shape[1]).movedim(0, 1)
|
||||
invert = invert_slices(timestep_zero_index, img.shape[1])
|
||||
for s in invert:
|
||||
modulation_dims.append((s[0], s[1], 0))
|
||||
for s in timestep_zero_index:
|
||||
modulation_dims.append((s[0], s[1], 1))
|
||||
extra_kwargs["modulation_dims_img"] = modulation_dims
|
||||
txt_vec = vec[:batch]
|
||||
|
||||
if self.params.global_modulation:
|
||||
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(txt_vec))
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.double_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
@@ -226,8 +195,7 @@ class Flux(nn.Module):
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"),
|
||||
**extra_kwargs)
|
||||
transformer_options=args.get("transformer_options"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img,
|
||||
@@ -245,8 +213,7 @@ class Flux(nn.Module):
|
||||
vec=vec,
|
||||
pe=pe,
|
||||
attn_mask=attn_mask,
|
||||
transformer_options=transformer_options,
|
||||
**extra_kwargs)
|
||||
transformer_options=transformer_options)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
@@ -263,12 +230,6 @@ class Flux(nn.Module):
|
||||
if self.params.global_modulation:
|
||||
vec, _ = self.single_stream_modulation(vec_orig)
|
||||
|
||||
extra_kwargs = {}
|
||||
if timestep_zero_index is not None:
|
||||
lambda a: 0 if a == 0 else a + txt.shape[1]
|
||||
modulation_dims_combined = list(map(lambda x: (0 if x[0] == 0 else x[0] + txt.shape[1], x[1] + txt.shape[1], x[2]), modulation_dims))
|
||||
extra_kwargs["modulation_dims"] = modulation_dims_combined
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
|
||||
@@ -281,8 +242,7 @@ class Flux(nn.Module):
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"),
|
||||
**extra_kwargs)
|
||||
transformer_options=args.get("transformer_options"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img,
|
||||
@@ -293,7 +253,7 @@ class Flux(nn.Module):
|
||||
{"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options, **extra_kwargs)
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
@@ -304,11 +264,7 @@ class Flux(nn.Module):
|
||||
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
extra_kwargs = {}
|
||||
if timestep_zero_index is not None:
|
||||
extra_kwargs["modulation_dims"] = modulation_dims
|
||||
|
||||
img = self.final_layer(img, vec_orig, **extra_kwargs) # (N, T, patch_size ** 2 * out_channels)
|
||||
img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
|
||||
def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
|
||||
@@ -356,16 +312,13 @@ class Flux(nn.Module):
|
||||
w_len = ((w_orig + (patch_size // 2)) // patch_size)
|
||||
img, img_ids = self.process_img(x, transformer_options=transformer_options)
|
||||
img_tokens = img.shape[1]
|
||||
timestep_zero_index = None
|
||||
if ref_latents is not None:
|
||||
ref_num_tokens = []
|
||||
h = 0
|
||||
w = 0
|
||||
index = 0
|
||||
ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method)
|
||||
timestep_zero = ref_latents_method == "index_timestep_zero"
|
||||
for ref in ref_latents:
|
||||
if ref_latents_method in ("index", "index_timestep_zero"):
|
||||
if ref_latents_method == "index":
|
||||
index += self.params.ref_index_scale
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
@@ -389,13 +342,6 @@ class Flux(nn.Module):
|
||||
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
|
||||
img = torch.cat([img, kontext], dim=1)
|
||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||
ref_num_tokens.append(kontext.shape[1])
|
||||
if timestep_zero:
|
||||
if index > 0:
|
||||
timestep = torch.cat([timestep, timestep * 0], dim=0)
|
||||
timestep_zero_index = [[img_tokens, img_ids.shape[1]]]
|
||||
transformer_options = transformer_options.copy()
|
||||
transformer_options["reference_image_num_tokens"] = ref_num_tokens
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
|
||||
|
||||
@@ -403,6 +349,6 @@ class Flux(nn.Module):
|
||||
for i in self.params.txt_ids_dims:
|
||||
txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
|
||||
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||
out = out[:, :img_tokens]
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig]
|
||||
|
||||
@@ -372,8 +372,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
break
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
if first_op_done == False:
|
||||
model_management.soft_empty_cache(True)
|
||||
if cleared_cache == False:
|
||||
|
||||
@@ -258,8 +258,7 @@ def slice_attention(q, k, v):
|
||||
r1[:, :, i:end] = torch.bmm(v, s2)
|
||||
del s2
|
||||
break
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
model_management.soft_empty_cache(True)
|
||||
steps *= 2
|
||||
if steps > 128:
|
||||
@@ -315,8 +314,7 @@ def pytorch_attention(q, k, v):
|
||||
try:
|
||||
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||
out = out.transpose(2, 3).reshape(orig_shape)
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
except model_management.OOM_EXCEPTION:
|
||||
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
|
||||
oom_fallback = True
|
||||
if oom_fallback:
|
||||
|
||||
@@ -169,8 +169,7 @@ def _get_attention_scores_no_kv_chunking(
|
||||
try:
|
||||
attn_probs = attn_scores.softmax(dim=-1)
|
||||
del attn_scores
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
except model_management.OOM_EXCEPTION:
|
||||
logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
|
||||
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined
|
||||
torch.exp(attn_scores, out=attn_scores)
|
||||
|
||||
@@ -149,9 +149,6 @@ class Attention(nn.Module):
|
||||
seq_img = hidden_states.shape[1]
|
||||
seq_txt = encoder_hidden_states.shape[1]
|
||||
|
||||
transformer_patches = transformer_options.get("patches", {})
|
||||
extra_options = transformer_options.copy()
|
||||
|
||||
# Project and reshape to BHND format (batch, heads, seq, dim)
|
||||
img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
|
||||
img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
|
||||
@@ -170,22 +167,15 @@ class Attention(nn.Module):
|
||||
joint_key = torch.cat([txt_key, img_key], dim=2)
|
||||
joint_value = torch.cat([txt_value, img_value], dim=2)
|
||||
|
||||
joint_query = apply_rope1(joint_query, image_rotary_emb)
|
||||
joint_key = apply_rope1(joint_key, image_rotary_emb)
|
||||
|
||||
if encoder_hidden_states_mask is not None:
|
||||
attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask
|
||||
else:
|
||||
attn_mask = None
|
||||
|
||||
extra_options["img_slice"] = [txt_query.shape[2], joint_query.shape[2]]
|
||||
if "attn1_patch" in transformer_patches:
|
||||
patch = transformer_patches["attn1_patch"]
|
||||
for p in patch:
|
||||
out = p(joint_query, joint_key, joint_value, pe=image_rotary_emb, attn_mask=encoder_hidden_states_mask, extra_options=extra_options)
|
||||
joint_query, joint_key, joint_value, image_rotary_emb, encoder_hidden_states_mask = out.get("q", joint_query), out.get("k", joint_key), out.get("v", joint_value), out.get("pe", image_rotary_emb), out.get("attn_mask", encoder_hidden_states_mask)
|
||||
|
||||
joint_query = apply_rope1(joint_query, image_rotary_emb)
|
||||
joint_key = apply_rope1(joint_key, image_rotary_emb)
|
||||
|
||||
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
|
||||
attn_mask, transformer_options=transformer_options,
|
||||
skip_reshape=True)
|
||||
@@ -454,7 +444,6 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
|
||||
timestep_zero_index = None
|
||||
if ref_latents is not None:
|
||||
ref_num_tokens = []
|
||||
h = 0
|
||||
w = 0
|
||||
index = 0
|
||||
@@ -485,16 +474,16 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
|
||||
hidden_states = torch.cat([hidden_states, kontext], dim=1)
|
||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||
ref_num_tokens.append(kontext.shape[1])
|
||||
if timestep_zero:
|
||||
if index > 0:
|
||||
timestep = torch.cat([timestep, timestep * 0], dim=0)
|
||||
timestep_zero_index = num_embeds
|
||||
transformer_options = transformer_options.copy()
|
||||
transformer_options["reference_image_num_tokens"] = ref_num_tokens
|
||||
|
||||
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
||||
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
|
||||
del ids, txt_ids, img_ids
|
||||
|
||||
hidden_states = self.img_in(hidden_states)
|
||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||
@@ -506,18 +495,6 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
patches = transformer_options.get("patches", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
|
||||
if "post_input" in patches:
|
||||
for p in patches["post_input"]:
|
||||
out = p({"img": hidden_states, "txt": encoder_hidden_states, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
|
||||
hidden_states = out["img"]
|
||||
encoder_hidden_states = out["txt"]
|
||||
img_ids = out["img_ids"]
|
||||
txt_ids = out["txt_ids"]
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
|
||||
del ids, txt_ids, img_ids
|
||||
|
||||
transformer_options["total_blocks"] = len(self.transformer_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import json
|
||||
import comfy.memory_management
|
||||
import comfy.supported_models
|
||||
import comfy.supported_models_base
|
||||
import comfy.utils
|
||||
@@ -1119,13 +1118,8 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
|
||||
new[:old_weight.shape[0]] = old_weight
|
||||
old_weight = new
|
||||
|
||||
if old_weight is out_sd.get(t[0], None) and comfy.memory_management.aimdo_enabled:
|
||||
old_weight = old_weight.clone()
|
||||
|
||||
w = old_weight.narrow(offset[0], offset[1], offset[2])
|
||||
else:
|
||||
if comfy.memory_management.aimdo_enabled:
|
||||
weight = weight.clone()
|
||||
old_weight = weight
|
||||
w = weight
|
||||
w[:] = fun(weight)
|
||||
|
||||
@@ -270,23 +270,6 @@ try:
|
||||
except:
|
||||
OOM_EXCEPTION = Exception
|
||||
|
||||
try:
|
||||
ACCELERATOR_ERROR = torch.AcceleratorError
|
||||
except AttributeError:
|
||||
ACCELERATOR_ERROR = RuntimeError
|
||||
|
||||
def is_oom(e):
|
||||
if isinstance(e, OOM_EXCEPTION):
|
||||
return True
|
||||
if isinstance(e, ACCELERATOR_ERROR) and (getattr(e, 'error_code', None) == 2 or "out of memory" in str(e).lower()):
|
||||
discard_cuda_async_error()
|
||||
return True
|
||||
return False
|
||||
|
||||
def raise_non_oom(e):
|
||||
if not is_oom(e):
|
||||
raise e
|
||||
|
||||
XFORMERS_VERSION = ""
|
||||
XFORMERS_ENABLED_VAE = True
|
||||
if args.disable_xformers:
|
||||
@@ -1280,7 +1263,7 @@ def discard_cuda_async_error():
|
||||
b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
|
||||
_ = a + b
|
||||
synchronize()
|
||||
except RuntimeError:
|
||||
except torch.AcceleratorError:
|
||||
#Dump it! We already know about it from the synchronous return
|
||||
pass
|
||||
|
||||
|
||||
@@ -599,27 +599,6 @@ class ModelPatcher:
|
||||
|
||||
return models
|
||||
|
||||
def model_patches_call_function(self, function_name="cleanup", arguments={}):
|
||||
to = self.model_options["transformer_options"]
|
||||
if "patches" in to:
|
||||
patches = to["patches"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for i in range(len(patch_list)):
|
||||
if hasattr(patch_list[i], function_name):
|
||||
getattr(patch_list[i], function_name)(**arguments)
|
||||
if "patches_replace" in to:
|
||||
patches = to["patches_replace"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for k in patch_list:
|
||||
if hasattr(patch_list[k], function_name):
|
||||
getattr(patch_list[k], function_name)(**arguments)
|
||||
if "model_function_wrapper" in self.model_options:
|
||||
wrap_func = self.model_options["model_function_wrapper"]
|
||||
if hasattr(wrap_func, function_name):
|
||||
getattr(wrap_func, function_name)(**arguments)
|
||||
|
||||
def model_dtype(self):
|
||||
if hasattr(self.model, "get_dtype"):
|
||||
return self.model.get_dtype()
|
||||
@@ -1083,7 +1062,6 @@ class ModelPatcher:
|
||||
return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)
|
||||
|
||||
def cleanup(self):
|
||||
self.model_patches_call_function(function_name="cleanup")
|
||||
self.clean_hooks()
|
||||
if hasattr(self.model, "current_patcher"):
|
||||
self.model.current_patcher = None
|
||||
|
||||
@@ -954,8 +954,7 @@ class VAE:
|
||||
if pixel_samples is None:
|
||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
||||
pixel_samples[x:x+batch_number] = out
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
except model_management.OOM_EXCEPTION:
|
||||
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
|
||||
#exception and the exception itself refs them all until we get out of this except block.
|
||||
@@ -1030,8 +1029,7 @@ class VAE:
|
||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
||||
samples[x:x + batch_number] = out
|
||||
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
except model_management.OOM_EXCEPTION:
|
||||
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
|
||||
#exception and the exception itself refs them all until we get out of this except block.
|
||||
|
||||
@@ -297,7 +297,7 @@ class Float(ComfyTypeIO):
|
||||
'''Float input.'''
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
|
||||
display_mode: NumberDisplay=None, gradient_stops: list[dict]=None,
|
||||
display_mode: NumberDisplay=None, gradient_stops: list[list[float]]=None,
|
||||
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
|
||||
self.min = min
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class RevePostprocessingOperation(BaseModel):
|
||||
process: str = Field(..., description="The postprocessing operation: upscale or remove_background.")
|
||||
upscale_factor: int | None = Field(
|
||||
None,
|
||||
description="Upscale factor (2, 3, or 4). Only used when process is upscale.",
|
||||
ge=2,
|
||||
le=4,
|
||||
)
|
||||
|
||||
|
||||
class ReveImageCreateRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
aspect_ratio: str | None = Field(...)
|
||||
version: str = Field(...)
|
||||
test_time_scaling: int = Field(
|
||||
...,
|
||||
description="If included, the model will spend more effort making better images. Values between 1 and 15.",
|
||||
ge=1,
|
||||
le=15,
|
||||
)
|
||||
postprocessing: list[RevePostprocessingOperation] | None = Field(
|
||||
None, description="Optional postprocessing operations to apply after generation."
|
||||
)
|
||||
|
||||
|
||||
class ReveImageEditRequest(BaseModel):
|
||||
edit_instruction: str = Field(...)
|
||||
reference_image: str = Field(..., description="A base64 encoded image to use as reference for the edit.")
|
||||
aspect_ratio: str | None = Field(...)
|
||||
version: str = Field(...)
|
||||
test_time_scaling: int | None = Field(
|
||||
...,
|
||||
description="If included, the model will spend more effort making better images. Values between 1 and 15.",
|
||||
ge=1,
|
||||
le=15,
|
||||
)
|
||||
postprocessing: list[RevePostprocessingOperation] | None = Field(
|
||||
None, description="Optional postprocessing operations to apply after generation."
|
||||
)
|
||||
|
||||
|
||||
class ReveImageRemixRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
reference_images: list[str] = Field(..., description="A list of 1-6 base64 encoded reference images.")
|
||||
aspect_ratio: str | None = Field(...)
|
||||
version: str = Field(...)
|
||||
test_time_scaling: int | None = Field(
|
||||
...,
|
||||
description="If included, the model will spend more effort making better images. Values between 1 and 15.",
|
||||
ge=1,
|
||||
le=15,
|
||||
)
|
||||
postprocessing: list[RevePostprocessingOperation] | None = Field(
|
||||
None, description="Optional postprocessing operations to apply after generation."
|
||||
)
|
||||
|
||||
|
||||
class ReveImageResponse(BaseModel):
|
||||
image: str | None = Field(None, description="The base64 encoded image data.")
|
||||
request_id: str | None = Field(None, description="A unique id for the request.")
|
||||
credits_used: float | None = Field(None, description="The number of credits used for this request.")
|
||||
version: str | None = Field(None, description="The specific model version used.")
|
||||
content_violation: bool | None = Field(
|
||||
None, description="Indicates whether the generated image violates the content policy."
|
||||
)
|
||||
@@ -1,395 +0,0 @@
|
||||
from io import BytesIO
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.reve import (
|
||||
ReveImageCreateRequest,
|
||||
ReveImageEditRequest,
|
||||
ReveImageRemixRequest,
|
||||
RevePostprocessingOperation,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
bytesio_to_image_tensor,
|
||||
sync_op_raw,
|
||||
tensor_to_base64_string,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
|
||||
def _build_postprocessing(upscale: dict, remove_background: bool) -> list[RevePostprocessingOperation] | None:
|
||||
ops = []
|
||||
if upscale["upscale"] == "enabled":
|
||||
ops.append(
|
||||
RevePostprocessingOperation(
|
||||
process="upscale",
|
||||
upscale_factor=upscale["upscale_factor"],
|
||||
)
|
||||
)
|
||||
if remove_background:
|
||||
ops.append(RevePostprocessingOperation(process="remove_background"))
|
||||
return ops or None
|
||||
|
||||
|
||||
def _postprocessing_inputs():
|
||||
return [
|
||||
IO.DynamicCombo.Input(
|
||||
"upscale",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("disabled", []),
|
||||
IO.DynamicCombo.Option(
|
||||
"enabled",
|
||||
[
|
||||
IO.Int.Input(
|
||||
"upscale_factor",
|
||||
default=2,
|
||||
min=2,
|
||||
max=4,
|
||||
step=1,
|
||||
tooltip="Upscale factor (2x, 3x, or 4x).",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="Upscale the generated image. May add additional cost.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"remove_background",
|
||||
default=False,
|
||||
tooltip="Remove the background from the generated image. May add additional cost.",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _reve_price_extractor(headers: dict) -> float | None:
|
||||
credits_used = headers.get("x-reve-credits-used")
|
||||
if credits_used is not None:
|
||||
return float(credits_used) / 524.48
|
||||
return None
|
||||
|
||||
|
||||
def _reve_response_header_validator(headers: dict) -> None:
|
||||
error_code = headers.get("x-reve-error-code")
|
||||
if error_code:
|
||||
raise ValueError(f"Reve API error: {error_code}")
|
||||
if headers.get("x-reve-content-violation", "").lower() == "true":
|
||||
raise ValueError("The generated image was flagged for content policy violation.")
|
||||
|
||||
|
||||
def _model_inputs(versions: list[str], aspect_ratios: list[str]):
|
||||
return [
|
||||
IO.DynamicCombo.Option(
|
||||
version,
|
||||
[
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=aspect_ratios,
|
||||
tooltip="Aspect ratio of the output image.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"test_time_scaling",
|
||||
default=1,
|
||||
min=1,
|
||||
max=5,
|
||||
step=1,
|
||||
tooltip="Higher values produce better images but cost more credits.",
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
)
|
||||
for version in versions
|
||||
]
|
||||
|
||||
|
||||
class ReveImageCreateNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ReveImageCreateNode",
|
||||
display_name="Reve Image Create",
|
||||
category="api node/image/Reve",
|
||||
description="Generate images from text descriptions using Reve.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text description of the desired image. Maximum 2560 characters.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=_model_inputs(
|
||||
["reve-create@20250915"],
|
||||
aspect_ratios=["3:2", "16:9", "9:16", "2:3", "4:3", "3:4", "1:1"],
|
||||
),
|
||||
tooltip="Model version to use for generation.",
|
||||
),
|
||||
*_postprocessing_inputs(),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.03432,"format":{"approximate":true,"note":"(base)"}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: dict,
|
||||
upscale: dict,
|
||||
remove_background: bool,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=2560)
|
||||
response = await sync_op_raw(
|
||||
cls,
|
||||
ApiEndpoint(
|
||||
path="/proxy/reve/v1/image/create",
|
||||
method="POST",
|
||||
headers={"Accept": "image/webp"},
|
||||
),
|
||||
as_binary=True,
|
||||
price_extractor=_reve_price_extractor,
|
||||
response_header_validator=_reve_response_header_validator,
|
||||
data=ReveImageCreateRequest(
|
||||
prompt=prompt,
|
||||
aspect_ratio=model["aspect_ratio"],
|
||||
version=model["model"],
|
||||
test_time_scaling=model["test_time_scaling"],
|
||||
postprocessing=_build_postprocessing(upscale, remove_background),
|
||||
),
|
||||
)
|
||||
return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
|
||||
|
||||
|
||||
class ReveImageEditNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ReveImageEditNode",
|
||||
display_name="Reve Image Edit",
|
||||
category="api node/image/Reve",
|
||||
description="Edit images using natural language instructions with Reve.",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="The image to edit."),
|
||||
IO.String.Input(
|
||||
"edit_instruction",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text description of how to edit the image. Maximum 2560 characters.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=_model_inputs(
|
||||
["reve-edit@20250915", "reve-edit-fast@20251030"],
|
||||
aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"],
|
||||
),
|
||||
tooltip="Model version to use for editing.",
|
||||
),
|
||||
*_postprocessing_inputs(),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=["model"],
|
||||
),
|
||||
expr="""
|
||||
(
|
||||
$isFast := $contains(widgets.model, "fast");
|
||||
$base := $isFast ? 0.01001 : 0.0572;
|
||||
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
edit_instruction: str,
|
||||
model: dict,
|
||||
upscale: dict,
|
||||
remove_background: bool,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(edit_instruction, min_length=1, max_length=2560)
|
||||
tts = model["test_time_scaling"]
|
||||
ar = model["aspect_ratio"]
|
||||
response = await sync_op_raw(
|
||||
cls,
|
||||
ApiEndpoint(
|
||||
path="/proxy/reve/v1/image/edit",
|
||||
method="POST",
|
||||
headers={"Accept": "image/webp"},
|
||||
),
|
||||
as_binary=True,
|
||||
price_extractor=_reve_price_extractor,
|
||||
response_header_validator=_reve_response_header_validator,
|
||||
data=ReveImageEditRequest(
|
||||
edit_instruction=edit_instruction,
|
||||
reference_image=tensor_to_base64_string(image),
|
||||
aspect_ratio=ar if ar != "auto" else None,
|
||||
version=model["model"],
|
||||
test_time_scaling=tts if tts and tts > 1 else None,
|
||||
postprocessing=_build_postprocessing(upscale, remove_background),
|
||||
),
|
||||
)
|
||||
return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
|
||||
|
||||
|
||||
class ReveImageRemixNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ReveImageRemixNode",
|
||||
display_name="Reve Image Remix",
|
||||
category="api node/image/Reve",
|
||||
description="Combine reference images with text prompts to create new images using Reve.",
|
||||
inputs=[
|
||||
IO.Autogrow.Input(
|
||||
"reference_images",
|
||||
template=IO.Autogrow.TemplatePrefix(
|
||||
IO.Image.Input("image"),
|
||||
prefix="image_",
|
||||
min=1,
|
||||
max=6,
|
||||
),
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text description of the desired image. "
|
||||
"May include XML img tags to reference specific images by index, "
|
||||
"e.g. <img>0</img>, <img>1</img>, etc.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=_model_inputs(
|
||||
["reve-remix@20250915", "reve-remix-fast@20251030"],
|
||||
aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"],
|
||||
),
|
||||
tooltip="Model version to use for remixing.",
|
||||
),
|
||||
*_postprocessing_inputs(),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=["model"],
|
||||
),
|
||||
expr="""
|
||||
(
|
||||
$isFast := $contains(widgets.model, "fast");
|
||||
$base := $isFast ? 0.01001 : 0.0572;
|
||||
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
reference_images: IO.Autogrow.Type,
|
||||
prompt: str,
|
||||
model: dict,
|
||||
upscale: dict,
|
||||
remove_background: bool,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=2560)
|
||||
if not reference_images:
|
||||
raise ValueError("At least one reference image is required.")
|
||||
ref_base64_list = []
|
||||
for key in reference_images:
|
||||
ref_base64_list.append(tensor_to_base64_string(reference_images[key]))
|
||||
if len(ref_base64_list) > 6:
|
||||
raise ValueError("Maximum 6 reference images are allowed.")
|
||||
tts = model["test_time_scaling"]
|
||||
ar = model["aspect_ratio"]
|
||||
response = await sync_op_raw(
|
||||
cls,
|
||||
ApiEndpoint(
|
||||
path="/proxy/reve/v1/image/remix",
|
||||
method="POST",
|
||||
headers={"Accept": "image/webp"},
|
||||
),
|
||||
as_binary=True,
|
||||
price_extractor=_reve_price_extractor,
|
||||
response_header_validator=_reve_response_header_validator,
|
||||
data=ReveImageRemixRequest(
|
||||
prompt=prompt,
|
||||
reference_images=ref_base64_list,
|
||||
aspect_ratio=ar if ar != "auto" else None,
|
||||
version=model["model"],
|
||||
test_time_scaling=tts if tts and tts > 1 else None,
|
||||
postprocessing=_build_postprocessing(upscale, remove_background),
|
||||
),
|
||||
)
|
||||
return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
|
||||
|
||||
|
||||
class ReveExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
ReveImageCreateNode,
|
||||
ReveImageEditNode,
|
||||
ReveImageRemixNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> ReveExtension:
|
||||
return ReveExtension()
|
||||
@@ -67,7 +67,6 @@ class _RequestConfig:
|
||||
progress_origin_ts: float | None = None
|
||||
price_extractor: Callable[[dict[str, Any]], float | None] | None = None
|
||||
is_rate_limited: Callable[[int, Any], bool] | None = None
|
||||
response_header_validator: Callable[[dict[str, str]], None] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -203,13 +202,11 @@ async def sync_op_raw(
|
||||
monitor_progress: bool = True,
|
||||
max_retries_on_rate_limit: int = 16,
|
||||
is_rate_limited: Callable[[int, Any], bool] | None = None,
|
||||
response_header_validator: Callable[[dict[str, str]], None] | None = None,
|
||||
) -> dict[str, Any] | bytes:
|
||||
"""
|
||||
Make a single network request.
|
||||
- If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
|
||||
- If as_binary=True: returns bytes.
|
||||
- response_header_validator: optional callback receiving response headers dict
|
||||
"""
|
||||
if isinstance(data, BaseModel):
|
||||
data = data.model_dump(exclude_none=True)
|
||||
@@ -235,7 +232,6 @@ async def sync_op_raw(
|
||||
price_extractor=price_extractor,
|
||||
max_retries_on_rate_limit=max_retries_on_rate_limit,
|
||||
is_rate_limited=is_rate_limited,
|
||||
response_header_validator=response_header_validator,
|
||||
)
|
||||
return await _request_base(cfg, expect_binary=as_binary)
|
||||
|
||||
@@ -773,12 +769,6 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total
|
||||
)
|
||||
bytes_payload = bytes(buff)
|
||||
resp_headers = {k.lower(): v for k, v in resp.headers.items()}
|
||||
if cfg.price_extractor:
|
||||
with contextlib.suppress(Exception):
|
||||
extracted_price = cfg.price_extractor(resp_headers)
|
||||
if cfg.response_header_validator:
|
||||
cfg.response_header_validator(resp_headers)
|
||||
operation_succeeded = True
|
||||
final_elapsed_seconds = int(time.monotonic() - start_time)
|
||||
request_logger.log_request_response(
|
||||
@@ -786,7 +776,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=resp_headers,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=bytes_payload,
|
||||
)
|
||||
return bytes_payload
|
||||
|
||||
@@ -6,7 +6,6 @@ import comfy.model_management
|
||||
import torch
|
||||
import math
|
||||
import nodes
|
||||
import comfy.ldm.flux.math
|
||||
|
||||
class CLIPTextEncodeFlux(io.ComfyNode):
|
||||
@classmethod
|
||||
@@ -232,68 +231,6 @@ class Flux2Scheduler(io.ComfyNode):
|
||||
sigmas = get_schedule(steps, round(seq_len))
|
||||
return io.NodeOutput(sigmas)
|
||||
|
||||
class KV_Attn_Input:
|
||||
def __init__(self):
|
||||
self.cache = {}
|
||||
|
||||
def __call__(self, q, k, v, extra_options, **kwargs):
|
||||
reference_image_num_tokens = extra_options.get("reference_image_num_tokens", [])
|
||||
if len(reference_image_num_tokens) == 0:
|
||||
return {}
|
||||
|
||||
ref_toks = sum(reference_image_num_tokens)
|
||||
cache_key = "{}_{}".format(extra_options["block_type"], extra_options["block_index"])
|
||||
if cache_key in self.cache:
|
||||
kk, vv = self.cache[cache_key]
|
||||
self.set_cache = False
|
||||
return {"q": q, "k": torch.cat((k, kk), dim=2), "v": torch.cat((v, vv), dim=2)}
|
||||
|
||||
self.cache[cache_key] = (k[:, :, -ref_toks:].clone(), v[:, :, -ref_toks:].clone())
|
||||
self.set_cache = True
|
||||
return {"q": q, "k": k, "v": v}
|
||||
|
||||
def cleanup(self):
|
||||
self.cache = {}
|
||||
|
||||
|
||||
class FluxKVCache(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="FluxKVCache",
|
||||
display_name="Flux KV Cache",
|
||||
description="Enables KV Cache optimization for reference images on Flux family models.",
|
||||
category="",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Model.Input("model", tooltip="The model to use KV Cache on."),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(tooltip="The patched model with KV Cache enabled."),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: io.Model.Type) -> io.NodeOutput:
|
||||
m = model.clone()
|
||||
input_patch_obj = KV_Attn_Input()
|
||||
|
||||
def model_input_patch(inputs):
|
||||
if len(input_patch_obj.cache) > 0:
|
||||
ref_image_tokens = sum(inputs["transformer_options"].get("reference_image_num_tokens", []))
|
||||
if ref_image_tokens > 0:
|
||||
img = inputs["img"]
|
||||
inputs["img"] = img[:, :-ref_image_tokens]
|
||||
return inputs
|
||||
|
||||
m.set_model_attn1_patch(input_patch_obj)
|
||||
m.set_model_post_input_patch(model_input_patch)
|
||||
if hasattr(model.model.diffusion_model, "params"):
|
||||
m.add_object_patch("diffusion_model.params.default_ref_method", "index_timestep_zero")
|
||||
else:
|
||||
m.add_object_patch("diffusion_model.default_ref_method", "index_timestep_zero")
|
||||
|
||||
return io.NodeOutput(m)
|
||||
|
||||
class FluxExtension(ComfyExtension):
|
||||
@override
|
||||
@@ -306,7 +243,6 @@ class FluxExtension(ComfyExtension):
|
||||
FluxKontextMultiReferenceLatentMethod,
|
||||
EmptyFlux2LatentImage,
|
||||
Flux2Scheduler,
|
||||
FluxKVCache,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1,127 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
import folder_paths
|
||||
import node_helpers
|
||||
from comfy_api.latest import ComfyExtension, io, UI
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
def hex_to_rgb(hex_color: str) -> tuple[float, float, float]:
|
||||
hex_color = hex_color.lstrip("#")
|
||||
if len(hex_color) != 6:
|
||||
return (0.0, 0.0, 0.0)
|
||||
r = int(hex_color[0:2], 16) / 255.0
|
||||
g = int(hex_color[2:4], 16) / 255.0
|
||||
b = int(hex_color[4:6], 16) / 255.0
|
||||
return (r, g, b)
|
||||
|
||||
|
||||
class PainterNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="Painter",
|
||||
display_name="Painter",
|
||||
category="image",
|
||||
inputs=[
|
||||
io.Image.Input(
|
||||
"image",
|
||||
optional=True,
|
||||
tooltip="Optional base image to paint over",
|
||||
),
|
||||
io.String.Input(
|
||||
"mask",
|
||||
default="",
|
||||
socketless=True,
|
||||
extra_dict={"widgetType": "PAINTER", "image_upload": True},
|
||||
),
|
||||
io.Int.Input(
|
||||
"width",
|
||||
default=512,
|
||||
min=64,
|
||||
max=4096,
|
||||
step=64,
|
||||
socketless=True,
|
||||
extra_dict={"hidden": True},
|
||||
),
|
||||
io.Int.Input(
|
||||
"height",
|
||||
default=512,
|
||||
min=64,
|
||||
max=4096,
|
||||
step=64,
|
||||
socketless=True,
|
||||
extra_dict={"hidden": True},
|
||||
),
|
||||
io.Color.Input("bg_color", default="#000000"),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output("IMAGE"),
|
||||
io.Mask.Output("MASK"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, mask, width, height, bg_color="#000000", image=None) -> io.NodeOutput:
|
||||
if image is not None:
|
||||
base_image = image[:1]
|
||||
h, w = base_image.shape[1], base_image.shape[2]
|
||||
else:
|
||||
h, w = height, width
|
||||
r, g, b = hex_to_rgb(bg_color)
|
||||
base_image = torch.zeros((1, h, w, 3), dtype=torch.float32)
|
||||
base_image[0, :, :, 0] = r
|
||||
base_image[0, :, :, 1] = g
|
||||
base_image[0, :, :, 2] = b
|
||||
|
||||
if mask and mask.strip():
|
||||
mask_path = folder_paths.get_annotated_filepath(mask)
|
||||
painter_img = node_helpers.pillow(Image.open, mask_path)
|
||||
painter_img = painter_img.convert("RGBA")
|
||||
|
||||
if painter_img.size != (w, h):
|
||||
painter_img = painter_img.resize((w, h), Image.LANCZOS)
|
||||
|
||||
painter_np = np.array(painter_img).astype(np.float32) / 255.0
|
||||
painter_rgb = painter_np[:, :, :3]
|
||||
painter_alpha = painter_np[:, :, 3:4]
|
||||
|
||||
mask_tensor = torch.from_numpy(painter_np[:, :, 3]).unsqueeze(0)
|
||||
|
||||
base_np = base_image[0].cpu().numpy()
|
||||
composited = painter_rgb * painter_alpha + base_np * (1.0 - painter_alpha)
|
||||
out_image = torch.from_numpy(composited).unsqueeze(0)
|
||||
else:
|
||||
mask_tensor = torch.zeros((1, h, w), dtype=torch.float32)
|
||||
out_image = base_image
|
||||
|
||||
return io.NodeOutput(out_image, mask_tensor, ui=UI.PreviewImage(out_image))
|
||||
|
||||
@classmethod
|
||||
def fingerprint_inputs(cls, mask, width, height, bg_color="#000000", image=None):
|
||||
if mask and mask.strip():
|
||||
mask_path = folder_paths.get_annotated_filepath(mask)
|
||||
if os.path.exists(mask_path):
|
||||
m = hashlib.sha256()
|
||||
with open(mask_path, "rb") as f:
|
||||
m.update(f.read())
|
||||
return m.digest().hex()
|
||||
return ""
|
||||
|
||||
|
||||
|
||||
class PainterExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self):
|
||||
return [PainterNode]
|
||||
|
||||
|
||||
async def comfy_entrypoint():
|
||||
return PainterExtension()
|
||||
@@ -86,8 +86,7 @@ class ImageUpscaleWithModel(io.ComfyNode):
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
|
||||
oom = False
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
tile //= 2
|
||||
if tile < 128:
|
||||
raise e
|
||||
|
||||
@@ -612,7 +612,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
logging.error(traceback.format_exc())
|
||||
tips = ""
|
||||
|
||||
if comfy.model_management.is_oom(ex):
|
||||
if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
|
||||
tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
|
||||
logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary()))
|
||||
logging.error("Got an OOM, unloading all loaded models.")
|
||||
|
||||
43
main.py
43
main.py
@@ -3,16 +3,16 @@ comfy.options.enable_args_parsing()
|
||||
|
||||
import os
|
||||
import importlib.util
|
||||
import shutil
|
||||
import importlib.metadata
|
||||
import folder_paths
|
||||
import time
|
||||
from comfy.cli_args import args, enables_dynamic_vram
|
||||
from app.logger import setup_logger
|
||||
from app.assets.seeder import asset_seeder
|
||||
from app.assets.services import register_output_files
|
||||
import itertools
|
||||
import utils.extra_config
|
||||
from utils.mime_types import init_mime_types
|
||||
import faulthandler
|
||||
import logging
|
||||
import sys
|
||||
from comfy_execution.progress import get_progress_state
|
||||
@@ -27,8 +27,6 @@ if __name__ == "__main__":
|
||||
|
||||
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
|
||||
|
||||
faulthandler.enable(file=sys.stderr, all_threads=False)
|
||||
|
||||
import comfy_aimdo.control
|
||||
|
||||
if enables_dynamic_vram():
|
||||
@@ -68,15 +66,8 @@ if __name__ == "__main__":
|
||||
|
||||
|
||||
def handle_comfyui_manager_unavailable():
|
||||
manager_req_path = os.path.join(os.path.dirname(os.path.abspath(folder_paths.__file__)), "manager_requirements.txt")
|
||||
uv_available = shutil.which("uv") is not None
|
||||
|
||||
pip_cmd = f"{sys.executable} -m pip install -r {manager_req_path}"
|
||||
msg = f"\n\nTo use the `--enable-manager` feature, the `comfyui-manager` package must be installed first.\ncommand:\n\t{pip_cmd}"
|
||||
if uv_available:
|
||||
msg += f"\nor using uv:\n\tuv pip install -r {manager_req_path}"
|
||||
msg += "\n"
|
||||
logging.warning(msg)
|
||||
if not args.windows_standalone_build:
|
||||
logging.warning(f"\n\nYou appear to be running comfyui-manager from source, this is not recommended. Please install comfyui-manager using the following command:\ncommand:\n\t{sys.executable} -m pip install --pre comfyui_manager\n")
|
||||
args.enable_manager = False
|
||||
|
||||
|
||||
@@ -184,6 +175,7 @@ execute_prestartup_script()
|
||||
|
||||
# Main code
|
||||
import asyncio
|
||||
import shutil
|
||||
import threading
|
||||
import gc
|
||||
|
||||
@@ -192,7 +184,6 @@ if 'torch' in sys.modules:
|
||||
|
||||
|
||||
import comfy.utils
|
||||
from app.assets.seeder import asset_seeder
|
||||
|
||||
import execution
|
||||
import server
|
||||
@@ -240,6 +231,24 @@ def cuda_malloc_warning():
|
||||
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
||||
|
||||
|
||||
def _collect_output_absolute_paths(history_result: dict) -> list[str]:
|
||||
"""Extract absolute file paths for output items from a history result."""
|
||||
paths = []
|
||||
base_dir = folder_paths.get_directory_by_type("output")
|
||||
for node_output in history_result.get("outputs", {}).values():
|
||||
for items in node_output.values():
|
||||
if not isinstance(items, list):
|
||||
continue
|
||||
for item in items:
|
||||
if not isinstance(item, dict) or item.get("type") != "output":
|
||||
continue
|
||||
filename = item.get("filename")
|
||||
if not filename:
|
||||
continue
|
||||
paths.append(os.path.join(base_dir, item.get("subfolder", ""), filename))
|
||||
return paths
|
||||
|
||||
|
||||
def prompt_worker(q, server_instance):
|
||||
current_time: float = 0.0
|
||||
cache_type = execution.CacheType.CLASSIC
|
||||
@@ -274,6 +283,7 @@ def prompt_worker(q, server_instance):
|
||||
|
||||
asset_seeder.pause()
|
||||
e.execute(item[2], prompt_id, extra_data, item[4])
|
||||
|
||||
need_gc = True
|
||||
|
||||
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
|
||||
@@ -317,6 +327,11 @@ def prompt_worker(q, server_instance):
|
||||
last_gc_collect = current_time
|
||||
need_gc = False
|
||||
hook_breaker_ac10a0.restore_functions()
|
||||
|
||||
if not asset_seeder.is_disabled():
|
||||
paths = _collect_output_absolute_paths(e.history_result)
|
||||
if register_output_files(paths, user_metadata={"prompt_id": prompt_id}) > 0:
|
||||
asset_seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
|
||||
asset_seeder.resume()
|
||||
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
comfyui_manager==4.1b2
|
||||
comfyui_manager==4.1b1
|
||||
|
||||
@@ -32,7 +32,7 @@ async def cache_control(
|
||||
)
|
||||
|
||||
if request.path.endswith(".js") or request.path.endswith(".css") or is_entry_point:
|
||||
response.headers.setdefault("Cache-Control", "no-store")
|
||||
response.headers.setdefault("Cache-Control", "no-cache")
|
||||
return response
|
||||
|
||||
# Early return for non-image files - no cache headers needed
|
||||
|
||||
1
nodes.py
1
nodes.py
@@ -2450,7 +2450,6 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_nag.py",
|
||||
"nodes_sdpose.py",
|
||||
"nodes_math.py",
|
||||
"nodes_painter.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.41.18
|
||||
comfyui-workflow-templates==0.9.21
|
||||
comfyui-frontend-package==1.39.19
|
||||
comfyui-workflow-templates==0.9.11
|
||||
comfyui-embedded-docs==0.4.3
|
||||
torch
|
||||
torchsde
|
||||
@@ -22,8 +22,8 @@ alembic
|
||||
SQLAlchemy
|
||||
filelock
|
||||
av>=14.2.0
|
||||
comfy-kitchen>=0.2.8
|
||||
comfy-aimdo>=0.2.10
|
||||
comfy-kitchen>=0.2.7
|
||||
comfy-aimdo>=0.2.9
|
||||
requests
|
||||
simpleeval>=1.0.0
|
||||
blake3
|
||||
|
||||
@@ -310,7 +310,7 @@ class PromptServer():
|
||||
@routes.get("/")
|
||||
async def get_root(request):
|
||||
response = web.FileResponse(os.path.join(self.web_root, "index.html"))
|
||||
response.headers['Cache-Control'] = 'no-store, must-revalidate'
|
||||
response.headers['Cache-Control'] = 'no-cache'
|
||||
response.headers["Pragma"] = "no-cache"
|
||||
response.headers["Expires"] = "0"
|
||||
return response
|
||||
|
||||
@@ -28,31 +28,31 @@ CACHE_SCENARIOS = [
|
||||
},
|
||||
# JavaScript/CSS scenarios
|
||||
{
|
||||
"name": "js_no_store",
|
||||
"name": "js_no_cache",
|
||||
"path": "/script.js",
|
||||
"status": 200,
|
||||
"expected_cache": "no-store",
|
||||
"expected_cache": "no-cache",
|
||||
"should_have_header": True,
|
||||
},
|
||||
{
|
||||
"name": "css_no_store",
|
||||
"name": "css_no_cache",
|
||||
"path": "/styles.css",
|
||||
"status": 200,
|
||||
"expected_cache": "no-store",
|
||||
"expected_cache": "no-cache",
|
||||
"should_have_header": True,
|
||||
},
|
||||
{
|
||||
"name": "index_json_no_store",
|
||||
"name": "index_json_no_cache",
|
||||
"path": "/api/index.json",
|
||||
"status": 200,
|
||||
"expected_cache": "no-store",
|
||||
"expected_cache": "no-cache",
|
||||
"should_have_header": True,
|
||||
},
|
||||
{
|
||||
"name": "localized_index_json_no_store",
|
||||
"name": "localized_index_json_no_cache",
|
||||
"path": "/templates/index.zh.json",
|
||||
"status": 200,
|
||||
"expected_cache": "no-store",
|
||||
"expected_cache": "no-cache",
|
||||
"should_have_header": True,
|
||||
},
|
||||
# Non-matching files
|
||||
|
||||
250
tests/test_asset_seeder.py
Normal file
250
tests/test_asset_seeder.py
Normal file
@@ -0,0 +1,250 @@
|
||||
"""Tests for app.assets.seeder – enqueue_enrich and pending-queue behaviour."""
|
||||
|
||||
import threading
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.assets.seeder import Progress, _AssetSeeder, State
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def seeder():
|
||||
"""Fresh seeder instance for each test."""
|
||||
return _AssetSeeder()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _reset_to_idle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResetToIdle:
|
||||
def test_sets_idle_and_clears_progress(self, seeder):
|
||||
"""_reset_to_idle should move state to IDLE and snapshot progress."""
|
||||
progress = Progress(scanned=10, total=20, created=5, skipped=3)
|
||||
seeder._state = State.RUNNING
|
||||
seeder._progress = progress
|
||||
|
||||
with seeder._lock:
|
||||
seeder._reset_to_idle()
|
||||
|
||||
assert seeder._state is State.IDLE
|
||||
assert seeder._progress is None
|
||||
assert seeder._last_progress is progress
|
||||
|
||||
def test_noop_when_progress_already_none(self, seeder):
|
||||
"""_reset_to_idle should handle None progress gracefully."""
|
||||
seeder._state = State.CANCELLING
|
||||
seeder._progress = None
|
||||
|
||||
with seeder._lock:
|
||||
seeder._reset_to_idle()
|
||||
|
||||
assert seeder._state is State.IDLE
|
||||
assert seeder._progress is None
|
||||
assert seeder._last_progress is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# enqueue_enrich – immediate start when idle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnqueueEnrichStartsImmediately:
|
||||
def test_starts_when_idle(self, seeder):
|
||||
"""enqueue_enrich should delegate to start_enrich and return True when idle."""
|
||||
with patch.object(seeder, "start_enrich", return_value=True) as mock:
|
||||
assert seeder.enqueue_enrich(roots=("output",), compute_hashes=True) is True
|
||||
mock.assert_called_once_with(roots=("output",), compute_hashes=True)
|
||||
|
||||
def test_no_pending_when_started_immediately(self, seeder):
|
||||
"""No pending request should be stored when start_enrich succeeds."""
|
||||
with patch.object(seeder, "start_enrich", return_value=True):
|
||||
seeder.enqueue_enrich(roots=("output",))
|
||||
assert seeder._pending_enrich is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# enqueue_enrich – queuing when busy
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnqueueEnrichQueuesWhenBusy:
|
||||
def test_queues_when_busy(self, seeder):
|
||||
"""enqueue_enrich should store a pending request when seeder is busy."""
|
||||
with patch.object(seeder, "start_enrich", return_value=False):
|
||||
result = seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
|
||||
|
||||
assert result is False
|
||||
assert seeder._pending_enrich == {
|
||||
"roots": ("models",),
|
||||
"compute_hashes": False,
|
||||
}
|
||||
|
||||
def test_queues_preserves_compute_hashes_true(self, seeder):
|
||||
with patch.object(seeder, "start_enrich", return_value=False):
|
||||
seeder.enqueue_enrich(roots=("input",), compute_hashes=True)
|
||||
|
||||
assert seeder._pending_enrich["compute_hashes"] is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# enqueue_enrich – merging when a pending request already exists
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnqueueEnrichMergesPending:
|
||||
def _make_busy(self, seeder):
|
||||
"""Patch start_enrich to always return False (seeder busy)."""
|
||||
return patch.object(seeder, "start_enrich", return_value=False)
|
||||
|
||||
def test_merges_roots(self, seeder):
|
||||
"""A second enqueue should merge roots with the existing pending request."""
|
||||
with self._make_busy(seeder):
|
||||
seeder.enqueue_enrich(roots=("models",))
|
||||
seeder.enqueue_enrich(roots=("output",))
|
||||
|
||||
merged = set(seeder._pending_enrich["roots"])
|
||||
assert merged == {"models", "output"}
|
||||
|
||||
def test_merges_overlapping_roots(self, seeder):
|
||||
"""Duplicate roots should be deduplicated."""
|
||||
with self._make_busy(seeder):
|
||||
seeder.enqueue_enrich(roots=("models", "input"))
|
||||
seeder.enqueue_enrich(roots=("input", "output"))
|
||||
|
||||
merged = set(seeder._pending_enrich["roots"])
|
||||
assert merged == {"models", "input", "output"}
|
||||
|
||||
def test_compute_hashes_sticky_true(self, seeder):
|
||||
"""Once compute_hashes is True it should stay True after merging."""
|
||||
with self._make_busy(seeder):
|
||||
seeder.enqueue_enrich(roots=("models",), compute_hashes=True)
|
||||
seeder.enqueue_enrich(roots=("output",), compute_hashes=False)
|
||||
|
||||
assert seeder._pending_enrich["compute_hashes"] is True
|
||||
|
||||
def test_compute_hashes_upgrades_to_true(self, seeder):
|
||||
"""A later enqueue with compute_hashes=True should upgrade the pending request."""
|
||||
with self._make_busy(seeder):
|
||||
seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
|
||||
seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
|
||||
|
||||
assert seeder._pending_enrich["compute_hashes"] is True
|
||||
|
||||
def test_compute_hashes_stays_false(self, seeder):
|
||||
"""If both enqueues have compute_hashes=False it stays False."""
|
||||
with self._make_busy(seeder):
|
||||
seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
|
||||
seeder.enqueue_enrich(roots=("output",), compute_hashes=False)
|
||||
|
||||
assert seeder._pending_enrich["compute_hashes"] is False
|
||||
|
||||
def test_triple_merge(self, seeder):
|
||||
"""Three successive enqueues should all merge correctly."""
|
||||
with self._make_busy(seeder):
|
||||
seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
|
||||
seeder.enqueue_enrich(roots=("input",), compute_hashes=False)
|
||||
seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
|
||||
|
||||
merged = set(seeder._pending_enrich["roots"])
|
||||
assert merged == {"models", "input", "output"}
|
||||
assert seeder._pending_enrich["compute_hashes"] is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pending enrich drains after scan completes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPendingEnrichDrain:
|
||||
"""Verify that _run_scan drains _pending_enrich via start_enrich."""
|
||||
|
||||
@patch("app.assets.seeder.dependencies_available", return_value=True)
|
||||
@patch("app.assets.seeder.get_all_known_prefixes", return_value=[])
|
||||
@patch("app.assets.seeder.sync_root_safely", return_value=set())
|
||||
@patch("app.assets.seeder.collect_paths_for_roots", return_value=[])
|
||||
@patch("app.assets.seeder.build_asset_specs", return_value=([], {}, 0))
|
||||
def test_pending_enrich_starts_after_scan(self, *_mocks):
|
||||
"""After a fast scan finishes, the pending enrich should be started."""
|
||||
seeder = _AssetSeeder()
|
||||
|
||||
seeder._pending_enrich = {
|
||||
"roots": ("output",),
|
||||
"compute_hashes": True,
|
||||
}
|
||||
|
||||
with patch.object(seeder, "start_enrich", return_value=True) as mock_start:
|
||||
seeder.start_fast(roots=("models",))
|
||||
seeder.wait(timeout=5)
|
||||
|
||||
mock_start.assert_called_once_with(
|
||||
roots=("output",),
|
||||
compute_hashes=True,
|
||||
)
|
||||
|
||||
assert seeder._pending_enrich is None
|
||||
|
||||
@patch("app.assets.seeder.dependencies_available", return_value=True)
|
||||
@patch("app.assets.seeder.get_all_known_prefixes", return_value=[])
|
||||
@patch("app.assets.seeder.sync_root_safely", return_value=set())
|
||||
@patch("app.assets.seeder.collect_paths_for_roots", return_value=[])
|
||||
@patch("app.assets.seeder.build_asset_specs", return_value=([], {}, 0))
|
||||
def test_pending_cleared_even_when_start_fails(self, *_mocks):
|
||||
"""_pending_enrich should be cleared even if start_enrich returns False."""
|
||||
seeder = _AssetSeeder()
|
||||
seeder._pending_enrich = {
|
||||
"roots": ("output",),
|
||||
"compute_hashes": False,
|
||||
}
|
||||
|
||||
with patch.object(seeder, "start_enrich", return_value=False):
|
||||
seeder.start_fast(roots=("models",))
|
||||
seeder.wait(timeout=5)
|
||||
|
||||
assert seeder._pending_enrich is None
|
||||
|
||||
@patch("app.assets.seeder.dependencies_available", return_value=True)
|
||||
@patch("app.assets.seeder.get_all_known_prefixes", return_value=[])
|
||||
@patch("app.assets.seeder.sync_root_safely", return_value=set())
|
||||
@patch("app.assets.seeder.collect_paths_for_roots", return_value=[])
|
||||
@patch("app.assets.seeder.build_asset_specs", return_value=([], {}, 0))
|
||||
def test_no_drain_when_no_pending(self, *_mocks):
|
||||
"""start_enrich should not be called when there is no pending request."""
|
||||
seeder = _AssetSeeder()
|
||||
assert seeder._pending_enrich is None
|
||||
|
||||
with patch.object(seeder, "start_enrich", return_value=True) as mock_start:
|
||||
seeder.start_fast(roots=("models",))
|
||||
seeder.wait(timeout=5)
|
||||
|
||||
mock_start.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Thread-safety of enqueue_enrich
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnqueueEnrichThreadSafety:
|
||||
def test_concurrent_enqueues(self, seeder):
|
||||
"""Multiple threads enqueuing should not lose roots."""
|
||||
with patch.object(seeder, "start_enrich", return_value=False):
|
||||
barrier = threading.Barrier(3)
|
||||
|
||||
def enqueue(root):
|
||||
barrier.wait()
|
||||
seeder.enqueue_enrich(roots=(root,), compute_hashes=False)
|
||||
|
||||
threads = [
|
||||
threading.Thread(target=enqueue, args=(r,))
|
||||
for r in ("models", "input", "output")
|
||||
]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join(timeout=5)
|
||||
|
||||
merged = set(seeder._pending_enrich["roots"])
|
||||
assert merged == {"models", "input", "output"}
|
||||
Reference in New Issue
Block a user