From 82745e382bfadf8e7df7e54ed13a90ba7fcfcfc8 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Tue, 10 Feb 2026 15:19:06 +0200 Subject: [PATCH] fix(api-nodes): retry on connection errors during polling instead of aborting --- comfy_api_nodes/util/client.py | 48 ++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index 8a1259506..e1d58a56d 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -149,6 +149,7 @@ async def poll_op( estimated_duration: int | None = None, cancel_endpoint: ApiEndpoint | None = None, cancel_timeout: float = 10.0, + max_consecutive_poll_failures: int = 10, ) -> M: raw = await poll_op_raw( cls, @@ -169,6 +170,7 @@ async def poll_op( estimated_duration=estimated_duration, cancel_endpoint=cancel_endpoint, cancel_timeout=cancel_timeout, + max_consecutive_poll_failures=max_consecutive_poll_failures, ) if not isinstance(raw, dict): raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).") @@ -246,6 +248,7 @@ async def poll_op_raw( estimated_duration: int | None = None, cancel_endpoint: ApiEndpoint | None = None, cancel_timeout: float = 10.0, + max_consecutive_poll_failures: int = 10, ) -> dict[str, Any]: """ Polls an endpoint until the task reaches a terminal state. Displays time while queued/processing, @@ -253,6 +256,10 @@ async def poll_op_raw( Uses default complete, failed and queued states assumption. + If individual poll requests fail due to connection issues (DNS, network), the loop continues + retrying up to ``max_consecutive_poll_failures`` times before giving up. The remote task is + likely still running, so transient network hiccups should not kill the entire operation. + Returns the final JSON response from the poll endpoint. """ completed_states = _normalize_statuses(COMPLETED_STATUSES if completed_statuses is None else completed_statuses) @@ -291,6 +298,7 @@ async def poll_op_raw( logging.debug("Polling ticker exited: %s", exc) ticker_task = asyncio.create_task(_ticker()) + consecutive_poll_failures = 0 try: while consumed_attempts < max_poll_attempts: try: @@ -325,6 +333,46 @@ async def poll_op_raw( monitor_progress=False, ) raise + except (LocalNetworkError, ApiServerError) as poll_err: + # Connection-level failure. The remote task is likely still running - keep polling. + consecutive_poll_failures += 1 + consumed_attempts += 1 + if consecutive_poll_failures >= max_consecutive_poll_failures: + logging.error( + "Poll request failed %d consecutive times, giving up: %s", + consecutive_poll_failures, + poll_err, + ) + raise + logging.warning( + "Poll request failed due to connection error (%d/%d consecutive). " + "Task is likely still running on the server. Retrying in %.1fs: %s", + consecutive_poll_failures, + max_consecutive_poll_failures, + poll_interval * 2, + poll_err, + ) + state.status_label = "Reconnecting" + try: + await sleep_with_interrupt(poll_interval * 2, cls, None, None, None) + except ProcessingInterrupted: + if cancel_endpoint: + with contextlib.suppress(Exception): + await sync_op_raw( + cls, + cancel_endpoint, + timeout=cancel_timeout, + max_retries=0, + wait_label="Cancelling task", + estimated_duration=None, + as_binary=False, + final_label_on_success=None, + monitor_progress=False, + ) + raise + continue + + consecutive_poll_failures = 0 try: status = _normalize_status_value(status_extractor(resp_json))