mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-09 14:09:58 +00:00
Compare commits
18 Commits
dev/Combo-
...
deepme987/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1cd2730b25 | ||
|
|
4b1444fc7a | ||
|
|
8cbbea8f6a | ||
|
|
13917b3880 | ||
|
|
f21f6b2212 | ||
|
|
eb0686bbb6 | ||
|
|
5de94e70ec | ||
|
|
76b75f3ad7 | ||
|
|
0c63b4f6e3 | ||
|
|
7d437687c2 | ||
|
|
e2ddf28d78 | ||
|
|
076639fed9 | ||
|
|
d4351f77f8 | ||
|
|
9837dd368a | ||
|
|
62ec9a3238 | ||
|
|
b20cb7892e | ||
|
|
b9b24d425b | ||
|
|
d731cb6ae1 |
36
.github/workflows/release-stable-all.yml
vendored
36
.github/workflows/release-stable-all.yml
vendored
@@ -20,29 +20,12 @@ jobs:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
cache_tag: "cu130"
|
||||
python_minor: "13"
|
||||
python_patch: "11"
|
||||
python_patch: "12"
|
||||
rel_name: "nvidia"
|
||||
rel_extra_name: ""
|
||||
test_release: true
|
||||
secrets: inherit
|
||||
|
||||
release_nvidia_cu128:
|
||||
permissions:
|
||||
contents: "write"
|
||||
packages: "write"
|
||||
pull-requests: "read"
|
||||
name: "Release NVIDIA cu128"
|
||||
uses: ./.github/workflows/stable-release.yml
|
||||
with:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
cache_tag: "cu128"
|
||||
python_minor: "12"
|
||||
python_patch: "10"
|
||||
rel_name: "nvidia"
|
||||
rel_extra_name: "_cu128"
|
||||
test_release: true
|
||||
secrets: inherit
|
||||
|
||||
release_nvidia_cu126:
|
||||
permissions:
|
||||
contents: "write"
|
||||
@@ -76,3 +59,20 @@ jobs:
|
||||
rel_extra_name: ""
|
||||
test_release: false
|
||||
secrets: inherit
|
||||
|
||||
release_xpu:
|
||||
permissions:
|
||||
contents: "write"
|
||||
packages: "write"
|
||||
pull-requests: "read"
|
||||
name: "Release Intel XPU"
|
||||
uses: ./.github/workflows/stable-release.yml
|
||||
with:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
cache_tag: "xpu"
|
||||
python_minor: "13"
|
||||
python_patch: "12"
|
||||
rel_name: "intel"
|
||||
rel_extra_name: ""
|
||||
test_release: true
|
||||
secrets: inherit
|
||||
|
||||
@@ -61,6 +61,7 @@ See what ComfyUI can do with the [newer template workflows](https://comfy.org/wo
|
||||
|
||||
## Features
|
||||
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
||||
- NOTE: There are many more models supported than the list below, if you want to see what is supported see our templates list inside ComfyUI.
|
||||
- Image Models
|
||||
- SD1.x, SD2.x ([unCLIP](https://comfyanonymous.github.io/ComfyUI_examples/unclip/))
|
||||
- [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
|
||||
@@ -136,7 +137,7 @@ ComfyUI follows a weekly release cycle targeting Monday but this regularly chang
|
||||
- Builds a new release using the latest stable core version
|
||||
|
||||
3. **[ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend)**
|
||||
- Weekly frontend updates are merged into the core repository
|
||||
- Every 2+ weeks frontend updates are merged into the core repository
|
||||
- Features are frozen for the upcoming core release
|
||||
- Development continues for the next release cycle
|
||||
|
||||
@@ -275,7 +276,7 @@ Nvidia users should install stable pytorch using this command:
|
||||
|
||||
This is the command to install pytorch nightly instead which might have performance improvements.
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu130```
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu132```
|
||||
|
||||
#### Troubleshooting
|
||||
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from typing import TYPE_CHECKING, TypedDict
|
||||
@@ -7,7 +11,6 @@ if TYPE_CHECKING:
|
||||
from comfy_api.latest._io_public import NodeReplace
|
||||
|
||||
from comfy_execution.graph_utils import is_link
|
||||
import nodes
|
||||
|
||||
class NodeStruct(TypedDict):
|
||||
inputs: dict[str, str | int | float | bool | tuple[str, int]]
|
||||
@@ -43,6 +46,7 @@ class NodeReplaceManager:
|
||||
return old_node_id in self._replacements
|
||||
|
||||
def apply_replacements(self, prompt: dict[str, NodeStruct]):
|
||||
import nodes
|
||||
connections: dict[str, list[tuple[str, str, int]]] = {}
|
||||
need_replacement: set[str] = set()
|
||||
for node_number, node_struct in prompt.items():
|
||||
@@ -94,6 +98,60 @@ class NodeReplaceManager:
|
||||
previous_input = prompt[conn_node_number]["inputs"][conn_input_id]
|
||||
previous_input[1] = new_output_idx
|
||||
|
||||
def load_from_json(self, module_dir: str, module_name: str, _node_replace_class=None):
|
||||
"""Load node_replacements.json from a custom node directory and register replacements.
|
||||
|
||||
Custom node authors can ship a node_replacements.json file in their repo root
|
||||
to define node replacements declaratively. The file format matches the output
|
||||
of NodeReplace.as_dict(), keyed by old_node_id.
|
||||
|
||||
Fail-open: all errors are logged and skipped so a malformed file never
|
||||
prevents the custom node from loading.
|
||||
"""
|
||||
replacements_path = os.path.join(module_dir, "node_replacements.json")
|
||||
if not os.path.isfile(replacements_path):
|
||||
return
|
||||
|
||||
try:
|
||||
with open(replacements_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
if not isinstance(data, dict):
|
||||
logging.warning(f"node_replacements.json in {module_name} must be a JSON object, skipping.")
|
||||
return
|
||||
|
||||
if _node_replace_class is None:
|
||||
from comfy_api.latest._io import NodeReplace
|
||||
_node_replace_class = NodeReplace
|
||||
|
||||
count = 0
|
||||
for old_node_id, replacements in data.items():
|
||||
if not isinstance(replacements, list):
|
||||
logging.warning(f"node_replacements.json in {module_name}: value for '{old_node_id}' must be a list, skipping.")
|
||||
continue
|
||||
for entry in replacements:
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
new_node_id = entry.get("new_node_id", "")
|
||||
if not new_node_id:
|
||||
logging.warning(f"node_replacements.json in {module_name}: entry for '{old_node_id}' missing 'new_node_id', skipping.")
|
||||
continue
|
||||
self.register(_node_replace_class(
|
||||
new_node_id=new_node_id,
|
||||
old_node_id=entry.get("old_node_id", old_node_id),
|
||||
old_widget_ids=entry.get("old_widget_ids"),
|
||||
input_mapping=entry.get("input_mapping"),
|
||||
output_mapping=entry.get("output_mapping"),
|
||||
))
|
||||
count += 1
|
||||
|
||||
if count > 0:
|
||||
logging.info(f"Loaded {count} node replacement(s) from {module_name}/node_replacements.json")
|
||||
except json.JSONDecodeError as e:
|
||||
logging.warning(f"Failed to parse node_replacements.json in {module_name}: {e}")
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to load node_replacements.json from {module_name}: {e}")
|
||||
|
||||
def as_dict(self):
|
||||
"""Serialize all replacements to dict."""
|
||||
return {
|
||||
|
||||
@@ -3,12 +3,9 @@ from ..diffusionmodules.openaimodel import Timestep
|
||||
import torch
|
||||
|
||||
class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
|
||||
def __init__(self, *args, clip_stats_path=None, timestep_dim=256, **kwargs):
|
||||
def __init__(self, *args, timestep_dim=256, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if clip_stats_path is None:
|
||||
clip_mean, clip_std = torch.zeros(timestep_dim), torch.ones(timestep_dim)
|
||||
else:
|
||||
clip_mean, clip_std = torch.load(clip_stats_path, map_location="cpu")
|
||||
clip_mean, clip_std = torch.zeros(timestep_dim), torch.ones(timestep_dim)
|
||||
self.register_buffer("data_mean", clip_mean[None, :], persistent=False)
|
||||
self.register_buffer("data_std", clip_std[None, :], persistent=False)
|
||||
self.time_embed = Timestep(timestep_dim)
|
||||
|
||||
@@ -1745,6 +1745,8 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
|
||||
temp_sd = comfy.utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True)
|
||||
if len(temp_sd) > 0:
|
||||
sd = temp_sd
|
||||
if custom_operations is None:
|
||||
sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata)
|
||||
|
||||
parameters = comfy.utils.calculate_parameters(sd)
|
||||
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||
|
||||
226
comfy_api_nodes/apis/wan.py
Normal file
226
comfy_api_nodes/apis/wan.py
Normal file
@@ -0,0 +1,226 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Text2ImageInputField(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: str | None = Field(None)
|
||||
|
||||
|
||||
class Image2ImageInputField(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: str | None = Field(None)
|
||||
images: list[str] = Field(..., min_length=1, max_length=2)
|
||||
|
||||
|
||||
class Text2VideoInputField(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: str | None = Field(None)
|
||||
audio_url: str | None = Field(None)
|
||||
|
||||
|
||||
class Image2VideoInputField(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: str | None = Field(None)
|
||||
img_url: str = Field(...)
|
||||
audio_url: str | None = Field(None)
|
||||
|
||||
|
||||
class Reference2VideoInputField(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: str | None = Field(None)
|
||||
reference_video_urls: list[str] = Field(...)
|
||||
|
||||
|
||||
class Txt2ImageParametersField(BaseModel):
|
||||
size: str = Field(...)
|
||||
n: int = Field(1, description="Number of images to generate.") # we support only value=1
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
prompt_extend: bool = Field(True)
|
||||
watermark: bool = Field(False)
|
||||
|
||||
|
||||
class Image2ImageParametersField(BaseModel):
|
||||
size: str | None = Field(None)
|
||||
n: int = Field(1, description="Number of images to generate.") # we support only value=1
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
watermark: bool = Field(False)
|
||||
|
||||
|
||||
class Text2VideoParametersField(BaseModel):
|
||||
size: str = Field(...)
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
duration: int = Field(5, ge=5, le=15)
|
||||
prompt_extend: bool = Field(True)
|
||||
watermark: bool = Field(False)
|
||||
audio: bool = Field(False, description="Whether to generate audio automatically.")
|
||||
shot_type: str = Field("single")
|
||||
|
||||
|
||||
class Image2VideoParametersField(BaseModel):
|
||||
resolution: str = Field(...)
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
duration: int = Field(5, ge=5, le=15)
|
||||
prompt_extend: bool = Field(True)
|
||||
watermark: bool = Field(False)
|
||||
audio: bool = Field(False, description="Whether to generate audio automatically.")
|
||||
shot_type: str = Field("single")
|
||||
|
||||
|
||||
class Reference2VideoParametersField(BaseModel):
|
||||
size: str = Field(...)
|
||||
duration: int = Field(5, ge=5, le=15)
|
||||
shot_type: str = Field("single")
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
watermark: bool = Field(False)
|
||||
|
||||
|
||||
class Text2ImageTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
input: Text2ImageInputField = Field(...)
|
||||
parameters: Txt2ImageParametersField = Field(...)
|
||||
|
||||
|
||||
class Image2ImageTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
input: Image2ImageInputField = Field(...)
|
||||
parameters: Image2ImageParametersField = Field(...)
|
||||
|
||||
|
||||
class Text2VideoTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
input: Text2VideoInputField = Field(...)
|
||||
parameters: Text2VideoParametersField = Field(...)
|
||||
|
||||
|
||||
class Image2VideoTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
input: Image2VideoInputField = Field(...)
|
||||
parameters: Image2VideoParametersField = Field(...)
|
||||
|
||||
|
||||
class Reference2VideoTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
input: Reference2VideoInputField = Field(...)
|
||||
parameters: Reference2VideoParametersField = Field(...)
|
||||
|
||||
|
||||
class Wan27MediaItem(BaseModel):
|
||||
type: str = Field(...)
|
||||
url: str = Field(...)
|
||||
|
||||
|
||||
class Wan27ReferenceVideoInputField(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: str | None = Field(None)
|
||||
media: list[Wan27MediaItem] = Field(...)
|
||||
|
||||
|
||||
class Wan27ReferenceVideoParametersField(BaseModel):
|
||||
resolution: str = Field(...)
|
||||
ratio: str | None = Field(None)
|
||||
duration: int = Field(5, ge=2, le=10)
|
||||
watermark: bool = Field(False)
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
|
||||
|
||||
class Wan27ReferenceVideoTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
input: Wan27ReferenceVideoInputField = Field(...)
|
||||
parameters: Wan27ReferenceVideoParametersField = Field(...)
|
||||
|
||||
|
||||
class Wan27ImageToVideoInputField(BaseModel):
|
||||
prompt: str | None = Field(None)
|
||||
negative_prompt: str | None = Field(None)
|
||||
media: list[Wan27MediaItem] = Field(...)
|
||||
|
||||
|
||||
class Wan27ImageToVideoParametersField(BaseModel):
|
||||
resolution: str = Field(...)
|
||||
duration: int = Field(5, ge=2, le=15)
|
||||
prompt_extend: bool = Field(True)
|
||||
watermark: bool = Field(False)
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
|
||||
|
||||
class Wan27ImageToVideoTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
input: Wan27ImageToVideoInputField = Field(...)
|
||||
parameters: Wan27ImageToVideoParametersField = Field(...)
|
||||
|
||||
|
||||
class Wan27VideoEditInputField(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
media: list[Wan27MediaItem] = Field(...)
|
||||
|
||||
|
||||
class Wan27VideoEditParametersField(BaseModel):
|
||||
resolution: str = Field(...)
|
||||
ratio: str | None = Field(None)
|
||||
duration: int = Field(0)
|
||||
audio_setting: str = Field("auto")
|
||||
watermark: bool = Field(False)
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
|
||||
|
||||
class Wan27VideoEditTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
input: Wan27VideoEditInputField = Field(...)
|
||||
parameters: Wan27VideoEditParametersField = Field(...)
|
||||
|
||||
|
||||
class Wan27Text2VideoParametersField(BaseModel):
|
||||
resolution: str = Field(...)
|
||||
ratio: str | None = Field(None)
|
||||
duration: int = Field(5, ge=2, le=15)
|
||||
prompt_extend: bool = Field(True)
|
||||
watermark: bool = Field(False)
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
|
||||
|
||||
class Wan27Text2VideoTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
input: Text2VideoInputField = Field(...)
|
||||
parameters: Wan27Text2VideoParametersField = Field(...)
|
||||
|
||||
|
||||
class TaskCreationOutputField(BaseModel):
|
||||
task_id: str = Field(...)
|
||||
task_status: str = Field(...)
|
||||
|
||||
|
||||
class TaskCreationResponse(BaseModel):
|
||||
output: TaskCreationOutputField | None = Field(None)
|
||||
request_id: str = Field(...)
|
||||
code: str | None = Field(None, description="Error code for the failed request.")
|
||||
message: str | None = Field(None, description="Details about the failed request.")
|
||||
|
||||
|
||||
class TaskResult(BaseModel):
|
||||
url: str | None = Field(None)
|
||||
code: str | None = Field(None)
|
||||
message: str | None = Field(None)
|
||||
|
||||
|
||||
class ImageTaskStatusOutputField(TaskCreationOutputField):
|
||||
task_id: str = Field(...)
|
||||
task_status: str = Field(...)
|
||||
results: list[TaskResult] | None = Field(None)
|
||||
|
||||
|
||||
class VideoTaskStatusOutputField(TaskCreationOutputField):
|
||||
task_id: str = Field(...)
|
||||
task_status: str = Field(...)
|
||||
video_url: str | None = Field(None)
|
||||
code: str | None = Field(None)
|
||||
message: str | None = Field(None)
|
||||
|
||||
|
||||
class ImageTaskStatusResponse(BaseModel):
|
||||
output: ImageTaskStatusOutputField | None = Field(None)
|
||||
request_id: str = Field(...)
|
||||
|
||||
|
||||
class VideoTaskStatusResponse(BaseModel):
|
||||
output: VideoTaskStatusOutputField | None = Field(None)
|
||||
request_id: str = Field(...)
|
||||
File diff suppressed because it is too large
Load Diff
6
nodes.py
6
nodes.py
@@ -2228,6 +2228,12 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
|
||||
|
||||
LOADED_MODULE_DIRS[module_name] = os.path.abspath(module_dir)
|
||||
|
||||
# Only load node_replacements.json from directory-based custom nodes (proper packs).
|
||||
# Single-file .py nodes share a parent dir, so checking there would be incorrect.
|
||||
if os.path.isdir(module_path):
|
||||
from server import PromptServer
|
||||
PromptServer.instance.node_replace_manager.load_from_json(module_dir, module_name)
|
||||
|
||||
try:
|
||||
from comfy_config import config_parser
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.42.8
|
||||
comfyui-workflow-templates==0.9.39
|
||||
comfyui-workflow-templates==0.9.44
|
||||
comfyui-embedded-docs==0.4.3
|
||||
torch
|
||||
torchsde
|
||||
|
||||
@@ -146,6 +146,10 @@ def is_loopback(host):
|
||||
def create_origin_only_middleware():
|
||||
@web.middleware
|
||||
async def origin_only_middleware(request: web.Request, handler):
|
||||
if 'Sec-Fetch-Site' in request.headers:
|
||||
sec_fetch_site = request.headers['Sec-Fetch-Site']
|
||||
if sec_fetch_site == 'cross-site':
|
||||
return web.Response(status=403)
|
||||
#this code is used to prevent the case where a random website can queue comfy workflows by making a POST to 127.0.0.1 which browsers don't prevent for some dumb reason.
|
||||
#in that case the Host and Origin hostnames won't match
|
||||
#I know the proper fix would be to add a cookie but this should take care of the problem in the meantime
|
||||
|
||||
217
tests/test_node_replacements_json.py
Normal file
217
tests/test_node_replacements_json.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""Tests for NodeReplaceManager.load_from_json — auto-registration of
|
||||
node_replacements.json from custom node directories."""
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from app.node_replace_manager import NodeReplaceManager
|
||||
|
||||
|
||||
class SimpleNodeReplace:
|
||||
"""Lightweight stand-in for comfy_api.latest._io.NodeReplace (avoids torch import)."""
|
||||
def __init__(self, new_node_id, old_node_id, old_widget_ids=None,
|
||||
input_mapping=None, output_mapping=None):
|
||||
self.new_node_id = new_node_id
|
||||
self.old_node_id = old_node_id
|
||||
self.old_widget_ids = old_widget_ids
|
||||
self.input_mapping = input_mapping
|
||||
self.output_mapping = output_mapping
|
||||
|
||||
def as_dict(self):
|
||||
return {
|
||||
"new_node_id": self.new_node_id,
|
||||
"old_node_id": self.old_node_id,
|
||||
"old_widget_ids": self.old_widget_ids,
|
||||
"input_mapping": list(self.input_mapping) if self.input_mapping else None,
|
||||
"output_mapping": list(self.output_mapping) if self.output_mapping else None,
|
||||
}
|
||||
|
||||
|
||||
class TestLoadFromJson(unittest.TestCase):
|
||||
"""Test auto-registration of node_replacements.json from custom node directories."""
|
||||
|
||||
def setUp(self):
|
||||
self.tmpdir = tempfile.mkdtemp()
|
||||
self.manager = NodeReplaceManager()
|
||||
|
||||
def _write_json(self, data):
|
||||
path = os.path.join(self.tmpdir, "node_replacements.json")
|
||||
with open(path, "w") as f:
|
||||
json.dump(data, f)
|
||||
|
||||
def _load(self):
|
||||
self.manager.load_from_json(self.tmpdir, "test-node-pack", _node_replace_class=SimpleNodeReplace)
|
||||
|
||||
def test_no_file_does_nothing(self):
|
||||
"""No node_replacements.json — should silently do nothing."""
|
||||
self._load()
|
||||
self.assertEqual(self.manager.as_dict(), {})
|
||||
|
||||
def test_empty_object(self):
|
||||
"""Empty {} — should do nothing."""
|
||||
self._write_json({})
|
||||
self._load()
|
||||
self.assertEqual(self.manager.as_dict(), {})
|
||||
|
||||
def test_single_replacement(self):
|
||||
"""Single replacement entry registers correctly."""
|
||||
self._write_json({
|
||||
"OldNode": [{
|
||||
"new_node_id": "NewNode",
|
||||
"old_node_id": "OldNode",
|
||||
"input_mapping": [{"new_id": "model", "old_id": "ckpt_name"}],
|
||||
"output_mapping": [{"new_idx": 0, "old_idx": 0}],
|
||||
}]
|
||||
})
|
||||
self._load()
|
||||
result = self.manager.as_dict()
|
||||
self.assertIn("OldNode", result)
|
||||
self.assertEqual(len(result["OldNode"]), 1)
|
||||
entry = result["OldNode"][0]
|
||||
self.assertEqual(entry["new_node_id"], "NewNode")
|
||||
self.assertEqual(entry["old_node_id"], "OldNode")
|
||||
self.assertEqual(entry["input_mapping"], [{"new_id": "model", "old_id": "ckpt_name"}])
|
||||
self.assertEqual(entry["output_mapping"], [{"new_idx": 0, "old_idx": 0}])
|
||||
|
||||
def test_multiple_replacements(self):
|
||||
"""Multiple old_node_ids each with entries."""
|
||||
self._write_json({
|
||||
"NodeA": [{"new_node_id": "NodeB", "old_node_id": "NodeA"}],
|
||||
"NodeC": [{"new_node_id": "NodeD", "old_node_id": "NodeC"}],
|
||||
})
|
||||
self._load()
|
||||
result = self.manager.as_dict()
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertIn("NodeA", result)
|
||||
self.assertIn("NodeC", result)
|
||||
|
||||
def test_multiple_alternatives_for_same_node(self):
|
||||
"""Multiple replacement options for the same old node."""
|
||||
self._write_json({
|
||||
"OldNode": [
|
||||
{"new_node_id": "AltA", "old_node_id": "OldNode"},
|
||||
{"new_node_id": "AltB", "old_node_id": "OldNode"},
|
||||
]
|
||||
})
|
||||
self._load()
|
||||
result = self.manager.as_dict()
|
||||
self.assertEqual(len(result["OldNode"]), 2)
|
||||
|
||||
def test_null_mappings(self):
|
||||
"""Null input/output mappings (trivial replacement)."""
|
||||
self._write_json({
|
||||
"OldNode": [{
|
||||
"new_node_id": "NewNode",
|
||||
"old_node_id": "OldNode",
|
||||
"input_mapping": None,
|
||||
"output_mapping": None,
|
||||
}]
|
||||
})
|
||||
self._load()
|
||||
entry = self.manager.as_dict()["OldNode"][0]
|
||||
self.assertIsNone(entry["input_mapping"])
|
||||
self.assertIsNone(entry["output_mapping"])
|
||||
|
||||
def test_old_node_id_defaults_to_key(self):
|
||||
"""If old_node_id is missing from entry, uses the dict key."""
|
||||
self._write_json({
|
||||
"OldNode": [{"new_node_id": "NewNode"}]
|
||||
})
|
||||
self._load()
|
||||
entry = self.manager.as_dict()["OldNode"][0]
|
||||
self.assertEqual(entry["old_node_id"], "OldNode")
|
||||
|
||||
def test_invalid_json_skips(self):
|
||||
"""Invalid JSON file — should warn and skip, not crash."""
|
||||
path = os.path.join(self.tmpdir, "node_replacements.json")
|
||||
with open(path, "w") as f:
|
||||
f.write("{invalid json")
|
||||
self._load()
|
||||
self.assertEqual(self.manager.as_dict(), {})
|
||||
|
||||
def test_non_object_json_skips(self):
|
||||
"""JSON array instead of object — should warn and skip."""
|
||||
self._write_json([1, 2, 3])
|
||||
self._load()
|
||||
self.assertEqual(self.manager.as_dict(), {})
|
||||
|
||||
def test_non_list_value_skips(self):
|
||||
"""Value is not a list — should warn and skip that key."""
|
||||
self._write_json({
|
||||
"OldNode": "not a list",
|
||||
"GoodNode": [{"new_node_id": "NewNode", "old_node_id": "GoodNode"}],
|
||||
})
|
||||
self._load()
|
||||
result = self.manager.as_dict()
|
||||
self.assertNotIn("OldNode", result)
|
||||
self.assertIn("GoodNode", result)
|
||||
|
||||
def test_with_old_widget_ids(self):
|
||||
"""old_widget_ids are passed through."""
|
||||
self._write_json({
|
||||
"OldNode": [{
|
||||
"new_node_id": "NewNode",
|
||||
"old_node_id": "OldNode",
|
||||
"old_widget_ids": ["width", "height"],
|
||||
}]
|
||||
})
|
||||
self._load()
|
||||
entry = self.manager.as_dict()["OldNode"][0]
|
||||
self.assertEqual(entry["old_widget_ids"], ["width", "height"])
|
||||
|
||||
def test_set_value_in_input_mapping(self):
|
||||
"""input_mapping with set_value entries."""
|
||||
self._write_json({
|
||||
"OldNode": [{
|
||||
"new_node_id": "NewNode",
|
||||
"old_node_id": "OldNode",
|
||||
"input_mapping": [
|
||||
{"new_id": "method", "set_value": "lanczos"},
|
||||
{"new_id": "size", "old_id": "dimension"},
|
||||
],
|
||||
}]
|
||||
})
|
||||
self._load()
|
||||
entry = self.manager.as_dict()["OldNode"][0]
|
||||
self.assertEqual(len(entry["input_mapping"]), 2)
|
||||
|
||||
def test_missing_new_node_id_skipped(self):
|
||||
"""Entry without new_node_id is skipped."""
|
||||
self._write_json({
|
||||
"OldNode": [
|
||||
{"old_node_id": "OldNode"},
|
||||
{"new_node_id": "", "old_node_id": "OldNode"},
|
||||
{"new_node_id": "ValidNew", "old_node_id": "OldNode"},
|
||||
]
|
||||
})
|
||||
self._load()
|
||||
result = self.manager.as_dict()
|
||||
self.assertEqual(len(result["OldNode"]), 1)
|
||||
self.assertEqual(result["OldNode"][0]["new_node_id"], "ValidNew")
|
||||
|
||||
def test_non_dict_entry_skipped(self):
|
||||
"""Non-dict entries in the list are silently skipped."""
|
||||
self._write_json({
|
||||
"OldNode": [
|
||||
"not a dict",
|
||||
{"new_node_id": "NewNode", "old_node_id": "OldNode"},
|
||||
]
|
||||
})
|
||||
self._load()
|
||||
result = self.manager.as_dict()
|
||||
self.assertEqual(len(result["OldNode"]), 1)
|
||||
|
||||
def test_has_replacement_after_load(self):
|
||||
"""Manager reports has_replacement correctly after JSON load."""
|
||||
self._write_json({
|
||||
"OldNode": [{"new_node_id": "NewNode", "old_node_id": "OldNode"}],
|
||||
})
|
||||
self.assertFalse(self.manager.has_replacement("OldNode"))
|
||||
self._load()
|
||||
self.assertTrue(self.manager.has_replacement("OldNode"))
|
||||
self.assertFalse(self.manager.has_replacement("UnknownNode"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user