Compare commits

..

2 Commits

9 changed files with 21 additions and 93 deletions

View File

@@ -193,8 +193,6 @@ class HiddenInputTypeDict(TypedDict):
"""EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node)."""
dynprompt: NotRequired[Literal["DYNPROMPT"]]
"""DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion."""
prompt_id: NotRequired[Literal["PROMPT_ID"]]
"""PROMPT_ID is the unique identifier of the current prompt/job being executed. Useful for associating progress updates with specific jobs."""
class InputTypeDict(TypedDict):

View File

@@ -485,7 +485,7 @@ class WanVAE(nn.Module):
iter_ = 1 + (t - 1) // 4
feat_map = None
if iter_ > 1:
feat_map = [None] * count_conv3d(self.encoder)
feat_map = [None] * count_conv3d(self.decoder)
## 对encode输入的x按时间拆分为1、4、4、4....
for i in range(iter_):
conv_idx = [0]

View File

@@ -167,15 +167,17 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
x = to_dequant(x, dtype)
if not resident and lowvram_fn is not None:
x = to_dequant(x, dtype if compute_dtype is None else compute_dtype)
#FIXME: this is not accurate, we need to be sensitive to the compute dtype
x = lowvram_fn(x)
if (want_requant and len(fns) == 0 or update_weight):
if (isinstance(orig, QuantizedTensor) and
(want_requant and len(fns) == 0 or update_weight)):
seed = comfy.utils.string_to_seed(s.seed_key)
if isinstance(orig, QuantizedTensor):
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
else:
y = comfy.float.stochastic_rounding(x, orig.dtype, seed=seed)
if want_requant and len(fns) == 0:
x = y
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
if want_requant and len(fns) == 0:
#The layer actually wants our freshly saved QT
x = y
elif update_weight:
y = comfy.float.stochastic_rounding(x, orig.dtype, seed = comfy.utils.string_to_seed(s.seed_key))
if update_weight:
orig.copy_(y)
for f in fns:
@@ -615,8 +617,7 @@ def fp8_linear(self, input):
if input.ndim != 2:
return None
lora_compute_dtype=comfy.model_management.lora_compute_dtype(input.device)
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True, compute_dtype=lora_compute_dtype, want_requant=True)
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
scale_input = torch.ones((), device=input.device, dtype=torch.float32)

View File

@@ -12,7 +12,6 @@ from comfy.cli_args import args
# Default server capabilities
SERVER_FEATURE_FLAGS: dict[str, Any] = {
"supports_preview_metadata": True,
"supports_progress_text_metadata": True,
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
"extension": {"manager": {"supports_v4": True}},
"node_replacements": True,

View File

@@ -1269,16 +1269,9 @@ class V3Data(TypedDict):
'When True, the value of the dynamic input will be in the format (value, path_key).'
class HiddenHolder:
"""Holds hidden input values resolved during node execution.
Hidden inputs are special values automatically provided by the execution
engine (e.g., node ID, prompt data, authentication tokens) rather than
being connected by the user in the graph.
"""
def __init__(self, unique_id: str, prompt: Any,
extra_pnginfo: Any, dynprompt: Any,
auth_token_comfy_org: str, api_key_comfy_org: str,
prompt_id: str = None, **kwargs):
auth_token_comfy_org: str, api_key_comfy_org: str, **kwargs):
self.unique_id = unique_id
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
self.prompt = prompt
@@ -1291,8 +1284,6 @@ class HiddenHolder:
"""AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend."""
self.api_key_comfy_org = api_key_comfy_org
"""API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend."""
self.prompt_id = prompt_id
"""PROMPT_ID is the unique identifier of the current prompt/job being executed."""
def __getattr__(self, key: str):
'''If hidden variable not found, return None.'''
@@ -1300,14 +1291,6 @@ class HiddenHolder:
@classmethod
def from_dict(cls, d: dict | None):
"""Create a HiddenHolder from a dictionary of hidden input values.
Args:
d: Dictionary mapping Hidden enum values to their resolved values.
Returns:
A new HiddenHolder instance with values populated from the dict.
"""
if d is None:
d = {}
return cls(
@@ -1317,7 +1300,6 @@ class HiddenHolder:
dynprompt=d.get(Hidden.dynprompt, None),
auth_token_comfy_org=d.get(Hidden.auth_token_comfy_org, None),
api_key_comfy_org=d.get(Hidden.api_key_comfy_org, None),
prompt_id=d.get(Hidden.prompt_id, None),
)
@classmethod
@@ -1340,8 +1322,6 @@ class Hidden(str, Enum):
"""AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend."""
api_key_comfy_org = "API_KEY_COMFY_ORG"
"""API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend."""
prompt_id = "PROMPT_ID"
"""PROMPT_ID is the unique identifier of the current prompt/job being executed. Useful for associating progress updates with specific jobs."""
@dataclass

View File

@@ -17,7 +17,6 @@ from pydantic import BaseModel
from comfy import utils
from comfy_api.latest import IO
from comfy_execution.utils import get_executing_context
from server import PromptServer
from . import request_logger
@@ -437,17 +436,6 @@ def _display_text(
status: str | int | None = None,
price: float | None = None,
) -> None:
"""Send a progress text message to the client for display on a node.
Assembles status, price, and text lines, then sends them via WebSocket.
Automatically retrieves the current prompt_id from the execution context.
Args:
node_cls: The ComfyNode class sending the progress text.
text: Optional text content to display.
status: Optional status string or code to display.
price: Optional price in dollars to display as credits.
"""
display_lines: list[str] = []
if status:
display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}")
@@ -458,9 +446,7 @@ def _display_text(
if text is not None:
display_lines.append(text)
if display_lines:
ctx = get_executing_context()
prompt_id = ctx.prompt_id if ctx is not None else None
PromptServer.instance.send_progress_text("\n".join(display_lines), get_node_id(node_cls), prompt_id=prompt_id)
PromptServer.instance.send_progress_text("\n".join(display_lines), get_node_id(node_cls))
def _display_time_progress(

View File

@@ -566,7 +566,7 @@ class GetImageSize(IO.ComfyNode):
IO.Int.Output(display_name="height"),
IO.Int.Output(display_name="batch_size"),
],
hidden=[IO.Hidden.unique_id, IO.Hidden.prompt_id],
hidden=[IO.Hidden.unique_id],
)
@classmethod
@@ -577,7 +577,7 @@ class GetImageSize(IO.ComfyNode):
# Send progress text to display size on the node
if cls.hidden.unique_id:
PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", cls.hidden.unique_id, prompt_id=cls.hidden.prompt_id)
PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", cls.hidden.unique_id)
return IO.NodeOutput(width, height, batch_size)

View File

@@ -149,7 +149,7 @@ class CacheSet:
SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")
def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}, prompt_id=None):
def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}):
is_v3 = issubclass(class_def, _ComfyNodeInternal)
v3_data: io.V3Data = {}
hidden_inputs_v3 = {}
@@ -196,8 +196,6 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None)
if io.Hidden.api_key_comfy_org.name in hidden:
hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None)
if io.Hidden.prompt_id.name in hidden:
hidden_inputs_v3[io.Hidden.prompt_id] = prompt_id
else:
if "hidden" in valid_inputs:
h = valid_inputs["hidden"]
@@ -214,8 +212,6 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
if h[x] == "API_KEY_COMFY_ORG":
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
if h[x] == "PROMPT_ID":
input_data_all[x] = [prompt_id]
v3_data["hidden_inputs"] = hidden_inputs_v3
return input_data_all, missing_keys, v3_data
@@ -473,7 +469,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
has_subgraph = False
else:
get_progress_state().start_progress(unique_id)
input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data, prompt_id=prompt_id)
input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
if server.client_id is not None:
server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)

View File

@@ -1233,45 +1233,13 @@ class PromptServer():
return json_data
def send_progress_text(
self,
text: Union[bytes, bytearray, str],
node_id: str,
prompt_id: Optional[str] = None,
sid=None,
self, text: Union[bytes, bytearray, str], node_id: str, sid=None
):
"""Send a progress text message to the client via WebSocket.
Encodes the text as a binary message with length-prefixed node_id. When
prompt_id is provided and the client supports the ``supports_progress_text_metadata``
feature flag, the prompt_id is prepended as an additional length-prefixed field.
Args:
text: The progress text content to send.
node_id: The unique identifier of the node sending the progress.
prompt_id: Optional prompt/job identifier to associate the message with.
sid: Optional session ID to target a specific client.
"""
if isinstance(text, str):
text = text.encode("utf-8")
node_id_bytes = str(node_id).encode("utf-8")
# Auto-resolve sid to the currently executing client
target_sid = sid if sid is not None else self.client_id
# Pack the node_id length as a 4-byte unsigned integer, followed by the node_id bytes
message = struct.pack(">I", len(node_id_bytes)) + node_id_bytes + text
# When prompt_id is available and client supports the new format,
# prepend prompt_id as a length-prefixed field before node_id
if prompt_id and feature_flags.supports_feature(
self.sockets_metadata, target_sid, "supports_progress_text_metadata"
):
prompt_id_bytes = prompt_id.encode("utf-8")
message = (
struct.pack(">I", len(prompt_id_bytes))
+ prompt_id_bytes
+ struct.pack(">I", len(node_id_bytes))
+ node_id_bytes
+ text
)
else:
message = struct.pack(">I", len(node_id_bytes)) + node_id_bytes + text
self.send_sync(BinaryEventTypes.TEXT, message, target_sid)
self.send_sync(BinaryEventTypes.TEXT, message, sid)