mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-10 15:50:02 +00:00
feat: add minLength/maxLength validation for String inputs
This commit is contained in:
@@ -326,11 +326,14 @@ class String(ComfyTypeIO):
|
||||
'''String input.'''
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None,
|
||||
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
|
||||
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None,
|
||||
min_length: int=None, max_length: int=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
|
||||
self.multiline = multiline
|
||||
self.placeholder = placeholder
|
||||
self.dynamic_prompts = dynamic_prompts
|
||||
self.min_length = min_length
|
||||
self.max_length = max_length
|
||||
self.default: str
|
||||
|
||||
def as_dict(self):
|
||||
@@ -338,6 +341,8 @@ class String(ComfyTypeIO):
|
||||
"multiline": self.multiline,
|
||||
"placeholder": self.placeholder,
|
||||
"dynamicPrompts": self.dynamic_prompts,
|
||||
"minLength": self.min_length,
|
||||
"maxLength": self.max_length,
|
||||
})
|
||||
|
||||
@comfytype(io_type="COMBO")
|
||||
|
||||
66
execution.py
66
execution.py
@@ -81,7 +81,7 @@ class IsChangedCache:
|
||||
return self.is_changed[node_id]
|
||||
|
||||
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
|
||||
input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None)
|
||||
input_data_all, _, v3_data, _ = get_input_data(node["inputs"], class_def, node_id, None)
|
||||
try:
|
||||
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name, v3_data=v3_data)
|
||||
is_changed = await resolve_map_node_over_list_results(is_changed)
|
||||
@@ -213,7 +213,35 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
|
||||
if h[x] == "API_KEY_COMFY_ORG":
|
||||
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
|
||||
v3_data["hidden_inputs"] = hidden_inputs_v3
|
||||
return input_data_all, missing_keys, v3_data
|
||||
return input_data_all, missing_keys, v3_data, valid_inputs
|
||||
|
||||
def validate_resolved_inputs(input_data_all, class_def, valid_inputs):
|
||||
"""Validate resolved input values against schema constraints.
|
||||
|
||||
This is needed because validate_inputs() only sees direct widget values.
|
||||
Linked inputs aren't resolved during validate_inputs(), so this runs after resolution to catch any violations.
|
||||
"""
|
||||
for x, values in input_data_all.items():
|
||||
input_type, input_category, extra_info = get_input_info(class_def, x, valid_inputs)
|
||||
if input_type != "STRING":
|
||||
continue
|
||||
min_length = extra_info.get("minLength")
|
||||
max_length = extra_info.get("maxLength")
|
||||
if min_length is None and max_length is None:
|
||||
continue
|
||||
for val in values:
|
||||
if val is None or not isinstance(val, str):
|
||||
continue
|
||||
if min_length is not None and len(val) < min_length:
|
||||
raise ValueError(
|
||||
f"Input '{x}': value length {len(val)} is shorter than "
|
||||
f"minimum length of {min_length}"
|
||||
)
|
||||
if max_length is not None and len(val) > max_length:
|
||||
raise ValueError(
|
||||
f"Input '{x}': value length {len(val)} is longer than "
|
||||
f"maximum length of {max_length}"
|
||||
)
|
||||
|
||||
map_node_over_list = None #Don't hook this please
|
||||
|
||||
@@ -469,7 +497,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
has_subgraph = False
|
||||
else:
|
||||
get_progress_state().start_progress(unique_id)
|
||||
input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
|
||||
input_data_all, missing_keys, v3_data, valid_inputs = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
|
||||
if server.client_id is not None:
|
||||
server.last_node_id = display_node_id
|
||||
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
|
||||
@@ -498,6 +526,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
execution_list.make_input_strong_link(unique_id, i)
|
||||
return (ExecutionResult.PENDING, None, None)
|
||||
|
||||
validate_resolved_inputs(input_data_all, class_def, valid_inputs)
|
||||
|
||||
def execution_block_cb(block):
|
||||
if block.message is not None:
|
||||
mes = {
|
||||
@@ -940,6 +970,34 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
if input_type == "STRING":
|
||||
if "minLength" in extra_info and len(val) < extra_info["minLength"]:
|
||||
error = {
|
||||
"type": "value_shorter_than_min_length",
|
||||
"message": "Value length {} shorter than min length of {}".format(len(val), extra_info["minLength"]),
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val,
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
if "maxLength" in extra_info and len(val) > extra_info["maxLength"]:
|
||||
error = {
|
||||
"type": "value_longer_than_max_length",
|
||||
"message": "Value length {} longer than max length of {}".format(len(val), extra_info["maxLength"]),
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val,
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
if isinstance(input_type, list) or input_type == io.Combo.io_type:
|
||||
if input_type == io.Combo.io_type:
|
||||
combo_options = extra_info.get("options", [])
|
||||
@@ -971,7 +1029,7 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
||||
continue
|
||||
|
||||
if len(validate_function_inputs) > 0 or validate_has_kwargs:
|
||||
input_data_all, _, v3_data = get_input_data(inputs, obj_class, unique_id)
|
||||
input_data_all, _, v3_data, _ = get_input_data(inputs, obj_class, unique_id)
|
||||
input_filtered = {}
|
||||
for x in input_data_all:
|
||||
if x in validate_function_inputs or validate_has_kwargs:
|
||||
|
||||
@@ -1011,3 +1011,49 @@ class TestExecution:
|
||||
"""Test getting a non-existent job returns 404"""
|
||||
job = client.get_job("nonexistent-job-id")
|
||||
assert job is None, "Non-existent job should return None"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("text, expect_error", [
|
||||
("hello", False), # 5 chars, within [3, 10]
|
||||
("abc", False), # 3 chars, exact min boundary
|
||||
("abcdefghij", False), # 10 chars, exact max boundary
|
||||
("ab", True), # 2 chars, below min
|
||||
("abcdefghijk", True), # 11 chars, above max
|
||||
("", True), # 0 chars, below min
|
||||
])
|
||||
def test_string_length_widget_validation(self, text, expect_error, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test minLength/maxLength validation for direct widget values (validate_inputs path)."""
|
||||
g = builder
|
||||
node = g.node("StubStringWithLength", text=text)
|
||||
g.node("SaveImage", images=node.out(0))
|
||||
if expect_error:
|
||||
with pytest.raises(urllib.error.HTTPError) as exc_info:
|
||||
client.run(g)
|
||||
assert exc_info.value.code == 400
|
||||
else:
|
||||
client.run(g)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("text, expect_error", [
|
||||
("hello", False), # 5 chars, within [3, 10]
|
||||
("abc", False), # 3 chars, exact min boundary
|
||||
("abcdefghij", False), # 10 chars, exact max boundary
|
||||
("ab", True), # 2 chars, below min
|
||||
("abcdefghijk", True), # 11 chars, above max
|
||||
("", True), # 0 chars, below min
|
||||
])
|
||||
def test_string_length_linked_validation(self, text, expect_error, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test minLength/maxLength validation for linked inputs (validate_resolved_inputs path)."""
|
||||
g = builder
|
||||
str_node = g.node("StubStringOutput", value=text)
|
||||
node = g.node("StubStringWithLength", text=str_node.out(0))
|
||||
g.node("SaveImage", images=node.out(0))
|
||||
|
||||
if expect_error:
|
||||
try:
|
||||
client.run(g)
|
||||
assert False, "Should have raised an error"
|
||||
except Exception as e:
|
||||
assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}"
|
||||
else:
|
||||
client.run(g)
|
||||
|
||||
@@ -113,12 +113,48 @@ class StubFloat:
|
||||
def stub_float(self, value):
|
||||
return (value,)
|
||||
|
||||
class StubStringOutput:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": ("STRING", {"default": ""}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING",)
|
||||
FUNCTION = "stub_string"
|
||||
|
||||
CATEGORY = "Testing/Stub Nodes"
|
||||
|
||||
def stub_string(self, value):
|
||||
return (value,)
|
||||
|
||||
class StubStringWithLength:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"text": ("STRING", {"default": "hello", "minLength": 3, "maxLength": 10}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "stub_string_with_length"
|
||||
|
||||
CATEGORY = "Testing/Stub Nodes"
|
||||
|
||||
def stub_string_with_length(self, text):
|
||||
return (torch.zeros(1, 64, 64, 3),)
|
||||
|
||||
TEST_STUB_NODE_CLASS_MAPPINGS = {
|
||||
"StubImage": StubImage,
|
||||
"StubConstantImage": StubConstantImage,
|
||||
"StubMask": StubMask,
|
||||
"StubInt": StubInt,
|
||||
"StubFloat": StubFloat,
|
||||
"StubStringOutput": StubStringOutput,
|
||||
"StubStringWithLength": StubStringWithLength,
|
||||
}
|
||||
TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"StubImage": "Stub Image",
|
||||
@@ -126,4 +162,6 @@ TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"StubMask": "Stub Mask",
|
||||
"StubInt": "Stub Int",
|
||||
"StubFloat": "Stub Float",
|
||||
"StubStringOutput": "Stub String Output",
|
||||
"StubStringWithLength": "Stub String With Length",
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user