mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-09 23:30:02 +00:00
fix: send_progress_text unicasts to client_id instead of broadcasting
- Default sid to self.client_id when not explicitly provided, matching every other WS message dispatch (executing, executed, progress_state, etc.) - Previously sid=None caused broadcast to all connected clients - Format signature per ruff, remove redundant comments - Add unit tests for routing, legacy format, and new prompt_id format Amp-Thread-ID: https://ampcode.com/threads/T-019ca3ce-c530-75dd-8d68-349e745a022e
This commit is contained in:
17
server.py
17
server.py
@@ -1233,7 +1233,11 @@ class PromptServer():
|
|||||||
return json_data
|
return json_data
|
||||||
|
|
||||||
def send_progress_text(
|
def send_progress_text(
|
||||||
self, text: Union[bytes, bytearray, str], node_id: str, prompt_id: Optional[str] = None, sid=None
|
self,
|
||||||
|
text: Union[bytes, bytearray, str],
|
||||||
|
node_id: str,
|
||||||
|
prompt_id: Optional[str] = None,
|
||||||
|
sid=None,
|
||||||
):
|
):
|
||||||
"""Send a progress text message to the client via WebSocket.
|
"""Send a progress text message to the client via WebSocket.
|
||||||
|
|
||||||
@@ -1251,15 +1255,15 @@ class PromptServer():
|
|||||||
text = text.encode("utf-8")
|
text = text.encode("utf-8")
|
||||||
node_id_bytes = str(node_id).encode("utf-8")
|
node_id_bytes = str(node_id).encode("utf-8")
|
||||||
|
|
||||||
# When prompt_id is provided and client supports the new format,
|
# Auto-resolve sid to the currently executing client
|
||||||
# prepend prompt_id as a length-prefixed field before node_id
|
|
||||||
target_sid = sid if sid is not None else self.client_id
|
target_sid = sid if sid is not None else self.client_id
|
||||||
|
|
||||||
|
# When prompt_id is available and client supports the new format,
|
||||||
|
# prepend prompt_id as a length-prefixed field before node_id
|
||||||
if prompt_id and feature_flags.supports_feature(
|
if prompt_id and feature_flags.supports_feature(
|
||||||
self.sockets_metadata, target_sid, "supports_progress_text_metadata"
|
self.sockets_metadata, target_sid, "supports_progress_text_metadata"
|
||||||
):
|
):
|
||||||
prompt_id_bytes = prompt_id.encode("utf-8")
|
prompt_id_bytes = prompt_id.encode("utf-8")
|
||||||
# Pack prompt_id length as a 4-byte unsigned integer, followed by prompt_id bytes,
|
|
||||||
# then node_id length as a 4-byte unsigned integer, followed by node_id bytes, then text
|
|
||||||
message = (
|
message = (
|
||||||
struct.pack(">I", len(prompt_id_bytes))
|
struct.pack(">I", len(prompt_id_bytes))
|
||||||
+ prompt_id_bytes
|
+ prompt_id_bytes
|
||||||
@@ -1268,7 +1272,6 @@ class PromptServer():
|
|||||||
+ text
|
+ text
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Pack the node_id length as a 4-byte unsigned integer, followed by the node_id bytes
|
|
||||||
message = struct.pack(">I", len(node_id_bytes)) + node_id_bytes + text
|
message = struct.pack(">I", len(node_id_bytes)) + node_id_bytes + text
|
||||||
|
|
||||||
self.send_sync(BinaryEventTypes.TEXT, message, sid)
|
self.send_sync(BinaryEventTypes.TEXT, message, target_sid)
|
||||||
|
|||||||
207
tests-unit/prompt_server_test/send_progress_text_test.py
Normal file
207
tests-unit/prompt_server_test/send_progress_text_test.py
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
"""Tests for send_progress_text routing and binary format logic.
|
||||||
|
|
||||||
|
These tests verify:
|
||||||
|
1. sid defaults to client_id (unicast) instead of None (broadcast)
|
||||||
|
2. Legacy binary format when prompt_id absent or client unsupported
|
||||||
|
3. New binary format with prompt_id when client supports the feature flag
|
||||||
|
"""
|
||||||
|
|
||||||
|
import struct
|
||||||
|
|
||||||
|
from comfy_api import feature_flags
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers – replicate the packing logic so we can assert on the wire format
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _unpack_legacy(message: bytes):
|
||||||
|
"""Unpack a legacy progress_text binary message -> (node_id, text)."""
|
||||||
|
offset = 0
|
||||||
|
node_id_len = struct.unpack_from(">I", message, offset)[0]
|
||||||
|
offset += 4
|
||||||
|
node_id = message[offset : offset + node_id_len].decode("utf-8")
|
||||||
|
offset += node_id_len
|
||||||
|
text = message[offset:].decode("utf-8")
|
||||||
|
return node_id, text
|
||||||
|
|
||||||
|
|
||||||
|
def _unpack_with_prompt_id(message: bytes):
|
||||||
|
"""Unpack new format -> (prompt_id, node_id, text)."""
|
||||||
|
offset = 0
|
||||||
|
prompt_id_len = struct.unpack_from(">I", message, offset)[0]
|
||||||
|
offset += 4
|
||||||
|
prompt_id = message[offset : offset + prompt_id_len].decode("utf-8")
|
||||||
|
offset += prompt_id_len
|
||||||
|
node_id_len = struct.unpack_from(">I", message, offset)[0]
|
||||||
|
offset += 4
|
||||||
|
node_id = message[offset : offset + node_id_len].decode("utf-8")
|
||||||
|
offset += node_id_len
|
||||||
|
text = message[offset:].decode("utf-8")
|
||||||
|
return prompt_id, node_id, text
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Minimal stub that mirrors send_progress_text logic from server.py
|
||||||
|
# We can't import server.py directly (it pulls in torch via nodes.py),
|
||||||
|
# so we replicate the method body here. If the implementation changes,
|
||||||
|
# these tests should be updated in tandem.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class _StubServer:
|
||||||
|
"""Stub that captures send_sync calls and runs the real packing logic."""
|
||||||
|
|
||||||
|
def __init__(self, client_id=None, sockets_metadata=None):
|
||||||
|
self.client_id = client_id
|
||||||
|
self.sockets_metadata = sockets_metadata or {}
|
||||||
|
self.sent = [] # list of (event, data, sid)
|
||||||
|
|
||||||
|
def send_sync(self, event, data, sid=None):
|
||||||
|
self.sent.append((event, data, sid))
|
||||||
|
|
||||||
|
def send_progress_text(self, text, node_id, prompt_id=None, sid=None):
|
||||||
|
if isinstance(text, str):
|
||||||
|
text = text.encode("utf-8")
|
||||||
|
node_id_bytes = str(node_id).encode("utf-8")
|
||||||
|
|
||||||
|
target_sid = sid if sid is not None else self.client_id
|
||||||
|
|
||||||
|
if prompt_id and feature_flags.supports_feature(
|
||||||
|
self.sockets_metadata, target_sid, "supports_progress_text_metadata"
|
||||||
|
):
|
||||||
|
prompt_id_bytes = prompt_id.encode("utf-8")
|
||||||
|
message = (
|
||||||
|
struct.pack(">I", len(prompt_id_bytes))
|
||||||
|
+ prompt_id_bytes
|
||||||
|
+ struct.pack(">I", len(node_id_bytes))
|
||||||
|
+ node_id_bytes
|
||||||
|
+ text
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
message = struct.pack(">I", len(node_id_bytes)) + node_id_bytes + text
|
||||||
|
|
||||||
|
self.send_sync(3, message, target_sid) # 3 == BinaryEventTypes.TEXT
|
||||||
|
|
||||||
|
|
||||||
|
# ===========================================================================
|
||||||
|
# Routing tests
|
||||||
|
# ===========================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestSendProgressTextRouting:
|
||||||
|
"""Verify sid resolution: defaults to client_id, overridable via sid param."""
|
||||||
|
|
||||||
|
def test_defaults_to_client_id_when_sid_not_provided(self):
|
||||||
|
server = _StubServer(client_id="active-client-123")
|
||||||
|
server.send_progress_text("hello", "node1")
|
||||||
|
|
||||||
|
_, _, sid = server.sent[0]
|
||||||
|
assert sid == "active-client-123"
|
||||||
|
|
||||||
|
def test_explicit_sid_overrides_client_id(self):
|
||||||
|
server = _StubServer(client_id="active-client-123")
|
||||||
|
server.send_progress_text("hello", "node1", sid="explicit-sid")
|
||||||
|
|
||||||
|
_, _, sid = server.sent[0]
|
||||||
|
assert sid == "explicit-sid"
|
||||||
|
|
||||||
|
def test_broadcasts_when_no_client_id_and_no_sid(self):
|
||||||
|
server = _StubServer(client_id=None)
|
||||||
|
server.send_progress_text("hello", "node1")
|
||||||
|
|
||||||
|
_, _, sid = server.sent[0]
|
||||||
|
assert sid is None
|
||||||
|
|
||||||
|
|
||||||
|
# ===========================================================================
|
||||||
|
# Legacy format tests
|
||||||
|
# ===========================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestSendProgressTextLegacyFormat:
|
||||||
|
"""Verify legacy binary format: [4B node_id_len][node_id][text]."""
|
||||||
|
|
||||||
|
def test_legacy_format_no_prompt_id(self):
|
||||||
|
server = _StubServer(client_id="c1")
|
||||||
|
server.send_progress_text("some text", "node-42")
|
||||||
|
|
||||||
|
_, data, _ = server.sent[0]
|
||||||
|
node_id, text = _unpack_legacy(data)
|
||||||
|
assert node_id == "node-42"
|
||||||
|
assert text == "some text"
|
||||||
|
|
||||||
|
def test_legacy_format_when_client_unsupported(self):
|
||||||
|
server = _StubServer(
|
||||||
|
client_id="c1",
|
||||||
|
sockets_metadata={"c1": {"feature_flags": {}}},
|
||||||
|
)
|
||||||
|
server.send_progress_text("text", "node1", prompt_id="prompt-abc")
|
||||||
|
|
||||||
|
_, data, _ = server.sent[0]
|
||||||
|
node_id, text = _unpack_legacy(data)
|
||||||
|
assert node_id == "node1"
|
||||||
|
assert text == "text"
|
||||||
|
|
||||||
|
def test_bytes_input_preserved(self):
|
||||||
|
server = _StubServer(client_id="c1")
|
||||||
|
server.send_progress_text(b"raw bytes", "node1")
|
||||||
|
|
||||||
|
_, data, _ = server.sent[0]
|
||||||
|
node_id, text = _unpack_legacy(data)
|
||||||
|
assert text == "raw bytes"
|
||||||
|
|
||||||
|
|
||||||
|
# ===========================================================================
|
||||||
|
# New format tests
|
||||||
|
# ===========================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestSendProgressTextNewFormat:
|
||||||
|
"""Verify new format: [4B prompt_id_len][prompt_id][4B node_id_len][node_id][text]."""
|
||||||
|
|
||||||
|
def test_includes_prompt_id_when_supported(self):
|
||||||
|
server = _StubServer(
|
||||||
|
client_id="c1",
|
||||||
|
sockets_metadata={
|
||||||
|
"c1": {"feature_flags": {"supports_progress_text_metadata": True}}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
server.send_progress_text("progress!", "node-7", prompt_id="prompt-xyz")
|
||||||
|
|
||||||
|
_, data, _ = server.sent[0]
|
||||||
|
prompt_id, node_id, text = _unpack_with_prompt_id(data)
|
||||||
|
assert prompt_id == "prompt-xyz"
|
||||||
|
assert node_id == "node-7"
|
||||||
|
assert text == "progress!"
|
||||||
|
|
||||||
|
def test_new_format_with_explicit_sid(self):
|
||||||
|
server = _StubServer(
|
||||||
|
client_id=None,
|
||||||
|
sockets_metadata={
|
||||||
|
"my-sid": {"feature_flags": {"supports_progress_text_metadata": True}}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
server.send_progress_text("txt", "n1", prompt_id="p1", sid="my-sid")
|
||||||
|
|
||||||
|
_, data, sid = server.sent[0]
|
||||||
|
assert sid == "my-sid"
|
||||||
|
prompt_id, node_id, text = _unpack_with_prompt_id(data)
|
||||||
|
assert prompt_id == "p1"
|
||||||
|
assert node_id == "n1"
|
||||||
|
assert text == "txt"
|
||||||
|
|
||||||
|
|
||||||
|
# ===========================================================================
|
||||||
|
# Feature flag tests
|
||||||
|
# ===========================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestProgressTextFeatureFlag:
|
||||||
|
"""Verify the supports_progress_text_metadata flag exists in server features."""
|
||||||
|
|
||||||
|
def test_flag_in_server_features(self):
|
||||||
|
features = feature_flags.get_server_features()
|
||||||
|
assert "supports_progress_text_metadata" in features
|
||||||
|
assert features["supports_progress_text_metadata"] is True
|
||||||
Reference in New Issue
Block a user