Merge remote-tracking branch 'origin/master' into feat/cache-provider-api

This commit is contained in:
Deep Mehta
2026-03-03 15:11:17 -08:00
264 changed files with 20689 additions and 1608 deletions

View File

@@ -14,6 +14,7 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = {
"supports_preview_metadata": True,
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
"extension": {"manager": {"supports_v4": True}},
"node_replacements": True,
}

View File

@@ -7,7 +7,7 @@ from comfy_api.internal.singleton import ProxiedSingleton
from comfy_api.internal.async_to_sync import create_sync_class
from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from ._input_impl import VideoFromFile, VideoFromComponents
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, File3D
from . import _io_public as io
from . import _ui_public as ui
from comfy_execution.utils import get_executing_context
@@ -21,6 +21,17 @@ class ComfyAPI_latest(ComfyAPIBase):
VERSION = "latest"
STABLE = False
def __init__(self):
super().__init__()
self.node_replacement = self.NodeReplacement()
self.execution = self.Execution()
class NodeReplacement(ProxiedSingleton):
async def register(self, node_replace: io.NodeReplace) -> None:
"""Register a node replacement mapping."""
from server import PromptServer
PromptServer.instance.node_replace_manager.register(node_replace)
class Execution(ProxiedSingleton):
async def set_progress(
self,
@@ -73,8 +84,6 @@ class ComfyAPI_latest(ComfyAPIBase):
image=to_display,
)
execution: Execution
class ComfyExtension(ABC):
async def on_load(self) -> None:
"""
@@ -105,6 +114,7 @@ class Types:
VideoComponents = VideoComponents
MESH = MESH
VOXEL = VOXEL
File3D = File3D
class Caching:

View File

@@ -34,6 +34,21 @@ class VideoInput(ABC):
"""
pass
@abstractmethod
def as_trimmed(
self,
start_time: float | None = None,
duration: float | None = None,
strict_duration: bool = False,
) -> VideoInput | None:
"""
Create a new VideoInput which is trimmed to have the corresponding start_time and duration
Returns:
A new VideoInput, or None if the result would have negative duration
"""
pass
def get_stream_source(self) -> Union[str, io.BytesIO]:
"""
Get a streamable source for the video. This allows processing without

View File

@@ -6,6 +6,7 @@ from typing import Optional
from .._input import AudioInput, VideoInput
import av
import io
import itertools
import json
import numpy as np
import math
@@ -29,7 +30,6 @@ def container_to_output_format(container_format: str | None) -> str | None:
formats = container_format.split(",")
return formats[0]
def get_open_write_kwargs(
dest: str | io.BytesIO, container_format: str, to_format: str | None
) -> dict:
@@ -57,12 +57,14 @@ class VideoFromFile(VideoInput):
Class representing video input from a file.
"""
def __init__(self, file: str | io.BytesIO):
def __init__(self, file: str | io.BytesIO, *, start_time: float=0, duration: float=0):
"""
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
containing the file contents.
"""
self.__file = file
self.__start_time = start_time
self.__duration = duration
def get_stream_source(self) -> str | io.BytesIO:
"""
@@ -96,6 +98,16 @@ class VideoFromFile(VideoInput):
Returns:
Duration in seconds
"""
raw_duration = self._get_raw_duration()
if self.__start_time < 0:
duration_from_start = min(raw_duration, -self.__start_time)
else:
duration_from_start = raw_duration - self.__start_time
if self.__duration:
return min(self.__duration, duration_from_start)
return duration_from_start
def _get_raw_duration(self) -> float:
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
@@ -113,9 +125,13 @@ class VideoFromFile(VideoInput):
if video_stream and video_stream.average_rate:
frame_count = 0
container.seek(0)
for packet in container.demux(video_stream):
for _ in packet.decode():
frame_count += 1
frame_iterator = (
container.decode(video_stream)
if video_stream.codec.capabilities & 0x100
else container.demux(video_stream)
)
for packet in frame_iterator:
frame_count += 1
if frame_count > 0:
return float(frame_count / video_stream.average_rate)
@@ -131,36 +147,54 @@ class VideoFromFile(VideoInput):
with av.open(self.__file, mode="r") as container:
video_stream = self._get_first_video_stream(container)
# 1. Prefer the frames field if available
if video_stream.frames and video_stream.frames > 0:
# 1. Prefer the frames field if available and usable
if (
video_stream.frames
and video_stream.frames > 0
and not self.__start_time
and not self.__duration
):
return int(video_stream.frames)
# 2. Try to estimate from duration and average_rate using only metadata
if container.duration is not None and video_stream.average_rate:
duration_seconds = float(container.duration / av.time_base)
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
if estimated_frames > 0:
return estimated_frames
if (
getattr(video_stream, "duration", None) is not None
and getattr(video_stream, "time_base", None) is not None
and video_stream.average_rate
):
duration_seconds = float(video_stream.duration * video_stream.time_base)
raw_duration = float(video_stream.duration * video_stream.time_base)
if self.__start_time < 0:
duration_from_start = min(raw_duration, -self.__start_time)
else:
duration_from_start = raw_duration - self.__start_time
duration_seconds = min(self.__duration, duration_from_start)
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
if estimated_frames > 0:
return estimated_frames
# 3. Last resort: decode frames and count them (streaming)
frame_count = 0
container.seek(0)
for packet in container.demux(video_stream):
for _ in packet.decode():
frame_count += 1
if frame_count == 0:
raise ValueError(f"Could not determine frame count for file '{self.__file}'")
if self.__start_time < 0:
start_time = max(self._get_raw_duration() + self.__start_time, 0)
else:
start_time = self.__start_time
frame_count = 1
start_pts = int(start_time / video_stream.time_base)
end_pts = int((start_time + self.__duration) / video_stream.time_base)
container.seek(start_pts, stream=video_stream)
frame_iterator = (
container.decode(video_stream)
if video_stream.codec.capabilities & 0x100
else container.demux(video_stream)
)
for frame in frame_iterator:
if frame.pts >= start_pts:
break
else:
raise ValueError(f"Could not determine frame count for file '{self.__file}'\nNo frames exist for start_time {self.__start_time}")
for frame in frame_iterator:
if frame.pts >= end_pts:
break
frame_count += 1
return frame_count
def get_frame_rate(self) -> Fraction:
@@ -199,9 +233,21 @@ class VideoFromFile(VideoInput):
return container.format.name
def get_components_internal(self, container: InputContainer) -> VideoComponents:
video_stream = self._get_first_video_stream(container)
if self.__start_time < 0:
start_time = max(self._get_raw_duration() + self.__start_time, 0)
else:
start_time = self.__start_time
# Get video frames
frames = []
for frame in container.decode(video=0):
start_pts = int(start_time / video_stream.time_base)
end_pts = int((start_time + self.__duration) / video_stream.time_base)
container.seek(start_pts, stream=video_stream)
for frame in container.decode(video_stream):
if frame.pts < start_pts:
continue
if self.__duration and frame.pts >= end_pts:
break
img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3)
img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3)
frames.append(img)
@@ -209,31 +255,44 @@ class VideoFromFile(VideoInput):
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)
# Get frame rate
video_stream = next(s for s in container.streams if s.type == 'video')
frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)
frame_rate = Fraction(video_stream.average_rate) if video_stream.average_rate else Fraction(1)
# Get audio if available
audio = None
try:
container.seek(0) # Reset the container to the beginning
for stream in container.streams:
if stream.type != 'audio':
continue
assert isinstance(stream, av.AudioStream)
audio_frames = []
for packet in container.demux(stream):
for frame in packet.decode():
assert isinstance(frame, av.AudioFrame)
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
if len(audio_frames) > 0:
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
audio = AudioInput({
"waveform": audio_tensor,
"sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
})
except StopIteration:
pass # No audio stream
container.seek(start_pts, stream=video_stream)
# Use last stream for consistency
if len(container.streams.audio):
audio_stream = container.streams.audio[-1]
audio_frames = []
resample = av.audio.resampler.AudioResampler(format='fltp').resample
frames = itertools.chain.from_iterable(
map(resample, container.decode(audio_stream))
)
has_first_frame = False
for frame in frames:
offset_seconds = start_time - frame.pts * audio_stream.time_base
to_skip = int(offset_seconds * audio_stream.sample_rate)
if to_skip < frame.samples:
has_first_frame = True
break
if has_first_frame:
audio_frames.append(frame.to_ndarray()[..., to_skip:])
for frame in frames:
if frame.time > start_time + self.__duration:
break
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
if len(audio_frames) > 0:
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
if self.__duration:
audio_data = audio_data[..., :int(self.__duration * audio_stream.sample_rate)]
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
audio = AudioInput({
"waveform": audio_tensor,
"sample_rate": int(audio_stream.sample_rate) if audio_stream.sample_rate else 1,
})
metadata = container.metadata
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
@@ -250,7 +309,7 @@ class VideoFromFile(VideoInput):
path: str | io.BytesIO,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None
metadata: Optional[dict] = None,
):
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
@@ -262,15 +321,14 @@ class VideoFromFile(VideoInput):
reuse_streams = False
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
reuse_streams = False
if self.__start_time or self.__duration:
reuse_streams = False
if not reuse_streams:
components = self.get_components_internal(container)
video = VideoFromComponents(components)
return video.save_to(
path,
format=format,
codec=codec,
metadata=metadata
path, format=format, codec=codec, metadata=metadata
)
streams = container.streams
@@ -304,10 +362,21 @@ class VideoFromFile(VideoInput):
output_container.mux(packet)
def _get_first_video_stream(self, container: InputContainer):
video_stream = next((s for s in container.streams if s.type == "video"), None)
if video_stream is None:
raise ValueError(f"No video stream found in file '{self.__file}'")
return video_stream
if len(container.streams.video):
return container.streams.video[0]
raise ValueError(f"No video stream found in file '{self.__file}'")
def as_trimmed(
self, start_time: float = 0, duration: float = 0, strict_duration: bool = True
) -> VideoInput | None:
trimmed = VideoFromFile(
self.get_stream_source(),
start_time=start_time + self.__start_time,
duration=duration,
)
if trimmed.get_duration() < duration and strict_duration:
return None
return trimmed
class VideoFromComponents(VideoInput):
@@ -322,7 +391,7 @@ class VideoFromComponents(VideoInput):
return VideoComponents(
images=self.__components.images,
audio=self.__components.audio,
frame_rate=self.__components.frame_rate
frame_rate=self.__components.frame_rate,
)
def save_to(
@@ -330,7 +399,7 @@ class VideoFromComponents(VideoInput):
path: str,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None
metadata: Optional[dict] = None,
):
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
raise ValueError("Only MP4 format is supported for now")
@@ -357,7 +426,10 @@ class VideoFromComponents(VideoInput):
audio_stream: Optional[av.AudioStream] = None
if self.__components.audio:
audio_sample_rate = int(self.__components.audio['sample_rate'])
audio_stream = output.add_stream('aac', rate=audio_sample_rate)
waveform = self.__components.audio['waveform']
waveform = waveform[0, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
layout = {1: 'mono', 2: 'stereo', 6: '5.1'}.get(waveform.shape[0], 'stereo')
audio_stream = output.add_stream('aac', rate=audio_sample_rate, layout=layout)
# Encode video
for i, frame in enumerate(self.__components.images):
@@ -372,12 +444,21 @@ class VideoFromComponents(VideoInput):
output.mux(packet)
if audio_stream and self.__components.audio:
waveform = self.__components.audio['waveform']
waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().cpu().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo')
frame = av.AudioFrame.from_ndarray(waveform.float().cpu().contiguous().numpy(), format='fltp', layout=layout)
frame.sample_rate = audio_sample_rate
frame.pts = 0
output.mux(audio_stream.encode(frame))
# Flush encoder
output.mux(audio_stream.encode(None))
def as_trimmed(
self,
start_time: float | None = None,
duration: float | None = None,
strict_duration: bool = True,
) -> VideoInput | None:
if self.get_duration() < start_time + duration:
return None
#TODO Consider tracking duration and trimming at time of save?
return VideoFromFile(self.get_stream_source(), start_time=start_time, duration=duration)

View File

@@ -27,7 +27,7 @@ if TYPE_CHECKING:
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
prune_dict, shallow_clone_class)
from comfy_execution.graph_utils import ExecutionBlocker
from ._util import MESH, VOXEL, SVG as _SVG
from ._util import MESH, VOXEL, SVG as _SVG, File3D
class FolderType(str, Enum):
@@ -73,8 +73,15 @@ class RemoteOptions:
class NumberDisplay(str, Enum):
number = "number"
slider = "slider"
gradient_slider = "gradientslider"
class ControlAfterGenerate(str, Enum):
fixed = "fixed"
increment = "increment"
decrement = "decrement"
randomize = "randomize"
class _ComfyType(ABC):
Type = Any
io_type: str = None
@@ -263,7 +270,7 @@ class Int(ComfyTypeIO):
class Input(WidgetInput):
'''Integer input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool | ControlAfterGenerate=None,
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
self.min = min
@@ -290,13 +297,15 @@ class Float(ComfyTypeIO):
'''Float input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
display_mode: NumberDisplay=None, gradient_stops: list[list[float]]=None,
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
self.min = min
self.max = max
self.step = step
self.round = round
self.display_mode = display_mode
self.gradient_stops = gradient_stops
self.default: float
def as_dict(self):
@@ -306,6 +315,7 @@ class Float(ComfyTypeIO):
"step": self.step,
"round": self.round,
"display": self.display_mode,
"gradient_stops": self.gradient_stops,
})
@comfytype(io_type="STRING")
@@ -345,7 +355,7 @@ class Combo(ComfyTypeIO):
tooltip: str=None,
lazy: bool=None,
default: str | int | Enum = None,
control_after_generate: bool=None,
control_after_generate: bool | ControlAfterGenerate=None,
upload: UploadType=None,
image_folder: FolderType=None,
remote: RemoteOptions=None,
@@ -389,7 +399,7 @@ class MultiCombo(ComfyTypeI):
Type = list[str]
class Input(Combo.Input):
def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None,
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool | ControlAfterGenerate=None,
socketless: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link, advanced=advanced)
self.multiselect = True
@@ -667,6 +677,49 @@ class Voxel(ComfyTypeIO):
class Mesh(ComfyTypeIO):
Type = MESH
@comfytype(io_type="FILE_3D")
class File3DAny(ComfyTypeIO):
"""General 3D file type - accepts any supported 3D format."""
Type = File3D
@comfytype(io_type="FILE_3D_GLB")
class File3DGLB(ComfyTypeIO):
"""GLB format 3D file - binary glTF, best for web and cross-platform."""
Type = File3D
@comfytype(io_type="FILE_3D_GLTF")
class File3DGLTF(ComfyTypeIO):
"""GLTF format 3D file - JSON-based glTF with external resources."""
Type = File3D
@comfytype(io_type="FILE_3D_FBX")
class File3DFBX(ComfyTypeIO):
"""FBX format 3D file - best for game engines and animation."""
Type = File3D
@comfytype(io_type="FILE_3D_OBJ")
class File3DOBJ(ComfyTypeIO):
"""OBJ format 3D file - simple geometry format."""
Type = File3D
@comfytype(io_type="FILE_3D_STL")
class File3DSTL(ComfyTypeIO):
"""STL format 3D file - best for 3D printing."""
Type = File3D
@comfytype(io_type="FILE_3D_USDZ")
class File3DUSDZ(ComfyTypeIO):
"""USDZ format 3D file - Apple AR format."""
Type = File3D
@comfytype(io_type="HOOKS")
class Hooks(ComfyTypeIO):
if TYPE_CHECKING:
@@ -1146,6 +1199,47 @@ class ImageCompare(ComfyTypeI):
def as_dict(self):
return super().as_dict()
@comfytype(io_type="COLOR")
class Color(ComfyTypeIO):
Type = str
class Input(WidgetInput):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
socketless: bool=True, advanced: bool=None, default: str="#ffffff"):
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
self.default: str
def as_dict(self):
return super().as_dict()
@comfytype(io_type="BOUNDING_BOX")
class BoundingBox(ComfyTypeIO):
class BoundingBoxDict(TypedDict):
x: int
y: int
width: int
height: int
Type = BoundingBoxDict
class Input(WidgetInput):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
socketless: bool=True, default: dict=None, component: str=None, force_input: bool=None):
super().__init__(id, display_name, optional, tooltip, None, default, socketless)
self.component = component
self.force_input = force_input
if default is None:
self.default = {"x": 0, "y": 0, "width": 512, "height": 512}
def as_dict(self):
d = super().as_dict()
if self.component:
d["component"] = self.component
if self.force_input is not None:
d["forceInput"] = self.force_input
return d
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
DYNAMIC_INPUT_LOOKUP[io_type] = func
@@ -1234,6 +1328,7 @@ class Hidden(str, Enum):
class NodeInfoV1:
input: dict=None
input_order: dict[str, list[str]]=None
is_input_list: bool=None
output: list[str]=None
output_is_list: list[bool]=None
output_name: list[str]=None
@@ -1251,23 +1346,7 @@ class NodeInfoV1:
api_node: bool=None
price_badge: dict | None = None
search_aliases: list[str]=None
@dataclass
class NodeInfoV3:
input: dict=None
output: dict=None
hidden: list[str]=None
name: str=None
display_name: str=None
description: str=None
python_module: Any = None
category: str=None
output_node: bool=None
deprecated: bool=None
experimental: bool=None
dev_only: bool=None
api_node: bool=None
price_badge: dict | None = None
essentials_category: str=None
@dataclass
@@ -1389,6 +1468,8 @@ class Schema:
"""Flags a node as expandable, allowing NodeOutput to include 'expand' property."""
accept_all_inputs: bool=False
"""When True, all inputs from the prompt will be passed to the node as kwargs, even if not defined in the schema."""
essentials_category: str | None = None
"""Optional category for the Essentials tab. Path-based like category field (e.g., 'Basic', 'Image Tools/Editing')."""
def validate(self):
'''Validate the schema:
@@ -1477,6 +1558,7 @@ class Schema:
info = NodeInfoV1(
input=input,
input_order={key: list(value.keys()) for (key, value) in input.items()},
is_input_list=self.is_input_list,
output=output,
output_is_list=output_is_list,
output_name=output_name,
@@ -1494,40 +1576,7 @@ class Schema:
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"),
price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None,
search_aliases=self.search_aliases if self.search_aliases else None,
)
return info
def get_v3_info(self, cls) -> NodeInfoV3:
input_dict = {}
output_dict = {}
hidden_list = []
# TODO: make sure dynamic types will be handled correctly
if self.inputs:
for input in self.inputs:
add_to_dict_v3(input, input_dict)
if self.outputs:
for output in self.outputs:
add_to_dict_v3(output, output_dict)
if self.hidden:
for hidden in self.hidden:
hidden_list.append(hidden.value)
info = NodeInfoV3(
input=input_dict,
output=output_dict,
hidden=hidden_list,
name=self.node_id,
display_name=self.display_name,
description=self.description,
category=self.category,
output_node=self.is_output_node,
deprecated=self.is_deprecated,
experimental=self.is_experimental,
dev_only=self.is_dev_only,
api_node=self.is_api_node,
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"),
price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None,
essentials_category=self.essentials_category,
)
return info
@@ -1585,9 +1634,6 @@ def add_to_dict_v1(i: Input, d: dict):
as_dict.pop("optional", None)
d.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict)
def add_to_dict_v3(io: Input | Output, d: dict):
d[io.id] = (io.get_io_type(), io.as_dict())
class DynamicPathsDefaultValue:
EMPTY_DICT = "empty_dict"
@@ -1748,13 +1794,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
# set hidden
type_clone.hidden = HiddenHolder.from_v3_data(v3_data)
return type_clone
@final
@classmethod
def GET_NODE_INFO_V3(cls) -> dict[str, Any]:
schema = cls.GET_SCHEMA()
info = schema.get_v3_info(cls)
return asdict(info)
#############################################
# V1 Backwards Compatibility code
#--------------------------------------------
@@ -2032,11 +2071,74 @@ class _UIOutput(ABC):
...
class InputMapOldId(TypedDict):
"""Map an old node input to a new node input by ID."""
new_id: str
old_id: str
class InputMapSetValue(TypedDict):
"""Set a specific value for a new node input."""
new_id: str
set_value: Any
InputMap = InputMapOldId | InputMapSetValue
"""
Input mapping for node replacement. Type is inferred by dictionary keys:
- {"new_id": str, "old_id": str} - maps old input to new input
- {"new_id": str, "set_value": Any} - sets a specific value for new input
"""
class OutputMap(TypedDict):
"""Map outputs of node replacement via indexes."""
new_idx: int
old_idx: int
class NodeReplace:
"""
Defines a possible node replacement, mapping inputs and outputs of the old node to the new node.
Also supports assigning specific values to the input widgets of the new node.
Args:
new_node_id: The class name of the new replacement node.
old_node_id: The class name of the deprecated node.
old_widget_ids: Ordered list of input IDs for widgets that may not have an input slot
connected. The workflow JSON stores widget values by their relative position index,
not by ID. This list maps those positional indexes to input IDs, enabling the
replacement system to correctly identify widget values during node migration.
input_mapping: List of input mappings from old node to new node.
output_mapping: List of output mappings from old node to new node.
"""
def __init__(self,
new_node_id: str,
old_node_id: str,
old_widget_ids: list[str] | None=None,
input_mapping: list[InputMap] | None=None,
output_mapping: list[OutputMap] | None=None,
):
self.new_node_id = new_node_id
self.old_node_id = old_node_id
self.old_widget_ids = old_widget_ids
self.input_mapping = input_mapping
self.output_mapping = output_mapping
def as_dict(self):
"""Create serializable representation of the node replacement."""
return {
"new_node_id": self.new_node_id,
"old_node_id": self.old_node_id,
"old_widget_ids": self.old_widget_ids,
"input_mapping": list(self.input_mapping) if self.input_mapping else None,
"output_mapping": list(self.output_mapping) if self.output_mapping else None,
}
__all__ = [
"FolderType",
"UploadType",
"RemoteOptions",
"NumberDisplay",
"ControlAfterGenerate",
"comfytype",
"Custom",
@@ -2082,6 +2184,13 @@ __all__ = [
"LossMap",
"Voxel",
"Mesh",
"File3DAny",
"File3DGLB",
"File3DGLTF",
"File3DFBX",
"File3DOBJ",
"File3DSTL",
"File3DUSDZ",
"Hooks",
"HookKeyframes",
"TimestepsRange",
@@ -2099,6 +2208,7 @@ __all__ = [
"AnyType",
"MultiType",
"Tracks",
"Color",
# Dynamic Types
"MatchType",
"DynamicCombo",
@@ -2107,14 +2217,14 @@ __all__ = [
"HiddenHolder",
"Hidden",
"NodeInfoV1",
"NodeInfoV3",
"Schema",
"ComfyNode",
"NodeOutput",
"add_to_dict_v1",
"add_to_dict_v3",
"V3Data",
"ImageCompare",
"PriceBadgeDepends",
"PriceBadge",
"BoundingBox",
"NodeReplace",
]

View File

@@ -1,5 +1,5 @@
from .video_types import VideoContainer, VideoCodec, VideoComponents
from .geometry_types import VOXEL, MESH
from .geometry_types import VOXEL, MESH, File3D
from .image_types import SVG
__all__ = [
@@ -9,5 +9,6 @@ __all__ = [
"VideoComponents",
"VOXEL",
"MESH",
"File3D",
"SVG",
]

View File

@@ -1,3 +1,8 @@
import shutil
from io import BytesIO
from pathlib import Path
from typing import IO
import torch
@@ -10,3 +15,75 @@ class MESH:
def __init__(self, vertices: torch.Tensor, faces: torch.Tensor):
self.vertices = vertices
self.faces = faces
class File3D:
"""Class representing a 3D file from a file path or binary stream.
Supports both disk-backed (file path) and memory-backed (BytesIO) storage.
"""
def __init__(self, source: str | IO[bytes], file_format: str = ""):
self._source = source
self._format = file_format or self._infer_format()
def _infer_format(self) -> str:
if isinstance(self._source, str):
return Path(self._source).suffix.lstrip(".").lower()
return ""
@property
def format(self) -> str:
return self._format
@format.setter
def format(self, value: str) -> None:
self._format = value.lstrip(".").lower() if value else ""
@property
def is_disk_backed(self) -> bool:
return isinstance(self._source, str)
def get_source(self) -> str | IO[bytes]:
if isinstance(self._source, str):
return self._source
if hasattr(self._source, "seek"):
self._source.seek(0)
return self._source
def get_data(self) -> BytesIO:
if isinstance(self._source, str):
with open(self._source, "rb") as f:
result = BytesIO(f.read())
return result
if hasattr(self._source, "seek"):
self._source.seek(0)
if isinstance(self._source, BytesIO):
return self._source
return BytesIO(self._source.read())
def save_to(self, path: str) -> str:
dest = Path(path)
dest.parent.mkdir(parents=True, exist_ok=True)
if isinstance(self._source, str):
if Path(self._source).resolve() != dest.resolve():
shutil.copy2(self._source, dest)
else:
if hasattr(self._source, "seek"):
self._source.seek(0)
with open(dest, "wb") as f:
f.write(self._source.read())
return str(dest)
def get_bytes(self) -> bytes:
if isinstance(self._source, str):
return Path(self._source).read_bytes()
if hasattr(self._source, "seek"):
self._source.seek(0)
return self._source.read()
def __repr__(self) -> str:
if isinstance(self._source, str):
return f"File3D(source={self._source!r}, format={self._format!r})"
return f"File3D(<stream>, format={self._format!r})"