fix: send_progress_text unicasts to client_id instead of broadcasting

- Default sid to self.client_id when not explicitly provided, matching
  every other WS message dispatch (executing, executed, progress_state, etc.)
- Previously sid=None caused broadcast to all connected clients
- Format signature per ruff, remove redundant comments
- Add unit tests for routing, legacy format, and new prompt_id format

Amp-Thread-ID: https://ampcode.com/threads/T-019ca3ce-c530-75dd-8d68-349e745a022e
This commit is contained in:
bymyself
2026-03-04 20:40:46 +00:00
parent 83df2a88bd
commit d74dfd2570
2 changed files with 217 additions and 7 deletions

View File

@@ -1233,7 +1233,11 @@ class PromptServer():
return json_data
def send_progress_text(
self, text: Union[bytes, bytearray, str], node_id: str, prompt_id: Optional[str] = None, sid=None
self,
text: Union[bytes, bytearray, str],
node_id: str,
prompt_id: Optional[str] = None,
sid=None,
):
"""Send a progress text message to the client via WebSocket.
@@ -1251,15 +1255,15 @@ class PromptServer():
text = text.encode("utf-8")
node_id_bytes = str(node_id).encode("utf-8")
# When prompt_id is provided and client supports the new format,
# prepend prompt_id as a length-prefixed field before node_id
# Auto-resolve sid to the currently executing client
target_sid = sid if sid is not None else self.client_id
# When prompt_id is available and client supports the new format,
# prepend prompt_id as a length-prefixed field before node_id
if prompt_id and feature_flags.supports_feature(
self.sockets_metadata, target_sid, "supports_progress_text_metadata"
):
prompt_id_bytes = prompt_id.encode("utf-8")
# Pack prompt_id length as a 4-byte unsigned integer, followed by prompt_id bytes,
# then node_id length as a 4-byte unsigned integer, followed by node_id bytes, then text
message = (
struct.pack(">I", len(prompt_id_bytes))
+ prompt_id_bytes
@@ -1268,7 +1272,6 @@ class PromptServer():
+ text
)
else:
# Pack the node_id length as a 4-byte unsigned integer, followed by the node_id bytes
message = struct.pack(">I", len(node_id_bytes)) + node_id_bytes + text
self.send_sync(BinaryEventTypes.TEXT, message, sid)
self.send_sync(BinaryEventTypes.TEXT, message, target_sid)

View File

@@ -0,0 +1,207 @@
"""Tests for send_progress_text routing and binary format logic.
These tests verify:
1. sid defaults to client_id (unicast) instead of None (broadcast)
2. Legacy binary format when prompt_id absent or client unsupported
3. New binary format with prompt_id when client supports the feature flag
"""
import struct
from comfy_api import feature_flags
# ---------------------------------------------------------------------------
# Helpers replicate the packing logic so we can assert on the wire format
# ---------------------------------------------------------------------------
def _unpack_legacy(message: bytes):
"""Unpack a legacy progress_text binary message -> (node_id, text)."""
offset = 0
node_id_len = struct.unpack_from(">I", message, offset)[0]
offset += 4
node_id = message[offset : offset + node_id_len].decode("utf-8")
offset += node_id_len
text = message[offset:].decode("utf-8")
return node_id, text
def _unpack_with_prompt_id(message: bytes):
"""Unpack new format -> (prompt_id, node_id, text)."""
offset = 0
prompt_id_len = struct.unpack_from(">I", message, offset)[0]
offset += 4
prompt_id = message[offset : offset + prompt_id_len].decode("utf-8")
offset += prompt_id_len
node_id_len = struct.unpack_from(">I", message, offset)[0]
offset += 4
node_id = message[offset : offset + node_id_len].decode("utf-8")
offset += node_id_len
text = message[offset:].decode("utf-8")
return prompt_id, node_id, text
# ---------------------------------------------------------------------------
# Minimal stub that mirrors send_progress_text logic from server.py
# We can't import server.py directly (it pulls in torch via nodes.py),
# so we replicate the method body here. If the implementation changes,
# these tests should be updated in tandem.
# ---------------------------------------------------------------------------
class _StubServer:
"""Stub that captures send_sync calls and runs the real packing logic."""
def __init__(self, client_id=None, sockets_metadata=None):
self.client_id = client_id
self.sockets_metadata = sockets_metadata or {}
self.sent = [] # list of (event, data, sid)
def send_sync(self, event, data, sid=None):
self.sent.append((event, data, sid))
def send_progress_text(self, text, node_id, prompt_id=None, sid=None):
if isinstance(text, str):
text = text.encode("utf-8")
node_id_bytes = str(node_id).encode("utf-8")
target_sid = sid if sid is not None else self.client_id
if prompt_id and feature_flags.supports_feature(
self.sockets_metadata, target_sid, "supports_progress_text_metadata"
):
prompt_id_bytes = prompt_id.encode("utf-8")
message = (
struct.pack(">I", len(prompt_id_bytes))
+ prompt_id_bytes
+ struct.pack(">I", len(node_id_bytes))
+ node_id_bytes
+ text
)
else:
message = struct.pack(">I", len(node_id_bytes)) + node_id_bytes + text
self.send_sync(3, message, target_sid) # 3 == BinaryEventTypes.TEXT
# ===========================================================================
# Routing tests
# ===========================================================================
class TestSendProgressTextRouting:
"""Verify sid resolution: defaults to client_id, overridable via sid param."""
def test_defaults_to_client_id_when_sid_not_provided(self):
server = _StubServer(client_id="active-client-123")
server.send_progress_text("hello", "node1")
_, _, sid = server.sent[0]
assert sid == "active-client-123"
def test_explicit_sid_overrides_client_id(self):
server = _StubServer(client_id="active-client-123")
server.send_progress_text("hello", "node1", sid="explicit-sid")
_, _, sid = server.sent[0]
assert sid == "explicit-sid"
def test_broadcasts_when_no_client_id_and_no_sid(self):
server = _StubServer(client_id=None)
server.send_progress_text("hello", "node1")
_, _, sid = server.sent[0]
assert sid is None
# ===========================================================================
# Legacy format tests
# ===========================================================================
class TestSendProgressTextLegacyFormat:
"""Verify legacy binary format: [4B node_id_len][node_id][text]."""
def test_legacy_format_no_prompt_id(self):
server = _StubServer(client_id="c1")
server.send_progress_text("some text", "node-42")
_, data, _ = server.sent[0]
node_id, text = _unpack_legacy(data)
assert node_id == "node-42"
assert text == "some text"
def test_legacy_format_when_client_unsupported(self):
server = _StubServer(
client_id="c1",
sockets_metadata={"c1": {"feature_flags": {}}},
)
server.send_progress_text("text", "node1", prompt_id="prompt-abc")
_, data, _ = server.sent[0]
node_id, text = _unpack_legacy(data)
assert node_id == "node1"
assert text == "text"
def test_bytes_input_preserved(self):
server = _StubServer(client_id="c1")
server.send_progress_text(b"raw bytes", "node1")
_, data, _ = server.sent[0]
node_id, text = _unpack_legacy(data)
assert text == "raw bytes"
# ===========================================================================
# New format tests
# ===========================================================================
class TestSendProgressTextNewFormat:
"""Verify new format: [4B prompt_id_len][prompt_id][4B node_id_len][node_id][text]."""
def test_includes_prompt_id_when_supported(self):
server = _StubServer(
client_id="c1",
sockets_metadata={
"c1": {"feature_flags": {"supports_progress_text_metadata": True}}
},
)
server.send_progress_text("progress!", "node-7", prompt_id="prompt-xyz")
_, data, _ = server.sent[0]
prompt_id, node_id, text = _unpack_with_prompt_id(data)
assert prompt_id == "prompt-xyz"
assert node_id == "node-7"
assert text == "progress!"
def test_new_format_with_explicit_sid(self):
server = _StubServer(
client_id=None,
sockets_metadata={
"my-sid": {"feature_flags": {"supports_progress_text_metadata": True}}
},
)
server.send_progress_text("txt", "n1", prompt_id="p1", sid="my-sid")
_, data, sid = server.sent[0]
assert sid == "my-sid"
prompt_id, node_id, text = _unpack_with_prompt_id(data)
assert prompt_id == "p1"
assert node_id == "n1"
assert text == "txt"
# ===========================================================================
# Feature flag tests
# ===========================================================================
class TestProgressTextFeatureFlag:
"""Verify the supports_progress_text_metadata flag exists in server features."""
def test_flag_in_server_features(self):
features = feature_flags.get_server_features()
assert "supports_progress_text_metadata" in features
assert features["supports_progress_text_metadata"] is True