Compare commits

...

9 Commits

Author SHA1 Message Date
Jedrzej Kosinski
a145651cc0 Track custom node startup errors and expose via API endpoint
Store import and prestartup errors in NODE_STARTUP_ERRORS dict (nodes.py,
main.py) and add GET /custom_node_startup_errors endpoint (server.py) so
the frontend/Manager can distinguish failed imports from missing nodes.

Ref: ComfyUI-Launcher#303
Amp-Thread-ID: https://ampcode.com/threads/T-019d2346-6e6f-75e0-a97f-cdb6e26859f7
Co-authored-by: Amp <amp@ampcode.com>
2026-03-24 23:41:01 -07:00
Krishna Chaitanya
b53b10ea61 Fix Train LoRA crash when training_dtype is "none" with bfloat16 LoRA weights (#13145)
When training_dtype is set to "none" and the model's native dtype is
float16, GradScaler was unconditionally enabled. However, GradScaler
does not support bfloat16 gradients (only float16/float32), causing a
NotImplementedError when lora_dtype is "bf16" (the default).

Fix by only enabling GradScaler when LoRA parameters are not in
bfloat16, since bfloat16 has the same exponent range as float32 and
does not need gradient scaling to avoid underflow.

Fixes #13124
2026-03-24 23:53:44 -04:00
Luke Mino-Altherr
7d5534d8e5 feat(assets): register output files as assets after prompt execution (#12812) 2026-03-24 20:48:55 -07:00
Kohaku-Blueleaf
5ebb0c2e0b FP8 bwd training (#13121) 2026-03-24 20:39:04 -04:00
Dante
a0a64c679f Add Number Convert node (#13041)
* Add Number Convert node for unified numeric type conversion

Consolidates fragmented IntToFloat/FloatToInt nodes (previously only
available via third-party packs like ComfyMath, FillNodes, etc.) into
a single core node.

- Single input accepting INT, FLOAT, STRING, and BOOL types
- Two outputs: FLOAT and INT
- Conversion: bool→0/1, string→parsed number, float↔int standard cast
- Follows Math Expression node patterns (comfy_api, io.Schema, etc.)

Refs: COM-16925

* Register nodes_number_convert.py in extras_files list

Without this entry in nodes.py, the Number Convert node file
would not be discovered and loaded at startup.

* Add isfinite guard, exception chaining, and unit tests for Number Convert node

- Add math.isfinite() check to prevent int() crash on inf/nan string inputs
- Use 'from None' for cleaner exception chaining on string parse failure
- Add 21 unit tests covering all input types and error paths
2026-03-24 15:38:08 -07:00
Terry Jia
8e73678dae CURVE node (#12757)
* CURVE node

* remove curve to sigmas node

* feat: add CurveInput ABC with MonotoneCubicCurve implementation (#12986)

CurveInput is an abstract base class so future curve representations
(bezier, LUT-based, analytical functions) can be added without breaking
downstream nodes that type-check against CurveInput.

MonotoneCubicCurve is the concrete implementation that:
- Mirrors frontend createMonotoneInterpolator (curveUtils.ts) exactly
- Pre-computes slopes as numpy arrays at construction time
- Provides vectorised interp_array() using numpy for batch evaluation
- interp() for single-value evaluation
- to_lut() for generating lookup tables

CurveEditor node wraps raw widget points in MonotoneCubicCurve.

* linear curve

* refactor: move CurveEditor to comfy_extras/nodes_curve.py with V3 schema

* feat: add HISTOGRAM type and histogram support to CurveEditor

* code improve

---------

Co-authored-by: Christian Byrne <cbyrne@comfy.org>
2026-03-24 17:47:28 -04:00
comfyanonymous
c2862b24af Update templates package version. (#13141) 2026-03-24 17:36:12 -04:00
Alexander Piskun
f9ec85f739 feat(api-nodes): update xAI Grok nodes (#13140) 2026-03-24 13:27:39 -07:00
Kelly Yang
2d5fd3f5dd fix: set default values of Color Adjustment node to zero (#13084)
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-03-24 14:22:30 -04:00
30 changed files with 1608 additions and 38 deletions

View File

@@ -1,6 +1,7 @@
from app.assets.database.queries.asset import (
asset_exists_by_hash,
bulk_insert_assets,
create_stub_asset,
get_asset_by_hash,
get_existing_asset_ids,
reassign_asset_references,
@@ -12,6 +13,7 @@ from app.assets.database.queries.asset_reference import (
UnenrichedReferenceRow,
bulk_insert_references_ignore_conflicts,
bulk_update_enrichment_level,
count_active_siblings,
bulk_update_is_missing,
bulk_update_needs_verify,
convert_metadata_to_rows,
@@ -80,6 +82,8 @@ __all__ = [
"bulk_insert_references_ignore_conflicts",
"bulk_insert_tags_and_meta",
"bulk_update_enrichment_level",
"count_active_siblings",
"create_stub_asset",
"bulk_update_is_missing",
"bulk_update_needs_verify",
"convert_metadata_to_rows",

View File

@@ -78,6 +78,18 @@ def upsert_asset(
return asset, created, updated
def create_stub_asset(
session: Session,
size_bytes: int,
mime_type: str | None = None,
) -> Asset:
"""Create a new asset with no hash (stub for later enrichment)."""
asset = Asset(size_bytes=size_bytes, mime_type=mime_type, hash=None)
session.add(asset)
session.flush()
return asset
def bulk_insert_assets(
session: Session,
rows: list[dict],

View File

@@ -114,6 +114,23 @@ def get_reference_by_file_path(
)
def count_active_siblings(
session: Session,
asset_id: str,
exclude_reference_id: str,
) -> int:
"""Count active (non-deleted) references to an asset, excluding one reference."""
return (
session.query(AssetReference)
.filter(
AssetReference.asset_id == asset_id,
AssetReference.id != exclude_reference_id,
AssetReference.deleted_at.is_(None),
)
.count()
)
def reference_exists_for_asset_id(
session: Session,
asset_id: str,

View File

@@ -13,6 +13,7 @@ from app.assets.database.queries import (
delete_references_by_ids,
ensure_tags_exist,
get_asset_by_hash,
get_reference_by_id,
get_references_for_prefixes,
get_unenriched_references,
mark_references_missing_outside_prefixes,
@@ -338,6 +339,7 @@ def build_asset_specs(
"metadata": metadata,
"hash": asset_hash,
"mime_type": mime_type,
"job_id": None,
}
)
tag_pool.update(tags)
@@ -426,6 +428,7 @@ def enrich_asset(
except OSError:
return new_level
initial_mtime_ns = get_mtime_ns(stat_p)
rel_fname = compute_relative_filename(file_path)
mime_type: str | None = None
metadata = None
@@ -489,6 +492,18 @@ def enrich_asset(
except Exception as e:
logging.warning("Failed to hash %s: %s", file_path, e)
# Optimistic guard: if the reference's mtime_ns changed since we
# started (e.g. ingest_existing_file updated it), our results are
# stale — discard them to avoid overwriting fresh registration data.
ref = get_reference_by_id(session, reference_id)
if ref is None or ref.mtime_ns != initial_mtime_ns:
session.rollback()
logging.info(
"Ref %s mtime changed during enrichment, discarding stale result",
reference_id,
)
return ENRICHMENT_STUB
if extract_metadata and metadata:
system_metadata = metadata.to_user_metadata()
set_reference_system_metadata(session, reference_id, system_metadata)

View File

@@ -77,7 +77,9 @@ class _AssetSeeder:
"""
def __init__(self) -> None:
self._lock = threading.Lock()
# RLock is required because _run_scan() drains pending work while
# holding _lock and re-enters start() which also acquires _lock.
self._lock = threading.RLock()
self._state = State.IDLE
self._progress: Progress | None = None
self._last_progress: Progress | None = None
@@ -92,6 +94,7 @@ class _AssetSeeder:
self._prune_first: bool = False
self._progress_callback: ProgressCallback | None = None
self._disabled: bool = False
self._pending_enrich: dict | None = None
def disable(self) -> None:
"""Disable the asset seeder, preventing any scans from starting."""
@@ -196,6 +199,42 @@ class _AssetSeeder:
compute_hashes=compute_hashes,
)
def enqueue_enrich(
self,
roots: tuple[RootType, ...] = ("models", "input", "output"),
compute_hashes: bool = False,
) -> bool:
"""Start an enrichment scan now, or queue it for after the current scan.
If the seeder is idle, starts immediately. Otherwise, the enrich
request is stored and will run automatically when the current scan
finishes.
Args:
roots: Tuple of root types to scan
compute_hashes: If True, compute blake3 hashes
Returns:
True if started immediately, False if queued for later
"""
with self._lock:
if self.start_enrich(roots=roots, compute_hashes=compute_hashes):
return True
if self._pending_enrich is not None:
existing_roots = set(self._pending_enrich["roots"])
existing_roots.update(roots)
self._pending_enrich["roots"] = tuple(existing_roots)
self._pending_enrich["compute_hashes"] = (
self._pending_enrich["compute_hashes"] or compute_hashes
)
else:
self._pending_enrich = {
"roots": roots,
"compute_hashes": compute_hashes,
}
logging.info("Enrich scan queued (roots=%s)", self._pending_enrich["roots"])
return False
def cancel(self) -> bool:
"""Request cancellation of the current scan.
@@ -381,9 +420,13 @@ class _AssetSeeder:
return marked
finally:
with self._lock:
self._last_progress = self._progress
self._state = State.IDLE
self._progress = None
self._reset_to_idle()
def _reset_to_idle(self) -> None:
"""Reset state to IDLE, preserving last progress. Caller must hold _lock."""
self._last_progress = self._progress
self._state = State.IDLE
self._progress = None
def _is_cancelled(self) -> bool:
"""Check if cancellation has been requested."""
@@ -594,9 +637,18 @@ class _AssetSeeder:
},
)
with self._lock:
self._last_progress = self._progress
self._state = State.IDLE
self._progress = None
self._reset_to_idle()
pending = self._pending_enrich
if pending is not None:
self._pending_enrich = None
if not self.start_enrich(
roots=pending["roots"],
compute_hashes=pending["compute_hashes"],
):
logging.warning(
"Pending enrich scan could not start (roots=%s)",
pending["roots"],
)
def _run_fast_phase(self, roots: tuple[RootType, ...]) -> tuple[int, int, int]:
"""Run phase 1: fast scan to create stub records.

View File

@@ -23,6 +23,8 @@ from app.assets.services.ingest import (
DependencyMissingError,
HashMismatchError,
create_from_hash,
ingest_existing_file,
register_output_files,
upload_from_temp_path,
)
from app.assets.database.queries import (
@@ -72,6 +74,8 @@ __all__ = [
"delete_asset_reference",
"get_asset_by_hash",
"get_asset_detail",
"ingest_existing_file",
"register_output_files",
"get_mtime_ns",
"get_size_and_mtime_ns",
"list_assets_page",

View File

@@ -37,6 +37,7 @@ class SeedAssetSpec(TypedDict):
metadata: ExtractedMetadata | None
hash: str | None
mime_type: str | None
job_id: str | None
class AssetRow(TypedDict):
@@ -60,6 +61,7 @@ class ReferenceRow(TypedDict):
name: str
preview_id: str | None
user_metadata: dict[str, Any] | None
job_id: str | None
created_at: datetime
updated_at: datetime
last_access_time: datetime
@@ -167,6 +169,7 @@ def batch_insert_seed_assets(
"name": spec["info_name"],
"preview_id": None,
"user_metadata": user_metadata,
"job_id": spec.get("job_id"),
"created_at": current_time,
"updated_at": current_time,
"last_access_time": current_time,

View File

@@ -9,6 +9,9 @@ from sqlalchemy.orm import Session
import app.assets.services.hashing as hashing
from app.assets.database.queries import (
add_tags_to_reference,
count_active_siblings,
create_stub_asset,
ensure_tags_exist,
fetch_reference_and_asset,
get_asset_by_hash,
get_reference_by_file_path,
@@ -23,7 +26,8 @@ from app.assets.database.queries import (
upsert_reference,
validate_tags_exist,
)
from app.assets.helpers import normalize_tags
from app.assets.helpers import get_utc_now, normalize_tags
from app.assets.services.bulk_ingest import batch_insert_seed_assets
from app.assets.services.file_utils import get_size_and_mtime_ns
from app.assets.services.path_utils import (
compute_relative_filename,
@@ -130,6 +134,102 @@ def _ingest_file_from_path(
)
def register_output_files(
file_paths: Sequence[str],
user_metadata: UserMetadata = None,
job_id: str | None = None,
) -> int:
"""Register a batch of output file paths as assets.
Returns the number of files successfully registered.
"""
registered = 0
for abs_path in file_paths:
if not os.path.isfile(abs_path):
continue
try:
if ingest_existing_file(
abs_path, user_metadata=user_metadata, job_id=job_id
):
registered += 1
except Exception:
logging.exception("Failed to register output: %s", abs_path)
return registered
def ingest_existing_file(
abs_path: str,
user_metadata: UserMetadata = None,
extra_tags: Sequence[str] = (),
owner_id: str = "",
job_id: str | None = None,
) -> bool:
"""Register an existing on-disk file as an asset stub.
If a reference already exists for this path, updates mtime_ns, job_id,
size_bytes, and resets enrichment so the enricher will re-hash it.
For brand-new paths, inserts a stub record (hash=NULL) for immediate
UX visibility.
Returns True if a row was inserted or updated, False otherwise.
"""
locator = os.path.abspath(abs_path)
size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path)
mime_type = mimetypes.guess_type(abs_path, strict=False)[0]
name, path_tags = get_name_and_tags_from_asset_path(abs_path)
tags = list(dict.fromkeys(path_tags + list(extra_tags)))
with create_session() as session:
existing_ref = get_reference_by_file_path(session, locator)
if existing_ref is not None:
now = get_utc_now()
existing_ref.mtime_ns = mtime_ns
existing_ref.job_id = job_id
existing_ref.is_missing = False
existing_ref.deleted_at = None
existing_ref.updated_at = now
existing_ref.enrichment_level = 0
asset = existing_ref.asset
if asset:
# If other refs share this asset, detach to a new stub
# instead of mutating the shared row.
siblings = count_active_siblings(session, asset.id, existing_ref.id)
if siblings > 0:
new_asset = create_stub_asset(
session,
size_bytes=size_bytes,
mime_type=mime_type or asset.mime_type,
)
existing_ref.asset_id = new_asset.id
else:
asset.hash = None
asset.size_bytes = size_bytes
if mime_type:
asset.mime_type = mime_type
session.commit()
return True
spec = {
"abs_path": abs_path,
"size_bytes": size_bytes,
"mtime_ns": mtime_ns,
"info_name": name,
"tags": tags,
"fname": os.path.basename(abs_path),
"metadata": None,
"hash": None,
"mime_type": mime_type,
"job_id": job_id,
}
if tags:
ensure_tags_exist(session, tags)
result = batch_insert_seed_assets(session, [spec], owner_id=owner_id)
session.commit()
return result.won_paths > 0
def _register_existing_asset(
asset_hash: str,
name: str,

File diff suppressed because one or more lines are too long

View File

@@ -55,6 +55,7 @@ total_vram = 0
# Training Related State
in_training = False
training_fp8_bwd = False
def get_supported_float8_types():

View File

@@ -777,8 +777,16 @@ from .quant_ops import (
class QuantLinearFunc(torch.autograd.Function):
"""Custom autograd function for quantized linear: quantized forward, compute_dtype backward.
Handles any input rank by flattening to 2D for matmul and restoring shape after.
"""Custom autograd function for quantized linear: quantized forward, optionally FP8 backward.
When training_fp8_bwd is enabled:
- Forward: quantize input per layout (FP8/NVFP4), use quantized matmul
- Backward: all matmuls use FP8 tensor cores via torch.mm dispatch
- Cached input is FP8 (half the memory of bf16)
When training_fp8_bwd is disabled:
- Forward: quantize input per layout, use quantized matmul
- Backward: dequantize weight to compute_dtype, use standard matmul
"""
@staticmethod
@@ -786,7 +794,7 @@ class QuantLinearFunc(torch.autograd.Function):
input_shape = input_float.shape
inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D
# Quantize input (same as inference path)
# Quantize input for forward (same layout as weight)
if layout_type is not None:
q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
else:
@@ -797,43 +805,68 @@ class QuantLinearFunc(torch.autograd.Function):
output = torch.nn.functional.linear(q_input, w, b)
# Restore original input shape
# Unflatten output to match original input shape
if len(input_shape) > 2:
output = output.unflatten(0, input_shape[:-1])
ctx.save_for_backward(input_float, weight)
# Save for backward
ctx.input_shape = input_shape
ctx.has_bias = bias is not None
ctx.compute_dtype = compute_dtype
ctx.weight_requires_grad = weight.requires_grad
ctx.fp8_bwd = comfy.model_management.training_fp8_bwd
if ctx.fp8_bwd:
# Cache FP8 quantized input — half the memory of bf16
if isinstance(q_input, QuantizedTensor) and layout_type.startswith('TensorCoreFP8'):
ctx.q_input = q_input # already FP8, reuse
else:
# NVFP4 or other layout — quantize input to FP8 for backward
ctx.q_input = QuantizedTensor.from_float(inp, "TensorCoreFP8E4M3Layout")
ctx.save_for_backward(weight)
else:
ctx.q_input = None
ctx.save_for_backward(input_float, weight)
return output
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, grad_output):
input_float, weight = ctx.saved_tensors
compute_dtype = ctx.compute_dtype
grad_2d = grad_output.flatten(0, -2).to(compute_dtype)
# Dequantize weight to compute dtype for backward matmul
if isinstance(weight, QuantizedTensor):
weight_f = weight.dequantize().to(compute_dtype)
# Value casting — only difference between fp8 and non-fp8 paths
if ctx.fp8_bwd:
weight, = ctx.saved_tensors
# Wrap as FP8 QuantizedTensors → torch.mm dispatches to _scaled_mm
grad_mm = QuantizedTensor.from_float(grad_2d, "TensorCoreFP8E5M2Layout")
if isinstance(weight, QuantizedTensor) and weight._layout_cls.startswith("TensorCoreFP8"):
weight_mm = weight
elif isinstance(weight, QuantizedTensor):
weight_mm = QuantizedTensor.from_float(weight.dequantize().to(compute_dtype), "TensorCoreFP8E4M3Layout")
else:
weight_mm = QuantizedTensor.from_float(weight.to(compute_dtype), "TensorCoreFP8E4M3Layout")
input_mm = ctx.q_input
else:
weight_f = weight.to(compute_dtype)
input_float, weight = ctx.saved_tensors
# Standard tensors → torch.mm does regular matmul
grad_mm = grad_2d
if isinstance(weight, QuantizedTensor):
weight_mm = weight.dequantize().to(compute_dtype)
else:
weight_mm = weight.to(compute_dtype)
input_mm = input_float.flatten(0, -2).to(compute_dtype) if ctx.weight_requires_grad else None
# grad_input = grad_output @ weight
grad_input = torch.mm(grad_2d, weight_f)
# Computation — same for both paths, dispatch handles the rest
grad_input = torch.mm(grad_mm, weight_mm)
if len(ctx.input_shape) > 2:
grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])
# grad_weight (only if weight requires grad, typically frozen for quantized training)
grad_weight = None
if ctx.weight_requires_grad:
input_f = input_float.flatten(0, -2).to(compute_dtype)
grad_weight = torch.mm(grad_2d.t(), input_f)
grad_weight = torch.mm(grad_mm.t(), input_mm)
# grad_bias
grad_bias = None
if ctx.has_bias:
grad_bias = grad_2d.sum(dim=0)

View File

@@ -5,6 +5,10 @@ from comfy_api.latest._input import (
MaskInput,
LatentInput,
VideoInput,
CurvePoint,
CurveInput,
MonotoneCubicCurve,
LinearCurve,
)
__all__ = [
@@ -13,4 +17,8 @@ __all__ = [
"MaskInput",
"LatentInput",
"VideoInput",
"CurvePoint",
"CurveInput",
"MonotoneCubicCurve",
"LinearCurve",
]

View File

@@ -1,4 +1,5 @@
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
from .curve_types import CurvePoint, CurveInput, MonotoneCubicCurve, LinearCurve
from .video_types import VideoInput
__all__ = [
@@ -7,4 +8,8 @@ __all__ = [
"VideoInput",
"MaskInput",
"LatentInput",
"CurvePoint",
"CurveInput",
"MonotoneCubicCurve",
"LinearCurve",
]

View File

@@ -0,0 +1,219 @@
from __future__ import annotations
import logging
import math
from abc import ABC, abstractmethod
import numpy as np
logger = logging.getLogger(__name__)
CurvePoint = tuple[float, float]
class CurveInput(ABC):
"""Abstract base class for curve inputs.
Subclasses represent different curve representations (control-point
interpolation, analytical functions, LUT-based, etc.) while exposing a
uniform evaluation interface to downstream nodes.
"""
@property
@abstractmethod
def points(self) -> list[CurvePoint]:
"""The control points that define this curve."""
@abstractmethod
def interp(self, x: float) -> float:
"""Evaluate the curve at a single *x* value in [0, 1]."""
def interp_array(self, xs: np.ndarray) -> np.ndarray:
"""Vectorised evaluation over a numpy array of x values.
Subclasses should override this for better performance. The default
falls back to scalar ``interp`` calls.
"""
return np.fromiter((self.interp(float(x)) for x in xs), dtype=np.float64, count=len(xs))
def to_lut(self, size: int = 256) -> np.ndarray:
"""Generate a float64 lookup table of *size* evenly-spaced samples in [0, 1]."""
return self.interp_array(np.linspace(0.0, 1.0, size))
@staticmethod
def from_raw(data) -> CurveInput:
"""Convert raw curve data (dict or point list) to a CurveInput instance.
Accepts:
- A ``CurveInput`` instance (returned as-is).
- A dict with ``"points"`` and optional ``"interpolation"`` keys.
- A bare list/sequence of ``(x, y)`` pairs (defaults to monotone cubic).
"""
if isinstance(data, CurveInput):
return data
if isinstance(data, dict):
raw_points = data["points"]
interpolation = data.get("interpolation", "monotone_cubic")
else:
raw_points = data
interpolation = "monotone_cubic"
points = [(float(x), float(y)) for x, y in raw_points]
if interpolation == "linear":
return LinearCurve(points)
if interpolation != "monotone_cubic":
logger.warning("Unknown curve interpolation %r, falling back to monotone_cubic", interpolation)
return MonotoneCubicCurve(points)
class MonotoneCubicCurve(CurveInput):
"""Monotone cubic Hermite interpolation over control points.
Mirrors the frontend ``createMonotoneInterpolator`` in
``ComfyUI_frontend/src/components/curve/curveUtils.ts`` so that
backend evaluation matches the editor preview exactly.
All heavy work (sorting, slope computation) happens once at construction.
``interp_array`` is fully vectorised with numpy.
"""
def __init__(self, control_points: list[CurvePoint]):
sorted_pts = sorted(control_points, key=lambda p: p[0])
self._points = [(float(x), float(y)) for x, y in sorted_pts]
self._xs = np.array([p[0] for p in self._points], dtype=np.float64)
self._ys = np.array([p[1] for p in self._points], dtype=np.float64)
self._slopes = self._compute_slopes()
@property
def points(self) -> list[CurvePoint]:
return list(self._points)
def _compute_slopes(self) -> np.ndarray:
xs, ys = self._xs, self._ys
n = len(xs)
if n < 2:
return np.zeros(n, dtype=np.float64)
dx = np.diff(xs)
dy = np.diff(ys)
dx_safe = np.where(dx == 0, 1.0, dx)
deltas = np.where(dx == 0, 0.0, dy / dx_safe)
slopes = np.empty(n, dtype=np.float64)
slopes[0] = deltas[0]
slopes[-1] = deltas[-1]
for i in range(1, n - 1):
if deltas[i - 1] * deltas[i] <= 0:
slopes[i] = 0.0
else:
slopes[i] = (deltas[i - 1] + deltas[i]) / 2
for i in range(n - 1):
if deltas[i] == 0:
slopes[i] = 0.0
slopes[i + 1] = 0.0
else:
alpha = slopes[i] / deltas[i]
beta = slopes[i + 1] / deltas[i]
s = alpha * alpha + beta * beta
if s > 9:
t = 3 / math.sqrt(s)
slopes[i] = t * alpha * deltas[i]
slopes[i + 1] = t * beta * deltas[i]
return slopes
def interp(self, x: float) -> float:
xs, ys, slopes = self._xs, self._ys, self._slopes
n = len(xs)
if n == 0:
return 0.0
if n == 1:
return float(ys[0])
if x <= xs[0]:
return float(ys[0])
if x >= xs[-1]:
return float(ys[-1])
hi = int(np.searchsorted(xs, x, side='right'))
hi = min(hi, n - 1)
lo = hi - 1
dx = xs[hi] - xs[lo]
if dx == 0:
return float(ys[lo])
t = (x - xs[lo]) / dx
t2 = t * t
t3 = t2 * t
h00 = 2 * t3 - 3 * t2 + 1
h10 = t3 - 2 * t2 + t
h01 = -2 * t3 + 3 * t2
h11 = t3 - t2
return float(h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi])
def interp_array(self, xs_in: np.ndarray) -> np.ndarray:
"""Fully vectorised evaluation using numpy."""
xs, ys, slopes = self._xs, self._ys, self._slopes
n = len(xs)
if n == 0:
return np.zeros_like(xs_in, dtype=np.float64)
if n == 1:
return np.full_like(xs_in, ys[0], dtype=np.float64)
hi = np.searchsorted(xs, xs_in, side='right').clip(1, n - 1)
lo = hi - 1
dx = xs[hi] - xs[lo]
dx_safe = np.where(dx == 0, 1.0, dx)
t = np.where(dx == 0, 0.0, (xs_in - xs[lo]) / dx_safe)
t2 = t * t
t3 = t2 * t
h00 = 2 * t3 - 3 * t2 + 1
h10 = t3 - 2 * t2 + t
h01 = -2 * t3 + 3 * t2
h11 = t3 - t2
result = h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi]
result = np.where(xs_in <= xs[0], ys[0], result)
result = np.where(xs_in >= xs[-1], ys[-1], result)
return result
def __repr__(self) -> str:
return f"MonotoneCubicCurve(points={self._points})"
class LinearCurve(CurveInput):
"""Piecewise linear interpolation over control points.
Mirrors the frontend ``createLinearInterpolator`` in
``ComfyUI_frontend/src/components/curve/curveUtils.ts``.
"""
def __init__(self, control_points: list[CurvePoint]):
sorted_pts = sorted(control_points, key=lambda p: p[0])
self._points = [(float(x), float(y)) for x, y in sorted_pts]
self._xs = np.array([p[0] for p in self._points], dtype=np.float64)
self._ys = np.array([p[1] for p in self._points], dtype=np.float64)
@property
def points(self) -> list[CurvePoint]:
return list(self._points)
def interp(self, x: float) -> float:
xs, ys = self._xs, self._ys
n = len(xs)
if n == 0:
return 0.0
if n == 1:
return float(ys[0])
return float(np.interp(x, xs, ys))
def interp_array(self, xs_in: np.ndarray) -> np.ndarray:
if len(self._xs) == 0:
return np.zeros_like(xs_in, dtype=np.float64)
if len(self._xs) == 1:
return np.full_like(xs_in, self._ys[0], dtype=np.float64)
return np.interp(xs_in, self._xs, self._ys)
def __repr__(self) -> str:
return f"LinearCurve(points={self._points})"

View File

@@ -23,7 +23,7 @@ if TYPE_CHECKING:
from comfy.samplers import CFGGuider, Sampler
from comfy.sd import CLIP, VAE
from comfy.sd import StyleModel as StyleModel_
from comfy_api.input import VideoInput
from comfy_api.input import VideoInput, CurveInput as CurveInput_
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
prune_dict, shallow_clone_class)
from comfy_execution.graph_utils import ExecutionBlocker
@@ -1242,8 +1242,9 @@ class BoundingBox(ComfyTypeIO):
@comfytype(io_type="CURVE")
class Curve(ComfyTypeIO):
CurvePoint = tuple[float, float]
Type = list[CurvePoint]
from comfy_api.input import CurvePoint
if TYPE_CHECKING:
Type = CurveInput_
class Input(WidgetInput):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
@@ -1252,6 +1253,18 @@ class Curve(ComfyTypeIO):
if default is None:
self.default = [(0.0, 0.0), (1.0, 1.0)]
def as_dict(self):
d = super().as_dict()
if self.default is not None:
d["default"] = {"points": [list(p) for p in self.default], "interpolation": "monotone_cubic"}
return d
@comfytype(io_type="HISTOGRAM")
class Histogram(ComfyTypeIO):
"""A histogram represented as a list of bin counts."""
Type = list[int]
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
@@ -2240,5 +2253,6 @@ __all__ = [
"PriceBadge",
"BoundingBox",
"Curve",
"Histogram",
"NodeReplace",
]

View File

@@ -29,13 +29,21 @@ class ImageEditRequest(BaseModel):
class VideoGenerationRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(...)
image: InputUrlObject | None = Field(...)
image: InputUrlObject | None = Field(None)
reference_images: list[InputUrlObject] | None = Field(None)
duration: int = Field(...)
aspect_ratio: str | None = Field(...)
resolution: str = Field(...)
seed: int = Field(...)
class VideoExtensionRequest(BaseModel):
prompt: str = Field(...)
video: InputUrlObject = Field(...)
duration: int = Field(default=6)
model: str | None = Field(default=None)
class VideoEditRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(...)

View File

@@ -8,6 +8,7 @@ from comfy_api_nodes.apis.grok import (
ImageGenerationResponse,
InputUrlObject,
VideoEditRequest,
VideoExtensionRequest,
VideoGenerationRequest,
VideoGenerationResponse,
VideoStatusResponse,
@@ -21,6 +22,7 @@ from comfy_api_nodes.util import (
poll_op,
sync_op,
tensor_to_base64_string,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
validate_string,
validate_video_duration,
@@ -33,6 +35,13 @@ def _extract_grok_price(response) -> float | None:
return None
def _extract_grok_video_price(response) -> float | None:
price = _extract_grok_price(response)
if price is not None:
return price * 1.43
return None
class GrokImageNode(IO.ComfyNode):
@classmethod
@@ -354,6 +363,8 @@ class GrokVideoNode(IO.ComfyNode):
seed: int,
image: Input.Image | None = None,
) -> IO.NodeOutput:
if model == "grok-imagine-video-beta":
model = "grok-imagine-video"
image_url = None
if image is not None:
if get_number_of_images(image) != 1:
@@ -462,6 +473,244 @@ class GrokVideoEditNode(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
class GrokVideoReferenceNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="GrokVideoReferenceNode",
display_name="Grok Reference-to-Video",
category="api node/video/Grok",
description="Generate video guided by reference images as style and content references.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
tooltip="Text description of the desired video.",
),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"grok-imagine-video",
[
IO.Autogrow.Input(
"reference_images",
template=IO.Autogrow.TemplatePrefix(
IO.Image.Input("image"),
prefix="reference_",
min=1,
max=7,
),
tooltip="Up to 7 reference images to guide the video generation.",
),
IO.Combo.Input(
"resolution",
options=["480p", "720p"],
tooltip="The resolution of the output video.",
),
IO.Combo.Input(
"aspect_ratio",
options=["16:9", "4:3", "3:2", "1:1", "2:3", "3:4", "9:16"],
tooltip="The aspect ratio of the output video.",
),
IO.Int.Input(
"duration",
default=6,
min=2,
max=10,
step=1,
tooltip="The duration of the output video in seconds.",
display_mode=IO.NumberDisplay.slider,
),
],
),
],
tooltip="The model to use for video generation.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=["model.duration", "model.resolution"],
input_groups=["model.reference_images"],
),
expr="""
(
$res := $lookup(widgets, "model.resolution");
$dur := $lookup(widgets, "model.duration");
$refs := inputGroups["model.reference_images"];
$rate := $res = "720p" ? 0.07 : 0.05;
$price := ($rate * $dur + 0.002 * $refs) * 1.43;
{"type":"usd","usd": $price}
)
""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
model: dict,
seed: int,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
ref_image_urls = await upload_images_to_comfyapi(
cls,
list(model["reference_images"].values()),
mime_type="image/png",
wait_label="Uploading base images",
max_images=7,
)
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/xai/v1/videos/generations", method="POST"),
data=VideoGenerationRequest(
model=model["model"],
reference_images=[InputUrlObject(url=i) for i in ref_image_urls],
prompt=prompt,
resolution=model["resolution"],
duration=model["duration"],
aspect_ratio=model["aspect_ratio"],
seed=seed,
),
response_model=VideoGenerationResponse,
)
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
status_extractor=lambda r: r.status if r.status is not None else "complete",
response_model=VideoStatusResponse,
price_extractor=_extract_grok_video_price,
)
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
class GrokVideoExtendNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="GrokVideoExtendNode",
display_name="Grok Video Extend",
category="api node/video/Grok",
description="Extend an existing video with a seamless continuation based on a text prompt.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
tooltip="Text description of what should happen next in the video.",
),
IO.Video.Input("video", tooltip="Source video to extend. MP4 format, 2-15 seconds."),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"grok-imagine-video",
[
IO.Int.Input(
"duration",
default=8,
min=2,
max=10,
step=1,
tooltip="Length of the extension in seconds.",
display_mode=IO.NumberDisplay.slider,
),
],
),
],
tooltip="The model to use for video extension.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model.duration"]),
expr="""
(
$dur := $lookup(widgets, "model.duration");
{
"type": "range_usd",
"min_usd": (0.02 + 0.05 * $dur) * 1.43,
"max_usd": (0.15 + 0.05 * $dur) * 1.43
}
)
""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
video: Input.Video,
model: dict,
seed: int,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
validate_video_duration(video, min_duration=2, max_duration=15)
video_size = get_fs_object_size(video.get_stream_source())
if video_size > 50 * 1024 * 1024:
raise ValueError(f"Video size ({video_size / 1024 / 1024:.1f}MB) exceeds 50MB limit.")
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/xai/v1/videos/extensions", method="POST"),
data=VideoExtensionRequest(
prompt=prompt,
video=InputUrlObject(url=await upload_video_to_comfyapi(cls, video)),
duration=model["duration"],
),
response_model=VideoGenerationResponse,
)
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
status_extractor=lambda r: r.status if r.status is not None else "complete",
response_model=VideoStatusResponse,
price_extractor=_extract_grok_video_price,
)
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
class GrokExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@@ -469,7 +718,9 @@ class GrokExtension(ComfyExtension):
GrokImageNode,
GrokImageEditNode,
GrokVideoNode,
GrokVideoReferenceNode,
GrokVideoEditNode,
GrokVideoExtendNode,
]

View File

@@ -0,0 +1,42 @@
from __future__ import annotations
from comfy_api.latest import ComfyExtension, io
from comfy_api.input import CurveInput
from typing_extensions import override
class CurveEditor(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="CurveEditor",
display_name="Curve Editor",
category="utils",
inputs=[
io.Curve.Input("curve"),
io.Histogram.Input("histogram", optional=True),
],
outputs=[
io.Curve.Output("curve"),
],
)
@classmethod
def execute(cls, curve, histogram=None) -> io.NodeOutput:
result = CurveInput.from_raw(curve)
ui = {}
if histogram is not None:
ui["histogram"] = histogram if isinstance(histogram, list) else list(histogram)
return io.NodeOutput(result, ui=ui) if ui else io.NodeOutput(result)
class CurveExtension(ComfyExtension):
@override
async def get_node_list(self):
return [CurveEditor]
async def comfy_entrypoint():
return CurveExtension()

View File

@@ -0,0 +1,79 @@
"""Number Convert node for unified numeric type conversion.
Provides a single node that converts INT, FLOAT, STRING, and BOOL
inputs into FLOAT and INT outputs.
"""
from __future__ import annotations
import math
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class NumberConvertNode(io.ComfyNode):
"""Converts various types to numeric FLOAT and INT outputs."""
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="ComfyNumberConvert",
display_name="Number Convert",
category="math",
search_aliases=[
"int to float", "float to int", "number convert",
"int2float", "float2int", "cast", "parse number",
"string to number", "bool to int",
],
inputs=[
io.MultiType.Input(
"value",
[io.Int, io.Float, io.String, io.Boolean],
display_name="value",
),
],
outputs=[
io.Float.Output(display_name="FLOAT"),
io.Int.Output(display_name="INT"),
],
)
@classmethod
def execute(cls, value) -> io.NodeOutput:
if isinstance(value, bool):
float_val = 1.0 if value else 0.0
elif isinstance(value, (int, float)):
float_val = float(value)
elif isinstance(value, str):
text = value.strip()
if not text:
raise ValueError("Cannot convert empty string to number.")
try:
float_val = float(text)
except ValueError:
raise ValueError(
f"Cannot convert string to number: {value!r}"
) from None
else:
raise TypeError(
f"Unsupported input type: {type(value).__name__}"
)
if not math.isfinite(float_val):
raise ValueError(
f"Cannot convert non-finite value to number: {float_val}"
)
return io.NodeOutput(float_val, int(float_val))
class NumberConvertExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [NumberConvertNode]
async def comfy_entrypoint() -> NumberConvertExtension:
return NumberConvertExtension()

View File

@@ -1030,6 +1030,11 @@ class TrainLoraNode(io.ComfyNode):
default="bf16",
tooltip="The dtype to use for lora.",
),
io.Boolean.Input(
"quantized_backward",
default=False,
tooltip="When using training_dtype 'none' and training on quantized model, doing backward with quantized matmul when enabled.",
),
io.Combo.Input(
"algorithm",
options=list(adapter_maps.keys()),
@@ -1097,6 +1102,7 @@ class TrainLoraNode(io.ComfyNode):
seed,
training_dtype,
lora_dtype,
quantized_backward,
algorithm,
gradient_checkpointing,
checkpoint_depth,
@@ -1117,6 +1123,7 @@ class TrainLoraNode(io.ComfyNode):
seed = seed[0]
training_dtype = training_dtype[0]
lora_dtype = lora_dtype[0]
quantized_backward = quantized_backward[0]
algorithm = algorithm[0]
gradient_checkpointing = gradient_checkpointing[0]
offloading = offloading[0]
@@ -1125,6 +1132,8 @@ class TrainLoraNode(io.ComfyNode):
bucket_mode = bucket_mode[0]
bypass_mode = bypass_mode[0]
comfy.model_management.training_fp8_bwd = quantized_backward
# Process latents based on mode
if bucket_mode:
latents = _process_latents_bucket_mode(latents)
@@ -1137,6 +1146,7 @@ class TrainLoraNode(io.ComfyNode):
# Setup model and dtype
mp = model.clone()
use_grad_scaler = False
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
if training_dtype != "none":
dtype = node_helpers.string_to_torch_dtype(training_dtype)
mp.set_model_compute_dtype(dtype)
@@ -1145,7 +1155,10 @@ class TrainLoraNode(io.ComfyNode):
model_dtype = mp.model.get_dtype()
if model_dtype == torch.float16:
dtype = torch.float16
use_grad_scaler = True
# GradScaler only supports float16 gradients, not bfloat16.
# Only enable it when lora params will also be in float16.
if lora_dtype != torch.bfloat16:
use_grad_scaler = True
# Warn about fp16 accumulation instability during training
if PerformanceFeature.Fp16Accumulation in args.fast:
logging.warning(
@@ -1156,7 +1169,6 @@ class TrainLoraNode(io.ComfyNode):
else:
# For fp8, bf16, or other dtypes, use bf16 autocast
dtype = torch.bfloat16
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
# Prepare latents and compute counts
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16

52
main.py
View File

@@ -9,6 +9,8 @@ import folder_paths
import time
from comfy.cli_args import args, enables_dynamic_vram
from app.logger import setup_logger
from app.assets.seeder import asset_seeder
from app.assets.services import register_output_files
import itertools
import utils.extra_config
from utils.mime_types import init_mime_types
@@ -137,7 +139,16 @@ def execute_prestartup_script():
spec.loader.exec_module(module)
return True
except Exception as e:
import traceback
logging.error(f"Failed to execute startup-script: {script_path} / {e}")
from nodes import NODE_STARTUP_ERRORS, get_module_name
node_module_name = get_module_name(os.path.dirname(script_path))
NODE_STARTUP_ERRORS[node_module_name] = {
"module_path": os.path.dirname(script_path),
"error": str(e),
"traceback": traceback.format_exc(),
"phase": "prestartup",
}
return False
node_paths = folder_paths.get_folder_paths("custom_nodes")
@@ -192,7 +203,6 @@ if 'torch' in sys.modules:
import comfy.utils
from app.assets.seeder import asset_seeder
import execution
import server
@@ -240,6 +250,38 @@ def cuda_malloc_warning():
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
def _collect_output_absolute_paths(history_result: dict) -> list[str]:
"""Extract absolute file paths for output items from a history result."""
paths: list[str] = []
seen: set[str] = set()
for node_output in history_result.get("outputs", {}).values():
for items in node_output.values():
if not isinstance(items, list):
continue
for item in items:
if not isinstance(item, dict):
continue
item_type = item.get("type")
if item_type not in ("output", "temp"):
continue
base_dir = folder_paths.get_directory_by_type(item_type)
if base_dir is None:
continue
base_dir = os.path.abspath(base_dir)
filename = item.get("filename")
if not filename:
continue
abs_path = os.path.abspath(
os.path.join(base_dir, item.get("subfolder", ""), filename)
)
if not abs_path.startswith(base_dir + os.sep) and abs_path != base_dir:
continue
if abs_path not in seen:
seen.add(abs_path)
paths.append(abs_path)
return paths
def prompt_worker(q, server_instance):
current_time: float = 0.0
cache_type = execution.CacheType.CLASSIC
@@ -274,6 +316,7 @@ def prompt_worker(q, server_instance):
asset_seeder.pause()
e.execute(item[2], prompt_id, extra_data, item[4])
need_gc = True
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
@@ -296,6 +339,10 @@ def prompt_worker(q, server_instance):
else:
logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
if not asset_seeder.is_disabled():
paths = _collect_output_absolute_paths(e.history_result)
register_output_files(paths, job_id=prompt_id)
flags = q.get_flags()
free_memory = flags.get("free_memory", False)
@@ -317,6 +364,9 @@ def prompt_worker(q, server_instance):
last_gc_collect = current_time
need_gc = False
hook_breaker_ac10a0.restore_functions()
if not asset_seeder.is_disabled():
asset_seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
asset_seeder.resume()

View File

@@ -2181,6 +2181,9 @@ EXTENSION_WEB_DIRS = {}
# Dictionary of successfully loaded module names and associated directories.
LOADED_MODULE_DIRS = {}
# Dictionary of custom node startup errors, keyed by module name.
NODE_STARTUP_ERRORS: dict[str, dict] = {}
def get_module_name(module_path: str) -> str:
"""
@@ -2298,6 +2301,13 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
except Exception as e:
logging.warning(traceback.format_exc())
logging.warning(f"Cannot import {module_path} module for custom nodes: {e}")
module_name = get_module_name(module_path)
NODE_STARTUP_ERRORS[module_name] = {
"module_path": module_path,
"error": str(e),
"traceback": traceback.format_exc(),
"phase": "import",
}
return False
async def init_external_custom_nodes():
@@ -2454,7 +2464,9 @@ async def init_builtin_extra_nodes():
"nodes_nag.py",
"nodes_sdpose.py",
"nodes_math.py",
"nodes_number_convert.py",
"nodes_painter.py",
"nodes_curve.py",
]
import_failed = []

View File

@@ -1,5 +1,5 @@
comfyui-frontend-package==1.42.8
comfyui-workflow-templates==0.9.26
comfyui-workflow-templates==0.9.36
comfyui-embedded-docs==0.4.3
torch
torchsde

View File

@@ -753,6 +753,10 @@ class PromptServer():
out[node_class] = node_info(node_class)
return web.json_response(out)
@routes.get("/custom_node_startup_errors")
async def get_custom_node_startup_errors(request):
return web.json_response(nodes.NODE_STARTUP_ERRORS)
@routes.get("/api/jobs")
async def get_jobs(request):
"""List all jobs with filtering, sorting, and pagination.

View File

@@ -3,7 +3,7 @@ from pathlib import Path
from unittest.mock import patch
import pytest
from sqlalchemy import create_engine
from sqlalchemy import create_engine, event
from sqlalchemy.orm import Session
from app.assets.database.models import Base
@@ -23,6 +23,21 @@ def db_engine():
return engine
@pytest.fixture
def db_engine_fk():
"""In-memory SQLite engine with foreign key enforcement enabled."""
engine = create_engine("sqlite:///:memory:")
@event.listens_for(engine, "connect")
def _set_pragma(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
Base.metadata.create_all(engine)
return engine
@pytest.fixture
def session(db_engine):
"""Session fixture for tests that need direct DB access."""

View File

@@ -1,9 +1,11 @@
"""Tests for asset enrichment (mime_type and hash population)."""
import os
from pathlib import Path
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetReference
from app.assets.services.file_utils import get_mtime_ns
from app.assets.scanner import (
ENRICHMENT_HASHED,
ENRICHMENT_METADATA,
@@ -20,6 +22,13 @@ def _create_stub_asset(
name: str | None = None,
) -> tuple[Asset, AssetReference]:
"""Create a stub asset with reference for testing enrichment."""
# Use the real file's mtime so the optimistic guard in enrich_asset passes
try:
stat_result = os.stat(file_path, follow_symlinks=True)
mtime_ns = get_mtime_ns(stat_result)
except OSError:
mtime_ns = 1234567890000000000
asset = Asset(
id=asset_id,
hash=None,
@@ -35,7 +44,7 @@ def _create_stub_asset(
name=name or f"test-asset-{asset_id}",
owner_id="system",
file_path=file_path,
mtime_ns=1234567890000000000,
mtime_ns=mtime_ns,
enrichment_level=ENRICHMENT_STUB,
)
session.add(ref)

View File

@@ -1,12 +1,18 @@
"""Tests for ingest services."""
from contextlib import contextmanager
from pathlib import Path
from unittest.mock import patch
import pytest
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session as SASession, Session
from app.assets.database.models import Asset, AssetReference, Tag
from app.assets.database.models import Asset, AssetReference, AssetReferenceTag, Tag
from app.assets.database.queries import get_reference_tags
from app.assets.services.ingest import _ingest_file_from_path, _register_existing_asset
from app.assets.services.ingest import (
_ingest_file_from_path,
_register_existing_asset,
ingest_existing_file,
)
class TestIngestFileFromPath:
@@ -235,3 +241,42 @@ class TestRegisterExistingAsset:
assert result.created is True
assert set(result.tags) == {"alpha", "beta"}
class TestIngestExistingFileTagFK:
"""Regression: ingest_existing_file must seed Tag rows before inserting
AssetReferenceTag rows, otherwise FK enforcement raises IntegrityError."""
def test_creates_tag_rows_before_reference_tags(self, db_engine_fk, temp_dir: Path):
"""With PRAGMA foreign_keys=ON, tags must exist in the tags table
before they can be referenced in asset_reference_tags."""
@contextmanager
def _create_session():
with SASession(db_engine_fk) as sess:
yield sess
file_path = temp_dir / "output.png"
file_path.write_bytes(b"image data")
with patch("app.assets.services.ingest.create_session", _create_session), \
patch(
"app.assets.services.ingest.get_name_and_tags_from_asset_path",
return_value=("output.png", ["output"]),
):
result = ingest_existing_file(
abs_path=str(file_path),
extra_tags=["my-job"],
)
assert result is True
with SASession(db_engine_fk) as sess:
tag_names = {t.name for t in sess.query(Tag).all()}
assert "output" in tag_names
assert "my-job" in tag_names
ref_tags = sess.query(AssetReferenceTag).all()
ref_tag_names = {rt.tag_name for rt in ref_tags}
assert "output" in ref_tag_names
assert "my-job" in ref_tag_names

View File

@@ -0,0 +1,123 @@
import pytest
from unittest.mock import patch, MagicMock
mock_nodes = MagicMock()
mock_nodes.MAX_RESOLUTION = 16384
mock_server = MagicMock()
with patch.dict("sys.modules", {"nodes": mock_nodes, "server": mock_server}):
from comfy_extras.nodes_number_convert import NumberConvertNode
class TestNumberConvertExecute:
@staticmethod
def _exec(value) -> object:
return NumberConvertNode.execute(value)
# --- INT input ---
def test_int_input(self):
result = self._exec(42)
assert result[0] == 42.0
assert result[1] == 42
def test_int_zero(self):
result = self._exec(0)
assert result[0] == 0.0
assert result[1] == 0
def test_int_negative(self):
result = self._exec(-7)
assert result[0] == -7.0
assert result[1] == -7
# --- FLOAT input ---
def test_float_input(self):
result = self._exec(3.14)
assert result[0] == 3.14
assert result[1] == 3
def test_float_truncation_toward_zero(self):
result = self._exec(-2.9)
assert result[0] == -2.9
assert result[1] == -2 # int() truncates toward zero, not floor
def test_float_output_type(self):
result = self._exec(5)
assert isinstance(result[0], float)
def test_int_output_type(self):
result = self._exec(5.7)
assert isinstance(result[1], int)
# --- BOOL input ---
def test_bool_true(self):
result = self._exec(True)
assert result[0] == 1.0
assert result[1] == 1
def test_bool_false(self):
result = self._exec(False)
assert result[0] == 0.0
assert result[1] == 0
# --- STRING input ---
def test_string_integer(self):
result = self._exec("42")
assert result[0] == 42.0
assert result[1] == 42
def test_string_float(self):
result = self._exec("3.14")
assert result[0] == 3.14
assert result[1] == 3
def test_string_negative(self):
result = self._exec("-5.5")
assert result[0] == -5.5
assert result[1] == -5
def test_string_with_whitespace(self):
result = self._exec(" 7.0 ")
assert result[0] == 7.0
assert result[1] == 7
def test_string_scientific_notation(self):
result = self._exec("1e3")
assert result[0] == 1000.0
assert result[1] == 1000
# --- STRING error paths ---
def test_empty_string_raises(self):
with pytest.raises(ValueError, match="Cannot convert empty string"):
self._exec("")
def test_whitespace_only_string_raises(self):
with pytest.raises(ValueError, match="Cannot convert empty string"):
self._exec(" ")
def test_non_numeric_string_raises(self):
with pytest.raises(ValueError, match="Cannot convert string to number"):
self._exec("abc")
def test_string_inf_raises(self):
with pytest.raises(ValueError, match="non-finite"):
self._exec("inf")
def test_string_nan_raises(self):
with pytest.raises(ValueError, match="non-finite"):
self._exec("nan")
def test_string_negative_inf_raises(self):
with pytest.raises(ValueError, match="non-finite"):
self._exec("-inf")
# --- Unsupported type ---
def test_unsupported_type_raises(self):
with pytest.raises(TypeError, match="Unsupported input type"):
self._exec([1, 2, 3])

View File

@@ -1,6 +1,7 @@
"""Unit tests for the _AssetSeeder background scanning class."""
import threading
import time
from unittest.mock import patch
import pytest
@@ -771,6 +772,188 @@ class TestSeederStopRestart:
assert collected_roots[1] == ("input",)
class TestEnqueueEnrichHandoff:
"""Test that the drain of _pending_enrich is atomic with start_enrich."""
def test_pending_enrich_runs_after_scan_completes(
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
"""A queued enrich request runs automatically when a scan finishes."""
enrich_roots_seen: list[tuple] = []
original_start = fresh_seeder.start
def tracking_start(*args, **kwargs):
phase = kwargs.get("phase")
roots = kwargs.get("roots", args[0] if args else None)
result = original_start(*args, **kwargs)
if phase == ScanPhase.ENRICH and result:
enrich_roots_seen.append(roots)
return result
fresh_seeder.start = tracking_start
# Start a fast scan, then enqueue an enrich while it's running
barrier = threading.Event()
reached = threading.Event()
def slow_collect(*args):
reached.set()
barrier.wait(timeout=5.0)
return []
with patch(
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
):
fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST)
assert reached.wait(timeout=2.0)
queued = fresh_seeder.enqueue_enrich(
roots=("input",), compute_hashes=True
)
assert queued is False # queued, not started immediately
barrier.set()
# Wait for the original scan + the auto-started enrich scan
deadline = time.monotonic() + 5.0
while fresh_seeder.get_status().state != State.IDLE and time.monotonic() < deadline:
time.sleep(0.05)
assert enrich_roots_seen == [("input",)]
def test_enqueue_enrich_during_drain_does_not_lose_work(
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
"""enqueue_enrich called concurrently with drain cannot drop work.
Simulates the race: another thread calls enqueue_enrich right as the
scan thread is draining _pending_enrich. The enqueue must either be
picked up by the draining scan or successfully start its own scan.
"""
barrier = threading.Event()
reached = threading.Event()
enrich_started = threading.Event()
enrich_call_count = 0
def slow_collect(*args):
reached.set()
barrier.wait(timeout=5.0)
return []
# Track how many times start_enrich actually fires
real_start_enrich = fresh_seeder.start_enrich
enrich_roots_seen: list[tuple] = []
def tracking_start_enrich(**kwargs):
nonlocal enrich_call_count
enrich_call_count += 1
enrich_roots_seen.append(kwargs.get("roots"))
result = real_start_enrich(**kwargs)
if result:
enrich_started.set()
return result
fresh_seeder.start_enrich = tracking_start_enrich
with patch(
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
):
# Start a scan
fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST)
assert reached.wait(timeout=2.0)
# Queue an enrich while scan is running
fresh_seeder.enqueue_enrich(roots=("output",), compute_hashes=False)
# Let scan finish — drain will fire start_enrich atomically
barrier.set()
# Wait for drain to complete and the enrich scan to start
assert enrich_started.wait(timeout=5.0), "Enrich scan was never started from drain"
assert ("output",) in enrich_roots_seen
def test_concurrent_enqueue_during_drain_not_lost(
self, fresh_seeder: _AssetSeeder,
):
"""A second enqueue_enrich arriving while drain is in progress is not lost.
Because the drain now holds _lock through the start_enrich call,
a concurrent enqueue_enrich will block until start_enrich has
transitioned state to RUNNING, then the enqueue will queue its
payload as _pending_enrich for the *next* drain.
"""
scan_barrier = threading.Event()
scan_reached = threading.Event()
enrich_barrier = threading.Event()
enrich_reached = threading.Event()
collect_call = 0
def gated_collect(*args):
nonlocal collect_call
collect_call += 1
if collect_call == 1:
# First call: the initial fast scan
scan_reached.set()
scan_barrier.wait(timeout=5.0)
return []
enrich_call = 0
def gated_get_unenriched(*args, **kwargs):
nonlocal enrich_call
enrich_call += 1
if enrich_call == 1:
# First enrich batch: signal and block
enrich_reached.set()
enrich_barrier.wait(timeout=5.0)
return []
with (
patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder.collect_paths_for_roots", side_effect=gated_collect),
patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
patch("app.assets.seeder.insert_asset_specs", return_value=0),
patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=gated_get_unenriched),
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
):
# 1. Start fast scan
fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST)
assert scan_reached.wait(timeout=2.0)
# 2. Queue enrich while fast scan is running
queued = fresh_seeder.enqueue_enrich(
roots=("input",), compute_hashes=False
)
assert queued is False
# 3. Let the fast scan finish — drain will start the enrich scan
scan_barrier.set()
# 4. Wait until the drained enrich scan is running
assert enrich_reached.wait(timeout=5.0)
# 5. Now enqueue another enrich while the drained scan is running
queued2 = fresh_seeder.enqueue_enrich(
roots=("output",), compute_hashes=True
)
assert queued2 is False # should be queued, not started
# Verify _pending_enrich was set (the second enqueue was captured)
with fresh_seeder._lock:
assert fresh_seeder._pending_enrich is not None
assert "output" in fresh_seeder._pending_enrich["roots"]
# Let the enrich scan finish
enrich_barrier.set()
deadline = time.monotonic() + 5.0
while fresh_seeder.get_status().state != State.IDLE and time.monotonic() < deadline:
time.sleep(0.05)
def _make_row(ref_id: str, asset_id: str = "a1") -> UnenrichedReferenceRow:
return UnenrichedReferenceRow(
reference_id=ref_id, asset_id=asset_id,

250
tests/test_asset_seeder.py Normal file
View File

@@ -0,0 +1,250 @@
"""Tests for app.assets.seeder enqueue_enrich and pending-queue behaviour."""
import threading
from unittest.mock import patch
import pytest
from app.assets.seeder import Progress, _AssetSeeder, State
@pytest.fixture()
def seeder():
"""Fresh seeder instance for each test."""
return _AssetSeeder()
# ---------------------------------------------------------------------------
# _reset_to_idle
# ---------------------------------------------------------------------------
class TestResetToIdle:
def test_sets_idle_and_clears_progress(self, seeder):
"""_reset_to_idle should move state to IDLE and snapshot progress."""
progress = Progress(scanned=10, total=20, created=5, skipped=3)
seeder._state = State.RUNNING
seeder._progress = progress
with seeder._lock:
seeder._reset_to_idle()
assert seeder._state is State.IDLE
assert seeder._progress is None
assert seeder._last_progress is progress
def test_noop_when_progress_already_none(self, seeder):
"""_reset_to_idle should handle None progress gracefully."""
seeder._state = State.CANCELLING
seeder._progress = None
with seeder._lock:
seeder._reset_to_idle()
assert seeder._state is State.IDLE
assert seeder._progress is None
assert seeder._last_progress is None
# ---------------------------------------------------------------------------
# enqueue_enrich immediate start when idle
# ---------------------------------------------------------------------------
class TestEnqueueEnrichStartsImmediately:
def test_starts_when_idle(self, seeder):
"""enqueue_enrich should delegate to start_enrich and return True when idle."""
with patch.object(seeder, "start_enrich", return_value=True) as mock:
assert seeder.enqueue_enrich(roots=("output",), compute_hashes=True) is True
mock.assert_called_once_with(roots=("output",), compute_hashes=True)
def test_no_pending_when_started_immediately(self, seeder):
"""No pending request should be stored when start_enrich succeeds."""
with patch.object(seeder, "start_enrich", return_value=True):
seeder.enqueue_enrich(roots=("output",))
assert seeder._pending_enrich is None
# ---------------------------------------------------------------------------
# enqueue_enrich queuing when busy
# ---------------------------------------------------------------------------
class TestEnqueueEnrichQueuesWhenBusy:
def test_queues_when_busy(self, seeder):
"""enqueue_enrich should store a pending request when seeder is busy."""
with patch.object(seeder, "start_enrich", return_value=False):
result = seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
assert result is False
assert seeder._pending_enrich == {
"roots": ("models",),
"compute_hashes": False,
}
def test_queues_preserves_compute_hashes_true(self, seeder):
with patch.object(seeder, "start_enrich", return_value=False):
seeder.enqueue_enrich(roots=("input",), compute_hashes=True)
assert seeder._pending_enrich["compute_hashes"] is True
# ---------------------------------------------------------------------------
# enqueue_enrich merging when a pending request already exists
# ---------------------------------------------------------------------------
class TestEnqueueEnrichMergesPending:
def _make_busy(self, seeder):
"""Patch start_enrich to always return False (seeder busy)."""
return patch.object(seeder, "start_enrich", return_value=False)
def test_merges_roots(self, seeder):
"""A second enqueue should merge roots with the existing pending request."""
with self._make_busy(seeder):
seeder.enqueue_enrich(roots=("models",))
seeder.enqueue_enrich(roots=("output",))
merged = set(seeder._pending_enrich["roots"])
assert merged == {"models", "output"}
def test_merges_overlapping_roots(self, seeder):
"""Duplicate roots should be deduplicated."""
with self._make_busy(seeder):
seeder.enqueue_enrich(roots=("models", "input"))
seeder.enqueue_enrich(roots=("input", "output"))
merged = set(seeder._pending_enrich["roots"])
assert merged == {"models", "input", "output"}
def test_compute_hashes_sticky_true(self, seeder):
"""Once compute_hashes is True it should stay True after merging."""
with self._make_busy(seeder):
seeder.enqueue_enrich(roots=("models",), compute_hashes=True)
seeder.enqueue_enrich(roots=("output",), compute_hashes=False)
assert seeder._pending_enrich["compute_hashes"] is True
def test_compute_hashes_upgrades_to_true(self, seeder):
"""A later enqueue with compute_hashes=True should upgrade the pending request."""
with self._make_busy(seeder):
seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
assert seeder._pending_enrich["compute_hashes"] is True
def test_compute_hashes_stays_false(self, seeder):
"""If both enqueues have compute_hashes=False it stays False."""
with self._make_busy(seeder):
seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
seeder.enqueue_enrich(roots=("output",), compute_hashes=False)
assert seeder._pending_enrich["compute_hashes"] is False
def test_triple_merge(self, seeder):
"""Three successive enqueues should all merge correctly."""
with self._make_busy(seeder):
seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
seeder.enqueue_enrich(roots=("input",), compute_hashes=False)
seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
merged = set(seeder._pending_enrich["roots"])
assert merged == {"models", "input", "output"}
assert seeder._pending_enrich["compute_hashes"] is True
# ---------------------------------------------------------------------------
# Pending enrich drains after scan completes
# ---------------------------------------------------------------------------
class TestPendingEnrichDrain:
"""Verify that _run_scan drains _pending_enrich via start_enrich."""
@patch("app.assets.seeder.dependencies_available", return_value=True)
@patch("app.assets.seeder.get_all_known_prefixes", return_value=[])
@patch("app.assets.seeder.sync_root_safely", return_value=set())
@patch("app.assets.seeder.collect_paths_for_roots", return_value=[])
@patch("app.assets.seeder.build_asset_specs", return_value=([], {}, 0))
def test_pending_enrich_starts_after_scan(self, *_mocks):
"""After a fast scan finishes, the pending enrich should be started."""
seeder = _AssetSeeder()
seeder._pending_enrich = {
"roots": ("output",),
"compute_hashes": True,
}
with patch.object(seeder, "start_enrich", return_value=True) as mock_start:
seeder.start_fast(roots=("models",))
seeder.wait(timeout=5)
mock_start.assert_called_once_with(
roots=("output",),
compute_hashes=True,
)
assert seeder._pending_enrich is None
@patch("app.assets.seeder.dependencies_available", return_value=True)
@patch("app.assets.seeder.get_all_known_prefixes", return_value=[])
@patch("app.assets.seeder.sync_root_safely", return_value=set())
@patch("app.assets.seeder.collect_paths_for_roots", return_value=[])
@patch("app.assets.seeder.build_asset_specs", return_value=([], {}, 0))
def test_pending_cleared_even_when_start_fails(self, *_mocks):
"""_pending_enrich should be cleared even if start_enrich returns False."""
seeder = _AssetSeeder()
seeder._pending_enrich = {
"roots": ("output",),
"compute_hashes": False,
}
with patch.object(seeder, "start_enrich", return_value=False):
seeder.start_fast(roots=("models",))
seeder.wait(timeout=5)
assert seeder._pending_enrich is None
@patch("app.assets.seeder.dependencies_available", return_value=True)
@patch("app.assets.seeder.get_all_known_prefixes", return_value=[])
@patch("app.assets.seeder.sync_root_safely", return_value=set())
@patch("app.assets.seeder.collect_paths_for_roots", return_value=[])
@patch("app.assets.seeder.build_asset_specs", return_value=([], {}, 0))
def test_no_drain_when_no_pending(self, *_mocks):
"""start_enrich should not be called when there is no pending request."""
seeder = _AssetSeeder()
assert seeder._pending_enrich is None
with patch.object(seeder, "start_enrich", return_value=True) as mock_start:
seeder.start_fast(roots=("models",))
seeder.wait(timeout=5)
mock_start.assert_not_called()
# ---------------------------------------------------------------------------
# Thread-safety of enqueue_enrich
# ---------------------------------------------------------------------------
class TestEnqueueEnrichThreadSafety:
def test_concurrent_enqueues(self, seeder):
"""Multiple threads enqueuing should not lose roots."""
with patch.object(seeder, "start_enrich", return_value=False):
barrier = threading.Barrier(3)
def enqueue(root):
barrier.wait()
seeder.enqueue_enrich(roots=(root,), compute_hashes=False)
threads = [
threading.Thread(target=enqueue, args=(r,))
for r in ("models", "input", "output")
]
for t in threads:
t.start()
for t in threads:
t.join(timeout=5)
merged = set(seeder._pending_enrich["roots"])
assert merged == {"models", "input", "output"}