diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 050031dc0..e9248090e 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -326,11 +326,14 @@ class String(ComfyTypeIO): '''String input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None, - socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): + socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None, + min_length: int=None, max_length: int=None): super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced) self.multiline = multiline self.placeholder = placeholder self.dynamic_prompts = dynamic_prompts + self.min_length = min_length + self.max_length = max_length self.default: str def as_dict(self): @@ -338,6 +341,8 @@ class String(ComfyTypeIO): "multiline": self.multiline, "placeholder": self.placeholder, "dynamicPrompts": self.dynamic_prompts, + "minLength": self.min_length, + "maxLength": self.max_length, }) @comfytype(io_type="COMBO") diff --git a/execution.py b/execution.py index 7ccdbf93e..ed6fcd7b4 100644 --- a/execution.py +++ b/execution.py @@ -81,7 +81,7 @@ class IsChangedCache: return self.is_changed[node_id] # Intentionally do not use cached outputs here. We only want constants in IS_CHANGED - input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None) + input_data_all, _, v3_data, _ = get_input_data(node["inputs"], class_def, node_id, None) try: is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name, v3_data=v3_data) is_changed = await resolve_map_node_over_list_results(is_changed) @@ -213,7 +213,35 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt= if h[x] == "API_KEY_COMFY_ORG": input_data_all[x] = [extra_data.get("api_key_comfy_org", None)] v3_data["hidden_inputs"] = hidden_inputs_v3 - return input_data_all, missing_keys, v3_data + return input_data_all, missing_keys, v3_data, valid_inputs + +def validate_resolved_inputs(input_data_all, class_def, valid_inputs): + """Validate resolved input values against schema constraints. + + This is needed because validate_inputs() only sees direct widget values. + Linked inputs aren't resolved during validate_inputs(), so this runs after resolution to catch any violations. + """ + for x, values in input_data_all.items(): + input_type, input_category, extra_info = get_input_info(class_def, x, valid_inputs) + if input_type != "STRING": + continue + min_length = extra_info.get("minLength") + max_length = extra_info.get("maxLength") + if min_length is None and max_length is None: + continue + for val in values: + if val is None or not isinstance(val, str): + continue + if min_length is not None and len(val) < min_length: + raise ValueError( + f"Input '{x}': value length {len(val)} is shorter than " + f"minimum length of {min_length}" + ) + if max_length is not None and len(val) > max_length: + raise ValueError( + f"Input '{x}': value length {len(val)} is longer than " + f"maximum length of {max_length}" + ) map_node_over_list = None #Don't hook this please @@ -469,7 +497,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, has_subgraph = False else: get_progress_state().start_progress(unique_id) - input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data) + input_data_all, missing_keys, v3_data, valid_inputs = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data) if server.client_id is not None: server.last_node_id = display_node_id server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) @@ -498,6 +526,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, execution_list.make_input_strong_link(unique_id, i) return (ExecutionResult.PENDING, None, None) + validate_resolved_inputs(input_data_all, class_def, valid_inputs) + def execution_block_cb(block): if block.message is not None: mes = { @@ -940,6 +970,34 @@ async def validate_inputs(prompt_id, prompt, item, validated): errors.append(error) continue + if input_type == "STRING": + if "minLength" in extra_info and len(val) < extra_info["minLength"]: + error = { + "type": "value_shorter_than_min_length", + "message": "Value length {} shorter than min length of {}".format(len(val), extra_info["minLength"]), + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + errors.append(error) + continue + if "maxLength" in extra_info and len(val) > extra_info["maxLength"]: + error = { + "type": "value_longer_than_max_length", + "message": "Value length {} longer than max length of {}".format(len(val), extra_info["maxLength"]), + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + errors.append(error) + continue + if isinstance(input_type, list) or input_type == io.Combo.io_type: if input_type == io.Combo.io_type: combo_options = extra_info.get("options", []) @@ -971,7 +1029,7 @@ async def validate_inputs(prompt_id, prompt, item, validated): continue if len(validate_function_inputs) > 0 or validate_has_kwargs: - input_data_all, _, v3_data = get_input_data(inputs, obj_class, unique_id) + input_data_all, _, v3_data, _ = get_input_data(inputs, obj_class, unique_id) input_filtered = {} for x in input_data_all: if x in validate_function_inputs or validate_has_kwargs: diff --git a/tests/execution/test_execution.py b/tests/execution/test_execution.py index f73ca7e3c..9959f7b11 100644 --- a/tests/execution/test_execution.py +++ b/tests/execution/test_execution.py @@ -1011,3 +1011,49 @@ class TestExecution: """Test getting a non-existent job returns 404""" job = client.get_job("nonexistent-job-id") assert job is None, "Non-existent job should return None" + + + @pytest.mark.parametrize("text, expect_error", [ + ("hello", False), # 5 chars, within [3, 10] + ("abc", False), # 3 chars, exact min boundary + ("abcdefghij", False), # 10 chars, exact max boundary + ("ab", True), # 2 chars, below min + ("abcdefghijk", True), # 11 chars, above max + ("", True), # 0 chars, below min + ]) + def test_string_length_widget_validation(self, text, expect_error, client: ComfyClient, builder: GraphBuilder): + """Test minLength/maxLength validation for direct widget values (validate_inputs path).""" + g = builder + node = g.node("StubStringWithLength", text=text) + g.node("SaveImage", images=node.out(0)) + if expect_error: + with pytest.raises(urllib.error.HTTPError) as exc_info: + client.run(g) + assert exc_info.value.code == 400 + else: + client.run(g) + + + @pytest.mark.parametrize("text, expect_error", [ + ("hello", False), # 5 chars, within [3, 10] + ("abc", False), # 3 chars, exact min boundary + ("abcdefghij", False), # 10 chars, exact max boundary + ("ab", True), # 2 chars, below min + ("abcdefghijk", True), # 11 chars, above max + ("", True), # 0 chars, below min + ]) + def test_string_length_linked_validation(self, text, expect_error, client: ComfyClient, builder: GraphBuilder): + """Test minLength/maxLength validation for linked inputs (validate_resolved_inputs path).""" + g = builder + str_node = g.node("StubStringOutput", value=text) + node = g.node("StubStringWithLength", text=str_node.out(0)) + g.node("SaveImage", images=node.out(0)) + + if expect_error: + try: + client.run(g) + assert False, "Should have raised an error" + except Exception as e: + assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}" + else: + client.run(g) diff --git a/tests/execution/testing_nodes/testing-pack/stubs.py b/tests/execution/testing_nodes/testing-pack/stubs.py index a1df87529..7e34028ed 100644 --- a/tests/execution/testing_nodes/testing-pack/stubs.py +++ b/tests/execution/testing_nodes/testing-pack/stubs.py @@ -113,12 +113,48 @@ class StubFloat: def stub_float(self, value): return (value,) +class StubStringOutput: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("STRING", {"default": ""}), + }, + } + + RETURN_TYPES = ("STRING",) + FUNCTION = "stub_string" + + CATEGORY = "Testing/Stub Nodes" + + def stub_string(self, value): + return (value,) + +class StubStringWithLength: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "text": ("STRING", {"default": "hello", "minLength": 3, "maxLength": 10}), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "stub_string_with_length" + + CATEGORY = "Testing/Stub Nodes" + + def stub_string_with_length(self, text): + return (torch.zeros(1, 64, 64, 3),) + TEST_STUB_NODE_CLASS_MAPPINGS = { "StubImage": StubImage, "StubConstantImage": StubConstantImage, "StubMask": StubMask, "StubInt": StubInt, "StubFloat": StubFloat, + "StubStringOutput": StubStringOutput, + "StubStringWithLength": StubStringWithLength, } TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = { "StubImage": "Stub Image", @@ -126,4 +162,6 @@ TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = { "StubMask": "Stub Mask", "StubInt": "Stub Int", "StubFloat": "Stub Float", + "StubStringOutput": "Stub String Output", + "StubStringWithLength": "Stub String With Length", }