Changed execute instance method to EXECUTE class method, added countermeasures to avoid state leaks, ready ability to add extra params to clean class type clone

This commit is contained in:
Jedrzej Kosinski
2025-06-05 04:12:44 -07:00
parent a7f515e913
commit d79a3cf990
3 changed files with 95 additions and 36 deletions

View File

@@ -532,6 +532,30 @@ class SchemaV3:
# """Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
class Serializer:
def __init_subclass__(cls, io_type: IO | str, **kwargs):
cls.io_type = io_type
super().__init_subclass__(**kwargs)
@classmethod
def serialize(cls, o: Any) -> str:
pass
@classmethod
def deserialize(cls, s: str) -> Any:
pass
def prepare_class_clone(c: ComfyNodeV3 | type[ComfyNodeV3]) -> type[ComfyNodeV3]:
"""Creates clone of real node class to prevent monkey-patching."""
c_type: type[ComfyNodeV3] = c if is_class(c) else type(c)
type_clone: type[ComfyNodeV3] = type(f"CLEAN_{c_type.__name__}", c_type.__bases__, {})
# TODO: what parameters should be carried over?
type_clone.SCHEMA = c_type.SCHEMA
# TODO: add anything we would want to expose inside node's EXECUTE function
return type_clone
class classproperty(object):
def __init__(self, f):
self.f = f
@@ -543,6 +567,43 @@ class ComfyNodeV3(ABC):
"""Common base class for all V3 nodes."""
RELATIVE_PYTHON_MODULE = None
SCHEMA = None
@classmethod
def GET_NODE_INFO_V3(cls) -> dict[str, Any]:
schema = cls.GET_SCHEMA()
# TODO: finish
return None
@classmethod
@abstractmethod
def DEFINE_SCHEMA(cls) -> SchemaV3:
"""
Override this function with one that returns a SchemaV3 instance.
"""
return None
DEFINE_SCHEMA = None
@classmethod
@abstractmethod
def EXECUTE(cls, **kwargs) -> NodeOutput:
pass
EXECUTE = None
@classmethod
def GET_SERIALIZERS(cls) -> list[Serializer]:
return []
def __init__(self):
self.__class__.VALIDATE_CLASS()
@classmethod
def VALIDATE_CLASS(cls):
if not callable(cls.DEFINE_SCHEMA):
raise Exception(f"No DEFINE_SCHEMA function was defined for node class {cls.__name__}.")
if not callable(cls.EXECUTE):
raise Exception(f"No execute function was defined for node class {cls.__name__}.")
#############################################
# V1 Backwards Compatibility code
#--------------------------------------------
@@ -623,7 +684,7 @@ class ComfyNodeV3(ABC):
cls.GET_SCHEMA()
return cls._OUTPUT_TOOLTIPS
FUNCTION = "execute"
FUNCTION = "EXECUTE"
@classmethod
def INPUT_TYPES(cls) -> dict[str, dict]:
@@ -642,6 +703,7 @@ class ComfyNodeV3(ABC):
@classmethod
def GET_SCHEMA(cls) -> SchemaV3:
cls.VALIDATE_CLASS()
schema = cls.DEFINE_SCHEMA()
if cls._DESCRIPTION is None:
cls._DESCRIPTION = schema.description
@@ -674,7 +736,7 @@ class ComfyNodeV3(ABC):
cls._RETURN_NAMES = output_name
cls._OUTPUT_IS_LIST = output_is_list
cls._OUTPUT_TOOLTIPS = output_tooltips
cls.SCHEMA = schema
return schema
@classmethod
@@ -716,31 +778,6 @@ class ComfyNodeV3(ABC):
#--------------------------------------------
#############################################
@classmethod
def GET_NODE_INFO_V3(cls) -> dict[str, Any]:
schema = cls.GET_SCHEMA()
# TODO: finish
return None
@classmethod
@abstractmethod
def DEFINE_SCHEMA(cls) -> SchemaV3:
"""
Override this function with one that returns a SchemaV3 instance.
"""
return None
DEFINE_SCHEMA = None
def __init__(self):
if self.DEFINE_SCHEMA is None:
raise Exception("No DEFINE_SCHEMA function was defined for this node.")
@abstractmethod
def execute(self, **kwargs) -> NodeOutput:
pass
# class ReturnedInputs:
# def __init__(self):
# pass
@@ -857,19 +894,20 @@ class TestNode(ComfyNodeV3):
def DEFINE_SCHEMA(cls):
return cls.SCHEMA
def execute(**kwargs):
def EXECUTE(**kwargs):
pass
if __name__ == "__main__":
print("hello there")
inputs: list[InputV3] = [
IntegerInput("tessfes", widgetType=IO.STRING),
IntegerInput("my_int"),
CustomInput("xyz", "XYZ"),
CustomInput("model1", "MODEL_M"),
ImageInput("my_image"),
FloatInput("my_float"),
MultitypedInput("my_inputs", [CustomType("MODEL_M"), CustomType("XYZ")]),
MultitypedInput("my_inputs", [StringInput, CustomType("MODEL_M"), CustomType("XYZ")]),
]
outputs: list[OutputV3] = [