From 3278ff75fd5c267ea4cf1bbe087fdd6f12c3f6b1 Mon Sep 17 00:00:00 2001 From: Johnpaul Date: Mon, 16 Feb 2026 23:53:50 +0100 Subject: [PATCH] feat: replay execution context on WS reconnect Store current execution state (prompt_id, cached nodes, executed nodes, outputs to execute) on the server instance during execution. On WS reconnect, replay execution_start, execution_cached, progress_state, and executing events so the frontend can restore progress tracking. Refactor ProgressRegistry to expose get_serialized_state() and reuse it in WebUIProgressHandler._send_progress_state(). --- comfy_execution/progress.py | 34 ++++++++++++++++++-------------- execution.py | 9 +++++++++ server.py | 39 ++++++++++++++++++++++++++++++++++--- 3 files changed, 64 insertions(+), 18 deletions(-) diff --git a/comfy_execution/progress.py b/comfy_execution/progress.py index f951a3350..1ffb72161 100644 --- a/comfy_execution/progress.py +++ b/comfy_execution/progress.py @@ -164,21 +164,7 @@ class WebUIProgressHandler(ProgressHandler): if self.server_instance is None: return - # Only send info for non-pending nodes - active_nodes = { - node_id: { - "value": state["value"], - "max": state["max"], - "state": state["state"].value, - "node_id": node_id, - "prompt_id": prompt_id, - "display_node_id": self.registry.dynprompt.get_display_node_id(node_id), - "parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id), - "real_node_id": self.registry.dynprompt.get_real_node_id(node_id), - } - for node_id, state in nodes.items() - if state["state"] != NodeState.Pending - } + active_nodes = self.registry.get_serialized_state() # Send a combined progress_state message with all node states # Include client_id to ensure message is only sent to the initiating client @@ -314,6 +300,24 @@ class ProgressRegistry: if handler.enabled: handler.finish_handler(node_id, entry, self.prompt_id) + def get_serialized_state(self) -> Dict[str, dict]: + """Return current node progress as a dict suitable for WS progress_state.""" + active: Dict[str, dict] = {} + for nid, state in self.nodes.items(): + if state["state"] == NodeState.Pending: + continue + active[nid] = { + "value": state["value"], + "max": state["max"], + "state": state["state"].value, + "node_id": nid, + "prompt_id": self.prompt_id, + "display_node_id": self.dynprompt.get_display_node_id(nid), + "parent_node_id": self.dynprompt.get_parent_node_id(nid), + "real_node_id": self.dynprompt.get_real_node_id(nid), + } + return active + def reset_handlers(self) -> None: """Reset all handlers""" for handler in self.handlers.values(): diff --git a/execution.py b/execution.py index 896862c6b..c9db011d9 100644 --- a/execution.py +++ b/execution.py @@ -701,6 +701,8 @@ class PromptExecutor: else: self.server.client_id = None + self.server.current_prompt_id = prompt_id + self.status_messages = [] self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) @@ -722,10 +724,13 @@ class PromptExecutor: self.add_message("execution_cached", { "nodes": cached_nodes, "prompt_id": prompt_id}, broadcast=False) + self.server.current_cached_nodes = cached_nodes pending_subgraph_results = {} pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results ui_node_outputs = {} executed = set() + self.server.current_executed_nodes = executed + self.server.current_outputs_to_execute = list(execute_outputs) execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) current_outputs = self.caches.outputs.all_node_ids() for node_id in list(execute_outputs): @@ -762,6 +767,10 @@ class PromptExecutor: "meta": meta_outputs, } self.server.last_node_id = None + self.server.current_prompt_id = None + self.server.current_outputs_to_execute = None + self.server.current_cached_nodes = None + self.server.current_executed_nodes = None if comfy.model_management.DISABLE_SMART_MEMORY: comfy.model_management.unload_all_models() diff --git a/server.py b/server.py index 2300393b2..f25fe0c6b 100644 --- a/server.py +++ b/server.py @@ -242,6 +242,10 @@ class PromptServer(): self.routes = routes self.last_node_id = None self.client_id = None + self.current_prompt_id = None + self.current_outputs_to_execute = None + self.current_cached_nodes = None + self.current_executed_nodes = None self.on_prompt_handlers = [] @@ -264,9 +268,38 @@ class PromptServer(): try: # Send initial state to the new client await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid) - # On reconnect if we are the currently executing client send the current node - if self.client_id == sid and self.last_node_id is not None: - await self.send("executing", { "node": self.last_node_id }, sid) + # On reconnect if we are the currently executing client, replay catch-up events + if self.client_id == sid and self.current_prompt_id is not None: + await self.send("execution_start", { + "prompt_id": self.current_prompt_id, + "timestamp": int(time.time() * 1000), + "outputs_to_execute": self.current_outputs_to_execute or [], + "executed_node_ids": list(self.current_executed_nodes) if self.current_executed_nodes else [], + }, sid) + + if self.current_cached_nodes: + await self.send("execution_cached", { + "nodes": self.current_cached_nodes, + "prompt_id": self.current_prompt_id, + "timestamp": int(time.time() * 1000), + }, sid) + + from comfy_execution.progress import get_progress_state + progress = get_progress_state() + if progress.prompt_id == self.current_prompt_id: + active_nodes = progress.get_serialized_state() + if active_nodes: + await self.send("progress_state", { + "prompt_id": self.current_prompt_id, + "nodes": active_nodes, + }, sid) + + if self.last_node_id is not None: + await self.send("executing", { + "node": self.last_node_id, + "display_node": self.last_node_id, + "prompt_id": self.current_prompt_id, + }, sid) # Flag to track if we've received the first message first_message = True