Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-15 20:50:01 +00:00)

Compare commits: pysssss/ba... → feat/cache (12 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 0141af0786 |  |
|  | 0440ebcf6e |  |
|  | 4afa80dc07 |  |
|  | d755f7ca19 |  |
|  | 2049066cff |  |
|  | 9b0ca8b95c |  |
|  | dcf686857c |  |
|  | 17eed38750 |  |
|  | f4623c0e1b |  |
|  | 5e4bbca1ad |  |
|  | e17571d9be |  |
|  | 6540aa0400 |  |
@@ -106,6 +106,42 @@ class Types:

```python
    MESH = MESH
    VOXEL = VOXEL


class Caching:
    """
    External cache provider API for distributed caching.

    Enables sharing cached results across multiple ComfyUI instances
    (e.g., Kubernetes pods) without monkey-patching internal methods.

    Example usage:
        from comfy_api.latest import Caching

        class MyRedisProvider(Caching.CacheProvider):
            def on_lookup(self, context):
                # Check Redis for cached result
                ...

            def on_store(self, context, value):
                # Store to Redis (can be async internally)
                ...

        Caching.register_provider(MyRedisProvider())
    """
    # Import from comfy_execution.cache_provider (source of truth)
    from comfy_execution.cache_provider import (
        CacheProvider,
        CacheContext,
        CacheValue,
        register_cache_provider as register_provider,
        unregister_cache_provider as unregister_provider,
        get_cache_providers as get_providers,
        has_cache_providers as has_providers,
        clear_cache_providers as clear_providers,
        estimate_value_size,
    )


ComfyAPI = ComfyAPI_latest

# Create a synchronous version of the API
```

@@ -125,6 +161,7 @@ __all__ = [

```python
    "Input",
    "InputImpl",
    "Types",
    "Caching",
    "ComfyExtension",
    "io",
    "IO",
```
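The `Caching` namespace above only re-exports the provider API under shorter aliases. As a quick sanity check of that surface, here is a minimal sketch (assuming a ComfyUI checkout where `comfy_api.latest` is importable; the no-op provider is purely illustrative):

```python
from comfy_api.latest import Caching

class NoopProvider(Caching.CacheProvider):
    """Illustrative provider: never hits on lookup, discards stores."""
    def on_lookup(self, context):
        return None  # always a miss

    def on_store(self, context, value):
        pass  # drop the value

Caching.register_provider(NoopProvider())
assert Caching.has_providers()
Caching.clear_providers()        # e.g., in a test teardown
assert not Caching.has_providers()
```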
Deleted file (67 lines) — the Grok request/response models, imported below as `comfy_api_nodes.apis.grok`:

@@ -1,67 +0,0 @@

```python
from pydantic import BaseModel, Field


class ImageGenerationRequest(BaseModel):
    model: str = Field(...)
    prompt: str = Field(...)
    aspect_ratio: str = Field(...)
    n: int = Field(...)
    seed: int = Field(...)
    response_for: str = Field("url")


class InputUrlObject(BaseModel):
    url: str = Field(...)


class ImageEditRequest(BaseModel):
    model: str = Field(...)
    image: InputUrlObject = Field(...)
    prompt: str = Field(...)
    resolution: str = Field(...)
    n: int = Field(...)
    seed: int = Field(...)
    response_for: str = Field("url")


class VideoGenerationRequest(BaseModel):
    model: str = Field(...)
    prompt: str = Field(...)
    image: InputUrlObject | None = Field(...)
    duration: int = Field(...)
    aspect_ratio: str | None = Field(...)
    resolution: str = Field(...)
    seed: int = Field(...)


class VideoEditRequest(BaseModel):
    model: str = Field(...)
    prompt: str = Field(...)
    video: InputUrlObject = Field(...)
    seed: int = Field(...)


class ImageResponseObject(BaseModel):
    url: str | None = Field(None)
    b64_json: str | None = Field(None)
    revised_prompt: str | None = Field(None)


class ImageGenerationResponse(BaseModel):
    data: list[ImageResponseObject] = Field(...)


class VideoGenerationResponse(BaseModel):
    request_id: str = Field(...)


class VideoResponseObject(BaseModel):
    url: str = Field(...)
    upsampled_prompt: str | None = Field(None)
    duration: int = Field(...)


class VideoStatusResponse(BaseModel):
    status: str | None = Field(None)
    video: VideoResponseObject | None = Field(None)
    model: str | None = Field(None)
```
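These models serialize directly to the JSON payloads the nodes below send through the proxy. A minimal sketch of the wire format (assuming Pydantic v2, which provides `model_dump_json`, and that the models above are importable):

```python
# Build the payload the GrokImageNode sends to /proxy/xai/v1/images/generations.
req = ImageGenerationRequest(
    model="grok-imagine-image-beta",
    prompt="a lighthouse at dusk",
    aspect_ratio="16:9",
    n=2,
    seed=42,
)
print(req.model_dump_json())
# Roughly: {"model":"grok-imagine-image-beta","prompt":"a lighthouse at dusk",
#           "aspect_ratio":"16:9","n":2,"seed":42,"response_for":"url"}
```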
Deleted file (417 lines) — the Grok API nodes:

@@ -1,417 +0,0 @@

```python
import torch
from typing_extensions import override

from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.grok import (
    ImageEditRequest,
    ImageGenerationRequest,
    ImageGenerationResponse,
    InputUrlObject,
    VideoEditRequest,
    VideoGenerationRequest,
    VideoGenerationResponse,
    VideoStatusResponse,
)
from comfy_api_nodes.util import (
    ApiEndpoint,
    download_url_to_image_tensor,
    download_url_to_video_output,
    get_fs_object_size,
    get_number_of_images,
    poll_op,
    sync_op,
    tensor_to_base64_string,
    upload_video_to_comfyapi,
    validate_string,
    validate_video_duration,
)


class GrokImageNode(IO.ComfyNode):

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="GrokImageNode",
            display_name="Grok Image",
            category="api node/image/Grok",
            description="Generate images using Grok based on a text prompt",
            inputs=[
                IO.Combo.Input("model", options=["grok-imagine-image-beta"]),
                IO.String.Input(
                    "prompt",
                    multiline=True,
                    tooltip="The text prompt used to generate the image",
                ),
                IO.Combo.Input(
                    "aspect_ratio",
                    options=[
                        "1:1",
                        "2:3",
                        "3:2",
                        "3:4",
                        "4:3",
                        "9:16",
                        "16:9",
                        "9:19.5",
                        "19.5:9",
                        "9:20",
                        "20:9",
                        "1:2",
                        "2:1",
                    ],
                ),
                IO.Int.Input(
                    "number_of_images",
                    default=1,
                    min=1,
                    max=10,
                    step=1,
                    tooltip="Number of images to generate",
                    display_mode=IO.NumberDisplay.number,
                ),
                IO.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=2147483647,
                    step=1,
                    display_mode=IO.NumberDisplay.number,
                    control_after_generate=True,
                    tooltip="Seed to determine if node should re-run; "
                    "actual results are nondeterministic regardless of seed.",
                ),
            ],
            outputs=[
                IO.Image.Output(),
            ],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                depends_on=IO.PriceBadgeDepends(widgets=["number_of_images"]),
                expr="""{"type":"usd","usd":0.033 * widgets.number_of_images}""",
            ),
        )

    @classmethod
    async def execute(
        cls,
        model: str,
        prompt: str,
        aspect_ratio: str,
        number_of_images: int,
        seed: int,
    ) -> IO.NodeOutput:
        validate_string(prompt, strip_whitespace=True, min_length=1)
        response = await sync_op(
            cls,
            ApiEndpoint(path="/proxy/xai/v1/images/generations", method="POST"),
            data=ImageGenerationRequest(
                model=model,
                prompt=prompt,
                aspect_ratio=aspect_ratio,
                n=number_of_images,
                seed=seed,
            ),
            response_model=ImageGenerationResponse,
        )
        if len(response.data) == 1:
            return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url))
        return IO.NodeOutput(
            torch.cat(
                [await download_url_to_image_tensor(i) for i in [str(d.url) for d in response.data if d.url]],
            )
        )


class GrokImageEditNode(IO.ComfyNode):

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="GrokImageEditNode",
            display_name="Grok Image Edit",
            category="api node/image/Grok",
            description="Modify an existing image based on a text prompt",
            inputs=[
                IO.Combo.Input("model", options=["grok-imagine-image-beta"]),
                IO.Image.Input("image"),
                IO.String.Input(
                    "prompt",
                    multiline=True,
                    tooltip="The text prompt used to generate the image",
                ),
                IO.Combo.Input("resolution", options=["1K"]),
                IO.Int.Input(
                    "number_of_images",
                    default=1,
                    min=1,
                    max=10,
                    step=1,
                    tooltip="Number of edited images to generate",
                    display_mode=IO.NumberDisplay.number,
                ),
                IO.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=2147483647,
                    step=1,
                    display_mode=IO.NumberDisplay.number,
                    control_after_generate=True,
                    tooltip="Seed to determine if node should re-run; "
                    "actual results are nondeterministic regardless of seed.",
                ),
            ],
            outputs=[
                IO.Image.Output(),
            ],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                depends_on=IO.PriceBadgeDepends(widgets=["number_of_images"]),
                expr="""{"type":"usd","usd":0.002 + 0.033 * widgets.number_of_images}""",
            ),
        )

    @classmethod
    async def execute(
        cls,
        model: str,
        image: Input.Image,
        prompt: str,
        resolution: str,
        number_of_images: int,
        seed: int,
    ) -> IO.NodeOutput:
        validate_string(prompt, strip_whitespace=True, min_length=1)
        if get_number_of_images(image) != 1:
            raise ValueError("Only one input image is supported.")
        response = await sync_op(
            cls,
            ApiEndpoint(path="/proxy/xai/v1/images/edits", method="POST"),
            data=ImageEditRequest(
                model=model,
                image=InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(image)}"),
                prompt=prompt,
                resolution=resolution.lower(),
                n=number_of_images,
                seed=seed,
            ),
            response_model=ImageGenerationResponse,
        )
        if len(response.data) == 1:
            return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url))
        return IO.NodeOutput(
            torch.cat(
                [await download_url_to_image_tensor(i) for i in [str(d.url) for d in response.data if d.url]],
            )
        )


class GrokVideoNode(IO.ComfyNode):

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="GrokVideoNode",
            display_name="Grok Video",
            category="api node/video/Grok",
            description="Generate video from a prompt or an image",
            inputs=[
                IO.Combo.Input("model", options=["grok-imagine-video-beta"]),
                IO.String.Input(
                    "prompt",
                    multiline=True,
                    tooltip="Text description of the desired video.",
                ),
                IO.Combo.Input(
                    "resolution",
                    options=["480p", "720p"],
                    tooltip="The resolution of the output video.",
                ),
                IO.Combo.Input(
                    "aspect_ratio",
                    options=["auto", "16:9", "4:3", "3:2", "1:1", "2:3", "3:4", "9:16"],
                    tooltip="The aspect ratio of the output video.",
                ),
                IO.Int.Input(
                    "duration",
                    default=6,
                    min=1,
                    max=15,
                    step=1,
                    tooltip="The duration of the output video in seconds.",
                    display_mode=IO.NumberDisplay.slider,
                ),
                IO.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=2147483647,
                    step=1,
                    display_mode=IO.NumberDisplay.number,
                    control_after_generate=True,
                    tooltip="Seed to determine if node should re-run; "
                    "actual results are nondeterministic regardless of seed.",
                ),
                IO.Image.Input("image", optional=True),
            ],
            outputs=[
                IO.Video.Output(),
            ],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                depends_on=IO.PriceBadgeDepends(widgets=["duration"], inputs=["image"]),
                expr="""
                (
                    $base := 0.181 * widgets.duration;
                    {"type":"usd","usd": inputs.image.connected ? $base + 0.002 : $base}
                )
                """,
            ),
        )

    @classmethod
    async def execute(
        cls,
        model: str,
        prompt: str,
        resolution: str,
        aspect_ratio: str,
        duration: int,
        seed: int,
        image: Input.Image | None = None,
    ) -> IO.NodeOutput:
        image_url = None
        if image is not None:
            if get_number_of_images(image) != 1:
                raise ValueError("Only one input image is supported.")
            image_url = InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(image)}")
        validate_string(prompt, strip_whitespace=True, min_length=1)
        initial_response = await sync_op(
            cls,
            ApiEndpoint(path="/proxy/xai/v1/videos/generations", method="POST"),
            data=VideoGenerationRequest(
                model=model,
                image=image_url,
                prompt=prompt,
                resolution=resolution,
                duration=duration,
                aspect_ratio=None if aspect_ratio == "auto" else aspect_ratio,
                seed=seed,
            ),
            response_model=VideoGenerationResponse,
        )
        response = await poll_op(
            cls,
            ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
            status_extractor=lambda r: r.status if r.status is not None else "complete",
            response_model=VideoStatusResponse,
        )
        return IO.NodeOutput(await download_url_to_video_output(response.video.url))


class GrokVideoEditNode(IO.ComfyNode):

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="GrokVideoEditNode",
            display_name="Grok Video Edit",
            category="api node/video/Grok",
            description="Edit an existing video based on a text prompt.",
            inputs=[
                IO.Combo.Input("model", options=["grok-imagine-video-beta"]),
                IO.String.Input(
                    "prompt",
                    multiline=True,
                    tooltip="Text description of the desired video.",
                ),
                IO.Video.Input("video", tooltip="Maximum supported duration is 8.7 seconds and 50MB file size."),
                IO.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=2147483647,
                    step=1,
                    display_mode=IO.NumberDisplay.number,
                    control_after_generate=True,
                    tooltip="Seed to determine if node should re-run; "
                    "actual results are nondeterministic regardless of seed.",
                ),
            ],
            outputs=[
                IO.Video.Output(),
            ],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                expr="""{"type":"usd","usd": 0.191, "format": {"suffix": "/sec", "approximate": true}}""",
            ),
        )

    @classmethod
    async def execute(
        cls,
        model: str,
        prompt: str,
        video: Input.Video,
        seed: int,
    ) -> IO.NodeOutput:
        validate_string(prompt, strip_whitespace=True, min_length=1)
        validate_video_duration(video, min_duration=1, max_duration=8.7)
        video_stream = video.get_stream_source()
        video_size = get_fs_object_size(video_stream)
        if video_size > 50 * 1024 * 1024:
            raise ValueError(f"Video size ({video_size / 1024 / 1024:.1f}MB) exceeds 50MB limit.")
        initial_response = await sync_op(
            cls,
            ApiEndpoint(path="/proxy/xai/v1/videos/edits", method="POST"),
            data=VideoEditRequest(
                model=model,
                video=InputUrlObject(url=await upload_video_to_comfyapi(cls, video)),
                prompt=prompt,
                seed=seed,
            ),
            response_model=VideoGenerationResponse,
        )
        response = await poll_op(
            cls,
            ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
            status_extractor=lambda r: r.status if r.status is not None else "complete",
            response_model=VideoStatusResponse,
        )
        return IO.NodeOutput(await download_url_to_video_output(response.video.url))


class GrokExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        return [
            GrokImageNode,
            GrokImageEditNode,
            GrokVideoNode,
            GrokVideoEditNode,
        ]


async def comfy_entrypoint() -> GrokExtension:
    return GrokExtension()
```
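Both video nodes follow the same submit-then-poll protocol: `sync_op` posts the generation request and returns a `request_id`, then `poll_op` repeatedly fetches `/proxy/xai/v1/videos/{request_id}` until the `status_extractor` reports completion (a missing status is treated as "complete"). A self-contained sketch of that pattern, not the actual `poll_op` implementation (which also handles auth, backoff, and cancellation):

```python
import asyncio
from typing import Awaitable, Callable, Optional

async def poll_until_complete(
    fetch_status: Callable[[], Awaitable[Optional[str]]],
    interval: float = 2.0,
    timeout: float = 600.0,
) -> None:
    """Illustrative polling loop mirroring the nodes' status_extractor logic."""
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    while True:
        status = await fetch_status()
        # A missing status means the job is done, per the lambda in the nodes above.
        if (status or "complete") == "complete":
            return
        if loop.time() > deadline:
            raise TimeoutError(f"video generation did not finish within {timeout}s")
        await asyncio.sleep(interval)
```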
comfy_execution/cache_provider.py — new file (319 lines)

@@ -0,0 +1,319 @@

```python
"""
External Cache Provider API for distributed caching.

This module provides a public API for external cache providers, enabling
distributed caching across multiple ComfyUI instances (e.g., Kubernetes pods).

Public API is also available via:
    from comfy_api.latest import Caching

Example usage:
    from comfy_execution.cache_provider import (
        CacheProvider, CacheContext, CacheValue, register_cache_provider
    )

    class MyRedisProvider(CacheProvider):
        def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
            # Check Redis/GCS for cached result
            ...

        def on_store(self, context: CacheContext, value: CacheValue) -> None:
            # Store to Redis/GCS (can be async internally)
            ...

    register_cache_provider(MyRedisProvider())
"""

from abc import ABC, abstractmethod
from typing import Any, Optional, Tuple, List
from dataclasses import dataclass
import hashlib
import json
import logging
import math
import pickle
import threading

logger = logging.getLogger(__name__)


# ============================================================
# Data Classes
# ============================================================

@dataclass
class CacheContext:
    """Context passed to provider methods."""
    prompt_id: str          # Current prompt execution ID
    node_id: str            # Node being cached
    class_type: str         # Node class type (e.g., "KSampler")
    cache_key: Any          # Raw cache key (frozenset structure)
    cache_key_bytes: bytes  # SHA256 hash for external storage key


@dataclass
class CacheValue:
    """
    Value stored in or retrieved from the external cache.

    The ui field is optional - implementations may choose to skip it
    (e.g., if it contains non-portable data like local file paths).
    """
    outputs: list              # The tensor/value outputs
    ui: Optional[dict] = None  # Optional UI data (may be skipped by implementations)


# ============================================================
# Provider Interface
# ============================================================

class CacheProvider(ABC):
    """
    Abstract base class for external cache providers.

    Thread Safety:
        Providers may be called from multiple threads. Implementations
        must be thread-safe.

    Error Handling:
        All methods are wrapped in try/except by the caller. Exceptions
        are logged but never propagate to break execution.

    Performance Guidelines:
        - on_lookup: Should complete in <500ms (including network)
        - on_store: Can be async internally (fire-and-forget)
        - should_cache: Should be fast (<1ms), called frequently
    """

    @abstractmethod
    def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
        """
        Check external storage for a cached result.

        Called AFTER a local cache miss (local-first for performance).

        Returns:
            CacheValue if found externally, None otherwise.

        Important:
            - Return None on any error (don't raise)
            - Validate data integrity before returning
        """
        pass

    @abstractmethod
    def on_store(self, context: CacheContext, value: CacheValue) -> None:
        """
        Store a value to the external cache.

        Called AFTER the value is stored in the local cache.

        Important:
            - Can be fire-and-forget (async internally)
            - Should never block execution
            - Handle serialization failures gracefully
        """
        pass

    def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool:
        """
        Filter which nodes should be externally cached.

        Called before on_lookup (value=None) and on_store (value provided).
        Return False to skip external caching for this node.

        Implementations can filter based on context.class_type, value size,
        or any custom logic. Use estimate_value_size() to get the value size.

        Default: Returns True (cache everything).
        """
        return True

    def on_prompt_start(self, prompt_id: str) -> None:
        """Called when prompt execution begins. Optional."""
        pass

    def on_prompt_end(self, prompt_id: str) -> None:
        """Called when prompt execution ends. Optional."""
        pass


# ============================================================
# Provider Registry
# ============================================================

_providers: List[CacheProvider] = []
_providers_lock = threading.Lock()
_providers_snapshot: Optional[Tuple[CacheProvider, ...]] = None


def register_cache_provider(provider: CacheProvider) -> None:
    """
    Register an external cache provider.

    Providers are called in registration order. The first provider to return
    a result from on_lookup wins.
    """
    global _providers_snapshot
    with _providers_lock:
        if provider in _providers:
            logger.warning(f"Provider {provider.__class__.__name__} already registered")
            return
        _providers.append(provider)
        _providers_snapshot = None  # Invalidate cache
        logger.info(f"Registered cache provider: {provider.__class__.__name__}")


def unregister_cache_provider(provider: CacheProvider) -> None:
    """Remove a previously registered provider."""
    global _providers_snapshot
    with _providers_lock:
        try:
            _providers.remove(provider)
            _providers_snapshot = None
            logger.info(f"Unregistered cache provider: {provider.__class__.__name__}")
        except ValueError:
            logger.warning(f"Provider {provider.__class__.__name__} was not registered")


def get_cache_providers() -> Tuple[CacheProvider, ...]:
    """Get registered providers (cached for performance)."""
    global _providers_snapshot
    snapshot = _providers_snapshot
    if snapshot is not None:
        return snapshot
    with _providers_lock:
        if _providers_snapshot is not None:
            return _providers_snapshot
        _providers_snapshot = tuple(_providers)
        return _providers_snapshot


def has_cache_providers() -> bool:
    """Fast check whether any providers are registered (no lock)."""
    return bool(_providers)


def clear_cache_providers() -> None:
    """Remove all providers. Useful for testing."""
    global _providers_snapshot
    with _providers_lock:
        _providers.clear()
        _providers_snapshot = None


# ============================================================
# Utilities
# ============================================================

def _canonicalize(obj: Any) -> Any:
    """
    Convert an object to a canonical, JSON-serializable form.

    This ensures deterministic ordering regardless of Python's hash randomization,
    which is critical for cross-pod cache key consistency. Frozensets in particular
    have non-deterministic iteration order between Python sessions.
    """
    if isinstance(obj, frozenset):
        # Sort frozenset items for deterministic ordering
        return ("__frozenset__", sorted(
            [_canonicalize(item) for item in obj],
            key=lambda x: json.dumps(x, sort_keys=True)
        ))
    elif isinstance(obj, set):
        return ("__set__", sorted(
            [_canonicalize(item) for item in obj],
            key=lambda x: json.dumps(x, sort_keys=True)
        ))
    elif isinstance(obj, tuple):
        return ("__tuple__", [_canonicalize(item) for item in obj])
    elif isinstance(obj, list):
        return [_canonicalize(item) for item in obj]
    elif isinstance(obj, dict):
        return {str(k): _canonicalize(v) for k, v in sorted(obj.items())}
    elif isinstance(obj, (int, float, str, bool, type(None))):
        return obj
    elif isinstance(obj, bytes):
        return ("__bytes__", obj.hex())
    elif hasattr(obj, 'value'):
        # Handle Unhashable class from ComfyUI
        return ("__unhashable__", _canonicalize(getattr(obj, 'value', None)))
    else:
        # For other types, use repr as fallback
        return ("__repr__", repr(obj))


def serialize_cache_key(cache_key: Any) -> bytes:
    """
    Serialize a cache key to bytes for external storage.

    Returns a SHA256 hash suitable for Redis/database keys.

    Note: Uses canonicalize + JSON serialization instead of pickle because
    pickle is NOT deterministic across Python sessions due to hash randomization
    affecting frozenset iteration order. This is critical for distributed caching
    where different pods need to compute the same hash for identical inputs.
    """
    try:
        canonical = _canonicalize(cache_key)
        json_str = json.dumps(canonical, sort_keys=True, separators=(',', ':'))
        return hashlib.sha256(json_str.encode('utf-8')).digest()
    except Exception as e:
        logger.warning(f"Failed to serialize cache key: {e}")
        # Fallback to pickle (non-deterministic but better than nothing)
        try:
            serialized = pickle.dumps(cache_key, protocol=4)
            return hashlib.sha256(serialized).digest()
        except Exception:
            return hashlib.sha256(str(id(cache_key)).encode()).digest()


def contains_nan(obj: Any) -> bool:
    """
    Check whether a cache key contains NaN (indicates an uncacheable node).

    NaN != NaN in Python, so the local cache never hits. But a serialized
    NaN would match, causing incorrect external hits. Such keys must be skipped.
    """
    if isinstance(obj, float):
        try:
            return math.isnan(obj)
        except (TypeError, ValueError):
            return False
    if hasattr(obj, 'value'):  # Unhashable class
        val = getattr(obj, 'value', None)
        if isinstance(val, float):
            try:
                return math.isnan(val)
            except (TypeError, ValueError):
                return False
    if isinstance(obj, (frozenset, tuple, list, set)):
        return any(contains_nan(item) for item in obj)
    if isinstance(obj, dict):
        return any(contains_nan(k) or contains_nan(v) for k, v in obj.items())
    return False


def estimate_value_size(value: CacheValue) -> int:
    """Estimate serialized size in bytes. Useful for size-based filtering."""
    try:
        import torch
    except ImportError:
        return 0

    total = 0

    def estimate(obj):
        nonlocal total
        if isinstance(obj, torch.Tensor):
            total += obj.numel() * obj.element_size()
        elif isinstance(obj, dict):
            for v in obj.values():
                estimate(v)
        elif isinstance(obj, (list, tuple)):
            for item in obj:
                estimate(item)

    for output in value.outputs:
        estimate(output)
    return total
```
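To make the contract concrete, here is a minimal in-process provider built only on the API above — a sketch suitable for tests or as a starting point for a real Redis/GCS backend (it keeps values in a local dict, so it does not actually share state across pods; the size limit is an arbitrary example):

```python
import threading
from comfy_execution.cache_provider import (
    CacheProvider, CacheContext, CacheValue,
    register_cache_provider, estimate_value_size,
)

class InMemoryProvider(CacheProvider):
    """Dict-backed provider; locked, per the thread-safety requirement above."""

    def __init__(self, max_value_bytes: int = 512 * 1024 * 1024):
        self._store: dict[bytes, CacheValue] = {}
        self._lock = threading.Lock()
        self._max_value_bytes = max_value_bytes

    def on_lookup(self, context: CacheContext):
        with self._lock:
            return self._store.get(context.cache_key_bytes)  # None on miss

    def on_store(self, context: CacheContext, value: CacheValue) -> None:
        with self._lock:
            self._store[context.cache_key_bytes] = value

    def should_cache(self, context: CacheContext, value=None) -> bool:
        # Example size-based filter using the provided helper.
        if value is not None and estimate_value_size(value) > self._max_value_bytes:
            return False
        return True

register_cache_provider(InMemoryProvider())
```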
@@ -155,6 +155,10 @@ class BasicCache:

```python
        self.cache = {}
        self.subcaches = {}

        # External cache provider support
        self._is_subcache = False
        self._current_prompt_id = ''

    async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
        self.dynprompt = dynprompt
        self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
```

@@ -201,20 +205,123 @@ class BasicCache:

```python
        cache_key = self.cache_key_set.get_data_key(node_id)
        self.cache[cache_key] = value

        # Notify external providers
        self._notify_providers_store(node_id, cache_key, value)

    def _get_immediate(self, node_id):
        if not self.initialized:
            return None
        cache_key = self.cache_key_set.get_data_key(node_id)

        # Check local cache first (fast path)
        if cache_key in self.cache:
            return self.cache[cache_key]
        else:
            # Check external providers on local miss
            external_result = self._check_providers_lookup(node_id, cache_key)
            if external_result is not None:
                self.cache[cache_key] = external_result  # Warm local cache
                return external_result

        return None

    def _notify_providers_store(self, node_id, cache_key, value):
        """Notify external providers of cache store."""
        from comfy_execution.cache_provider import (
            has_cache_providers, get_cache_providers,
            CacheContext, CacheValue,
            serialize_cache_key, contains_nan, logger
        )

        # Fast exit conditions
        if self._is_subcache:
            return
        if not has_cache_providers():
            return
        if not self._is_external_cacheable_value(value):
            return
        if contains_nan(cache_key):
            return

        context = CacheContext(
            prompt_id=self._current_prompt_id,
            node_id=node_id,
            class_type=self._get_class_type(node_id),
            cache_key=cache_key,
            cache_key_bytes=serialize_cache_key(cache_key)
        )
        cache_value = CacheValue(outputs=value.outputs, ui=value.ui)

        for provider in get_cache_providers():
            try:
                if provider.should_cache(context, cache_value):
                    provider.on_store(context, cache_value)
            except Exception as e:
                logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}")

    def _check_providers_lookup(self, node_id, cache_key):
        """Check external providers for cached result."""
        from comfy_execution.cache_provider import (
            has_cache_providers, get_cache_providers,
            CacheContext, CacheValue,
            serialize_cache_key, contains_nan, logger
        )

        if self._is_subcache:
            return None
        if not has_cache_providers():
            return None
        if contains_nan(cache_key):
            return None

        context = CacheContext(
            prompt_id=self._current_prompt_id,
            node_id=node_id,
            class_type=self._get_class_type(node_id),
            cache_key=cache_key,
            cache_key_bytes=serialize_cache_key(cache_key)
        )

        for provider in get_cache_providers():
            try:
                if not provider.should_cache(context):
                    continue
                result = provider.on_lookup(context)
                if result is not None:
                    if not isinstance(result, CacheValue):
                        logger.warning(f"Provider {provider.__class__.__name__} returned invalid type")
                        continue
                    if not isinstance(result.outputs, (list, tuple)):
                        logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs")
                        continue
                    # Import CacheEntry here to avoid circular import at module level
                    from execution import CacheEntry
                    return CacheEntry(ui=result.ui or {}, outputs=list(result.outputs))
            except Exception as e:
                logger.warning(f"Cache provider {provider.__class__.__name__} error on lookup: {e}")

        return None

    def _is_external_cacheable_value(self, value):
        """Check if value is a CacheEntry suitable for external caching (not objects cache)."""
        return hasattr(value, 'outputs') and hasattr(value, 'ui')

    def _get_class_type(self, node_id):
        """Get class_type for a node."""
        if not self.initialized or not self.dynprompt:
            return ''
        try:
            return self.dynprompt.get_node(node_id).get('class_type', '')
        except Exception:
            return ''

    async def _ensure_subcache(self, node_id, children_ids):
        subcache_key = self.cache_key_set.get_subcache_key(node_id)
        subcache = self.subcaches.get(subcache_key, None)
        if subcache is None:
            subcache = BasicCache(self.key_class)
            subcache._is_subcache = True  # Mark as subcache - excludes from external caching
            subcache._current_prompt_id = self._current_prompt_id  # Propagate prompt ID
            self.subcaches[subcache_key] = subcache
        await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
        return subcache
```
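Two of the guards above are easy to misread, so a short demonstration may help: NaN keys are skipped because NaN compares unequal to itself (so the local dict cache can never hit, while a byte-serialized key would), and `serialize_cache_key` is order-insensitive for frozensets. A sketch, assuming `cache_provider.py` from this diff is importable:

```python
import math
from comfy_execution.cache_provider import serialize_cache_key, contains_nan

nan = float("nan")
assert nan != nan                                  # why NaN keys never hit the local cache
assert contains_nan(frozenset([("seed", nan)]))    # ...so external caching skips them

# Deterministic hashing: element order inside a frozenset does not matter.
k1 = frozenset([("a", 1), ("b", 2)])
k2 = frozenset([("b", 2), ("a", 1)])
assert serialize_cache_key(k1) == serialize_cache_key(k2)
```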
@@ -1,810 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
import logging
|
||||
import ctypes.util
|
||||
import importlib.util
|
||||
from typing import TypedDict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import nodes
|
||||
from comfy_api.latest import ComfyExtension, io, ui
|
||||
from typing_extensions import override
|
||||
from utils.install_util import get_missing_requirements_message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _check_opengl_availability():
|
||||
"""Early check for OpenGL availability. Raises RuntimeError if unlikely to work."""
|
||||
logger.debug("_check_opengl_availability: starting")
|
||||
missing = []
|
||||
|
||||
# Check Python packages (using find_spec to avoid importing)
|
||||
logger.debug("_check_opengl_availability: checking for glfw package")
|
||||
if importlib.util.find_spec("glfw") is None:
|
||||
missing.append("glfw")
|
||||
|
||||
logger.debug("_check_opengl_availability: checking for OpenGL package")
|
||||
if importlib.util.find_spec("OpenGL") is None:
|
||||
missing.append("PyOpenGL")
|
||||
|
||||
if missing:
|
||||
raise RuntimeError(
|
||||
f"OpenGL dependencies not available.\n{get_missing_requirements_message()}\n"
|
||||
)
|
||||
|
||||
# On Linux without display, check if headless backends are available
|
||||
logger.debug(f"_check_opengl_availability: platform={sys.platform}")
|
||||
if sys.platform.startswith("linux"):
|
||||
has_display = os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY")
|
||||
logger.debug(f"_check_opengl_availability: has_display={bool(has_display)}")
|
||||
if not has_display:
|
||||
# Check for EGL or OSMesa libraries
|
||||
logger.debug("_check_opengl_availability: checking for EGL library")
|
||||
has_egl = ctypes.util.find_library("EGL")
|
||||
logger.debug("_check_opengl_availability: checking for OSMesa library")
|
||||
has_osmesa = ctypes.util.find_library("OSMesa")
|
||||
|
||||
# Error disabled for CI as it fails this check
|
||||
# if not has_egl and not has_osmesa:
|
||||
# raise RuntimeError(
|
||||
# "GLSL Shader node: No display and no headless backend (EGL/OSMesa) found.\n"
|
||||
# "See error below for installation instructions."
|
||||
# )
|
||||
logger.debug(f"Headless mode: EGL={'yes' if has_egl else 'no'}, OSMesa={'yes' if has_osmesa else 'no'}")
|
||||
|
||||
logger.debug("_check_opengl_availability: completed")
|
||||
|
||||
|
||||
# Run early check at import time
|
||||
logger.debug("nodes_glsl: running _check_opengl_availability at import time")
|
||||
_check_opengl_availability()
|
||||
|
||||
# OpenGL modules - initialized lazily when context is created
|
||||
gl = None
|
||||
glfw = None
|
||||
EGL = None
|
||||
|
||||
|
||||
def _import_opengl():
|
||||
"""Import OpenGL module. Called after context is created."""
|
||||
global gl
|
||||
if gl is None:
|
||||
logger.debug("_import_opengl: importing OpenGL.GL")
|
||||
import OpenGL.GL as _gl
|
||||
gl = _gl
|
||||
logger.debug("_import_opengl: import completed")
|
||||
return gl
|
||||
|
||||
|
||||
class SizeModeInput(TypedDict):
|
||||
size_mode: str
|
||||
width: int
|
||||
height: int
|
||||
|
||||
|
||||
MAX_IMAGES = 5 # u_image0-4
|
||||
MAX_UNIFORMS = 5 # u_float0-4, u_int0-4
|
||||
MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
|
||||
|
||||
# Vertex shader using gl_VertexID trick - no VBO needed.
|
||||
# Draws a single triangle that covers the entire screen:
|
||||
#
|
||||
# (-1,3)
|
||||
# /|
|
||||
# / | <- visible area is the unit square from (-1,-1) to (1,1)
|
||||
# / | parts outside get clipped away
|
||||
# (-1,-1)---(3,-1)
|
||||
#
|
||||
# v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1)
|
||||
VERTEX_SHADER = """#version 330 core
|
||||
out vec2 v_texCoord;
|
||||
void main() {
|
||||
vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3));
|
||||
v_texCoord = verts[gl_VertexID] * 0.5 + 0.5;
|
||||
gl_Position = vec4(verts[gl_VertexID], 0, 1);
|
||||
}
|
||||
"""
|
||||
|
||||
DEFAULT_FRAGMENT_SHADER = """#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform vec2 u_resolution;
|
||||
|
||||
in vec2 v_texCoord;
|
||||
layout(location = 0) out vec4 fragColor0;
|
||||
|
||||
void main() {
|
||||
fragColor0 = texture(u_image0, v_texCoord);
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def _convert_es_to_desktop(source: str) -> str:
|
||||
"""Convert GLSL ES (WebGL) shader source to desktop GLSL 330 core."""
|
||||
# Remove any existing #version directive
|
||||
source = re.sub(r"#version\s+\d+(\s+es)?\s*\n?", "", source, flags=re.IGNORECASE)
|
||||
# Remove precision qualifiers (not needed in desktop GLSL)
|
||||
source = re.sub(r"precision\s+(lowp|mediump|highp)\s+\w+\s*;\s*\n?", "", source)
|
||||
# Prepend desktop GLSL version
|
||||
return "#version 330 core\n" + source
|
||||
|
||||
|
||||
def _detect_output_count(source: str) -> int:
|
||||
"""Detect how many fragColor outputs are used in the shader.
|
||||
|
||||
Returns the count of outputs needed (1 to MAX_OUTPUTS).
|
||||
"""
|
||||
matches = re.findall(r"fragColor(\d+)", source)
|
||||
if not matches:
|
||||
return 1 # Default to 1 output if none found
|
||||
max_index = max(int(m) for m in matches)
|
||||
return min(max_index + 1, MAX_OUTPUTS)
|
||||
|
||||
|
||||
def _init_glfw():
|
||||
"""Initialize GLFW. Returns (window, glfw_module). Raises RuntimeError on failure."""
|
||||
logger.debug("_init_glfw: starting")
|
||||
# On macOS, glfw.init() must be called from main thread or it hangs forever
|
||||
if sys.platform == "darwin":
|
||||
logger.debug("_init_glfw: skipping on macOS")
|
||||
raise RuntimeError("GLFW backend not supported on macOS")
|
||||
|
||||
logger.debug("_init_glfw: importing glfw module")
|
||||
import glfw as _glfw
|
||||
|
||||
logger.debug("_init_glfw: calling glfw.init()")
|
||||
if not _glfw.init():
|
||||
raise RuntimeError("glfw.init() failed")
|
||||
|
||||
try:
|
||||
logger.debug("_init_glfw: setting window hints")
|
||||
_glfw.window_hint(_glfw.VISIBLE, _glfw.FALSE)
|
||||
_glfw.window_hint(_glfw.CONTEXT_VERSION_MAJOR, 3)
|
||||
_glfw.window_hint(_glfw.CONTEXT_VERSION_MINOR, 3)
|
||||
_glfw.window_hint(_glfw.OPENGL_PROFILE, _glfw.OPENGL_CORE_PROFILE)
|
||||
|
||||
logger.debug("_init_glfw: calling create_window()")
|
||||
window = _glfw.create_window(64, 64, "ComfyUI GLSL", None, None)
|
||||
if not window:
|
||||
raise RuntimeError("glfw.create_window() failed")
|
||||
|
||||
logger.debug("_init_glfw: calling make_context_current()")
|
||||
_glfw.make_context_current(window)
|
||||
logger.debug("_init_glfw: completed successfully")
|
||||
return window, _glfw
|
||||
except Exception:
|
||||
logger.debug("_init_glfw: failed, terminating glfw")
|
||||
_glfw.terminate()
|
||||
raise
|
||||
|
||||
|
||||
def _init_egl():
|
||||
"""Initialize EGL for headless rendering. Returns (display, context, surface, EGL_module). Raises RuntimeError on failure."""
|
||||
logger.debug("_init_egl: starting")
|
||||
from OpenGL import EGL as _EGL
|
||||
from OpenGL.EGL import (
|
||||
eglGetDisplay, eglInitialize, eglChooseConfig, eglCreateContext,
|
||||
eglMakeCurrent, eglCreatePbufferSurface, eglBindAPI,
|
||||
eglTerminate, eglDestroyContext, eglDestroySurface,
|
||||
EGL_DEFAULT_DISPLAY, EGL_NO_CONTEXT, EGL_NONE,
|
||||
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
|
||||
EGL_RED_SIZE, EGL_GREEN_SIZE, EGL_BLUE_SIZE, EGL_ALPHA_SIZE, EGL_DEPTH_SIZE,
|
||||
EGL_WIDTH, EGL_HEIGHT, EGL_OPENGL_API,
|
||||
)
|
||||
logger.debug("_init_egl: imports completed")
|
||||
|
||||
display = None
|
||||
context = None
|
||||
surface = None
|
||||
|
||||
try:
|
||||
logger.debug("_init_egl: calling eglGetDisplay()")
|
||||
display = eglGetDisplay(EGL_DEFAULT_DISPLAY)
|
||||
if display == _EGL.EGL_NO_DISPLAY:
|
||||
raise RuntimeError("eglGetDisplay() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglInitialize()")
|
||||
major, minor = _EGL.EGLint(), _EGL.EGLint()
|
||||
if not eglInitialize(display, major, minor):
|
||||
display = None # Not initialized, don't terminate
|
||||
raise RuntimeError("eglInitialize() failed")
|
||||
logger.debug(f"_init_egl: EGL version {major.value}.{minor.value}")
|
||||
|
||||
config_attribs = [
|
||||
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT,
|
||||
EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
|
||||
EGL_RED_SIZE, 8, EGL_GREEN_SIZE, 8, EGL_BLUE_SIZE, 8, EGL_ALPHA_SIZE, 8,
|
||||
EGL_DEPTH_SIZE, 0, EGL_NONE
|
||||
]
|
||||
configs = (_EGL.EGLConfig * 1)()
|
||||
num_configs = _EGL.EGLint()
|
||||
if not eglChooseConfig(display, config_attribs, configs, 1, num_configs) or num_configs.value == 0:
|
||||
raise RuntimeError("eglChooseConfig() failed")
|
||||
config = configs[0]
|
||||
logger.debug(f"_init_egl: config chosen, num_configs={num_configs.value}")
|
||||
|
||||
if not eglBindAPI(EGL_OPENGL_API):
|
||||
raise RuntimeError("eglBindAPI() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglCreateContext()")
|
||||
context_attribs = [
|
||||
_EGL.EGL_CONTEXT_MAJOR_VERSION, 3,
|
||||
_EGL.EGL_CONTEXT_MINOR_VERSION, 3,
|
||||
_EGL.EGL_CONTEXT_OPENGL_PROFILE_MASK, _EGL.EGL_CONTEXT_OPENGL_CORE_PROFILE_BIT,
|
||||
EGL_NONE
|
||||
]
|
||||
context = eglCreateContext(display, config, EGL_NO_CONTEXT, context_attribs)
|
||||
if context == EGL_NO_CONTEXT:
|
||||
raise RuntimeError("eglCreateContext() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglCreatePbufferSurface()")
|
||||
pbuffer_attribs = [EGL_WIDTH, 64, EGL_HEIGHT, 64, EGL_NONE]
|
||||
surface = eglCreatePbufferSurface(display, config, pbuffer_attribs)
|
||||
if surface == _EGL.EGL_NO_SURFACE:
|
||||
raise RuntimeError("eglCreatePbufferSurface() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglMakeCurrent()")
|
||||
if not eglMakeCurrent(display, surface, surface, context):
|
||||
raise RuntimeError("eglMakeCurrent() failed")
|
||||
|
||||
logger.debug("_init_egl: completed successfully")
|
||||
return display, context, surface, _EGL
|
||||
|
||||
except Exception:
|
||||
logger.debug("_init_egl: failed, cleaning up")
|
||||
# Clean up any resources on failure
|
||||
if surface is not None:
|
||||
eglDestroySurface(display, surface)
|
||||
if context is not None:
|
||||
eglDestroyContext(display, context)
|
||||
if display is not None:
|
||||
eglTerminate(display)
|
||||
raise
|
||||
|
||||
|
||||
def _init_osmesa():
|
||||
"""Initialize OSMesa for software rendering. Returns (context, buffer). Raises RuntimeError on failure."""
|
||||
import ctypes
|
||||
|
||||
logger.debug("_init_osmesa: starting")
|
||||
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
|
||||
|
||||
logger.debug("_init_osmesa: importing OpenGL.osmesa")
|
||||
from OpenGL import GL as _gl
|
||||
from OpenGL.osmesa import (
|
||||
OSMesaCreateContextExt, OSMesaMakeCurrent, OSMesaDestroyContext,
|
||||
OSMESA_RGBA,
|
||||
)
|
||||
logger.debug("_init_osmesa: imports completed")
|
||||
|
||||
ctx = OSMesaCreateContextExt(OSMESA_RGBA, 24, 0, 0, None)
|
||||
if not ctx:
|
||||
raise RuntimeError("OSMesaCreateContextExt() failed")
|
||||
|
||||
width, height = 64, 64
|
||||
buffer = (ctypes.c_ubyte * (width * height * 4))()
|
||||
|
||||
logger.debug("_init_osmesa: calling OSMesaMakeCurrent()")
|
||||
if not OSMesaMakeCurrent(ctx, buffer, _gl.GL_UNSIGNED_BYTE, width, height):
|
||||
OSMesaDestroyContext(ctx)
|
||||
raise RuntimeError("OSMesaMakeCurrent() failed")
|
||||
|
||||
logger.debug("_init_osmesa: completed successfully")
|
||||
return ctx, buffer
|
||||
|
||||
|
||||
class GLContext:
|
||||
"""Manages OpenGL context and resources for shader execution.
|
||||
|
||||
Tries backends in order: GLFW (desktop) → EGL (headless GPU) → OSMesa (software).
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_initialized = False
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if GLContext._initialized:
|
||||
logger.debug("GLContext.__init__: already initialized, skipping")
|
||||
return
|
||||
GLContext._initialized = True
|
||||
|
||||
logger.debug("GLContext.__init__: starting initialization")
|
||||
|
||||
global glfw, EGL
|
||||
|
||||
import time
|
||||
start = time.perf_counter()
|
||||
|
||||
self._backend = None
|
||||
self._window = None
|
||||
self._egl_display = None
|
||||
self._egl_context = None
|
||||
self._egl_surface = None
|
||||
self._osmesa_ctx = None
|
||||
self._osmesa_buffer = None
|
||||
|
||||
# Try backends in order: GLFW → EGL → OSMesa
|
||||
errors = []
|
||||
|
||||
logger.debug("GLContext.__init__: trying GLFW backend")
|
||||
try:
|
||||
self._window, glfw = _init_glfw()
|
||||
self._backend = "glfw"
|
||||
logger.debug("GLContext.__init__: GLFW backend succeeded")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: GLFW backend failed: {e}")
|
||||
errors.append(("GLFW", e))
|
||||
|
||||
if self._backend is None:
|
||||
logger.debug("GLContext.__init__: trying EGL backend")
|
||||
try:
|
||||
self._egl_display, self._egl_context, self._egl_surface, EGL = _init_egl()
|
||||
self._backend = "egl"
|
||||
logger.debug("GLContext.__init__: EGL backend succeeded")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: EGL backend failed: {e}")
|
||||
errors.append(("EGL", e))
|
||||
|
||||
if self._backend is None:
|
||||
logger.debug("GLContext.__init__: trying OSMesa backend")
|
||||
try:
|
||||
self._osmesa_ctx, self._osmesa_buffer = _init_osmesa()
|
||||
self._backend = "osmesa"
|
||||
logger.debug("GLContext.__init__: OSMesa backend succeeded")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: OSMesa backend failed: {e}")
|
||||
errors.append(("OSMesa", e))
|
||||
|
||||
if self._backend is None:
|
||||
if sys.platform == "win32":
|
||||
platform_help = (
|
||||
"Windows: Ensure GPU drivers are installed and display is available.\n"
|
||||
" CPU-only/headless mode is not supported on Windows."
|
||||
)
|
||||
elif sys.platform == "darwin":
|
||||
platform_help = (
|
||||
"macOS: GLFW is not supported.\n"
|
||||
" Install OSMesa via Homebrew: brew install mesa\n"
|
||||
" Then: pip install PyOpenGL PyOpenGL-accelerate"
|
||||
)
|
||||
else:
|
||||
platform_help = (
|
||||
"Linux: Install one of these backends:\n"
|
||||
" Desktop: sudo apt install libgl1-mesa-glx libglfw3\n"
|
||||
" Headless with GPU: sudo apt install libegl1-mesa libgl1-mesa-dri\n"
|
||||
" Headless (CPU): sudo apt install libosmesa6"
|
||||
)
|
||||
|
||||
error_details = "\n".join(f" {name}: {err}" for name, err in errors)
|
||||
raise RuntimeError(
|
||||
f"Failed to create OpenGL context.\n\n"
|
||||
f"Backend errors:\n{error_details}\n\n"
|
||||
f"{platform_help}"
|
||||
)
|
||||
|
||||
# Now import OpenGL.GL (after context is current)
|
||||
logger.debug("GLContext.__init__: importing OpenGL.GL")
|
||||
_import_opengl()
|
||||
|
||||
# Create VAO (required for core profile, but OSMesa may use compat profile)
|
||||
logger.debug("GLContext.__init__: creating VAO")
|
||||
self._vao = None
|
||||
try:
|
||||
vao = gl.glGenVertexArrays(1)
|
||||
gl.glBindVertexArray(vao)
|
||||
self._vao = vao # Only store after successful bind
|
||||
logger.debug("GLContext.__init__: VAO created successfully")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: VAO creation failed (may be expected for OSMesa): {e}")
|
||||
# OSMesa with older Mesa may not support VAOs
|
||||
# Clean up if we created but couldn't bind
|
||||
if vao:
|
||||
try:
|
||||
gl.glDeleteVertexArrays(1, [vao])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
elapsed = (time.perf_counter() - start) * 1000
|
||||
|
||||
# Log device info
|
||||
renderer = gl.glGetString(gl.GL_RENDERER)
|
||||
vendor = gl.glGetString(gl.GL_VENDOR)
|
||||
version = gl.glGetString(gl.GL_VERSION)
|
||||
renderer = renderer.decode() if renderer else "Unknown"
|
||||
vendor = vendor.decode() if vendor else "Unknown"
|
||||
version = version.decode() if version else "Unknown"
|
||||
|
||||
logger.info(f"GLSL context initialized in {elapsed:.1f}ms ({self._backend}) - {renderer} ({vendor}), GL {version}")
|
||||
|
||||
def make_current(self):
|
||||
if self._backend == "glfw":
|
||||
glfw.make_context_current(self._window)
|
||||
elif self._backend == "egl":
|
||||
from OpenGL.EGL import eglMakeCurrent
|
||||
eglMakeCurrent(self._egl_display, self._egl_surface, self._egl_surface, self._egl_context)
|
||||
elif self._backend == "osmesa":
|
||||
from OpenGL.osmesa import OSMesaMakeCurrent
|
||||
OSMesaMakeCurrent(self._osmesa_ctx, self._osmesa_buffer, gl.GL_UNSIGNED_BYTE, 64, 64)
|
||||
|
||||
if self._vao is not None:
|
||||
gl.glBindVertexArray(self._vao)
|
||||
|
||||
|
||||
def _compile_shader(source: str, shader_type: int) -> int:
|
||||
"""Compile a shader and return its ID."""
|
||||
shader = gl.glCreateShader(shader_type)
|
||||
gl.glShaderSource(shader, source)
|
||||
gl.glCompileShader(shader)
|
||||
|
||||
if gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS) != gl.GL_TRUE:
|
||||
error = gl.glGetShaderInfoLog(shader).decode()
|
||||
gl.glDeleteShader(shader)
|
||||
raise RuntimeError(f"Shader compilation failed:\n{error}")
|
||||
|
||||
return shader
|
||||
|
||||
|
||||
def _create_program(vertex_source: str, fragment_source: str) -> int:
|
||||
"""Create and link a shader program."""
|
||||
vertex_shader = _compile_shader(vertex_source, gl.GL_VERTEX_SHADER)
|
||||
try:
|
||||
fragment_shader = _compile_shader(fragment_source, gl.GL_FRAGMENT_SHADER)
|
||||
except RuntimeError:
|
||||
gl.glDeleteShader(vertex_shader)
|
||||
raise
|
||||
|
||||
program = gl.glCreateProgram()
|
||||
gl.glAttachShader(program, vertex_shader)
|
||||
gl.glAttachShader(program, fragment_shader)
|
||||
gl.glLinkProgram(program)
|
||||
|
||||
gl.glDeleteShader(vertex_shader)
|
||||
gl.glDeleteShader(fragment_shader)
|
||||
|
||||
if gl.glGetProgramiv(program, gl.GL_LINK_STATUS) != gl.GL_TRUE:
|
||||
error = gl.glGetProgramInfoLog(program).decode()
|
||||
gl.glDeleteProgram(program)
|
||||
raise RuntimeError(f"Program linking failed:\n{error}")
|
||||
|
||||
return program
|
||||
|
||||
|
||||
def _render_shader_batch(
|
||||
fragment_code: str,
|
||||
width: int,
|
||||
height: int,
|
||||
image_batches: list[list[np.ndarray]],
|
||||
floats: list[float],
|
||||
ints: list[int],
|
||||
) -> list[list[np.ndarray]]:
|
||||
"""
|
||||
Render a fragment shader for multiple batches efficiently.
|
||||
|
||||
Compiles shader once, reuses framebuffer/textures across batches.
|
||||
|
||||
Args:
|
||||
fragment_code: User's fragment shader code
|
||||
width: Output width
|
||||
height: Output height
|
||||
image_batches: List of batches, each batch is a list of input images (H, W, C) float32 [0,1]
|
||||
floats: List of float uniforms
|
||||
ints: List of int uniforms
|
||||
|
||||
Returns:
|
||||
List of batch outputs, each is a list of output images (H, W, 4) float32 [0,1]
|
||||
"""
|
||||
if not image_batches:
|
||||
return []
|
||||
|
||||
ctx = GLContext()
|
||||
ctx.make_current()
|
||||
|
||||
# Convert from GLSL ES to desktop GLSL 330
|
||||
fragment_source = _convert_es_to_desktop(fragment_code)
|
||||
|
||||
# Detect how many outputs the shader actually uses
|
||||
num_outputs = _detect_output_count(fragment_code)
|
||||
|
||||
# Track resources for cleanup
|
||||
program = None
|
||||
fbo = None
|
||||
output_textures = []
|
||||
input_textures = []
|
||||
|
||||
num_inputs = len(image_batches[0])
|
||||
|
||||
try:
|
||||
# Compile shaders (once for all batches)
|
||||
try:
|
||||
program = _create_program(VERTEX_SHADER, fragment_source)
|
||||
except RuntimeError:
|
||||
logger.error(f"Fragment shader:\n{fragment_source}")
|
||||
raise
|
||||
|
||||
gl.glUseProgram(program)
|
||||
|
||||
# Create framebuffer with only the needed color attachments
|
||||
fbo = gl.glGenFramebuffers(1)
|
||||
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
|
||||
|
||||
draw_buffers = []
|
||||
for i in range(num_outputs):
|
||||
tex = gl.glGenTextures(1)
|
||||
output_textures.append(tex)
|
||||
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
|
||||
gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, width, height, 0, gl.GL_RGBA, gl.GL_FLOAT, None)
|
||||
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
|
||||
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
            gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0 + i, gl.GL_TEXTURE_2D, tex, 0)
            draw_buffers.append(gl.GL_COLOR_ATTACHMENT0 + i)

        gl.glDrawBuffers(num_outputs, draw_buffers)

        if gl.glCheckFramebufferStatus(gl.GL_FRAMEBUFFER) != gl.GL_FRAMEBUFFER_COMPLETE:
            raise RuntimeError("Framebuffer is not complete")

        # Create input textures (reused for all batches)
        for i in range(num_inputs):
            tex = gl.glGenTextures(1)
            input_textures.append(tex)
            gl.glActiveTexture(gl.GL_TEXTURE0 + i)
            gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)

            loc = gl.glGetUniformLocation(program, f"u_image{i}")
            if loc >= 0:
                gl.glUniform1i(loc, i)

        # Set static uniforms (once for all batches)
        loc = gl.glGetUniformLocation(program, "u_resolution")
        if loc >= 0:
            gl.glUniform2f(loc, float(width), float(height))

        for i, v in enumerate(floats):
            loc = gl.glGetUniformLocation(program, f"u_float{i}")
            if loc >= 0:
                gl.glUniform1f(loc, v)

        for i, v in enumerate(ints):
            loc = gl.glGetUniformLocation(program, f"u_int{i}")
            if loc >= 0:
                gl.glUniform1i(loc, v)

        gl.glViewport(0, 0, width, height)
        gl.glDisable(gl.GL_BLEND)  # Ensure no alpha blending - write output directly

        # Process each batch
        all_batch_outputs = []
        for images in image_batches:
            # Update input textures with this batch's images
            for i, img in enumerate(images):
                gl.glActiveTexture(gl.GL_TEXTURE0 + i)
                gl.glBindTexture(gl.GL_TEXTURE_2D, input_textures[i])

                # Flip vertically for GL coordinates, ensure RGBA
                h, w, c = img.shape
                if c == 3:
                    img_upload = np.empty((h, w, 4), dtype=np.float32)
                    img_upload[:, :, :3] = img[::-1, :, :]
                    img_upload[:, :, 3] = 1.0
                else:
                    img_upload = np.ascontiguousarray(img[::-1, :, :])

                gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, w, h, 0, gl.GL_RGBA, gl.GL_FLOAT, img_upload)

            # Render
            gl.glClearColor(0, 0, 0, 0)
            gl.glClear(gl.GL_COLOR_BUFFER_BIT)
            gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3)

            # Read back outputs for this batch
            # (glGetTexImage is synchronous, implicitly waits for rendering)
            batch_outputs = []
            for tex in output_textures:
                gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
                data = gl.glGetTexImage(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA, gl.GL_FLOAT)
                img = np.frombuffer(data, dtype=np.float32).reshape(height, width, 4)
                batch_outputs.append(np.ascontiguousarray(img[::-1, :, :]))

            # Pad with black images for unused outputs
            black_img = np.zeros((height, width, 4), dtype=np.float32)
            for _ in range(num_outputs, MAX_OUTPUTS):
                batch_outputs.append(black_img)

            all_batch_outputs.append(batch_outputs)

        return all_batch_outputs

    finally:
        # Unbind before deleting
        gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
        gl.glUseProgram(0)

        if input_textures:
            gl.glDeleteTextures(len(input_textures), input_textures)
        if output_textures:
            gl.glDeleteTextures(len(output_textures), output_textures)
        if fbo is not None:
            gl.glDeleteFramebuffers(1, [fbo])
        if program is not None:
            gl.glDeleteProgram(program)

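For orientation, here is a hypothetical standalone invocation of the helper above (not part of the diff), showing the shapes it expects and returns; _render_shader_batch, DEFAULT_FRAGMENT_SHADER, and MAX_OUTPUTS come from this file, everything else is made up for illustration:

import numpy as np

# One batch containing one 64x64 RGB image; values in [0, 1].
frame = np.zeros((64, 64, 3), dtype=np.float32)
image_batches = [[frame]]

# Positional arguments mirror the call site in GLSLShader.execute below:
# (fragment_shader, width, height, image_batches, floats, ints).
outputs = _render_shader_batch(DEFAULT_FRAGMENT_SHADER, 64, 64, image_batches, [], [])

# One entry per batch, each padded to MAX_OUTPUTS RGBA float32 arrays.
assert outputs[0][0].shape == (64, 64, 4)
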
class GLSLShader(io.ComfyNode):

    @classmethod
    def define_schema(cls) -> io.Schema:
        image_template = io.Autogrow.TemplatePrefix(
            io.Image.Input("image"),
            prefix="image",
            min=1,
            max=MAX_IMAGES,
        )

        float_template = io.Autogrow.TemplatePrefix(
            io.Float.Input("float", default=0.0),
            prefix="u_float",
            min=0,
            max=MAX_UNIFORMS,
        )

        int_template = io.Autogrow.TemplatePrefix(
            io.Int.Input("int", default=0),
            prefix="u_int",
            min=0,
            max=MAX_UNIFORMS,
        )

        return io.Schema(
            node_id="GLSLShader",
            display_name="GLSL Shader",
            category="image/shader",
            description=(
                f"Apply GLSL fragment shaders to images. "
                f"Inputs: u_image0-{MAX_IMAGES-1} (sampler2D), u_resolution (vec2), "
                f"u_float0-{MAX_UNIFORMS-1}, u_int0-{MAX_UNIFORMS-1}. "
                f"Outputs: layout(location = 0-{MAX_OUTPUTS-1}) out vec4 fragColor0-{MAX_OUTPUTS-1}."
            ),
            inputs=[
                io.String.Input(
                    "fragment_shader",
                    default=DEFAULT_FRAGMENT_SHADER,
                    multiline=True,
                    tooltip="GLSL fragment shader source code (GLSL ES 3.00 / WebGL 2.0 compatible)",
                ),
                io.DynamicCombo.Input(
                    "size_mode",
                    options=[
                        io.DynamicCombo.Option("from_input", []),
                        io.DynamicCombo.Option(
                            "custom",
                            [
                                io.Int.Input(
                                    "width",
                                    default=512,
                                    min=1,
                                    max=nodes.MAX_RESOLUTION,
                                ),
                                io.Int.Input(
                                    "height",
                                    default=512,
                                    min=1,
                                    max=nodes.MAX_RESOLUTION,
                                ),
                            ],
                        ),
                    ],
                    tooltip="Output size: 'from_input' uses first input image dimensions, 'custom' allows manual size",
                ),
                io.Autogrow.Input("images", template=image_template),
                io.Autogrow.Input("floats", template=float_template),
                io.Autogrow.Input("ints", template=int_template),
            ],
            outputs=[
                io.Image.Output(display_name="IMAGE0"),
                io.Image.Output(display_name="IMAGE1"),
                io.Image.Output(display_name="IMAGE2"),
                io.Image.Output(display_name="IMAGE3"),
            ],
        )

    @classmethod
    def execute(
        cls,
        fragment_shader: str,
        size_mode: SizeModeInput,
        images: io.Autogrow.Type,
        floats: io.Autogrow.Type = None,
        ints: io.Autogrow.Type = None,
        **kwargs,
    ) -> io.NodeOutput:
        image_list = [v for v in images.values() if v is not None]
        float_list = (
            [v if v is not None else 0.0 for v in floats.values()] if floats else []
        )
        int_list = [v if v is not None else 0 for v in ints.values()] if ints else []

        if not image_list:
            raise ValueError("At least one input image is required")

        # Determine output dimensions
        if size_mode["size_mode"] == "custom":
            out_width = size_mode["width"]
            out_height = size_mode["height"]
        else:
            out_height, out_width = image_list[0].shape[1:3]

        batch_size = image_list[0].shape[0]

        # Prepare batches
        image_batches = []
        for batch_idx in range(batch_size):
            batch_images = [img_tensor[batch_idx].cpu().numpy().astype(np.float32) for img_tensor in image_list]
            image_batches.append(batch_images)

        all_batch_outputs = _render_shader_batch(
            fragment_shader,
            out_width,
            out_height,
            image_batches,
            float_list,
            int_list,
        )

        # Collect outputs into tensors
        all_outputs = [[] for _ in range(MAX_OUTPUTS)]
        for batch_outputs in all_batch_outputs:
            for i, out_img in enumerate(batch_outputs):
                all_outputs[i].append(torch.from_numpy(out_img))

        output_tensors = [torch.stack(all_outputs[i], dim=0) for i in range(MAX_OUTPUTS)]
        return io.NodeOutput(
            *output_tensors,
            ui=cls._build_ui_output(image_list, output_tensors[0]),
        )

    @classmethod
    def _build_ui_output(
        cls, image_list: list[torch.Tensor], output_batch: torch.Tensor
    ) -> dict[str, list]:
        """Build UI output with input and output images for client-side shader execution."""
        combined_inputs = torch.cat(image_list, dim=0)
        input_images_ui = ui.ImageSaveHelper.save_images(
            combined_inputs,
            filename_prefix="GLSLShader_input",
            folder_type=io.FolderType.temp,
            cls=None,
            compress_level=1,
        )

        output_images_ui = ui.ImageSaveHelper.save_images(
            output_batch,
            filename_prefix="GLSLShader_output",
            folder_type=io.FolderType.temp,
            cls=None,
            compress_level=1,
        )

        return {"input_images": input_images_ui, "images": output_images_ui}


class GLSLExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [GLSLShader]


async def comfy_entrypoint() -> GLSLExtension:
    return GLSLExtension()

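As a sketch of what a user might type into the fragment_shader input (not part of the diff), here is a fragment shader following the interface the schema's description string documents - u_image0 sampler, u_resolution, u_float0, and a location-0 fragColor0 output; the grayscale mix is an arbitrary example effect:

EXAMPLE_SHADER = """
#version 300 es
precision highp float;

uniform sampler2D u_image0;   // first input image
uniform vec2 u_resolution;    // output size in pixels
uniform float u_float0;       // example knob: 0.0 = original, 1.0 = grayscale

layout(location = 0) out vec4 fragColor0;

void main() {
    vec2 uv = gl_FragCoord.xy / u_resolution;
    vec4 color = texture(u_image0, uv);
    float gray = dot(color.rgb, vec3(0.299, 0.587, 0.114));
    fragColor0 = vec4(mix(color.rgb, vec3(gray), u_float0), color.a);
}
"""
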
execution.py
@@ -669,6 +669,22 @@ class PromptExecutor:
            }
            self.add_message("execution_error", mes, broadcast=False)

    def _notify_prompt_lifecycle(self, event: str, prompt_id: str):
        """Notify external cache providers of prompt lifecycle events."""
        from comfy_execution.cache_provider import has_cache_providers, get_cache_providers, logger

        if not has_cache_providers():
            return

        for provider in get_cache_providers():
            try:
                if event == "start":
                    provider.on_prompt_start(prompt_id)
                elif event == "end":
                    provider.on_prompt_end(prompt_id)
            except Exception as e:
                logger.warning(f"Cache provider {provider.__class__.__name__} error on {event}: {e}")

    def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
        asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))

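The lifecycle hooks above suggest a natural pattern: buffer on_store writes while a prompt runs and flush them once at on_prompt_end. A minimal sketch, assuming only the provider API already shown in this compare (the backend write itself is left abstract):

from typing import Optional

from comfy_execution.cache_provider import (
    CacheProvider,
    CacheContext,
    CacheValue,
    register_cache_provider,
)


class BatchingProvider(CacheProvider):
    """Buffers stores per prompt and flushes them when the prompt ends."""

    def __init__(self):
        self._pending: dict[str, list[tuple[CacheContext, CacheValue]]] = {}

    def on_prompt_start(self, prompt_id: str) -> None:
        self._pending[prompt_id] = []

    def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
        return None  # write-only provider: never serves a hit

    def on_store(self, context: CacheContext, value: CacheValue) -> None:
        # Buffer instead of writing immediately; flushed at prompt end.
        self._pending.setdefault(context.prompt_id, []).append((context, value))

    def on_prompt_end(self, prompt_id: str) -> None:
        for context, value in self._pending.pop(prompt_id, []):
            pass  # write (context.cache_key_bytes, value) to a shared backend here


register_cache_provider(BatchingProvider())
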
@@ -685,66 +701,77 @@ class PromptExecutor:
        self.status_messages = []
        self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)

        try:
            with torch.inference_mode():
                dynamic_prompt = DynamicPrompt(prompt)
                reset_progress_state(prompt_id, dynamic_prompt)
                add_progress_handler(WebUIProgressHandler(self.server))
                is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
                for cache in self.caches.all:
                    await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
                    cache.clean_unused()
                # Set prompt ID on caches for external provider integration
                for cache in self.caches.all:
                    cache._current_prompt_id = prompt_id

                cached_nodes = []
                for node_id in prompt:
                    if self.caches.outputs.get(node_id) is not None:
                        cached_nodes.append(node_id)
                # Notify external cache providers of prompt start
                self._notify_prompt_lifecycle("start", prompt_id)

                comfy.model_management.cleanup_models_gc()
                self.add_message("execution_cached",
                                 { "nodes": cached_nodes, "prompt_id": prompt_id},
                                 broadcast=False)
                pending_subgraph_results = {}
                pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
                ui_node_outputs = {}
                executed = set()
                execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
                current_outputs = self.caches.outputs.all_node_ids()
                for node_id in list(execute_outputs):
                    execution_list.add_node(node_id)

                while not execution_list.is_empty():
                    node_id, error, ex = await execution_list.stage_node_execution()
                    if error is not None:
                        self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
                        break

                    assert node_id is not None, "Node ID should not be None at this point"
                    result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
                    self.success = result != ExecutionResult.FAILURE
                    if result == ExecutionResult.FAILURE:
                        self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
                        break
                    elif result == ExecutionResult.PENDING:
                        execution_list.unstage_node_execution()
                    else: # result == ExecutionResult.SUCCESS:
                        execution_list.complete_node_execution()
                        self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
                else:
                    # Only execute when the while-loop ends without break
                    self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)

                ui_outputs = {}
                meta_outputs = {}
                for node_id, ui_info in ui_node_outputs.items():
                    ui_outputs[node_id] = ui_info["output"]
                    meta_outputs[node_id] = ui_info["meta"]
                self.history_result = {
                    "outputs": ui_outputs,
                    "meta": meta_outputs,
                }
                self.server.last_node_id = None
                if comfy.model_management.DISABLE_SMART_MEMORY:
                    comfy.model_management.unload_all_models()
        finally:
            # Notify external cache providers of prompt end
            self._notify_prompt_lifecycle("end", prompt_id)


async def validate_inputs(prompt_id, prompt, item, validated):

@@ -1 +1 @@
comfyui_manager==4.1b1
comfyui_manager==4.0.5

nodes.py
@@ -2432,7 +2432,6 @@ async def init_builtin_extra_nodes():
        "nodes_wanmove.py",
        "nodes_image_compare.py",
        "nodes_zimage.py",
        "nodes_glsl.py",
        "nodes_lora_debug.py"
    ]

@@ -1,5 +1,5 @@
comfyui-frontend-package==1.37.11
comfyui-workflow-templates==0.8.27
comfyui-workflow-templates==0.8.24
comfyui-embedded-docs==0.4.0
torch
torchsde
@@ -29,6 +29,3 @@ kornia>=0.7.1
spandrel
pydantic~=2.0
pydantic-settings~=2.0
PyOpenGL
PyOpenGL-accelerate
glfw

tests-unit/execution_test/test_cache_provider.py (new file)
@@ -0,0 +1,370 @@
"""Tests for external cache provider API."""

import importlib.util
import pytest
from typing import Optional


def _torch_available() -> bool:
    """Check if PyTorch is available."""
    return importlib.util.find_spec("torch") is not None


from comfy_execution.cache_provider import (
    CacheProvider,
    CacheContext,
    CacheValue,
    register_cache_provider,
    unregister_cache_provider,
    get_cache_providers,
    has_cache_providers,
    clear_cache_providers,
    serialize_cache_key,
    contains_nan,
    estimate_value_size,
    _canonicalize,
)


class TestCanonicalize:
    """Test _canonicalize function for deterministic ordering."""

    def test_frozenset_ordering_is_deterministic(self):
        """Frozensets should produce consistent canonical form regardless of iteration order."""
        # Create two frozensets with same content
        fs1 = frozenset([("a", 1), ("b", 2), ("c", 3)])
        fs2 = frozenset([("c", 3), ("a", 1), ("b", 2)])

        result1 = _canonicalize(fs1)
        result2 = _canonicalize(fs2)

        assert result1 == result2

    def test_nested_frozenset_ordering(self):
        """Nested frozensets should also be deterministically ordered."""
        inner1 = frozenset([1, 2, 3])
        inner2 = frozenset([3, 2, 1])

        fs1 = frozenset([("key", inner1)])
        fs2 = frozenset([("key", inner2)])

        result1 = _canonicalize(fs1)
        result2 = _canonicalize(fs2)

        assert result1 == result2

    def test_dict_ordering(self):
        """Dicts should be sorted by key."""
        d1 = {"z": 1, "a": 2, "m": 3}
        d2 = {"a": 2, "m": 3, "z": 1}

        result1 = _canonicalize(d1)
        result2 = _canonicalize(d2)

        assert result1 == result2

    def test_tuple_preserved(self):
        """Tuples should be marked and preserved."""
        t = (1, 2, 3)
        result = _canonicalize(t)

        assert result[0] == "__tuple__"
        assert result[1] == [1, 2, 3]

    def test_list_preserved(self):
        """Lists should be recursively canonicalized."""
        lst = [{"b": 2, "a": 1}, frozenset([3, 2, 1])]
        result = _canonicalize(lst)

        # First element should be dict with sorted keys
        assert result[0] == {"a": 1, "b": 2}
        # Second element should be canonicalized frozenset
        assert result[1][0] == "__frozenset__"

    def test_primitives_unchanged(self):
        """Primitive types should pass through unchanged."""
        assert _canonicalize(42) == 42
        assert _canonicalize(3.14) == 3.14
        assert _canonicalize("hello") == "hello"
        assert _canonicalize(True) is True
        assert _canonicalize(None) is None

    def test_bytes_converted(self):
        """Bytes should be converted to hex string."""
        b = b"\x00\xff"
        result = _canonicalize(b)

        assert result[0] == "__bytes__"
        assert result[1] == "00ff"

    def test_set_ordering(self):
        """Sets should be sorted like frozensets."""
        s1 = {3, 1, 2}
        s2 = {1, 2, 3}

        result1 = _canonicalize(s1)
        result2 = _canonicalize(s2)

        assert result1 == result2
        assert result1[0] == "__set__"

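Read together, these tests pin down the canonical form: type-tagged pairs for containers, with deterministic ordering for the unordered ones. A minimal sketch consistent with the assertions above (not the shipped implementation, which lives in comfy_execution/cache_provider.py; the json-based sort key is an assumption):

import json

def _canonicalize_sketch(obj):
    if isinstance(obj, (frozenset, set)):
        tag = "__frozenset__" if isinstance(obj, frozenset) else "__set__"
        items = [_canonicalize_sketch(v) for v in obj]
        # Sort by a stable serialization so iteration order never matters.
        return (tag, sorted(items, key=lambda v: json.dumps(v, default=str)))
    if isinstance(obj, dict):
        return {k: _canonicalize_sketch(obj[k]) for k in sorted(obj)}
    if isinstance(obj, tuple):
        return ("__tuple__", [_canonicalize_sketch(v) for v in obj])
    if isinstance(obj, list):
        return [_canonicalize_sketch(v) for v in obj]
    if isinstance(obj, bytes):
        return ("__bytes__", obj.hex())
    return obj  # int, float, str, bool, None pass through unchanged
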
class TestSerializeCacheKey:
    """Test serialize_cache_key for deterministic hashing."""

    def test_same_content_same_hash(self):
        """Same content should produce same hash."""
        key1 = frozenset([("node_1", frozenset([("input", "value")]))])
        key2 = frozenset([("node_1", frozenset([("input", "value")]))])

        hash1 = serialize_cache_key(key1)
        hash2 = serialize_cache_key(key2)

        assert hash1 == hash2

    def test_different_content_different_hash(self):
        """Different content should produce different hash."""
        key1 = frozenset([("node_1", "value_a")])
        key2 = frozenset([("node_1", "value_b")])

        hash1 = serialize_cache_key(key1)
        hash2 = serialize_cache_key(key2)

        assert hash1 != hash2

    def test_returns_bytes(self):
        """Should return bytes (SHA256 digest)."""
        key = frozenset([("test", 123)])
        result = serialize_cache_key(key)

        assert isinstance(result, bytes)
        assert len(result) == 32  # SHA256 produces 32 bytes

    def test_complex_nested_structure(self):
        """Complex nested structures should hash deterministically."""
        # Note: frozensets can only contain hashable types, so we use
        # nested frozensets of tuples to represent dict-like structures
        key = frozenset([
            ("node_1", frozenset([
                ("input_a", ("tuple", "value")),
                ("input_b", frozenset([("nested", "dict")])),
            ])),
            ("node_2", frozenset([
                ("param", 42),
            ])),
        ])

        # Hash twice to verify determinism
        hash1 = serialize_cache_key(key)
        hash2 = serialize_cache_key(key)

        assert hash1 == hash2

    def test_dict_in_cache_key(self):
        """Dicts passed directly to serialize_cache_key should work."""
        # This tests the _canonicalize function's ability to handle dicts
        key = {"node_1": {"input": "value"}, "node_2": 42}

        hash1 = serialize_cache_key(key)
        hash2 = serialize_cache_key(key)

        assert hash1 == hash2
        assert isinstance(hash1, bytes)
        assert len(hash1) == 32

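These tests only require that the digest be a deterministic 32-byte value over the canonical form. A minimal sketch matching that contract (again, not the shipped code; it reuses the canonicalization sketch above):

import hashlib
import json

def serialize_cache_key_sketch(key) -> bytes:
    canonical = _canonicalize_sketch(key)  # sketch from above
    encoded = json.dumps(canonical, sort_keys=True, default=str).encode("utf-8")
    return hashlib.sha256(encoded).digest()  # 32-byte SHA256 digest
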
class TestContainsNan:
    """Test contains_nan utility function."""

    def test_nan_float_detected(self):
        """NaN floats should be detected."""
        assert contains_nan(float('nan')) is True

    def test_regular_float_not_nan(self):
        """Regular floats should not be detected as NaN."""
        assert contains_nan(3.14) is False
        assert contains_nan(0.0) is False
        assert contains_nan(-1.5) is False

    def test_infinity_not_nan(self):
        """Infinity is not NaN."""
        assert contains_nan(float('inf')) is False
        assert contains_nan(float('-inf')) is False

    def test_nan_in_list(self):
        """NaN in list should be detected."""
        assert contains_nan([1, 2, float('nan'), 4]) is True
        assert contains_nan([1, 2, 3, 4]) is False

    def test_nan_in_tuple(self):
        """NaN in tuple should be detected."""
        assert contains_nan((1, float('nan'))) is True
        assert contains_nan((1, 2, 3)) is False

    def test_nan_in_frozenset(self):
        """NaN in frozenset should be detected."""
        assert contains_nan(frozenset([1, float('nan')])) is True
        assert contains_nan(frozenset([1, 2, 3])) is False

    def test_nan_in_dict_value(self):
        """NaN in dict value should be detected."""
        assert contains_nan({"key": float('nan')}) is True
        assert contains_nan({"key": 42}) is False

    def test_nan_in_nested_structure(self):
        """NaN in deeply nested structure should be detected."""
        nested = {"level1": [{"level2": (1, 2, float('nan'))}]}
        assert contains_nan(nested) is True

    def test_non_numeric_types(self):
        """Non-numeric types should not be NaN."""
        assert contains_nan("string") is False
        assert contains_nan(None) is False
        assert contains_nan(True) is False

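The behavior asserted above amounts to a recursive walk that only flags genuine float NaNs. A minimal sketch (not the shipped code):

import math

def contains_nan_sketch(obj) -> bool:
    if isinstance(obj, float):
        return math.isnan(obj)
    if isinstance(obj, dict):
        return any(contains_nan_sketch(v) for v in obj.values())
    if isinstance(obj, (list, tuple, set, frozenset)):
        return any(contains_nan_sketch(v) for v in obj)
    return False  # str, int, bool, None, etc. are never NaN
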
class TestEstimateValueSize:
    """Test estimate_value_size utility function."""

    def test_empty_outputs(self):
        """Empty outputs should have zero size."""
        value = CacheValue(outputs=[])
        assert estimate_value_size(value) == 0

    @pytest.mark.skipif(
        not _torch_available(),
        reason="PyTorch not available"
    )
    def test_tensor_size_estimation(self):
        """Tensor size should be estimated correctly."""
        import torch

        # 1000 float32 elements = 4000 bytes
        tensor = torch.zeros(1000, dtype=torch.float32)
        value = CacheValue(outputs=[[tensor]])

        size = estimate_value_size(value)
        assert size == 4000

    @pytest.mark.skipif(
        not _torch_available(),
        reason="PyTorch not available"
    )
    def test_nested_tensor_in_dict(self):
        """Tensors nested in dicts should be counted."""
        import torch

        tensor = torch.zeros(100, dtype=torch.float32)  # 400 bytes
        value = CacheValue(outputs=[[{"samples": tensor}]])

        size = estimate_value_size(value)
        assert size == 400

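The size estimate these tests describe is numel * element_size summed over every tensor reachable through the outputs, with non-tensor leaves contributing nothing. A minimal sketch (not the shipped code):

def estimate_value_size_sketch(value: CacheValue) -> int:
    import torch  # imported lazily, mirroring the optional-torch tests above

    def walk(obj) -> int:
        if isinstance(obj, torch.Tensor):
            return obj.numel() * obj.element_size()
        if isinstance(obj, dict):
            return sum(walk(v) for v in obj.values())
        if isinstance(obj, (list, tuple)):
            return sum(walk(v) for v in obj)
        return 0

    return walk(value.outputs)
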
class TestProviderRegistry:
    """Test cache provider registration and retrieval."""

    def setup_method(self):
        """Clear providers before each test."""
        clear_cache_providers()

    def teardown_method(self):
        """Clear providers after each test."""
        clear_cache_providers()

    def test_register_provider(self):
        """Provider should be registered successfully."""
        provider = MockCacheProvider()
        register_cache_provider(provider)

        assert has_cache_providers() is True
        providers = get_cache_providers()
        assert len(providers) == 1
        assert providers[0] is provider

    def test_unregister_provider(self):
        """Provider should be unregistered successfully."""
        provider = MockCacheProvider()
        register_cache_provider(provider)
        unregister_cache_provider(provider)

        assert has_cache_providers() is False

    def test_multiple_providers(self):
        """Multiple providers can be registered."""
        provider1 = MockCacheProvider()
        provider2 = MockCacheProvider()

        register_cache_provider(provider1)
        register_cache_provider(provider2)

        providers = get_cache_providers()
        assert len(providers) == 2

    def test_duplicate_registration_ignored(self):
        """Registering same provider twice should be ignored."""
        provider = MockCacheProvider()

        register_cache_provider(provider)
        register_cache_provider(provider)  # Should be ignored

        providers = get_cache_providers()
        assert len(providers) == 1

    def test_clear_providers(self):
        """clear_cache_providers should remove all providers."""
        provider1 = MockCacheProvider()
        provider2 = MockCacheProvider()

        register_cache_provider(provider1)
        register_cache_provider(provider2)
        clear_cache_providers()

        assert has_cache_providers() is False
        assert len(get_cache_providers()) == 0


class TestCacheContext:
    """Test CacheContext dataclass."""

    def test_context_creation(self):
        """CacheContext should be created with all fields."""
        context = CacheContext(
            prompt_id="prompt-123",
            node_id="node-456",
            class_type="KSampler",
            cache_key=frozenset([("test", "value")]),
            cache_key_bytes=b"hash_bytes",
        )

        assert context.prompt_id == "prompt-123"
        assert context.node_id == "node-456"
        assert context.class_type == "KSampler"
        assert context.cache_key == frozenset([("test", "value")])
        assert context.cache_key_bytes == b"hash_bytes"


class TestCacheValue:
    """Test CacheValue dataclass."""

    def test_value_creation(self):
        """CacheValue should be created with outputs."""
        outputs = [[{"samples": "tensor_data"}]]
        value = CacheValue(outputs=outputs)

        assert value.outputs == outputs


class MockCacheProvider(CacheProvider):
    """Mock cache provider for testing."""

    def __init__(self):
        self.lookups = []
        self.stores = []

    def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
        self.lookups.append(context)
        return None

    def on_store(self, context: CacheContext, value: CacheValue) -> None:
        self.stores.append((context, value))