mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-25 06:57:29 +00:00
Compare commits
7 Commits
feature/cu
...
pysssss/an
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
54ff5464bd | ||
|
|
333ff2e8a0 | ||
|
|
c821d8ee2a | ||
|
|
27b6f8a927 | ||
|
|
9ad848bd59 | ||
|
|
efe6439ad0 | ||
|
|
8d76bb94fd |
@@ -1,7 +1,6 @@
|
||||
from app.assets.database.queries.asset import (
|
||||
asset_exists_by_hash,
|
||||
bulk_insert_assets,
|
||||
create_stub_asset,
|
||||
get_asset_by_hash,
|
||||
get_existing_asset_ids,
|
||||
reassign_asset_references,
|
||||
@@ -13,7 +12,6 @@ from app.assets.database.queries.asset_reference import (
|
||||
UnenrichedReferenceRow,
|
||||
bulk_insert_references_ignore_conflicts,
|
||||
bulk_update_enrichment_level,
|
||||
count_active_siblings,
|
||||
bulk_update_is_missing,
|
||||
bulk_update_needs_verify,
|
||||
convert_metadata_to_rows,
|
||||
@@ -82,8 +80,6 @@ __all__ = [
|
||||
"bulk_insert_references_ignore_conflicts",
|
||||
"bulk_insert_tags_and_meta",
|
||||
"bulk_update_enrichment_level",
|
||||
"count_active_siblings",
|
||||
"create_stub_asset",
|
||||
"bulk_update_is_missing",
|
||||
"bulk_update_needs_verify",
|
||||
"convert_metadata_to_rows",
|
||||
|
||||
@@ -78,18 +78,6 @@ def upsert_asset(
|
||||
return asset, created, updated
|
||||
|
||||
|
||||
def create_stub_asset(
|
||||
session: Session,
|
||||
size_bytes: int,
|
||||
mime_type: str | None = None,
|
||||
) -> Asset:
|
||||
"""Create a new asset with no hash (stub for later enrichment)."""
|
||||
asset = Asset(size_bytes=size_bytes, mime_type=mime_type, hash=None)
|
||||
session.add(asset)
|
||||
session.flush()
|
||||
return asset
|
||||
|
||||
|
||||
def bulk_insert_assets(
|
||||
session: Session,
|
||||
rows: list[dict],
|
||||
|
||||
@@ -114,23 +114,6 @@ def get_reference_by_file_path(
|
||||
)
|
||||
|
||||
|
||||
def count_active_siblings(
|
||||
session: Session,
|
||||
asset_id: str,
|
||||
exclude_reference_id: str,
|
||||
) -> int:
|
||||
"""Count active (non-deleted) references to an asset, excluding one reference."""
|
||||
return (
|
||||
session.query(AssetReference)
|
||||
.filter(
|
||||
AssetReference.asset_id == asset_id,
|
||||
AssetReference.id != exclude_reference_id,
|
||||
AssetReference.deleted_at.is_(None),
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
|
||||
def reference_exists_for_asset_id(
|
||||
session: Session,
|
||||
asset_id: str,
|
||||
|
||||
@@ -13,7 +13,6 @@ from app.assets.database.queries import (
|
||||
delete_references_by_ids,
|
||||
ensure_tags_exist,
|
||||
get_asset_by_hash,
|
||||
get_reference_by_id,
|
||||
get_references_for_prefixes,
|
||||
get_unenriched_references,
|
||||
mark_references_missing_outside_prefixes,
|
||||
@@ -339,7 +338,6 @@ def build_asset_specs(
|
||||
"metadata": metadata,
|
||||
"hash": asset_hash,
|
||||
"mime_type": mime_type,
|
||||
"job_id": None,
|
||||
}
|
||||
)
|
||||
tag_pool.update(tags)
|
||||
@@ -428,7 +426,6 @@ def enrich_asset(
|
||||
except OSError:
|
||||
return new_level
|
||||
|
||||
initial_mtime_ns = get_mtime_ns(stat_p)
|
||||
rel_fname = compute_relative_filename(file_path)
|
||||
mime_type: str | None = None
|
||||
metadata = None
|
||||
@@ -492,18 +489,6 @@ def enrich_asset(
|
||||
except Exception as e:
|
||||
logging.warning("Failed to hash %s: %s", file_path, e)
|
||||
|
||||
# Optimistic guard: if the reference's mtime_ns changed since we
|
||||
# started (e.g. ingest_existing_file updated it), our results are
|
||||
# stale — discard them to avoid overwriting fresh registration data.
|
||||
ref = get_reference_by_id(session, reference_id)
|
||||
if ref is None or ref.mtime_ns != initial_mtime_ns:
|
||||
session.rollback()
|
||||
logging.info(
|
||||
"Ref %s mtime changed during enrichment, discarding stale result",
|
||||
reference_id,
|
||||
)
|
||||
return ENRICHMENT_STUB
|
||||
|
||||
if extract_metadata and metadata:
|
||||
system_metadata = metadata.to_user_metadata()
|
||||
set_reference_system_metadata(session, reference_id, system_metadata)
|
||||
|
||||
@@ -77,9 +77,7 @@ class _AssetSeeder:
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# RLock is required because _run_scan() drains pending work while
|
||||
# holding _lock and re-enters start() which also acquires _lock.
|
||||
self._lock = threading.RLock()
|
||||
self._lock = threading.Lock()
|
||||
self._state = State.IDLE
|
||||
self._progress: Progress | None = None
|
||||
self._last_progress: Progress | None = None
|
||||
@@ -94,7 +92,6 @@ class _AssetSeeder:
|
||||
self._prune_first: bool = False
|
||||
self._progress_callback: ProgressCallback | None = None
|
||||
self._disabled: bool = False
|
||||
self._pending_enrich: dict | None = None
|
||||
|
||||
def disable(self) -> None:
|
||||
"""Disable the asset seeder, preventing any scans from starting."""
|
||||
@@ -199,42 +196,6 @@ class _AssetSeeder:
|
||||
compute_hashes=compute_hashes,
|
||||
)
|
||||
|
||||
def enqueue_enrich(
|
||||
self,
|
||||
roots: tuple[RootType, ...] = ("models", "input", "output"),
|
||||
compute_hashes: bool = False,
|
||||
) -> bool:
|
||||
"""Start an enrichment scan now, or queue it for after the current scan.
|
||||
|
||||
If the seeder is idle, starts immediately. Otherwise, the enrich
|
||||
request is stored and will run automatically when the current scan
|
||||
finishes.
|
||||
|
||||
Args:
|
||||
roots: Tuple of root types to scan
|
||||
compute_hashes: If True, compute blake3 hashes
|
||||
|
||||
Returns:
|
||||
True if started immediately, False if queued for later
|
||||
"""
|
||||
with self._lock:
|
||||
if self.start_enrich(roots=roots, compute_hashes=compute_hashes):
|
||||
return True
|
||||
if self._pending_enrich is not None:
|
||||
existing_roots = set(self._pending_enrich["roots"])
|
||||
existing_roots.update(roots)
|
||||
self._pending_enrich["roots"] = tuple(existing_roots)
|
||||
self._pending_enrich["compute_hashes"] = (
|
||||
self._pending_enrich["compute_hashes"] or compute_hashes
|
||||
)
|
||||
else:
|
||||
self._pending_enrich = {
|
||||
"roots": roots,
|
||||
"compute_hashes": compute_hashes,
|
||||
}
|
||||
logging.info("Enrich scan queued (roots=%s)", self._pending_enrich["roots"])
|
||||
return False
|
||||
|
||||
def cancel(self) -> bool:
|
||||
"""Request cancellation of the current scan.
|
||||
|
||||
@@ -420,13 +381,9 @@ class _AssetSeeder:
|
||||
return marked
|
||||
finally:
|
||||
with self._lock:
|
||||
self._reset_to_idle()
|
||||
|
||||
def _reset_to_idle(self) -> None:
|
||||
"""Reset state to IDLE, preserving last progress. Caller must hold _lock."""
|
||||
self._last_progress = self._progress
|
||||
self._state = State.IDLE
|
||||
self._progress = None
|
||||
self._last_progress = self._progress
|
||||
self._state = State.IDLE
|
||||
self._progress = None
|
||||
|
||||
def _is_cancelled(self) -> bool:
|
||||
"""Check if cancellation has been requested."""
|
||||
@@ -637,18 +594,9 @@ class _AssetSeeder:
|
||||
},
|
||||
)
|
||||
with self._lock:
|
||||
self._reset_to_idle()
|
||||
pending = self._pending_enrich
|
||||
if pending is not None:
|
||||
self._pending_enrich = None
|
||||
if not self.start_enrich(
|
||||
roots=pending["roots"],
|
||||
compute_hashes=pending["compute_hashes"],
|
||||
):
|
||||
logging.warning(
|
||||
"Pending enrich scan could not start (roots=%s)",
|
||||
pending["roots"],
|
||||
)
|
||||
self._last_progress = self._progress
|
||||
self._state = State.IDLE
|
||||
self._progress = None
|
||||
|
||||
def _run_fast_phase(self, roots: tuple[RootType, ...]) -> tuple[int, int, int]:
|
||||
"""Run phase 1: fast scan to create stub records.
|
||||
|
||||
@@ -23,8 +23,6 @@ from app.assets.services.ingest import (
|
||||
DependencyMissingError,
|
||||
HashMismatchError,
|
||||
create_from_hash,
|
||||
ingest_existing_file,
|
||||
register_output_files,
|
||||
upload_from_temp_path,
|
||||
)
|
||||
from app.assets.database.queries import (
|
||||
@@ -74,8 +72,6 @@ __all__ = [
|
||||
"delete_asset_reference",
|
||||
"get_asset_by_hash",
|
||||
"get_asset_detail",
|
||||
"ingest_existing_file",
|
||||
"register_output_files",
|
||||
"get_mtime_ns",
|
||||
"get_size_and_mtime_ns",
|
||||
"list_assets_page",
|
||||
|
||||
@@ -37,7 +37,6 @@ class SeedAssetSpec(TypedDict):
|
||||
metadata: ExtractedMetadata | None
|
||||
hash: str | None
|
||||
mime_type: str | None
|
||||
job_id: str | None
|
||||
|
||||
|
||||
class AssetRow(TypedDict):
|
||||
@@ -61,7 +60,6 @@ class ReferenceRow(TypedDict):
|
||||
name: str
|
||||
preview_id: str | None
|
||||
user_metadata: dict[str, Any] | None
|
||||
job_id: str | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
last_access_time: datetime
|
||||
@@ -169,7 +167,6 @@ def batch_insert_seed_assets(
|
||||
"name": spec["info_name"],
|
||||
"preview_id": None,
|
||||
"user_metadata": user_metadata,
|
||||
"job_id": spec.get("job_id"),
|
||||
"created_at": current_time,
|
||||
"updated_at": current_time,
|
||||
"last_access_time": current_time,
|
||||
|
||||
@@ -9,9 +9,6 @@ from sqlalchemy.orm import Session
|
||||
import app.assets.services.hashing as hashing
|
||||
from app.assets.database.queries import (
|
||||
add_tags_to_reference,
|
||||
count_active_siblings,
|
||||
create_stub_asset,
|
||||
ensure_tags_exist,
|
||||
fetch_reference_and_asset,
|
||||
get_asset_by_hash,
|
||||
get_reference_by_file_path,
|
||||
@@ -26,8 +23,7 @@ from app.assets.database.queries import (
|
||||
upsert_reference,
|
||||
validate_tags_exist,
|
||||
)
|
||||
from app.assets.helpers import get_utc_now, normalize_tags
|
||||
from app.assets.services.bulk_ingest import batch_insert_seed_assets
|
||||
from app.assets.helpers import normalize_tags
|
||||
from app.assets.services.file_utils import get_size_and_mtime_ns
|
||||
from app.assets.services.path_utils import (
|
||||
compute_relative_filename,
|
||||
@@ -134,102 +130,6 @@ def _ingest_file_from_path(
|
||||
)
|
||||
|
||||
|
||||
def register_output_files(
|
||||
file_paths: Sequence[str],
|
||||
user_metadata: UserMetadata = None,
|
||||
job_id: str | None = None,
|
||||
) -> int:
|
||||
"""Register a batch of output file paths as assets.
|
||||
|
||||
Returns the number of files successfully registered.
|
||||
"""
|
||||
registered = 0
|
||||
for abs_path in file_paths:
|
||||
if not os.path.isfile(abs_path):
|
||||
continue
|
||||
try:
|
||||
if ingest_existing_file(
|
||||
abs_path, user_metadata=user_metadata, job_id=job_id
|
||||
):
|
||||
registered += 1
|
||||
except Exception:
|
||||
logging.exception("Failed to register output: %s", abs_path)
|
||||
return registered
|
||||
|
||||
|
||||
def ingest_existing_file(
|
||||
abs_path: str,
|
||||
user_metadata: UserMetadata = None,
|
||||
extra_tags: Sequence[str] = (),
|
||||
owner_id: str = "",
|
||||
job_id: str | None = None,
|
||||
) -> bool:
|
||||
"""Register an existing on-disk file as an asset stub.
|
||||
|
||||
If a reference already exists for this path, updates mtime_ns, job_id,
|
||||
size_bytes, and resets enrichment so the enricher will re-hash it.
|
||||
|
||||
For brand-new paths, inserts a stub record (hash=NULL) for immediate
|
||||
UX visibility.
|
||||
|
||||
Returns True if a row was inserted or updated, False otherwise.
|
||||
"""
|
||||
locator = os.path.abspath(abs_path)
|
||||
size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path)
|
||||
mime_type = mimetypes.guess_type(abs_path, strict=False)[0]
|
||||
name, path_tags = get_name_and_tags_from_asset_path(abs_path)
|
||||
tags = list(dict.fromkeys(path_tags + list(extra_tags)))
|
||||
|
||||
with create_session() as session:
|
||||
existing_ref = get_reference_by_file_path(session, locator)
|
||||
if existing_ref is not None:
|
||||
now = get_utc_now()
|
||||
existing_ref.mtime_ns = mtime_ns
|
||||
existing_ref.job_id = job_id
|
||||
existing_ref.is_missing = False
|
||||
existing_ref.deleted_at = None
|
||||
existing_ref.updated_at = now
|
||||
existing_ref.enrichment_level = 0
|
||||
|
||||
asset = existing_ref.asset
|
||||
if asset:
|
||||
# If other refs share this asset, detach to a new stub
|
||||
# instead of mutating the shared row.
|
||||
siblings = count_active_siblings(session, asset.id, existing_ref.id)
|
||||
if siblings > 0:
|
||||
new_asset = create_stub_asset(
|
||||
session,
|
||||
size_bytes=size_bytes,
|
||||
mime_type=mime_type or asset.mime_type,
|
||||
)
|
||||
existing_ref.asset_id = new_asset.id
|
||||
else:
|
||||
asset.hash = None
|
||||
asset.size_bytes = size_bytes
|
||||
if mime_type:
|
||||
asset.mime_type = mime_type
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
spec = {
|
||||
"abs_path": abs_path,
|
||||
"size_bytes": size_bytes,
|
||||
"mtime_ns": mtime_ns,
|
||||
"info_name": name,
|
||||
"tags": tags,
|
||||
"fname": os.path.basename(abs_path),
|
||||
"metadata": None,
|
||||
"hash": None,
|
||||
"mime_type": mime_type,
|
||||
"job_id": job_id,
|
||||
}
|
||||
if tags:
|
||||
ensure_tags_exist(session, tags)
|
||||
result = batch_insert_seed_assets(session, [spec], owner_id=owner_id)
|
||||
session.commit()
|
||||
return result.won_paths > 0
|
||||
|
||||
|
||||
def _register_existing_asset(
|
||||
asset_hash: str,
|
||||
name: str,
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -55,7 +55,6 @@ total_vram = 0
|
||||
|
||||
# Training Related State
|
||||
in_training = False
|
||||
training_fp8_bwd = False
|
||||
|
||||
|
||||
def get_supported_float8_types():
|
||||
|
||||
65
comfy/ops.py
65
comfy/ops.py
@@ -777,16 +777,8 @@ from .quant_ops import (
|
||||
|
||||
|
||||
class QuantLinearFunc(torch.autograd.Function):
|
||||
"""Custom autograd function for quantized linear: quantized forward, optionally FP8 backward.
|
||||
|
||||
When training_fp8_bwd is enabled:
|
||||
- Forward: quantize input per layout (FP8/NVFP4), use quantized matmul
|
||||
- Backward: all matmuls use FP8 tensor cores via torch.mm dispatch
|
||||
- Cached input is FP8 (half the memory of bf16)
|
||||
|
||||
When training_fp8_bwd is disabled:
|
||||
- Forward: quantize input per layout, use quantized matmul
|
||||
- Backward: dequantize weight to compute_dtype, use standard matmul
|
||||
"""Custom autograd function for quantized linear: quantized forward, compute_dtype backward.
|
||||
Handles any input rank by flattening to 2D for matmul and restoring shape after.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@@ -794,7 +786,7 @@ class QuantLinearFunc(torch.autograd.Function):
|
||||
input_shape = input_float.shape
|
||||
inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D
|
||||
|
||||
# Quantize input for forward (same layout as weight)
|
||||
# Quantize input (same as inference path)
|
||||
if layout_type is not None:
|
||||
q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
|
||||
else:
|
||||
@@ -805,68 +797,43 @@ class QuantLinearFunc(torch.autograd.Function):
|
||||
|
||||
output = torch.nn.functional.linear(q_input, w, b)
|
||||
|
||||
# Unflatten output to match original input shape
|
||||
# Restore original input shape
|
||||
if len(input_shape) > 2:
|
||||
output = output.unflatten(0, input_shape[:-1])
|
||||
|
||||
# Save for backward
|
||||
ctx.save_for_backward(input_float, weight)
|
||||
ctx.input_shape = input_shape
|
||||
ctx.has_bias = bias is not None
|
||||
ctx.compute_dtype = compute_dtype
|
||||
ctx.weight_requires_grad = weight.requires_grad
|
||||
ctx.fp8_bwd = comfy.model_management.training_fp8_bwd
|
||||
|
||||
if ctx.fp8_bwd:
|
||||
# Cache FP8 quantized input — half the memory of bf16
|
||||
if isinstance(q_input, QuantizedTensor) and layout_type.startswith('TensorCoreFP8'):
|
||||
ctx.q_input = q_input # already FP8, reuse
|
||||
else:
|
||||
# NVFP4 or other layout — quantize input to FP8 for backward
|
||||
ctx.q_input = QuantizedTensor.from_float(inp, "TensorCoreFP8E4M3Layout")
|
||||
ctx.save_for_backward(weight)
|
||||
else:
|
||||
ctx.q_input = None
|
||||
ctx.save_for_backward(input_float, weight)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@torch.autograd.function.once_differentiable
|
||||
def backward(ctx, grad_output):
|
||||
input_float, weight = ctx.saved_tensors
|
||||
compute_dtype = ctx.compute_dtype
|
||||
grad_2d = grad_output.flatten(0, -2).to(compute_dtype)
|
||||
|
||||
# Value casting — only difference between fp8 and non-fp8 paths
|
||||
if ctx.fp8_bwd:
|
||||
weight, = ctx.saved_tensors
|
||||
# Wrap as FP8 QuantizedTensors → torch.mm dispatches to _scaled_mm
|
||||
grad_mm = QuantizedTensor.from_float(grad_2d, "TensorCoreFP8E5M2Layout")
|
||||
if isinstance(weight, QuantizedTensor) and weight._layout_cls.startswith("TensorCoreFP8"):
|
||||
weight_mm = weight
|
||||
elif isinstance(weight, QuantizedTensor):
|
||||
weight_mm = QuantizedTensor.from_float(weight.dequantize().to(compute_dtype), "TensorCoreFP8E4M3Layout")
|
||||
else:
|
||||
weight_mm = QuantizedTensor.from_float(weight.to(compute_dtype), "TensorCoreFP8E4M3Layout")
|
||||
input_mm = ctx.q_input
|
||||
# Dequantize weight to compute dtype for backward matmul
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight_f = weight.dequantize().to(compute_dtype)
|
||||
else:
|
||||
input_float, weight = ctx.saved_tensors
|
||||
# Standard tensors → torch.mm does regular matmul
|
||||
grad_mm = grad_2d
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight_mm = weight.dequantize().to(compute_dtype)
|
||||
else:
|
||||
weight_mm = weight.to(compute_dtype)
|
||||
input_mm = input_float.flatten(0, -2).to(compute_dtype) if ctx.weight_requires_grad else None
|
||||
weight_f = weight.to(compute_dtype)
|
||||
|
||||
# Computation — same for both paths, dispatch handles the rest
|
||||
grad_input = torch.mm(grad_mm, weight_mm)
|
||||
# grad_input = grad_output @ weight
|
||||
grad_input = torch.mm(grad_2d, weight_f)
|
||||
if len(ctx.input_shape) > 2:
|
||||
grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])
|
||||
|
||||
# grad_weight (only if weight requires grad, typically frozen for quantized training)
|
||||
grad_weight = None
|
||||
if ctx.weight_requires_grad:
|
||||
grad_weight = torch.mm(grad_mm.t(), input_mm)
|
||||
input_f = input_float.flatten(0, -2).to(compute_dtype)
|
||||
grad_weight = torch.mm(grad_2d.t(), input_f)
|
||||
|
||||
# grad_bias
|
||||
grad_bias = None
|
||||
if ctx.has_bias:
|
||||
grad_bias = grad_2d.sum(dim=0)
|
||||
|
||||
@@ -5,10 +5,6 @@ from comfy_api.latest._input import (
|
||||
MaskInput,
|
||||
LatentInput,
|
||||
VideoInput,
|
||||
CurvePoint,
|
||||
CurveInput,
|
||||
MonotoneCubicCurve,
|
||||
LinearCurve,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -17,8 +13,4 @@ __all__ = [
|
||||
"MaskInput",
|
||||
"LatentInput",
|
||||
"VideoInput",
|
||||
"CurvePoint",
|
||||
"CurveInput",
|
||||
"MonotoneCubicCurve",
|
||||
"LinearCurve",
|
||||
]
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
|
||||
from .curve_types import CurvePoint, CurveInput, MonotoneCubicCurve, LinearCurve
|
||||
from .video_types import VideoInput
|
||||
|
||||
__all__ = [
|
||||
@@ -8,8 +7,4 @@ __all__ = [
|
||||
"VideoInput",
|
||||
"MaskInput",
|
||||
"LatentInput",
|
||||
"CurvePoint",
|
||||
"CurveInput",
|
||||
"MonotoneCubicCurve",
|
||||
"LinearCurve",
|
||||
]
|
||||
|
||||
@@ -1,219 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
from abc import ABC, abstractmethod
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
CurvePoint = tuple[float, float]
|
||||
|
||||
|
||||
class CurveInput(ABC):
|
||||
"""Abstract base class for curve inputs.
|
||||
|
||||
Subclasses represent different curve representations (control-point
|
||||
interpolation, analytical functions, LUT-based, etc.) while exposing a
|
||||
uniform evaluation interface to downstream nodes.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def points(self) -> list[CurvePoint]:
|
||||
"""The control points that define this curve."""
|
||||
|
||||
@abstractmethod
|
||||
def interp(self, x: float) -> float:
|
||||
"""Evaluate the curve at a single *x* value in [0, 1]."""
|
||||
|
||||
def interp_array(self, xs: np.ndarray) -> np.ndarray:
|
||||
"""Vectorised evaluation over a numpy array of x values.
|
||||
|
||||
Subclasses should override this for better performance. The default
|
||||
falls back to scalar ``interp`` calls.
|
||||
"""
|
||||
return np.fromiter((self.interp(float(x)) for x in xs), dtype=np.float64, count=len(xs))
|
||||
|
||||
def to_lut(self, size: int = 256) -> np.ndarray:
|
||||
"""Generate a float64 lookup table of *size* evenly-spaced samples in [0, 1]."""
|
||||
return self.interp_array(np.linspace(0.0, 1.0, size))
|
||||
|
||||
@staticmethod
|
||||
def from_raw(data) -> CurveInput:
|
||||
"""Convert raw curve data (dict or point list) to a CurveInput instance.
|
||||
|
||||
Accepts:
|
||||
- A ``CurveInput`` instance (returned as-is).
|
||||
- A dict with ``"points"`` and optional ``"interpolation"`` keys.
|
||||
- A bare list/sequence of ``(x, y)`` pairs (defaults to monotone cubic).
|
||||
"""
|
||||
if isinstance(data, CurveInput):
|
||||
return data
|
||||
if isinstance(data, dict):
|
||||
raw_points = data["points"]
|
||||
interpolation = data.get("interpolation", "monotone_cubic")
|
||||
else:
|
||||
raw_points = data
|
||||
interpolation = "monotone_cubic"
|
||||
points = [(float(x), float(y)) for x, y in raw_points]
|
||||
if interpolation == "linear":
|
||||
return LinearCurve(points)
|
||||
if interpolation != "monotone_cubic":
|
||||
logger.warning("Unknown curve interpolation %r, falling back to monotone_cubic", interpolation)
|
||||
return MonotoneCubicCurve(points)
|
||||
|
||||
|
||||
class MonotoneCubicCurve(CurveInput):
|
||||
"""Monotone cubic Hermite interpolation over control points.
|
||||
|
||||
Mirrors the frontend ``createMonotoneInterpolator`` in
|
||||
``ComfyUI_frontend/src/components/curve/curveUtils.ts`` so that
|
||||
backend evaluation matches the editor preview exactly.
|
||||
|
||||
All heavy work (sorting, slope computation) happens once at construction.
|
||||
``interp_array`` is fully vectorised with numpy.
|
||||
"""
|
||||
|
||||
def __init__(self, control_points: list[CurvePoint]):
|
||||
sorted_pts = sorted(control_points, key=lambda p: p[0])
|
||||
self._points = [(float(x), float(y)) for x, y in sorted_pts]
|
||||
self._xs = np.array([p[0] for p in self._points], dtype=np.float64)
|
||||
self._ys = np.array([p[1] for p in self._points], dtype=np.float64)
|
||||
self._slopes = self._compute_slopes()
|
||||
|
||||
@property
|
||||
def points(self) -> list[CurvePoint]:
|
||||
return list(self._points)
|
||||
|
||||
def _compute_slopes(self) -> np.ndarray:
|
||||
xs, ys = self._xs, self._ys
|
||||
n = len(xs)
|
||||
if n < 2:
|
||||
return np.zeros(n, dtype=np.float64)
|
||||
|
||||
dx = np.diff(xs)
|
||||
dy = np.diff(ys)
|
||||
dx_safe = np.where(dx == 0, 1.0, dx)
|
||||
deltas = np.where(dx == 0, 0.0, dy / dx_safe)
|
||||
|
||||
slopes = np.empty(n, dtype=np.float64)
|
||||
slopes[0] = deltas[0]
|
||||
slopes[-1] = deltas[-1]
|
||||
for i in range(1, n - 1):
|
||||
if deltas[i - 1] * deltas[i] <= 0:
|
||||
slopes[i] = 0.0
|
||||
else:
|
||||
slopes[i] = (deltas[i - 1] + deltas[i]) / 2
|
||||
|
||||
for i in range(n - 1):
|
||||
if deltas[i] == 0:
|
||||
slopes[i] = 0.0
|
||||
slopes[i + 1] = 0.0
|
||||
else:
|
||||
alpha = slopes[i] / deltas[i]
|
||||
beta = slopes[i + 1] / deltas[i]
|
||||
s = alpha * alpha + beta * beta
|
||||
if s > 9:
|
||||
t = 3 / math.sqrt(s)
|
||||
slopes[i] = t * alpha * deltas[i]
|
||||
slopes[i + 1] = t * beta * deltas[i]
|
||||
return slopes
|
||||
|
||||
def interp(self, x: float) -> float:
|
||||
xs, ys, slopes = self._xs, self._ys, self._slopes
|
||||
n = len(xs)
|
||||
if n == 0:
|
||||
return 0.0
|
||||
if n == 1:
|
||||
return float(ys[0])
|
||||
if x <= xs[0]:
|
||||
return float(ys[0])
|
||||
if x >= xs[-1]:
|
||||
return float(ys[-1])
|
||||
|
||||
hi = int(np.searchsorted(xs, x, side='right'))
|
||||
hi = min(hi, n - 1)
|
||||
lo = hi - 1
|
||||
|
||||
dx = xs[hi] - xs[lo]
|
||||
if dx == 0:
|
||||
return float(ys[lo])
|
||||
|
||||
t = (x - xs[lo]) / dx
|
||||
t2 = t * t
|
||||
t3 = t2 * t
|
||||
h00 = 2 * t3 - 3 * t2 + 1
|
||||
h10 = t3 - 2 * t2 + t
|
||||
h01 = -2 * t3 + 3 * t2
|
||||
h11 = t3 - t2
|
||||
return float(h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi])
|
||||
|
||||
def interp_array(self, xs_in: np.ndarray) -> np.ndarray:
|
||||
"""Fully vectorised evaluation using numpy."""
|
||||
xs, ys, slopes = self._xs, self._ys, self._slopes
|
||||
n = len(xs)
|
||||
if n == 0:
|
||||
return np.zeros_like(xs_in, dtype=np.float64)
|
||||
if n == 1:
|
||||
return np.full_like(xs_in, ys[0], dtype=np.float64)
|
||||
|
||||
hi = np.searchsorted(xs, xs_in, side='right').clip(1, n - 1)
|
||||
lo = hi - 1
|
||||
|
||||
dx = xs[hi] - xs[lo]
|
||||
dx_safe = np.where(dx == 0, 1.0, dx)
|
||||
t = np.where(dx == 0, 0.0, (xs_in - xs[lo]) / dx_safe)
|
||||
t2 = t * t
|
||||
t3 = t2 * t
|
||||
|
||||
h00 = 2 * t3 - 3 * t2 + 1
|
||||
h10 = t3 - 2 * t2 + t
|
||||
h01 = -2 * t3 + 3 * t2
|
||||
h11 = t3 - t2
|
||||
|
||||
result = h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi]
|
||||
result = np.where(xs_in <= xs[0], ys[0], result)
|
||||
result = np.where(xs_in >= xs[-1], ys[-1], result)
|
||||
return result
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"MonotoneCubicCurve(points={self._points})"
|
||||
|
||||
|
||||
class LinearCurve(CurveInput):
|
||||
"""Piecewise linear interpolation over control points.
|
||||
|
||||
Mirrors the frontend ``createLinearInterpolator`` in
|
||||
``ComfyUI_frontend/src/components/curve/curveUtils.ts``.
|
||||
"""
|
||||
|
||||
def __init__(self, control_points: list[CurvePoint]):
|
||||
sorted_pts = sorted(control_points, key=lambda p: p[0])
|
||||
self._points = [(float(x), float(y)) for x, y in sorted_pts]
|
||||
self._xs = np.array([p[0] for p in self._points], dtype=np.float64)
|
||||
self._ys = np.array([p[1] for p in self._points], dtype=np.float64)
|
||||
|
||||
@property
|
||||
def points(self) -> list[CurvePoint]:
|
||||
return list(self._points)
|
||||
|
||||
def interp(self, x: float) -> float:
|
||||
xs, ys = self._xs, self._ys
|
||||
n = len(xs)
|
||||
if n == 0:
|
||||
return 0.0
|
||||
if n == 1:
|
||||
return float(ys[0])
|
||||
return float(np.interp(x, xs, ys))
|
||||
|
||||
def interp_array(self, xs_in: np.ndarray) -> np.ndarray:
|
||||
if len(self._xs) == 0:
|
||||
return np.zeros_like(xs_in, dtype=np.float64)
|
||||
if len(self._xs) == 1:
|
||||
return np.full_like(xs_in, self._ys[0], dtype=np.float64)
|
||||
return np.interp(xs_in, self._xs, self._ys)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"LinearCurve(points={self._points})"
|
||||
@@ -23,7 +23,7 @@ if TYPE_CHECKING:
|
||||
from comfy.samplers import CFGGuider, Sampler
|
||||
from comfy.sd import CLIP, VAE
|
||||
from comfy.sd import StyleModel as StyleModel_
|
||||
from comfy_api.input import VideoInput, CurveInput as CurveInput_
|
||||
from comfy_api.input import VideoInput
|
||||
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
||||
prune_dict, shallow_clone_class)
|
||||
from comfy_execution.graph_utils import ExecutionBlocker
|
||||
@@ -1242,9 +1242,8 @@ class BoundingBox(ComfyTypeIO):
|
||||
|
||||
@comfytype(io_type="CURVE")
|
||||
class Curve(ComfyTypeIO):
|
||||
from comfy_api.input import CurvePoint
|
||||
if TYPE_CHECKING:
|
||||
Type = CurveInput_
|
||||
CurvePoint = tuple[float, float]
|
||||
Type = list[CurvePoint]
|
||||
|
||||
class Input(WidgetInput):
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||
@@ -1253,18 +1252,6 @@ class Curve(ComfyTypeIO):
|
||||
if default is None:
|
||||
self.default = [(0.0, 0.0), (1.0, 1.0)]
|
||||
|
||||
def as_dict(self):
|
||||
d = super().as_dict()
|
||||
if self.default is not None:
|
||||
d["default"] = {"points": [list(p) for p in self.default], "interpolation": "monotone_cubic"}
|
||||
return d
|
||||
|
||||
|
||||
@comfytype(io_type="HISTOGRAM")
|
||||
class Histogram(ComfyTypeIO):
|
||||
"""A histogram represented as a list of bin counts."""
|
||||
Type = list[int]
|
||||
|
||||
|
||||
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
||||
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
||||
@@ -2253,6 +2240,5 @@ __all__ = [
|
||||
"PriceBadge",
|
||||
"BoundingBox",
|
||||
"Curve",
|
||||
"Histogram",
|
||||
"NodeReplace",
|
||||
]
|
||||
|
||||
@@ -29,21 +29,13 @@ class ImageEditRequest(BaseModel):
|
||||
class VideoGenerationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
image: InputUrlObject | None = Field(None)
|
||||
reference_images: list[InputUrlObject] | None = Field(None)
|
||||
image: InputUrlObject | None = Field(...)
|
||||
duration: int = Field(...)
|
||||
aspect_ratio: str | None = Field(...)
|
||||
resolution: str = Field(...)
|
||||
seed: int = Field(...)
|
||||
|
||||
|
||||
class VideoExtensionRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
video: InputUrlObject = Field(...)
|
||||
duration: int = Field(default=6)
|
||||
model: str | None = Field(default=None)
|
||||
|
||||
|
||||
class VideoEditRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
|
||||
@@ -8,7 +8,6 @@ from comfy_api_nodes.apis.grok import (
|
||||
ImageGenerationResponse,
|
||||
InputUrlObject,
|
||||
VideoEditRequest,
|
||||
VideoExtensionRequest,
|
||||
VideoGenerationRequest,
|
||||
VideoGenerationResponse,
|
||||
VideoStatusResponse,
|
||||
@@ -22,7 +21,6 @@ from comfy_api_nodes.util import (
|
||||
poll_op,
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
validate_string,
|
||||
validate_video_duration,
|
||||
@@ -35,13 +33,6 @@ def _extract_grok_price(response) -> float | None:
|
||||
return None
|
||||
|
||||
|
||||
def _extract_grok_video_price(response) -> float | None:
|
||||
price = _extract_grok_price(response)
|
||||
if price is not None:
|
||||
return price * 1.43
|
||||
return None
|
||||
|
||||
|
||||
class GrokImageNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
@@ -363,8 +354,6 @@ class GrokVideoNode(IO.ComfyNode):
|
||||
seed: int,
|
||||
image: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
if model == "grok-imagine-video-beta":
|
||||
model = "grok-imagine-video"
|
||||
image_url = None
|
||||
if image is not None:
|
||||
if get_number_of_images(image) != 1:
|
||||
@@ -473,244 +462,6 @@ class GrokVideoEditNode(IO.ComfyNode):
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
||||
|
||||
|
||||
class GrokVideoReferenceNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GrokVideoReferenceNode",
|
||||
display_name="Grok Reference-to-Video",
|
||||
category="api node/video/Grok",
|
||||
description="Generate video guided by reference images as style and content references.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="Text description of the desired video.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"grok-imagine-video",
|
||||
[
|
||||
IO.Autogrow.Input(
|
||||
"reference_images",
|
||||
template=IO.Autogrow.TemplatePrefix(
|
||||
IO.Image.Input("image"),
|
||||
prefix="reference_",
|
||||
min=1,
|
||||
max=7,
|
||||
),
|
||||
tooltip="Up to 7 reference images to guide the video generation.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["480p", "720p"],
|
||||
tooltip="The resolution of the output video.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=["16:9", "4:3", "3:2", "1:1", "2:3", "3:4", "9:16"],
|
||||
tooltip="The aspect ratio of the output video.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=6,
|
||||
min=2,
|
||||
max=10,
|
||||
step=1,
|
||||
tooltip="The duration of the output video in seconds.",
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="The model to use for video generation.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to determine if node should re-run; "
|
||||
"actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=["model.duration", "model.resolution"],
|
||||
input_groups=["model.reference_images"],
|
||||
),
|
||||
expr="""
|
||||
(
|
||||
$res := $lookup(widgets, "model.resolution");
|
||||
$dur := $lookup(widgets, "model.duration");
|
||||
$refs := inputGroups["model.reference_images"];
|
||||
$rate := $res = "720p" ? 0.07 : 0.05;
|
||||
$price := ($rate * $dur + 0.002 * $refs) * 1.43;
|
||||
{"type":"usd","usd": $price}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: dict,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
ref_image_urls = await upload_images_to_comfyapi(
|
||||
cls,
|
||||
list(model["reference_images"].values()),
|
||||
mime_type="image/png",
|
||||
wait_label="Uploading base images",
|
||||
max_images=7,
|
||||
)
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/xai/v1/videos/generations", method="POST"),
|
||||
data=VideoGenerationRequest(
|
||||
model=model["model"],
|
||||
reference_images=[InputUrlObject(url=i) for i in ref_image_urls],
|
||||
prompt=prompt,
|
||||
resolution=model["resolution"],
|
||||
duration=model["duration"],
|
||||
aspect_ratio=model["aspect_ratio"],
|
||||
seed=seed,
|
||||
),
|
||||
response_model=VideoGenerationResponse,
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
|
||||
status_extractor=lambda r: r.status if r.status is not None else "complete",
|
||||
response_model=VideoStatusResponse,
|
||||
price_extractor=_extract_grok_video_price,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
||||
|
||||
|
||||
class GrokVideoExtendNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GrokVideoExtendNode",
|
||||
display_name="Grok Video Extend",
|
||||
category="api node/video/Grok",
|
||||
description="Extend an existing video with a seamless continuation based on a text prompt.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="Text description of what should happen next in the video.",
|
||||
),
|
||||
IO.Video.Input("video", tooltip="Source video to extend. MP4 format, 2-15 seconds."),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"grok-imagine-video",
|
||||
[
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=8,
|
||||
min=2,
|
||||
max=10,
|
||||
step=1,
|
||||
tooltip="Length of the extension in seconds.",
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="The model to use for video extension.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to determine if node should re-run; "
|
||||
"actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model.duration"]),
|
||||
expr="""
|
||||
(
|
||||
$dur := $lookup(widgets, "model.duration");
|
||||
{
|
||||
"type": "range_usd",
|
||||
"min_usd": (0.02 + 0.05 * $dur) * 1.43,
|
||||
"max_usd": (0.15 + 0.05 * $dur) * 1.43
|
||||
}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
video: Input.Video,
|
||||
model: dict,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
validate_video_duration(video, min_duration=2, max_duration=15)
|
||||
video_size = get_fs_object_size(video.get_stream_source())
|
||||
if video_size > 50 * 1024 * 1024:
|
||||
raise ValueError(f"Video size ({video_size / 1024 / 1024:.1f}MB) exceeds 50MB limit.")
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/xai/v1/videos/extensions", method="POST"),
|
||||
data=VideoExtensionRequest(
|
||||
prompt=prompt,
|
||||
video=InputUrlObject(url=await upload_video_to_comfyapi(cls, video)),
|
||||
duration=model["duration"],
|
||||
),
|
||||
response_model=VideoGenerationResponse,
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
|
||||
status_extractor=lambda r: r.status if r.status is not None else "complete",
|
||||
response_model=VideoStatusResponse,
|
||||
price_extractor=_extract_grok_video_price,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
||||
|
||||
|
||||
class GrokExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@@ -718,9 +469,7 @@ class GrokExtension(ComfyExtension):
|
||||
GrokImageNode,
|
||||
GrokImageEditNode,
|
||||
GrokVideoNode,
|
||||
GrokVideoReferenceNode,
|
||||
GrokVideoEditNode,
|
||||
GrokVideoExtendNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from comfy_api.input import CurveInput
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class CurveEditor(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CurveEditor",
|
||||
display_name="Curve Editor",
|
||||
category="utils",
|
||||
inputs=[
|
||||
io.Curve.Input("curve"),
|
||||
io.Histogram.Input("histogram", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Curve.Output("curve"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, curve, histogram=None) -> io.NodeOutput:
|
||||
result = CurveInput.from_raw(curve)
|
||||
|
||||
ui = {}
|
||||
if histogram is not None:
|
||||
ui["histogram"] = histogram if isinstance(histogram, list) else list(histogram)
|
||||
|
||||
return io.NodeOutput(result, ui=ui) if ui else io.NodeOutput(result)
|
||||
|
||||
|
||||
class CurveExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self):
|
||||
return [CurveEditor]
|
||||
|
||||
|
||||
async def comfy_entrypoint():
|
||||
return CurveExtension()
|
||||
@@ -1,85 +1,67 @@
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
import ctypes
|
||||
import logging
|
||||
import ctypes.util
|
||||
import importlib.util
|
||||
from typing import TypedDict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import nodes
|
||||
import comfy_angle
|
||||
from comfy_api.latest import ComfyExtension, io, ui
|
||||
from typing_extensions import override
|
||||
from utils.install_util import get_missing_requirements_message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _check_opengl_availability():
|
||||
"""Early check for OpenGL availability. Raises RuntimeError if unlikely to work."""
|
||||
logger.debug("_check_opengl_availability: starting")
|
||||
missing = []
|
||||
def _preload_angle():
|
||||
egl_path = comfy_angle.get_egl_path()
|
||||
gles_path = comfy_angle.get_glesv2_path()
|
||||
|
||||
# Check Python packages (using find_spec to avoid importing)
|
||||
logger.debug("_check_opengl_availability: checking for glfw package")
|
||||
if importlib.util.find_spec("glfw") is None:
|
||||
missing.append("glfw")
|
||||
if sys.platform == "win32":
|
||||
angle_dir = comfy_angle.get_lib_dir()
|
||||
os.add_dll_directory(angle_dir)
|
||||
os.environ["PATH"] = angle_dir + os.pathsep + os.environ.get("PATH", "")
|
||||
|
||||
logger.debug("_check_opengl_availability: checking for OpenGL package")
|
||||
if importlib.util.find_spec("OpenGL") is None:
|
||||
missing.append("PyOpenGL")
|
||||
|
||||
if missing:
|
||||
raise RuntimeError(
|
||||
f"OpenGL dependencies not available.\n{get_missing_requirements_message()}\n"
|
||||
)
|
||||
|
||||
# On Linux without display, check if headless backends are available
|
||||
logger.debug(f"_check_opengl_availability: platform={sys.platform}")
|
||||
if sys.platform.startswith("linux"):
|
||||
has_display = os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY")
|
||||
logger.debug(f"_check_opengl_availability: has_display={bool(has_display)}")
|
||||
if not has_display:
|
||||
# Check for EGL or OSMesa libraries
|
||||
logger.debug("_check_opengl_availability: checking for EGL library")
|
||||
has_egl = ctypes.util.find_library("EGL")
|
||||
logger.debug("_check_opengl_availability: checking for OSMesa library")
|
||||
has_osmesa = ctypes.util.find_library("OSMesa")
|
||||
|
||||
# Error disabled for CI as it fails this check
|
||||
# if not has_egl and not has_osmesa:
|
||||
# raise RuntimeError(
|
||||
# "GLSL Shader node: No display and no headless backend (EGL/OSMesa) found.\n"
|
||||
# "See error below for installation instructions."
|
||||
# )
|
||||
logger.debug(f"Headless mode: EGL={'yes' if has_egl else 'no'}, OSMesa={'yes' if has_osmesa else 'no'}")
|
||||
|
||||
logger.debug("_check_opengl_availability: completed")
|
||||
mode = 0 if sys.platform == "win32" else ctypes.RTLD_GLOBAL
|
||||
ctypes.CDLL(str(egl_path), mode=mode)
|
||||
ctypes.CDLL(str(gles_path), mode=mode)
|
||||
|
||||
|
||||
# Run early check at import time
|
||||
logger.debug("nodes_glsl: running _check_opengl_availability at import time")
|
||||
_check_opengl_availability()
|
||||
# Pre-load ANGLE *before* any PyOpenGL import so that the EGL platform
|
||||
# plugin picks up ANGLE's libEGL / libGLESv2 instead of system libs.
|
||||
_preload_angle()
|
||||
os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
|
||||
|
||||
# OpenGL modules - initialized lazily when context is created
|
||||
gl = None
|
||||
glfw = None
|
||||
EGL = None
|
||||
import OpenGL
|
||||
OpenGL.USE_ACCELERATE = False
|
||||
|
||||
|
||||
def _import_opengl():
|
||||
"""Import OpenGL module. Called after context is created."""
|
||||
global gl
|
||||
if gl is None:
|
||||
logger.debug("_import_opengl: importing OpenGL.GL")
|
||||
import OpenGL.GL as _gl
|
||||
gl = _gl
|
||||
logger.debug("_import_opengl: import completed")
|
||||
return gl
|
||||
def _patch_find_library():
|
||||
"""PyOpenGL's EGL platform looks for 'EGL' and 'GLESv2' by short name
|
||||
via ctypes.util.find_library, but ANGLE ships as 'libEGL' and
|
||||
'libGLESv2'. Patch find_library to return the full ANGLE paths so
|
||||
PyOpenGL loads the same libraries we pre-loaded."""
|
||||
if sys.platform == "linux":
|
||||
return
|
||||
import ctypes.util
|
||||
_orig = ctypes.util.find_library
|
||||
def _patched(name):
|
||||
if name == 'EGL':
|
||||
return comfy_angle.get_egl_path()
|
||||
if name == 'GLESv2':
|
||||
return comfy_angle.get_glesv2_path()
|
||||
return _orig(name)
|
||||
ctypes.util.find_library = _patched
|
||||
|
||||
|
||||
_patch_find_library()
|
||||
|
||||
from OpenGL import EGL
|
||||
from OpenGL import GLES3 as gl
|
||||
|
||||
class SizeModeInput(TypedDict):
|
||||
size_mode: str
|
||||
width: int
|
||||
@@ -100,7 +82,7 @@ MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
|
||||
# (-1,-1)---(3,-1)
|
||||
#
|
||||
# v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1)
|
||||
VERTEX_SHADER = """#version 330 core
|
||||
VERTEX_SHADER = """#version 300 es
|
||||
out vec2 v_texCoord;
|
||||
void main() {
|
||||
vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3));
|
||||
@@ -124,14 +106,21 @@ void main() {
|
||||
"""
|
||||
|
||||
|
||||
def _convert_es_to_desktop(source: str) -> str:
|
||||
"""Convert GLSL ES (WebGL) shader source to desktop GLSL 330 core."""
|
||||
# Remove any existing #version directive
|
||||
source = re.sub(r"#version\s+\d+(\s+es)?\s*\n?", "", source, flags=re.IGNORECASE)
|
||||
# Remove precision qualifiers (not needed in desktop GLSL)
|
||||
source = re.sub(r"precision\s+(lowp|mediump|highp)\s+\w+\s*;\s*\n?", "", source)
|
||||
# Prepend desktop GLSL version
|
||||
return "#version 330 core\n" + source
|
||||
|
||||
def _egl_attribs(*values):
|
||||
"""Build an EGL_NONE-terminated EGLint attribute array."""
|
||||
vals = list(values) + [EGL.EGL_NONE]
|
||||
return (ctypes.c_int32 * len(vals))(*vals)
|
||||
|
||||
|
||||
def _gl_str(name):
|
||||
"""Get an OpenGL string parameter."""
|
||||
v = gl.glGetString(name)
|
||||
if not v:
|
||||
return "Unknown"
|
||||
if isinstance(v, bytes):
|
||||
return v.decode(errors="replace")
|
||||
return ctypes.string_at(v).decode(errors="replace")
|
||||
|
||||
|
||||
def _detect_output_count(source: str) -> int:
|
||||
@@ -157,163 +146,8 @@ def _detect_pass_count(source: str) -> int:
|
||||
return 1
|
||||
|
||||
|
||||
def _init_glfw():
|
||||
"""Initialize GLFW. Returns (window, glfw_module). Raises RuntimeError on failure."""
|
||||
logger.debug("_init_glfw: starting")
|
||||
# On macOS, glfw.init() must be called from main thread or it hangs forever
|
||||
if sys.platform == "darwin":
|
||||
logger.debug("_init_glfw: skipping on macOS")
|
||||
raise RuntimeError("GLFW backend not supported on macOS")
|
||||
|
||||
logger.debug("_init_glfw: importing glfw module")
|
||||
import glfw as _glfw
|
||||
|
||||
logger.debug("_init_glfw: calling glfw.init()")
|
||||
if not _glfw.init():
|
||||
raise RuntimeError("glfw.init() failed")
|
||||
|
||||
try:
|
||||
logger.debug("_init_glfw: setting window hints")
|
||||
_glfw.window_hint(_glfw.VISIBLE, _glfw.FALSE)
|
||||
_glfw.window_hint(_glfw.CONTEXT_VERSION_MAJOR, 3)
|
||||
_glfw.window_hint(_glfw.CONTEXT_VERSION_MINOR, 3)
|
||||
_glfw.window_hint(_glfw.OPENGL_PROFILE, _glfw.OPENGL_CORE_PROFILE)
|
||||
|
||||
logger.debug("_init_glfw: calling create_window()")
|
||||
window = _glfw.create_window(64, 64, "ComfyUI GLSL", None, None)
|
||||
if not window:
|
||||
raise RuntimeError("glfw.create_window() failed")
|
||||
|
||||
logger.debug("_init_glfw: calling make_context_current()")
|
||||
_glfw.make_context_current(window)
|
||||
logger.debug("_init_glfw: completed successfully")
|
||||
return window, _glfw
|
||||
except Exception:
|
||||
logger.debug("_init_glfw: failed, terminating glfw")
|
||||
_glfw.terminate()
|
||||
raise
|
||||
|
||||
|
||||
def _init_egl():
|
||||
"""Initialize EGL for headless rendering. Returns (display, context, surface, EGL_module). Raises RuntimeError on failure."""
|
||||
logger.debug("_init_egl: starting")
|
||||
from OpenGL import EGL as _EGL
|
||||
from OpenGL.EGL import (
|
||||
eglGetDisplay, eglInitialize, eglChooseConfig, eglCreateContext,
|
||||
eglMakeCurrent, eglCreatePbufferSurface, eglBindAPI,
|
||||
eglTerminate, eglDestroyContext, eglDestroySurface,
|
||||
EGL_DEFAULT_DISPLAY, EGL_NO_CONTEXT, EGL_NONE,
|
||||
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
|
||||
EGL_RED_SIZE, EGL_GREEN_SIZE, EGL_BLUE_SIZE, EGL_ALPHA_SIZE, EGL_DEPTH_SIZE,
|
||||
EGL_WIDTH, EGL_HEIGHT, EGL_OPENGL_API,
|
||||
)
|
||||
logger.debug("_init_egl: imports completed")
|
||||
|
||||
display = None
|
||||
context = None
|
||||
surface = None
|
||||
|
||||
try:
|
||||
logger.debug("_init_egl: calling eglGetDisplay()")
|
||||
display = eglGetDisplay(EGL_DEFAULT_DISPLAY)
|
||||
if display == _EGL.EGL_NO_DISPLAY:
|
||||
raise RuntimeError("eglGetDisplay() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglInitialize()")
|
||||
major, minor = _EGL.EGLint(), _EGL.EGLint()
|
||||
if not eglInitialize(display, major, minor):
|
||||
display = None # Not initialized, don't terminate
|
||||
raise RuntimeError("eglInitialize() failed")
|
||||
logger.debug(f"_init_egl: EGL version {major.value}.{minor.value}")
|
||||
|
||||
config_attribs = [
|
||||
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT,
|
||||
EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
|
||||
EGL_RED_SIZE, 8, EGL_GREEN_SIZE, 8, EGL_BLUE_SIZE, 8, EGL_ALPHA_SIZE, 8,
|
||||
EGL_DEPTH_SIZE, 0, EGL_NONE
|
||||
]
|
||||
configs = (_EGL.EGLConfig * 1)()
|
||||
num_configs = _EGL.EGLint()
|
||||
if not eglChooseConfig(display, config_attribs, configs, 1, num_configs) or num_configs.value == 0:
|
||||
raise RuntimeError("eglChooseConfig() failed")
|
||||
config = configs[0]
|
||||
logger.debug(f"_init_egl: config chosen, num_configs={num_configs.value}")
|
||||
|
||||
if not eglBindAPI(EGL_OPENGL_API):
|
||||
raise RuntimeError("eglBindAPI() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglCreateContext()")
|
||||
context_attribs = [
|
||||
_EGL.EGL_CONTEXT_MAJOR_VERSION, 3,
|
||||
_EGL.EGL_CONTEXT_MINOR_VERSION, 3,
|
||||
_EGL.EGL_CONTEXT_OPENGL_PROFILE_MASK, _EGL.EGL_CONTEXT_OPENGL_CORE_PROFILE_BIT,
|
||||
EGL_NONE
|
||||
]
|
||||
context = eglCreateContext(display, config, EGL_NO_CONTEXT, context_attribs)
|
||||
if context == EGL_NO_CONTEXT:
|
||||
raise RuntimeError("eglCreateContext() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglCreatePbufferSurface()")
|
||||
pbuffer_attribs = [EGL_WIDTH, 64, EGL_HEIGHT, 64, EGL_NONE]
|
||||
surface = eglCreatePbufferSurface(display, config, pbuffer_attribs)
|
||||
if surface == _EGL.EGL_NO_SURFACE:
|
||||
raise RuntimeError("eglCreatePbufferSurface() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglMakeCurrent()")
|
||||
if not eglMakeCurrent(display, surface, surface, context):
|
||||
raise RuntimeError("eglMakeCurrent() failed")
|
||||
|
||||
logger.debug("_init_egl: completed successfully")
|
||||
return display, context, surface, _EGL
|
||||
|
||||
except Exception:
|
||||
logger.debug("_init_egl: failed, cleaning up")
|
||||
# Clean up any resources on failure
|
||||
if surface is not None:
|
||||
eglDestroySurface(display, surface)
|
||||
if context is not None:
|
||||
eglDestroyContext(display, context)
|
||||
if display is not None:
|
||||
eglTerminate(display)
|
||||
raise
|
||||
|
||||
|
||||
def _init_osmesa():
|
||||
"""Initialize OSMesa for software rendering. Returns (context, buffer). Raises RuntimeError on failure."""
|
||||
import ctypes
|
||||
|
||||
logger.debug("_init_osmesa: starting")
|
||||
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
|
||||
|
||||
logger.debug("_init_osmesa: importing OpenGL.osmesa")
|
||||
from OpenGL import GL as _gl
|
||||
from OpenGL.osmesa import (
|
||||
OSMesaCreateContextExt, OSMesaMakeCurrent, OSMesaDestroyContext,
|
||||
OSMESA_RGBA,
|
||||
)
|
||||
logger.debug("_init_osmesa: imports completed")
|
||||
|
||||
ctx = OSMesaCreateContextExt(OSMESA_RGBA, 24, 0, 0, None)
|
||||
if not ctx:
|
||||
raise RuntimeError("OSMesaCreateContextExt() failed")
|
||||
|
||||
width, height = 64, 64
|
||||
buffer = (ctypes.c_ubyte * (width * height * 4))()
|
||||
|
||||
logger.debug("_init_osmesa: calling OSMesaMakeCurrent()")
|
||||
if not OSMesaMakeCurrent(ctx, buffer, _gl.GL_UNSIGNED_BYTE, width, height):
|
||||
OSMesaDestroyContext(ctx)
|
||||
raise RuntimeError("OSMesaMakeCurrent() failed")
|
||||
|
||||
logger.debug("_init_osmesa: completed successfully")
|
||||
return ctx, buffer
|
||||
|
||||
|
||||
class GLContext:
|
||||
"""Manages OpenGL context and resources for shader execution.
|
||||
|
||||
Tries backends in order: GLFW (desktop) → EGL (headless GPU) → OSMesa (software).
|
||||
"""
|
||||
"""Manages an OpenGL ES 3.0 context via EGL/ANGLE (singleton)."""
|
||||
|
||||
_instance = None
|
||||
_initialized = False
|
||||
@@ -325,131 +159,111 @@ class GLContext:
|
||||
|
||||
def __init__(self):
|
||||
if GLContext._initialized:
|
||||
logger.debug("GLContext.__init__: already initialized, skipping")
|
||||
return
|
||||
|
||||
logger.debug("GLContext.__init__: starting initialization")
|
||||
|
||||
global glfw, EGL
|
||||
|
||||
import time
|
||||
start = time.perf_counter()
|
||||
|
||||
self._backend = None
|
||||
self._window = None
|
||||
self._egl_display = None
|
||||
self._egl_context = None
|
||||
self._egl_surface = None
|
||||
self._osmesa_ctx = None
|
||||
self._osmesa_buffer = None
|
||||
self._display = None
|
||||
self._surface = None
|
||||
self._context = None
|
||||
self._vao = None
|
||||
|
||||
# Try backends in order: GLFW → EGL → OSMesa
|
||||
errors = []
|
||||
|
||||
logger.debug("GLContext.__init__: trying GLFW backend")
|
||||
try:
|
||||
self._window, glfw = _init_glfw()
|
||||
self._backend = "glfw"
|
||||
logger.debug("GLContext.__init__: GLFW backend succeeded")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: GLFW backend failed: {e}")
|
||||
errors.append(("GLFW", e))
|
||||
self._display = EGL.eglGetDisplay(EGL.EGL_DEFAULT_DISPLAY)
|
||||
if not self._display:
|
||||
raise RuntimeError("eglGetDisplay() returned no display")
|
||||
|
||||
if self._backend is None:
|
||||
logger.debug("GLContext.__init__: trying EGL backend")
|
||||
try:
|
||||
self._egl_display, self._egl_context, self._egl_surface, EGL = _init_egl()
|
||||
self._backend = "egl"
|
||||
logger.debug("GLContext.__init__: EGL backend succeeded")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: EGL backend failed: {e}")
|
||||
errors.append(("EGL", e))
|
||||
major, minor = ctypes.c_int32(0), ctypes.c_int32(0)
|
||||
if not EGL.eglInitialize(self._display, ctypes.byref(major), ctypes.byref(minor)):
|
||||
err = EGL.eglGetError()
|
||||
self._display = None
|
||||
raise RuntimeError(f"eglInitialize() failed (EGL error: 0x{err:04X})")
|
||||
|
||||
if self._backend is None:
|
||||
logger.debug("GLContext.__init__: trying OSMesa backend")
|
||||
try:
|
||||
self._osmesa_ctx, self._osmesa_buffer = _init_osmesa()
|
||||
self._backend = "osmesa"
|
||||
logger.debug("GLContext.__init__: OSMesa backend succeeded")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: OSMesa backend failed: {e}")
|
||||
errors.append(("OSMesa", e))
|
||||
if not EGL.eglBindAPI(EGL.EGL_OPENGL_ES_API):
|
||||
raise RuntimeError("eglBindAPI(EGL_OPENGL_ES_API) failed")
|
||||
|
||||
if self._backend is None:
|
||||
if sys.platform == "win32":
|
||||
platform_help = (
|
||||
"Windows: Ensure GPU drivers are installed and display is available.\n"
|
||||
" CPU-only/headless mode is not supported on Windows."
|
||||
)
|
||||
elif sys.platform == "darwin":
|
||||
platform_help = (
|
||||
"macOS: GLFW is not supported.\n"
|
||||
" Install OSMesa via Homebrew: brew install mesa\n"
|
||||
" Then: pip install PyOpenGL PyOpenGL-accelerate"
|
||||
)
|
||||
else:
|
||||
platform_help = (
|
||||
"Linux: Install one of these backends:\n"
|
||||
" Desktop: sudo apt install libgl1-mesa-glx libglfw3\n"
|
||||
" Headless with GPU: sudo apt install libegl1-mesa libgl1-mesa-dri\n"
|
||||
" Headless (CPU): sudo apt install libosmesa6"
|
||||
)
|
||||
config = EGL.EGLConfig()
|
||||
n_configs = ctypes.c_int32(0)
|
||||
if not EGL.eglChooseConfig(
|
||||
self._display,
|
||||
_egl_attribs(
|
||||
EGL.EGL_RENDERABLE_TYPE, EGL.EGL_OPENGL_ES3_BIT,
|
||||
EGL.EGL_SURFACE_TYPE, EGL.EGL_PBUFFER_BIT,
|
||||
EGL.EGL_RED_SIZE, 8, EGL.EGL_GREEN_SIZE, 8,
|
||||
EGL.EGL_BLUE_SIZE, 8, EGL.EGL_ALPHA_SIZE, 8,
|
||||
),
|
||||
ctypes.byref(config), 1, ctypes.byref(n_configs),
|
||||
) or n_configs.value == 0:
|
||||
raise RuntimeError("eglChooseConfig() failed")
|
||||
|
||||
error_details = "\n".join(f" {name}: {err}" for name, err in errors)
|
||||
raise RuntimeError(
|
||||
f"Failed to create OpenGL context.\n\n"
|
||||
f"Backend errors:\n{error_details}\n\n"
|
||||
f"{platform_help}"
|
||||
self._surface = EGL.eglCreatePbufferSurface(
|
||||
self._display, config,
|
||||
_egl_attribs(EGL.EGL_WIDTH, 64, EGL.EGL_HEIGHT, 64),
|
||||
)
|
||||
if not self._surface:
|
||||
raise RuntimeError("eglCreatePbufferSurface() failed")
|
||||
|
||||
# Now import OpenGL.GL (after context is current)
|
||||
logger.debug("GLContext.__init__: importing OpenGL.GL")
|
||||
_import_opengl()
|
||||
self._context = EGL.eglCreateContext(
|
||||
self._display, config, EGL.EGL_NO_CONTEXT,
|
||||
_egl_attribs(EGL.EGL_CONTEXT_CLIENT_VERSION, 3),
|
||||
)
|
||||
if not self._context:
|
||||
raise RuntimeError("eglCreateContext() failed")
|
||||
|
||||
# Create VAO (required for core profile, but OSMesa may use compat profile)
|
||||
logger.debug("GLContext.__init__: creating VAO")
|
||||
try:
|
||||
vao = gl.glGenVertexArrays(1)
|
||||
gl.glBindVertexArray(vao)
|
||||
self._vao = vao # Only store after successful bind
|
||||
logger.debug("GLContext.__init__: VAO created successfully")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: VAO creation failed (may be expected for OSMesa): {e}")
|
||||
# OSMesa with older Mesa may not support VAOs
|
||||
# Clean up if we created but couldn't bind
|
||||
if vao:
|
||||
try:
|
||||
gl.glDeleteVertexArrays(1, [vao])
|
||||
except Exception:
|
||||
pass
|
||||
if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context):
|
||||
raise RuntimeError("eglMakeCurrent() failed")
|
||||
|
||||
self._vao = gl.glGenVertexArrays(1)
|
||||
gl.glBindVertexArray(self._vao)
|
||||
|
||||
except Exception:
|
||||
self._cleanup()
|
||||
raise
|
||||
|
||||
elapsed = (time.perf_counter() - start) * 1000
|
||||
|
||||
# Log device info
|
||||
renderer = gl.glGetString(gl.GL_RENDERER)
|
||||
vendor = gl.glGetString(gl.GL_VENDOR)
|
||||
version = gl.glGetString(gl.GL_VERSION)
|
||||
renderer = renderer.decode() if renderer else "Unknown"
|
||||
vendor = vendor.decode() if vendor else "Unknown"
|
||||
version = version.decode() if version else "Unknown"
|
||||
renderer = _gl_str(gl.GL_RENDERER)
|
||||
vendor = _gl_str(gl.GL_VENDOR)
|
||||
version = _gl_str(gl.GL_VERSION)
|
||||
|
||||
GLContext._initialized = True
|
||||
logger.info(f"GLSL context initialized in {elapsed:.1f}ms ({self._backend}) - {renderer} ({vendor}), GL {version}")
|
||||
logger.info(f"GLSL context initialized in {elapsed:.1f}ms - {renderer} ({vendor}), GL {version}")
|
||||
|
||||
def make_current(self):
|
||||
if self._backend == "glfw":
|
||||
glfw.make_context_current(self._window)
|
||||
elif self._backend == "egl":
|
||||
from OpenGL.EGL import eglMakeCurrent
|
||||
eglMakeCurrent(self._egl_display, self._egl_surface, self._egl_surface, self._egl_context)
|
||||
elif self._backend == "osmesa":
|
||||
from OpenGL.osmesa import OSMesaMakeCurrent
|
||||
OSMesaMakeCurrent(self._osmesa_ctx, self._osmesa_buffer, gl.GL_UNSIGNED_BYTE, 64, 64)
|
||||
|
||||
EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context)
|
||||
if self._vao is not None:
|
||||
gl.glBindVertexArray(self._vao)
|
||||
|
||||
def _cleanup(self):
|
||||
if not self._display:
|
||||
return
|
||||
try:
|
||||
if self._vao is not None:
|
||||
gl.glDeleteVertexArrays(1, [self._vao])
|
||||
self._vao = None
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
EGL.eglMakeCurrent(self._display, EGL.EGL_NO_SURFACE, EGL.EGL_NO_SURFACE, EGL.EGL_NO_CONTEXT)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if self._context:
|
||||
EGL.eglDestroyContext(self._display, self._context)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if self._surface:
|
||||
EGL.eglDestroySurface(self._display, self._surface)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
EGL.eglTerminate(self._display)
|
||||
except Exception:
|
||||
pass
|
||||
self._display = None
|
||||
|
||||
|
||||
def _compile_shader(source: str, shader_type: int) -> int:
|
||||
"""Compile a shader and return its ID."""
|
||||
@@ -457,8 +271,10 @@ def _compile_shader(source: str, shader_type: int) -> int:
|
||||
gl.glShaderSource(shader, source)
|
||||
gl.glCompileShader(shader)
|
||||
|
||||
if gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS) != gl.GL_TRUE:
|
||||
error = gl.glGetShaderInfoLog(shader).decode()
|
||||
if not gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS):
|
||||
error = gl.glGetShaderInfoLog(shader)
|
||||
if isinstance(error, bytes):
|
||||
error = error.decode(errors="replace")
|
||||
gl.glDeleteShader(shader)
|
||||
raise RuntimeError(f"Shader compilation failed:\n{error}")
|
||||
|
||||
@@ -482,8 +298,10 @@ def _create_program(vertex_source: str, fragment_source: str) -> int:
|
||||
gl.glDeleteShader(vertex_shader)
|
||||
gl.glDeleteShader(fragment_shader)
|
||||
|
||||
if gl.glGetProgramiv(program, gl.GL_LINK_STATUS) != gl.GL_TRUE:
|
||||
error = gl.glGetProgramInfoLog(program).decode()
|
||||
if not gl.glGetProgramiv(program, gl.GL_LINK_STATUS):
|
||||
error = gl.glGetProgramInfoLog(program)
|
||||
if isinstance(error, bytes):
|
||||
error = error.decode(errors="replace")
|
||||
gl.glDeleteProgram(program)
|
||||
raise RuntimeError(f"Program linking failed:\n{error}")
|
||||
|
||||
@@ -524,9 +342,6 @@ def _render_shader_batch(
|
||||
ctx = GLContext()
|
||||
ctx.make_current()
|
||||
|
||||
# Convert from GLSL ES to desktop GLSL 330
|
||||
fragment_source = _convert_es_to_desktop(fragment_code)
|
||||
|
||||
# Detect how many outputs the shader actually uses
|
||||
num_outputs = _detect_output_count(fragment_code)
|
||||
|
||||
@@ -546,9 +361,9 @@ def _render_shader_batch(
|
||||
try:
|
||||
# Compile shaders (once for all batches)
|
||||
try:
|
||||
program = _create_program(VERTEX_SHADER, fragment_source)
|
||||
program = _create_program(VERTEX_SHADER, fragment_code)
|
||||
except RuntimeError:
|
||||
logger.error(f"Fragment shader:\n{fragment_source}")
|
||||
logger.error(f"Fragment shader:\n{fragment_code}")
|
||||
raise
|
||||
|
||||
gl.glUseProgram(program)
|
||||
@@ -689,13 +504,13 @@ def _render_shader_batch(
|
||||
gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3)
|
||||
|
||||
# Read back outputs for this batch
|
||||
# (glGetTexImage is synchronous, implicitly waits for rendering)
|
||||
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
|
||||
batch_outputs = []
|
||||
for tex in output_textures:
|
||||
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
|
||||
data = gl.glGetTexImage(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA, gl.GL_FLOAT)
|
||||
img = np.frombuffer(data, dtype=np.float32).reshape(height, width, 4)
|
||||
batch_outputs.append(img[::-1, :, :].copy())
|
||||
for i in range(num_outputs):
|
||||
gl.glReadBuffer(gl.GL_COLOR_ATTACHMENT0 + i)
|
||||
buf = np.empty((height, width, 4), dtype=np.float32)
|
||||
gl.glReadPixels(0, 0, width, height, gl.GL_RGBA, gl.GL_FLOAT, buf)
|
||||
batch_outputs.append(buf[::-1, :, :].copy())
|
||||
|
||||
# Pad with black images for unused outputs
|
||||
black_img = np.zeros((height, width, 4), dtype=np.float32)
|
||||
@@ -716,16 +531,16 @@ def _render_shader_batch(
|
||||
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
|
||||
gl.glUseProgram(0)
|
||||
|
||||
for tex in input_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
for tex in output_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
for tex in ping_pong_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
if input_textures:
|
||||
gl.glDeleteTextures(len(input_textures), input_textures)
|
||||
if output_textures:
|
||||
gl.glDeleteTextures(len(output_textures), output_textures)
|
||||
if ping_pong_textures:
|
||||
gl.glDeleteTextures(len(ping_pong_textures), ping_pong_textures)
|
||||
if fbo is not None:
|
||||
gl.glDeleteFramebuffers(1, [fbo])
|
||||
for pp_fbo in ping_pong_fbos:
|
||||
gl.glDeleteFramebuffers(1, [pp_fbo])
|
||||
if ping_pong_fbos:
|
||||
gl.glDeleteFramebuffers(len(ping_pong_fbos), ping_pong_fbos)
|
||||
if program is not None:
|
||||
gl.glDeleteProgram(program)
|
||||
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
"""Number Convert node for unified numeric type conversion.
|
||||
|
||||
Provides a single node that converts INT, FLOAT, STRING, and BOOL
|
||||
inputs into FLOAT and INT outputs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class NumberConvertNode(io.ComfyNode):
|
||||
"""Converts various types to numeric FLOAT and INT outputs."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="ComfyNumberConvert",
|
||||
display_name="Number Convert",
|
||||
category="math",
|
||||
search_aliases=[
|
||||
"int to float", "float to int", "number convert",
|
||||
"int2float", "float2int", "cast", "parse number",
|
||||
"string to number", "bool to int",
|
||||
],
|
||||
inputs=[
|
||||
io.MultiType.Input(
|
||||
"value",
|
||||
[io.Int, io.Float, io.String, io.Boolean],
|
||||
display_name="value",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
io.Float.Output(display_name="FLOAT"),
|
||||
io.Int.Output(display_name="INT"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, value) -> io.NodeOutput:
|
||||
if isinstance(value, bool):
|
||||
float_val = 1.0 if value else 0.0
|
||||
elif isinstance(value, (int, float)):
|
||||
float_val = float(value)
|
||||
elif isinstance(value, str):
|
||||
text = value.strip()
|
||||
if not text:
|
||||
raise ValueError("Cannot convert empty string to number.")
|
||||
try:
|
||||
float_val = float(text)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Cannot convert string to number: {value!r}"
|
||||
) from None
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Unsupported input type: {type(value).__name__}"
|
||||
)
|
||||
|
||||
if not math.isfinite(float_val):
|
||||
raise ValueError(
|
||||
f"Cannot convert non-finite value to number: {float_val}"
|
||||
)
|
||||
|
||||
return io.NodeOutput(float_val, int(float_val))
|
||||
|
||||
|
||||
class NumberConvertExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [NumberConvertNode]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> NumberConvertExtension:
|
||||
return NumberConvertExtension()
|
||||
@@ -1030,11 +1030,6 @@ class TrainLoraNode(io.ComfyNode):
|
||||
default="bf16",
|
||||
tooltip="The dtype to use for lora.",
|
||||
),
|
||||
io.Boolean.Input(
|
||||
"quantized_backward",
|
||||
default=False,
|
||||
tooltip="When using training_dtype 'none' and training on quantized model, doing backward with quantized matmul when enabled.",
|
||||
),
|
||||
io.Combo.Input(
|
||||
"algorithm",
|
||||
options=list(adapter_maps.keys()),
|
||||
@@ -1102,7 +1097,6 @@ class TrainLoraNode(io.ComfyNode):
|
||||
seed,
|
||||
training_dtype,
|
||||
lora_dtype,
|
||||
quantized_backward,
|
||||
algorithm,
|
||||
gradient_checkpointing,
|
||||
checkpoint_depth,
|
||||
@@ -1123,7 +1117,6 @@ class TrainLoraNode(io.ComfyNode):
|
||||
seed = seed[0]
|
||||
training_dtype = training_dtype[0]
|
||||
lora_dtype = lora_dtype[0]
|
||||
quantized_backward = quantized_backward[0]
|
||||
algorithm = algorithm[0]
|
||||
gradient_checkpointing = gradient_checkpointing[0]
|
||||
offloading = offloading[0]
|
||||
@@ -1132,8 +1125,6 @@ class TrainLoraNode(io.ComfyNode):
|
||||
bucket_mode = bucket_mode[0]
|
||||
bypass_mode = bypass_mode[0]
|
||||
|
||||
comfy.model_management.training_fp8_bwd = quantized_backward
|
||||
|
||||
# Process latents based on mode
|
||||
if bucket_mode:
|
||||
latents = _process_latents_bucket_mode(latents)
|
||||
@@ -1146,7 +1137,6 @@ class TrainLoraNode(io.ComfyNode):
|
||||
# Setup model and dtype
|
||||
mp = model.clone()
|
||||
use_grad_scaler = False
|
||||
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
||||
if training_dtype != "none":
|
||||
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
||||
mp.set_model_compute_dtype(dtype)
|
||||
@@ -1155,10 +1145,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
model_dtype = mp.model.get_dtype()
|
||||
if model_dtype == torch.float16:
|
||||
dtype = torch.float16
|
||||
# GradScaler only supports float16 gradients, not bfloat16.
|
||||
# Only enable it when lora params will also be in float16.
|
||||
if lora_dtype != torch.bfloat16:
|
||||
use_grad_scaler = True
|
||||
use_grad_scaler = True
|
||||
# Warn about fp16 accumulation instability during training
|
||||
if PerformanceFeature.Fp16Accumulation in args.fast:
|
||||
logging.warning(
|
||||
@@ -1169,6 +1156,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
else:
|
||||
# For fp8, bf16, or other dtypes, use bf16 autocast
|
||||
dtype = torch.bfloat16
|
||||
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
||||
|
||||
# Prepare latents and compute counts
|
||||
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
|
||||
|
||||
52
main.py
52
main.py
@@ -9,8 +9,6 @@ import folder_paths
|
||||
import time
|
||||
from comfy.cli_args import args, enables_dynamic_vram
|
||||
from app.logger import setup_logger
|
||||
from app.assets.seeder import asset_seeder
|
||||
from app.assets.services import register_output_files
|
||||
import itertools
|
||||
import utils.extra_config
|
||||
from utils.mime_types import init_mime_types
|
||||
@@ -139,16 +137,7 @@ def execute_prestartup_script():
|
||||
spec.loader.exec_module(module)
|
||||
return True
|
||||
except Exception as e:
|
||||
import traceback
|
||||
logging.error(f"Failed to execute startup-script: {script_path} / {e}")
|
||||
from nodes import NODE_STARTUP_ERRORS, get_module_name
|
||||
node_module_name = get_module_name(os.path.dirname(script_path))
|
||||
NODE_STARTUP_ERRORS[node_module_name] = {
|
||||
"module_path": os.path.dirname(script_path),
|
||||
"error": str(e),
|
||||
"traceback": traceback.format_exc(),
|
||||
"phase": "prestartup",
|
||||
}
|
||||
return False
|
||||
|
||||
node_paths = folder_paths.get_folder_paths("custom_nodes")
|
||||
@@ -203,6 +192,7 @@ if 'torch' in sys.modules:
|
||||
|
||||
|
||||
import comfy.utils
|
||||
from app.assets.seeder import asset_seeder
|
||||
|
||||
import execution
|
||||
import server
|
||||
@@ -250,38 +240,6 @@ def cuda_malloc_warning():
|
||||
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
||||
|
||||
|
||||
def _collect_output_absolute_paths(history_result: dict) -> list[str]:
|
||||
"""Extract absolute file paths for output items from a history result."""
|
||||
paths: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for node_output in history_result.get("outputs", {}).values():
|
||||
for items in node_output.values():
|
||||
if not isinstance(items, list):
|
||||
continue
|
||||
for item in items:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
item_type = item.get("type")
|
||||
if item_type not in ("output", "temp"):
|
||||
continue
|
||||
base_dir = folder_paths.get_directory_by_type(item_type)
|
||||
if base_dir is None:
|
||||
continue
|
||||
base_dir = os.path.abspath(base_dir)
|
||||
filename = item.get("filename")
|
||||
if not filename:
|
||||
continue
|
||||
abs_path = os.path.abspath(
|
||||
os.path.join(base_dir, item.get("subfolder", ""), filename)
|
||||
)
|
||||
if not abs_path.startswith(base_dir + os.sep) and abs_path != base_dir:
|
||||
continue
|
||||
if abs_path not in seen:
|
||||
seen.add(abs_path)
|
||||
paths.append(abs_path)
|
||||
return paths
|
||||
|
||||
|
||||
def prompt_worker(q, server_instance):
|
||||
current_time: float = 0.0
|
||||
cache_type = execution.CacheType.CLASSIC
|
||||
@@ -316,7 +274,6 @@ def prompt_worker(q, server_instance):
|
||||
|
||||
asset_seeder.pause()
|
||||
e.execute(item[2], prompt_id, extra_data, item[4])
|
||||
|
||||
need_gc = True
|
||||
|
||||
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
|
||||
@@ -339,10 +296,6 @@ def prompt_worker(q, server_instance):
|
||||
else:
|
||||
logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
|
||||
|
||||
if not asset_seeder.is_disabled():
|
||||
paths = _collect_output_absolute_paths(e.history_result)
|
||||
register_output_files(paths, job_id=prompt_id)
|
||||
|
||||
flags = q.get_flags()
|
||||
free_memory = flags.get("free_memory", False)
|
||||
|
||||
@@ -364,9 +317,6 @@ def prompt_worker(q, server_instance):
|
||||
last_gc_collect = current_time
|
||||
need_gc = False
|
||||
hook_breaker_ac10a0.restore_functions()
|
||||
|
||||
if not asset_seeder.is_disabled():
|
||||
asset_seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
|
||||
asset_seeder.resume()
|
||||
|
||||
|
||||
|
||||
12
nodes.py
12
nodes.py
@@ -2181,9 +2181,6 @@ EXTENSION_WEB_DIRS = {}
|
||||
# Dictionary of successfully loaded module names and associated directories.
|
||||
LOADED_MODULE_DIRS = {}
|
||||
|
||||
# Dictionary of custom node startup errors, keyed by module name.
|
||||
NODE_STARTUP_ERRORS: dict[str, dict] = {}
|
||||
|
||||
|
||||
def get_module_name(module_path: str) -> str:
|
||||
"""
|
||||
@@ -2301,13 +2298,6 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
|
||||
except Exception as e:
|
||||
logging.warning(traceback.format_exc())
|
||||
logging.warning(f"Cannot import {module_path} module for custom nodes: {e}")
|
||||
module_name = get_module_name(module_path)
|
||||
NODE_STARTUP_ERRORS[module_name] = {
|
||||
"module_path": module_path,
|
||||
"error": str(e),
|
||||
"traceback": traceback.format_exc(),
|
||||
"phase": "import",
|
||||
}
|
||||
return False
|
||||
|
||||
async def init_external_custom_nodes():
|
||||
@@ -2464,9 +2454,7 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_nag.py",
|
||||
"nodes_sdpose.py",
|
||||
"nodes_math.py",
|
||||
"nodes_number_convert.py",
|
||||
"nodes_painter.py",
|
||||
"nodes_curve.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.42.8
|
||||
comfyui-workflow-templates==0.9.36
|
||||
comfyui-workflow-templates==0.9.26
|
||||
comfyui-embedded-docs==0.4.3
|
||||
torch
|
||||
torchsde
|
||||
@@ -33,5 +33,4 @@ kornia>=0.7.1
|
||||
spandrel
|
||||
pydantic~=2.0
|
||||
pydantic-settings~=2.0
|
||||
PyOpenGL
|
||||
glfw
|
||||
PyOpenGL>=3.1.8
|
||||
|
||||
@@ -753,10 +753,6 @@ class PromptServer():
|
||||
out[node_class] = node_info(node_class)
|
||||
return web.json_response(out)
|
||||
|
||||
@routes.get("/custom_node_startup_errors")
|
||||
async def get_custom_node_startup_errors(request):
|
||||
return web.json_response(nodes.NODE_STARTUP_ERRORS)
|
||||
|
||||
@routes.get("/api/jobs")
|
||||
async def get_jobs(request):
|
||||
"""List all jobs with filtering, sorting, and pagination.
|
||||
|
||||
@@ -3,7 +3,7 @@ from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import Base
|
||||
@@ -23,21 +23,6 @@ def db_engine():
|
||||
return engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_engine_fk():
|
||||
"""In-memory SQLite engine with foreign key enforcement enabled."""
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
|
||||
@event.listens_for(engine, "connect")
|
||||
def _set_pragma(dbapi_connection, connection_record):
|
||||
cursor = dbapi_connection.cursor()
|
||||
cursor.execute("PRAGMA foreign_keys=ON")
|
||||
cursor.close()
|
||||
|
||||
Base.metadata.create_all(engine)
|
||||
return engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session(db_engine):
|
||||
"""Session fixture for tests that need direct DB access."""
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
"""Tests for asset enrichment (mime_type and hash population)."""
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import Asset, AssetReference
|
||||
from app.assets.services.file_utils import get_mtime_ns
|
||||
from app.assets.scanner import (
|
||||
ENRICHMENT_HASHED,
|
||||
ENRICHMENT_METADATA,
|
||||
@@ -22,13 +20,6 @@ def _create_stub_asset(
|
||||
name: str | None = None,
|
||||
) -> tuple[Asset, AssetReference]:
|
||||
"""Create a stub asset with reference for testing enrichment."""
|
||||
# Use the real file's mtime so the optimistic guard in enrich_asset passes
|
||||
try:
|
||||
stat_result = os.stat(file_path, follow_symlinks=True)
|
||||
mtime_ns = get_mtime_ns(stat_result)
|
||||
except OSError:
|
||||
mtime_ns = 1234567890000000000
|
||||
|
||||
asset = Asset(
|
||||
id=asset_id,
|
||||
hash=None,
|
||||
@@ -44,7 +35,7 @@ def _create_stub_asset(
|
||||
name=name or f"test-asset-{asset_id}",
|
||||
owner_id="system",
|
||||
file_path=file_path,
|
||||
mtime_ns=mtime_ns,
|
||||
mtime_ns=1234567890000000000,
|
||||
enrichment_level=ENRICHMENT_STUB,
|
||||
)
|
||||
session.add(ref)
|
||||
|
||||
@@ -1,18 +1,12 @@
|
||||
"""Tests for ingest services."""
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session as SASession, Session
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import Asset, AssetReference, AssetReferenceTag, Tag
|
||||
from app.assets.database.models import Asset, AssetReference, Tag
|
||||
from app.assets.database.queries import get_reference_tags
|
||||
from app.assets.services.ingest import (
|
||||
_ingest_file_from_path,
|
||||
_register_existing_asset,
|
||||
ingest_existing_file,
|
||||
)
|
||||
from app.assets.services.ingest import _ingest_file_from_path, _register_existing_asset
|
||||
|
||||
|
||||
class TestIngestFileFromPath:
|
||||
@@ -241,42 +235,3 @@ class TestRegisterExistingAsset:
|
||||
|
||||
assert result.created is True
|
||||
assert set(result.tags) == {"alpha", "beta"}
|
||||
|
||||
|
||||
class TestIngestExistingFileTagFK:
|
||||
"""Regression: ingest_existing_file must seed Tag rows before inserting
|
||||
AssetReferenceTag rows, otherwise FK enforcement raises IntegrityError."""
|
||||
|
||||
def test_creates_tag_rows_before_reference_tags(self, db_engine_fk, temp_dir: Path):
|
||||
"""With PRAGMA foreign_keys=ON, tags must exist in the tags table
|
||||
before they can be referenced in asset_reference_tags."""
|
||||
|
||||
@contextmanager
|
||||
def _create_session():
|
||||
with SASession(db_engine_fk) as sess:
|
||||
yield sess
|
||||
|
||||
file_path = temp_dir / "output.png"
|
||||
file_path.write_bytes(b"image data")
|
||||
|
||||
with patch("app.assets.services.ingest.create_session", _create_session), \
|
||||
patch(
|
||||
"app.assets.services.ingest.get_name_and_tags_from_asset_path",
|
||||
return_value=("output.png", ["output"]),
|
||||
):
|
||||
result = ingest_existing_file(
|
||||
abs_path=str(file_path),
|
||||
extra_tags=["my-job"],
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
with SASession(db_engine_fk) as sess:
|
||||
tag_names = {t.name for t in sess.query(Tag).all()}
|
||||
assert "output" in tag_names
|
||||
assert "my-job" in tag_names
|
||||
|
||||
ref_tags = sess.query(AssetReferenceTag).all()
|
||||
ref_tag_names = {rt.tag_name for rt in ref_tags}
|
||||
assert "output" in ref_tag_names
|
||||
assert "my-job" in ref_tag_names
|
||||
|
||||
@@ -1,123 +0,0 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
mock_nodes = MagicMock()
|
||||
mock_nodes.MAX_RESOLUTION = 16384
|
||||
mock_server = MagicMock()
|
||||
|
||||
with patch.dict("sys.modules", {"nodes": mock_nodes, "server": mock_server}):
|
||||
from comfy_extras.nodes_number_convert import NumberConvertNode
|
||||
|
||||
|
||||
class TestNumberConvertExecute:
|
||||
@staticmethod
|
||||
def _exec(value) -> object:
|
||||
return NumberConvertNode.execute(value)
|
||||
|
||||
# --- INT input ---
|
||||
|
||||
def test_int_input(self):
|
||||
result = self._exec(42)
|
||||
assert result[0] == 42.0
|
||||
assert result[1] == 42
|
||||
|
||||
def test_int_zero(self):
|
||||
result = self._exec(0)
|
||||
assert result[0] == 0.0
|
||||
assert result[1] == 0
|
||||
|
||||
def test_int_negative(self):
|
||||
result = self._exec(-7)
|
||||
assert result[0] == -7.0
|
||||
assert result[1] == -7
|
||||
|
||||
# --- FLOAT input ---
|
||||
|
||||
def test_float_input(self):
|
||||
result = self._exec(3.14)
|
||||
assert result[0] == 3.14
|
||||
assert result[1] == 3
|
||||
|
||||
def test_float_truncation_toward_zero(self):
|
||||
result = self._exec(-2.9)
|
||||
assert result[0] == -2.9
|
||||
assert result[1] == -2 # int() truncates toward zero, not floor
|
||||
|
||||
def test_float_output_type(self):
|
||||
result = self._exec(5)
|
||||
assert isinstance(result[0], float)
|
||||
|
||||
def test_int_output_type(self):
|
||||
result = self._exec(5.7)
|
||||
assert isinstance(result[1], int)
|
||||
|
||||
# --- BOOL input ---
|
||||
|
||||
def test_bool_true(self):
|
||||
result = self._exec(True)
|
||||
assert result[0] == 1.0
|
||||
assert result[1] == 1
|
||||
|
||||
def test_bool_false(self):
|
||||
result = self._exec(False)
|
||||
assert result[0] == 0.0
|
||||
assert result[1] == 0
|
||||
|
||||
# --- STRING input ---
|
||||
|
||||
def test_string_integer(self):
|
||||
result = self._exec("42")
|
||||
assert result[0] == 42.0
|
||||
assert result[1] == 42
|
||||
|
||||
def test_string_float(self):
|
||||
result = self._exec("3.14")
|
||||
assert result[0] == 3.14
|
||||
assert result[1] == 3
|
||||
|
||||
def test_string_negative(self):
|
||||
result = self._exec("-5.5")
|
||||
assert result[0] == -5.5
|
||||
assert result[1] == -5
|
||||
|
||||
def test_string_with_whitespace(self):
|
||||
result = self._exec(" 7.0 ")
|
||||
assert result[0] == 7.0
|
||||
assert result[1] == 7
|
||||
|
||||
def test_string_scientific_notation(self):
|
||||
result = self._exec("1e3")
|
||||
assert result[0] == 1000.0
|
||||
assert result[1] == 1000
|
||||
|
||||
# --- STRING error paths ---
|
||||
|
||||
def test_empty_string_raises(self):
|
||||
with pytest.raises(ValueError, match="Cannot convert empty string"):
|
||||
self._exec("")
|
||||
|
||||
def test_whitespace_only_string_raises(self):
|
||||
with pytest.raises(ValueError, match="Cannot convert empty string"):
|
||||
self._exec(" ")
|
||||
|
||||
def test_non_numeric_string_raises(self):
|
||||
with pytest.raises(ValueError, match="Cannot convert string to number"):
|
||||
self._exec("abc")
|
||||
|
||||
def test_string_inf_raises(self):
|
||||
with pytest.raises(ValueError, match="non-finite"):
|
||||
self._exec("inf")
|
||||
|
||||
def test_string_nan_raises(self):
|
||||
with pytest.raises(ValueError, match="non-finite"):
|
||||
self._exec("nan")
|
||||
|
||||
def test_string_negative_inf_raises(self):
|
||||
with pytest.raises(ValueError, match="non-finite"):
|
||||
self._exec("-inf")
|
||||
|
||||
# --- Unsupported type ---
|
||||
|
||||
def test_unsupported_type_raises(self):
|
||||
with pytest.raises(TypeError, match="Unsupported input type"):
|
||||
self._exec([1, 2, 3])
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Unit tests for the _AssetSeeder background scanning class."""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@@ -772,188 +771,6 @@ class TestSeederStopRestart:
|
||||
assert collected_roots[1] == ("input",)
|
||||
|
||||
|
||||
class TestEnqueueEnrichHandoff:
|
||||
"""Test that the drain of _pending_enrich is atomic with start_enrich."""
|
||||
|
||||
def test_pending_enrich_runs_after_scan_completes(
|
||||
self, fresh_seeder: _AssetSeeder, mock_dependencies
|
||||
):
|
||||
"""A queued enrich request runs automatically when a scan finishes."""
|
||||
enrich_roots_seen: list[tuple] = []
|
||||
original_start = fresh_seeder.start
|
||||
|
||||
def tracking_start(*args, **kwargs):
|
||||
phase = kwargs.get("phase")
|
||||
roots = kwargs.get("roots", args[0] if args else None)
|
||||
result = original_start(*args, **kwargs)
|
||||
if phase == ScanPhase.ENRICH and result:
|
||||
enrich_roots_seen.append(roots)
|
||||
return result
|
||||
|
||||
fresh_seeder.start = tracking_start
|
||||
|
||||
# Start a fast scan, then enqueue an enrich while it's running
|
||||
barrier = threading.Event()
|
||||
reached = threading.Event()
|
||||
|
||||
def slow_collect(*args):
|
||||
reached.set()
|
||||
barrier.wait(timeout=5.0)
|
||||
return []
|
||||
|
||||
with patch(
|
||||
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
|
||||
):
|
||||
fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST)
|
||||
assert reached.wait(timeout=2.0)
|
||||
|
||||
queued = fresh_seeder.enqueue_enrich(
|
||||
roots=("input",), compute_hashes=True
|
||||
)
|
||||
assert queued is False # queued, not started immediately
|
||||
|
||||
barrier.set()
|
||||
|
||||
# Wait for the original scan + the auto-started enrich scan
|
||||
deadline = time.monotonic() + 5.0
|
||||
while fresh_seeder.get_status().state != State.IDLE and time.monotonic() < deadline:
|
||||
time.sleep(0.05)
|
||||
|
||||
assert enrich_roots_seen == [("input",)]
|
||||
|
||||
def test_enqueue_enrich_during_drain_does_not_lose_work(
|
||||
self, fresh_seeder: _AssetSeeder, mock_dependencies
|
||||
):
|
||||
"""enqueue_enrich called concurrently with drain cannot drop work.
|
||||
|
||||
Simulates the race: another thread calls enqueue_enrich right as the
|
||||
scan thread is draining _pending_enrich. The enqueue must either be
|
||||
picked up by the draining scan or successfully start its own scan.
|
||||
"""
|
||||
barrier = threading.Event()
|
||||
reached = threading.Event()
|
||||
enrich_started = threading.Event()
|
||||
|
||||
enrich_call_count = 0
|
||||
|
||||
def slow_collect(*args):
|
||||
reached.set()
|
||||
barrier.wait(timeout=5.0)
|
||||
return []
|
||||
|
||||
# Track how many times start_enrich actually fires
|
||||
real_start_enrich = fresh_seeder.start_enrich
|
||||
enrich_roots_seen: list[tuple] = []
|
||||
|
||||
def tracking_start_enrich(**kwargs):
|
||||
nonlocal enrich_call_count
|
||||
enrich_call_count += 1
|
||||
enrich_roots_seen.append(kwargs.get("roots"))
|
||||
result = real_start_enrich(**kwargs)
|
||||
if result:
|
||||
enrich_started.set()
|
||||
return result
|
||||
|
||||
fresh_seeder.start_enrich = tracking_start_enrich
|
||||
|
||||
with patch(
|
||||
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
|
||||
):
|
||||
# Start a scan
|
||||
fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST)
|
||||
assert reached.wait(timeout=2.0)
|
||||
|
||||
# Queue an enrich while scan is running
|
||||
fresh_seeder.enqueue_enrich(roots=("output",), compute_hashes=False)
|
||||
|
||||
# Let scan finish — drain will fire start_enrich atomically
|
||||
barrier.set()
|
||||
|
||||
# Wait for drain to complete and the enrich scan to start
|
||||
assert enrich_started.wait(timeout=5.0), "Enrich scan was never started from drain"
|
||||
assert ("output",) in enrich_roots_seen
|
||||
|
||||
def test_concurrent_enqueue_during_drain_not_lost(
|
||||
self, fresh_seeder: _AssetSeeder,
|
||||
):
|
||||
"""A second enqueue_enrich arriving while drain is in progress is not lost.
|
||||
|
||||
Because the drain now holds _lock through the start_enrich call,
|
||||
a concurrent enqueue_enrich will block until start_enrich has
|
||||
transitioned state to RUNNING, then the enqueue will queue its
|
||||
payload as _pending_enrich for the *next* drain.
|
||||
"""
|
||||
scan_barrier = threading.Event()
|
||||
scan_reached = threading.Event()
|
||||
enrich_barrier = threading.Event()
|
||||
enrich_reached = threading.Event()
|
||||
|
||||
collect_call = 0
|
||||
|
||||
def gated_collect(*args):
|
||||
nonlocal collect_call
|
||||
collect_call += 1
|
||||
if collect_call == 1:
|
||||
# First call: the initial fast scan
|
||||
scan_reached.set()
|
||||
scan_barrier.wait(timeout=5.0)
|
||||
return []
|
||||
|
||||
enrich_call = 0
|
||||
|
||||
def gated_get_unenriched(*args, **kwargs):
|
||||
nonlocal enrich_call
|
||||
enrich_call += 1
|
||||
if enrich_call == 1:
|
||||
# First enrich batch: signal and block
|
||||
enrich_reached.set()
|
||||
enrich_barrier.wait(timeout=5.0)
|
||||
return []
|
||||
|
||||
with (
|
||||
patch("app.assets.seeder.dependencies_available", return_value=True),
|
||||
patch("app.assets.seeder.sync_root_safely", return_value=set()),
|
||||
patch("app.assets.seeder.collect_paths_for_roots", side_effect=gated_collect),
|
||||
patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
|
||||
patch("app.assets.seeder.insert_asset_specs", return_value=0),
|
||||
patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=gated_get_unenriched),
|
||||
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
|
||||
):
|
||||
# 1. Start fast scan
|
||||
fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST)
|
||||
assert scan_reached.wait(timeout=2.0)
|
||||
|
||||
# 2. Queue enrich while fast scan is running
|
||||
queued = fresh_seeder.enqueue_enrich(
|
||||
roots=("input",), compute_hashes=False
|
||||
)
|
||||
assert queued is False
|
||||
|
||||
# 3. Let the fast scan finish — drain will start the enrich scan
|
||||
scan_barrier.set()
|
||||
|
||||
# 4. Wait until the drained enrich scan is running
|
||||
assert enrich_reached.wait(timeout=5.0)
|
||||
|
||||
# 5. Now enqueue another enrich while the drained scan is running
|
||||
queued2 = fresh_seeder.enqueue_enrich(
|
||||
roots=("output",), compute_hashes=True
|
||||
)
|
||||
assert queued2 is False # should be queued, not started
|
||||
|
||||
# Verify _pending_enrich was set (the second enqueue was captured)
|
||||
with fresh_seeder._lock:
|
||||
assert fresh_seeder._pending_enrich is not None
|
||||
assert "output" in fresh_seeder._pending_enrich["roots"]
|
||||
|
||||
# Let the enrich scan finish
|
||||
enrich_barrier.set()
|
||||
|
||||
deadline = time.monotonic() + 5.0
|
||||
while fresh_seeder.get_status().state != State.IDLE and time.monotonic() < deadline:
|
||||
time.sleep(0.05)
|
||||
|
||||
|
||||
def _make_row(ref_id: str, asset_id: str = "a1") -> UnenrichedReferenceRow:
|
||||
return UnenrichedReferenceRow(
|
||||
reference_id=ref_id, asset_id=asset_id,
|
||||
|
||||
@@ -1,250 +0,0 @@
|
||||
"""Tests for app.assets.seeder – enqueue_enrich and pending-queue behaviour."""
|
||||
|
||||
import threading
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.assets.seeder import Progress, _AssetSeeder, State
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def seeder():
|
||||
"""Fresh seeder instance for each test."""
|
||||
return _AssetSeeder()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _reset_to_idle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResetToIdle:
|
||||
def test_sets_idle_and_clears_progress(self, seeder):
|
||||
"""_reset_to_idle should move state to IDLE and snapshot progress."""
|
||||
progress = Progress(scanned=10, total=20, created=5, skipped=3)
|
||||
seeder._state = State.RUNNING
|
||||
seeder._progress = progress
|
||||
|
||||
with seeder._lock:
|
||||
seeder._reset_to_idle()
|
||||
|
||||
assert seeder._state is State.IDLE
|
||||
assert seeder._progress is None
|
||||
assert seeder._last_progress is progress
|
||||
|
||||
def test_noop_when_progress_already_none(self, seeder):
|
||||
"""_reset_to_idle should handle None progress gracefully."""
|
||||
seeder._state = State.CANCELLING
|
||||
seeder._progress = None
|
||||
|
||||
with seeder._lock:
|
||||
seeder._reset_to_idle()
|
||||
|
||||
assert seeder._state is State.IDLE
|
||||
assert seeder._progress is None
|
||||
assert seeder._last_progress is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# enqueue_enrich – immediate start when idle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnqueueEnrichStartsImmediately:
|
||||
def test_starts_when_idle(self, seeder):
|
||||
"""enqueue_enrich should delegate to start_enrich and return True when idle."""
|
||||
with patch.object(seeder, "start_enrich", return_value=True) as mock:
|
||||
assert seeder.enqueue_enrich(roots=("output",), compute_hashes=True) is True
|
||||
mock.assert_called_once_with(roots=("output",), compute_hashes=True)
|
||||
|
||||
def test_no_pending_when_started_immediately(self, seeder):
|
||||
"""No pending request should be stored when start_enrich succeeds."""
|
||||
with patch.object(seeder, "start_enrich", return_value=True):
|
||||
seeder.enqueue_enrich(roots=("output",))
|
||||
assert seeder._pending_enrich is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# enqueue_enrich – queuing when busy
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnqueueEnrichQueuesWhenBusy:
|
||||
def test_queues_when_busy(self, seeder):
|
||||
"""enqueue_enrich should store a pending request when seeder is busy."""
|
||||
with patch.object(seeder, "start_enrich", return_value=False):
|
||||
result = seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
|
||||
|
||||
assert result is False
|
||||
assert seeder._pending_enrich == {
|
||||
"roots": ("models",),
|
||||
"compute_hashes": False,
|
||||
}
|
||||
|
||||
def test_queues_preserves_compute_hashes_true(self, seeder):
|
||||
with patch.object(seeder, "start_enrich", return_value=False):
|
||||
seeder.enqueue_enrich(roots=("input",), compute_hashes=True)
|
||||
|
||||
assert seeder._pending_enrich["compute_hashes"] is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# enqueue_enrich – merging when a pending request already exists
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnqueueEnrichMergesPending:
|
||||
def _make_busy(self, seeder):
|
||||
"""Patch start_enrich to always return False (seeder busy)."""
|
||||
return patch.object(seeder, "start_enrich", return_value=False)
|
||||
|
||||
def test_merges_roots(self, seeder):
|
||||
"""A second enqueue should merge roots with the existing pending request."""
|
||||
with self._make_busy(seeder):
|
||||
seeder.enqueue_enrich(roots=("models",))
|
||||
seeder.enqueue_enrich(roots=("output",))
|
||||
|
||||
merged = set(seeder._pending_enrich["roots"])
|
||||
assert merged == {"models", "output"}
|
||||
|
||||
def test_merges_overlapping_roots(self, seeder):
|
||||
"""Duplicate roots should be deduplicated."""
|
||||
with self._make_busy(seeder):
|
||||
seeder.enqueue_enrich(roots=("models", "input"))
|
||||
seeder.enqueue_enrich(roots=("input", "output"))
|
||||
|
||||
merged = set(seeder._pending_enrich["roots"])
|
||||
assert merged == {"models", "input", "output"}
|
||||
|
||||
def test_compute_hashes_sticky_true(self, seeder):
|
||||
"""Once compute_hashes is True it should stay True after merging."""
|
||||
with self._make_busy(seeder):
|
||||
seeder.enqueue_enrich(roots=("models",), compute_hashes=True)
|
||||
seeder.enqueue_enrich(roots=("output",), compute_hashes=False)
|
||||
|
||||
assert seeder._pending_enrich["compute_hashes"] is True
|
||||
|
||||
def test_compute_hashes_upgrades_to_true(self, seeder):
|
||||
"""A later enqueue with compute_hashes=True should upgrade the pending request."""
|
||||
with self._make_busy(seeder):
|
||||
seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
|
||||
seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
|
||||
|
||||
assert seeder._pending_enrich["compute_hashes"] is True
|
||||
|
||||
def test_compute_hashes_stays_false(self, seeder):
|
||||
"""If both enqueues have compute_hashes=False it stays False."""
|
||||
with self._make_busy(seeder):
|
||||
seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
|
||||
seeder.enqueue_enrich(roots=("output",), compute_hashes=False)
|
||||
|
||||
assert seeder._pending_enrich["compute_hashes"] is False
|
||||
|
||||
def test_triple_merge(self, seeder):
|
||||
"""Three successive enqueues should all merge correctly."""
|
||||
with self._make_busy(seeder):
|
||||
seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
|
||||
seeder.enqueue_enrich(roots=("input",), compute_hashes=False)
|
||||
seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
|
||||
|
||||
merged = set(seeder._pending_enrich["roots"])
|
||||
assert merged == {"models", "input", "output"}
|
||||
assert seeder._pending_enrich["compute_hashes"] is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pending enrich drains after scan completes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPendingEnrichDrain:
|
||||
"""Verify that _run_scan drains _pending_enrich via start_enrich."""
|
||||
|
||||
@patch("app.assets.seeder.dependencies_available", return_value=True)
|
||||
@patch("app.assets.seeder.get_all_known_prefixes", return_value=[])
|
||||
@patch("app.assets.seeder.sync_root_safely", return_value=set())
|
||||
@patch("app.assets.seeder.collect_paths_for_roots", return_value=[])
|
||||
@patch("app.assets.seeder.build_asset_specs", return_value=([], {}, 0))
|
||||
def test_pending_enrich_starts_after_scan(self, *_mocks):
|
||||
"""After a fast scan finishes, the pending enrich should be started."""
|
||||
seeder = _AssetSeeder()
|
||||
|
||||
seeder._pending_enrich = {
|
||||
"roots": ("output",),
|
||||
"compute_hashes": True,
|
||||
}
|
||||
|
||||
with patch.object(seeder, "start_enrich", return_value=True) as mock_start:
|
||||
seeder.start_fast(roots=("models",))
|
||||
seeder.wait(timeout=5)
|
||||
|
||||
mock_start.assert_called_once_with(
|
||||
roots=("output",),
|
||||
compute_hashes=True,
|
||||
)
|
||||
|
||||
assert seeder._pending_enrich is None
|
||||
|
||||
@patch("app.assets.seeder.dependencies_available", return_value=True)
|
||||
@patch("app.assets.seeder.get_all_known_prefixes", return_value=[])
|
||||
@patch("app.assets.seeder.sync_root_safely", return_value=set())
|
||||
@patch("app.assets.seeder.collect_paths_for_roots", return_value=[])
|
||||
@patch("app.assets.seeder.build_asset_specs", return_value=([], {}, 0))
|
||||
def test_pending_cleared_even_when_start_fails(self, *_mocks):
|
||||
"""_pending_enrich should be cleared even if start_enrich returns False."""
|
||||
seeder = _AssetSeeder()
|
||||
seeder._pending_enrich = {
|
||||
"roots": ("output",),
|
||||
"compute_hashes": False,
|
||||
}
|
||||
|
||||
with patch.object(seeder, "start_enrich", return_value=False):
|
||||
seeder.start_fast(roots=("models",))
|
||||
seeder.wait(timeout=5)
|
||||
|
||||
assert seeder._pending_enrich is None
|
||||
|
||||
@patch("app.assets.seeder.dependencies_available", return_value=True)
|
||||
@patch("app.assets.seeder.get_all_known_prefixes", return_value=[])
|
||||
@patch("app.assets.seeder.sync_root_safely", return_value=set())
|
||||
@patch("app.assets.seeder.collect_paths_for_roots", return_value=[])
|
||||
@patch("app.assets.seeder.build_asset_specs", return_value=([], {}, 0))
|
||||
def test_no_drain_when_no_pending(self, *_mocks):
|
||||
"""start_enrich should not be called when there is no pending request."""
|
||||
seeder = _AssetSeeder()
|
||||
assert seeder._pending_enrich is None
|
||||
|
||||
with patch.object(seeder, "start_enrich", return_value=True) as mock_start:
|
||||
seeder.start_fast(roots=("models",))
|
||||
seeder.wait(timeout=5)
|
||||
|
||||
mock_start.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Thread-safety of enqueue_enrich
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnqueueEnrichThreadSafety:
|
||||
def test_concurrent_enqueues(self, seeder):
|
||||
"""Multiple threads enqueuing should not lose roots."""
|
||||
with patch.object(seeder, "start_enrich", return_value=False):
|
||||
barrier = threading.Barrier(3)
|
||||
|
||||
def enqueue(root):
|
||||
barrier.wait()
|
||||
seeder.enqueue_enrich(roots=(root,), compute_hashes=False)
|
||||
|
||||
threads = [
|
||||
threading.Thread(target=enqueue, args=(r,))
|
||||
for r in ("models", "input", "output")
|
||||
]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join(timeout=5)
|
||||
|
||||
merged = set(seeder._pending_enrich["roots"])
|
||||
assert merged == {"models", "input", "output"}
|
||||
Reference in New Issue
Block a user