mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-12 15:27:43 +00:00
Compare commits
8 Commits
range-type
...
feat/api-n
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
045221026d | ||
|
|
a2840e7552 | ||
|
|
a134423890 | ||
|
|
b920bdd77d | ||
|
|
5410ed34f5 | ||
|
|
e6be419a30 | ||
|
|
3d4aca8084 | ||
|
|
2d861fb146 |
2
.ci/windows_intel_base_files/run_intel_gpu.bat
Executable file
2
.ci/windows_intel_base_files/run_intel_gpu.bat
Executable file
@@ -0,0 +1,2 @@
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
|
||||
pause
|
||||
@@ -182,7 +182,7 @@
|
||||
]
|
||||
},
|
||||
"widgets_values": [
|
||||
50
|
||||
0
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -316,7 +316,7 @@
|
||||
"step": 1
|
||||
},
|
||||
"widgets_values": [
|
||||
30
|
||||
0
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -90,7 +90,7 @@ class HeatmapHead(torch.nn.Module):
|
||||
origin_max = np.max(hm[k])
|
||||
dr = np.zeros((H + 2 * border, W + 2 * border), dtype=np.float32)
|
||||
dr[border:-border, border:-border] = hm[k].copy()
|
||||
dr = gaussian_filter(dr, sigma=2.0)
|
||||
dr = gaussian_filter(dr, sigma=2.0, truncate=2.5)
|
||||
hm[k] = dr[border:-border, border:-border].copy()
|
||||
cur_max = np.max(hm[k])
|
||||
if cur_max > 0:
|
||||
|
||||
@@ -9,7 +9,6 @@ from comfy_api.latest._input import (
|
||||
CurveInput,
|
||||
MonotoneCubicCurve,
|
||||
LinearCurve,
|
||||
RangeInput,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -22,5 +21,4 @@ __all__ = [
|
||||
"CurveInput",
|
||||
"MonotoneCubicCurve",
|
||||
"LinearCurve",
|
||||
"RangeInput",
|
||||
]
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
|
||||
from .curve_types import CurvePoint, CurveInput, MonotoneCubicCurve, LinearCurve
|
||||
from .range_types import RangeInput
|
||||
from .video_types import VideoInput
|
||||
|
||||
__all__ = [
|
||||
@@ -13,5 +12,4 @@ __all__ = [
|
||||
"CurveInput",
|
||||
"MonotoneCubicCurve",
|
||||
"LinearCurve",
|
||||
"RangeInput",
|
||||
]
|
||||
|
||||
@@ -1,70 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RangeInput:
|
||||
"""Represents a levels/range adjustment: input range [min, max] with
|
||||
optional midpoint (gamma control).
|
||||
|
||||
Generates a 1D LUT identical to GIMP's levels mapping:
|
||||
1. Normalize input to [0, 1] using [min, max]
|
||||
2. Apply gamma correction: pow(value, 1/gamma)
|
||||
3. Clamp to [0, 1]
|
||||
|
||||
The midpoint field is a position in [0, 1] representing where the
|
||||
midtone falls within [min, max]. It maps to gamma via:
|
||||
gamma = -log2(midpoint)
|
||||
So midpoint=0.5 → gamma=1.0 (linear).
|
||||
"""
|
||||
|
||||
def __init__(self, min_val: float, max_val: float, midpoint: float | None = None):
|
||||
self.min_val = min_val
|
||||
self.max_val = max_val
|
||||
self.midpoint = midpoint
|
||||
|
||||
@staticmethod
|
||||
def from_raw(data) -> RangeInput:
|
||||
if isinstance(data, RangeInput):
|
||||
return data
|
||||
if isinstance(data, dict):
|
||||
return RangeInput(
|
||||
min_val=float(data.get("min", 0.0)),
|
||||
max_val=float(data.get("max", 1.0)),
|
||||
midpoint=float(data["midpoint"]) if data.get("midpoint") is not None else None,
|
||||
)
|
||||
raise TypeError(f"Cannot convert {type(data)} to RangeInput")
|
||||
|
||||
def to_lut(self, size: int = 256) -> np.ndarray:
|
||||
"""Generate a float64 lookup table mapping [0, 1] input through this
|
||||
levels adjustment.
|
||||
|
||||
The LUT maps normalized input values (0..1) to output values (0..1),
|
||||
matching the GIMP levels formula.
|
||||
"""
|
||||
xs = np.linspace(0.0, 1.0, size, dtype=np.float64)
|
||||
|
||||
in_range = self.max_val - self.min_val
|
||||
if abs(in_range) < 1e-10:
|
||||
return np.where(xs >= self.min_val, 1.0, 0.0).astype(np.float64)
|
||||
|
||||
# Normalize: map [min, max] → [0, 1]
|
||||
result = (xs - self.min_val) / in_range
|
||||
result = np.clip(result, 0.0, 1.0)
|
||||
|
||||
# Gamma correction from midpoint
|
||||
if self.midpoint is not None and self.midpoint > 0 and self.midpoint != 0.5:
|
||||
gamma = max(-math.log2(self.midpoint), 0.001)
|
||||
inv_gamma = 1.0 / gamma
|
||||
mask = result > 0
|
||||
result[mask] = np.power(result[mask], inv_gamma)
|
||||
|
||||
return result
|
||||
|
||||
def __repr__(self) -> str:
|
||||
mid = f", midpoint={self.midpoint}" if self.midpoint is not None else ""
|
||||
return f"RangeInput(min={self.min_val}, max={self.max_val}{mid})"
|
||||
@@ -1266,43 +1266,6 @@ class Histogram(ComfyTypeIO):
|
||||
Type = list[int]
|
||||
|
||||
|
||||
@comfytype(io_type="RANGE")
|
||||
class Range(ComfyTypeIO):
|
||||
from comfy_api.input import RangeInput
|
||||
if TYPE_CHECKING:
|
||||
Type = RangeInput
|
||||
|
||||
class Input(WidgetInput):
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||
socketless: bool=True, default: dict=None,
|
||||
display: str=None,
|
||||
gradient_stops: list=None,
|
||||
show_midpoint: bool=None,
|
||||
midpoint_scale: str=None,
|
||||
value_min: float=None,
|
||||
value_max: float=None,
|
||||
advanced: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
|
||||
if default is None:
|
||||
self.default = {"min": 0.0, "max": 1.0}
|
||||
self.display = display
|
||||
self.gradient_stops = gradient_stops
|
||||
self.show_midpoint = show_midpoint
|
||||
self.midpoint_scale = midpoint_scale
|
||||
self.value_min = value_min
|
||||
self.value_max = value_max
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict() | prune_dict({
|
||||
"display": self.display,
|
||||
"gradient_stops": self.gradient_stops,
|
||||
"show_midpoint": self.show_midpoint,
|
||||
"midpoint_scale": self.midpoint_scale,
|
||||
"value_min": self.value_min,
|
||||
"value_max": self.value_max,
|
||||
})
|
||||
|
||||
|
||||
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
||||
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
||||
DYNAMIC_INPUT_LOOKUP[io_type] = func
|
||||
@@ -2313,6 +2276,5 @@ __all__ = [
|
||||
"BoundingBox",
|
||||
"Curve",
|
||||
"Histogram",
|
||||
"Range",
|
||||
"NodeReplace",
|
||||
]
|
||||
|
||||
@@ -52,6 +52,26 @@ class TaskImageContent(BaseModel):
|
||||
role: Literal["first_frame", "last_frame", "reference_image"] | None = Field(None)
|
||||
|
||||
|
||||
class TaskVideoContentUrl(BaseModel):
|
||||
url: str = Field(...)
|
||||
|
||||
|
||||
class TaskVideoContent(BaseModel):
|
||||
type: str = Field("video_url")
|
||||
video_url: TaskVideoContentUrl = Field(...)
|
||||
role: str = Field("reference_video")
|
||||
|
||||
|
||||
class TaskAudioContentUrl(BaseModel):
|
||||
url: str = Field(...)
|
||||
|
||||
|
||||
class TaskAudioContent(BaseModel):
|
||||
type: str = Field("audio_url")
|
||||
audio_url: TaskAudioContentUrl = Field(...)
|
||||
role: str = Field("reference_audio")
|
||||
|
||||
|
||||
class Text2VideoTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
content: list[TaskTextContent] = Field(..., min_length=1)
|
||||
@@ -64,6 +84,17 @@ class Image2VideoTaskCreationRequest(BaseModel):
|
||||
generate_audio: bool | None = Field(...)
|
||||
|
||||
|
||||
class Seedance2TaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
content: list[TaskTextContent | TaskImageContent | TaskVideoContent | TaskAudioContent] = Field(..., min_length=1)
|
||||
generate_audio: bool | None = Field(None)
|
||||
resolution: str | None = Field(None)
|
||||
ratio: str | None = Field(None)
|
||||
duration: int | None = Field(None, ge=4, le=15)
|
||||
seed: int | None = Field(None, ge=0, le=2147483647)
|
||||
watermark: bool | None = Field(None)
|
||||
|
||||
|
||||
class TaskCreationResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
|
||||
@@ -77,12 +108,27 @@ class TaskStatusResult(BaseModel):
|
||||
video_url: str = Field(...)
|
||||
|
||||
|
||||
class TaskStatusUsage(BaseModel):
|
||||
completion_tokens: int = Field(0)
|
||||
total_tokens: int = Field(0)
|
||||
|
||||
|
||||
class TaskStatusResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
model: str = Field(...)
|
||||
status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...)
|
||||
error: TaskStatusError | None = Field(None)
|
||||
content: TaskStatusResult | None = Field(None)
|
||||
usage: TaskStatusUsage | None = Field(None)
|
||||
|
||||
|
||||
# Dollars per 1K tokens, keyed by (model_id, has_video_input).
|
||||
SEEDANCE2_PRICE_PER_1K_TOKENS = {
|
||||
("dreamina-seedance-2-0-260128", False): 0.007,
|
||||
("dreamina-seedance-2-0-260128", True): 0.0043,
|
||||
("dreamina-seedance-2-0-fast-260128", False): 0.0056,
|
||||
("dreamina-seedance-2-0-fast-260128", True): 0.0033,
|
||||
}
|
||||
|
||||
|
||||
RECOMMENDED_PRESETS = [
|
||||
@@ -112,6 +158,12 @@ RECOMMENDED_PRESETS_SEEDREAM_4 = [
|
||||
("Custom", None, None),
|
||||
]
|
||||
|
||||
# Seedance 2.0 reference video pixel count limits per model.
|
||||
SEEDANCE2_REF_VIDEO_PIXEL_LIMITS = {
|
||||
"dreamina-seedance-2-0-260128": {"min": 409_600, "max": 927_408},
|
||||
"dreamina-seedance-2-0-fast-260128": {"min": 409_600, "max": 927_408},
|
||||
}
|
||||
|
||||
# The time in this dictionary are given for 10 seconds duration.
|
||||
VIDEO_TASKS_EXECUTION_TIME = {
|
||||
"seedance-1-0-lite-t2v-250428": {
|
||||
|
||||
@@ -8,16 +8,23 @@ from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.bytedance import (
|
||||
RECOMMENDED_PRESETS,
|
||||
RECOMMENDED_PRESETS_SEEDREAM_4,
|
||||
SEEDANCE2_PRICE_PER_1K_TOKENS,
|
||||
SEEDANCE2_REF_VIDEO_PIXEL_LIMITS,
|
||||
VIDEO_TASKS_EXECUTION_TIME,
|
||||
Image2VideoTaskCreationRequest,
|
||||
ImageTaskCreationResponse,
|
||||
Seedance2TaskCreationRequest,
|
||||
Seedream4Options,
|
||||
Seedream4TaskCreationRequest,
|
||||
TaskAudioContent,
|
||||
TaskAudioContentUrl,
|
||||
TaskCreationResponse,
|
||||
TaskImageContent,
|
||||
TaskImageContentUrl,
|
||||
TaskStatusResponse,
|
||||
TaskTextContent,
|
||||
TaskVideoContent,
|
||||
TaskVideoContentUrl,
|
||||
Text2ImageTaskCreationRequest,
|
||||
Text2VideoTaskCreationRequest,
|
||||
)
|
||||
@@ -29,7 +36,10 @@ from comfy_api_nodes.util import (
|
||||
image_tensor_pair_to_batch,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_audio_to_comfyapi,
|
||||
upload_image_to_comfyapi,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
validate_image_aspect_ratio,
|
||||
validate_image_dimensions,
|
||||
validate_string,
|
||||
@@ -46,12 +56,56 @@ SEEDREAM_MODELS = {
|
||||
# Long-running tasks endpoints(e.g., video)
|
||||
BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
|
||||
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
|
||||
BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT = "/proxy/byteplus-seedance2/api/v3/contents/generations/tasks" # + /{task_id}
|
||||
|
||||
SEEDANCE_MODELS = {
|
||||
"Seedance 2.0": "dreamina-seedance-2-0-260128",
|
||||
"Seedance 2.0 Fast": "dreamina-seedance-2-0-fast-260128",
|
||||
}
|
||||
|
||||
DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-250428"}
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _validate_ref_video_pixels(video: Input.Video, model_id: str, index: int) -> None:
|
||||
"""Validate reference video pixel count against Seedance 2.0 model limits."""
|
||||
limits = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id)
|
||||
if not limits:
|
||||
return
|
||||
try:
|
||||
w, h = video.get_dimensions()
|
||||
except Exception:
|
||||
return
|
||||
pixels = w * h
|
||||
min_px = limits.get("min")
|
||||
max_px = limits.get("max")
|
||||
if min_px and pixels < min_px:
|
||||
raise ValueError(
|
||||
f"Reference video {index} is too small: {w}x{h} = {pixels:,}px. " f"Minimum is {min_px:,}px for this model."
|
||||
)
|
||||
if max_px and pixels > max_px:
|
||||
raise ValueError(
|
||||
f"Reference video {index} is too large: {w}x{h} = {pixels:,}px. "
|
||||
f"Maximum is {max_px:,}px for this model. Try downscaling the video."
|
||||
)
|
||||
|
||||
|
||||
def _seedance2_price_extractor(model_id: str, has_video_input: bool):
|
||||
"""Returns a price_extractor closure for Seedance 2.0 poll_op."""
|
||||
rate = SEEDANCE2_PRICE_PER_1K_TOKENS.get((model_id, has_video_input))
|
||||
if rate is None:
|
||||
return None
|
||||
|
||||
def extractor(response: TaskStatusResponse) -> float | None:
|
||||
if response.usage is None:
|
||||
return None
|
||||
return response.usage.total_tokens * 1.43 * rate / 1_000.0
|
||||
|
||||
return extractor
|
||||
|
||||
|
||||
def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
|
||||
if response.error:
|
||||
error_msg = f"ByteDance request failed. Code: {response.error['code']}, message: {response.error['message']}"
|
||||
@@ -335,8 +389,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
mp_provided = out_num_pixels / 1_000_000.0
|
||||
if ("seedream-4-5" in model or "seedream-5-0" in model) and out_num_pixels < 3686400:
|
||||
raise ValueError(
|
||||
f"Minimum image resolution for the selected model is 3.68MP, "
|
||||
f"but {mp_provided:.2f}MP provided."
|
||||
f"Minimum image resolution for the selected model is 3.68MP, " f"but {mp_provided:.2f}MP provided."
|
||||
)
|
||||
if "seedream-4-0" in model and out_num_pixels < 921600:
|
||||
raise ValueError(
|
||||
@@ -952,33 +1005,6 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
||||
)
|
||||
|
||||
|
||||
async def process_video_task(
|
||||
cls: type[IO.ComfyNode],
|
||||
payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest,
|
||||
estimated_duration: int | None,
|
||||
) -> IO.NodeOutput:
|
||||
if payload.model in DEPRECATED_MODELS:
|
||||
logger.warning(
|
||||
"Model '%s' is deprecated and will be deactivated on May 13, 2026. "
|
||||
"Please switch to a newer model. Recommended: seedance-1-0-pro-fast-251015.",
|
||||
payload.model,
|
||||
)
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
|
||||
data=payload,
|
||||
response_model=TaskCreationResponse,
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
|
||||
status_extractor=lambda r: r.status,
|
||||
estimated_duration=estimated_duration,
|
||||
response_model=TaskStatusResponse,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
||||
|
||||
|
||||
def raise_if_text_params(prompt: str, text_params: list[str]) -> None:
|
||||
for i in text_params:
|
||||
if f"--{i} " in prompt:
|
||||
@@ -1040,6 +1066,530 @@ PRICE_BADGE_VIDEO = IO.PriceBadge(
|
||||
)
|
||||
|
||||
|
||||
def _seedance2_text_inputs():
|
||||
return [
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text prompt for video generation.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["480p", "720p"],
|
||||
tooltip="Resolution of the output video.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"ratio",
|
||||
options=["16:9", "4:3", "1:1", "3:4", "9:16", "21:9", "adaptive"],
|
||||
tooltip="Aspect ratio of the output video.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=7,
|
||||
min=4,
|
||||
max=15,
|
||||
step=1,
|
||||
tooltip="Duration of the output video in seconds (4-15).",
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"generate_audio",
|
||||
default=True,
|
||||
tooltip="Enable audio generation for the output video.",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class ByteDance2TextToVideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ByteDance2TextToVideoNode",
|
||||
display_name="ByteDance Seedance 2.0 Text to Video",
|
||||
category="api node/video/ByteDance",
|
||||
description="Generate video using Seedance 2.0 models based on a text prompt.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs()),
|
||||
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs()),
|
||||
],
|
||||
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=False,
|
||||
tooltip="Whether to add a watermark to the video.",
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution", "model.duration"]),
|
||||
expr="""
|
||||
(
|
||||
$rate480 := 10044;
|
||||
$rate720 := 21600;
|
||||
$m := widgets.model;
|
||||
$pricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001;
|
||||
$res := $lookup(widgets, "model.resolution");
|
||||
$dur := $lookup(widgets, "model.duration");
|
||||
$rate := $res = "720p" ? $rate720 : $rate480;
|
||||
$cost := $dur * $rate * $pricePer1K / 1000;
|
||||
{"type": "usd", "usd": $cost, "format": {"approximate": true}}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: dict,
|
||||
seed: int,
|
||||
watermark: bool,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(model["prompt"], strip_whitespace=True, min_length=1)
|
||||
model_id = SEEDANCE_MODELS[model["model"]]
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
|
||||
data=Seedance2TaskCreationRequest(
|
||||
model=model_id,
|
||||
content=[TaskTextContent(text=model["prompt"])],
|
||||
generate_audio=model["generate_audio"],
|
||||
resolution=model["resolution"],
|
||||
ratio=model["ratio"],
|
||||
duration=model["duration"],
|
||||
seed=seed,
|
||||
watermark=watermark,
|
||||
),
|
||||
response_model=TaskCreationResponse,
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: r.status,
|
||||
price_extractor=_seedance2_price_extractor(model_id, has_video_input=False),
|
||||
poll_interval=9,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
||||
|
||||
|
||||
class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ByteDance2FirstLastFrameNode",
|
||||
display_name="ByteDance Seedance 2.0 First-Last-Frame to Video",
|
||||
category="api node/video/ByteDance",
|
||||
description="Generate video using Seedance 2.0 from a first frame image and optional last frame image.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs()),
|
||||
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs()),
|
||||
],
|
||||
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"first_frame",
|
||||
tooltip="First frame image for the video.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"last_frame",
|
||||
tooltip="Last frame image for the video.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=False,
|
||||
tooltip="Whether to add a watermark to the video.",
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution", "model.duration"]),
|
||||
expr="""
|
||||
(
|
||||
$rate480 := 10044;
|
||||
$rate720 := 21600;
|
||||
$m := widgets.model;
|
||||
$pricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001;
|
||||
$res := $lookup(widgets, "model.resolution");
|
||||
$dur := $lookup(widgets, "model.duration");
|
||||
$rate := $res = "720p" ? $rate720 : $rate480;
|
||||
$cost := $dur * $rate * $pricePer1K / 1000;
|
||||
{"type": "usd", "usd": $cost, "format": {"approximate": true}}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: dict,
|
||||
first_frame: Input.Image,
|
||||
seed: int,
|
||||
watermark: bool,
|
||||
last_frame: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(model["prompt"], strip_whitespace=True, min_length=1)
|
||||
model_id = SEEDANCE_MODELS[model["model"]]
|
||||
|
||||
content: list[TaskTextContent | TaskImageContent] = [
|
||||
TaskTextContent(text=model["prompt"]),
|
||||
TaskImageContent(
|
||||
image_url=TaskImageContentUrl(
|
||||
url=await upload_image_to_comfyapi(cls, first_frame, wait_label="Uploading first frame.")
|
||||
),
|
||||
role="first_frame",
|
||||
),
|
||||
]
|
||||
if last_frame is not None:
|
||||
content.append(
|
||||
TaskImageContent(
|
||||
image_url=TaskImageContentUrl(
|
||||
url=await upload_image_to_comfyapi(cls, last_frame, wait_label="Uploading last frame.")
|
||||
),
|
||||
role="last_frame",
|
||||
),
|
||||
)
|
||||
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
|
||||
data=Seedance2TaskCreationRequest(
|
||||
model=model_id,
|
||||
content=content,
|
||||
generate_audio=model["generate_audio"],
|
||||
resolution=model["resolution"],
|
||||
ratio=model["ratio"],
|
||||
duration=model["duration"],
|
||||
seed=seed,
|
||||
watermark=watermark,
|
||||
),
|
||||
response_model=TaskCreationResponse,
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: r.status,
|
||||
price_extractor=_seedance2_price_extractor(model_id, has_video_input=False),
|
||||
poll_interval=9,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
||||
|
||||
|
||||
def _seedance2_reference_inputs():
|
||||
return [
|
||||
*_seedance2_text_inputs(),
|
||||
IO.Autogrow.Input(
|
||||
"reference_images",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Image.Input("reference_image"),
|
||||
names=[
|
||||
"image_1",
|
||||
"image_2",
|
||||
"image_3",
|
||||
"image_4",
|
||||
"image_5",
|
||||
"image_6",
|
||||
"image_7",
|
||||
"image_8",
|
||||
"image_9",
|
||||
],
|
||||
min=0,
|
||||
),
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"reference_videos",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Video.Input("reference_video"),
|
||||
names=["video_1", "video_2", "video_3"],
|
||||
min=0,
|
||||
),
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"reference_audios",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Audio.Input("reference_audio"),
|
||||
names=["audio_1", "audio_2", "audio_3"],
|
||||
min=0,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class ByteDance2ReferenceNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ByteDance2ReferenceNode",
|
||||
display_name="ByteDance Seedance 2.0 Reference to Video",
|
||||
category="api node/video/ByteDance",
|
||||
description="Generate, edit, or extend video using Seedance 2.0 with reference images, "
|
||||
"videos, and audio. Supports multimodal reference, video editing, and video extension.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_reference_inputs()),
|
||||
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_reference_inputs()),
|
||||
],
|
||||
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=False,
|
||||
tooltip="Whether to add a watermark to the video.",
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=["model", "model.resolution", "model.duration"],
|
||||
input_groups=["model.reference_videos"],
|
||||
),
|
||||
expr="""
|
||||
(
|
||||
$rate480 := 10044;
|
||||
$rate720 := 21600;
|
||||
$m := widgets.model;
|
||||
$hasVideo := $lookup(inputGroups, "model.reference_videos") > 0;
|
||||
$noVideoPricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001;
|
||||
$videoPricePer1K := $contains($m, "fast") ? 0.004719 : 0.006149;
|
||||
$res := $lookup(widgets, "model.resolution");
|
||||
$dur := $lookup(widgets, "model.duration");
|
||||
$rate := $res = "720p" ? $rate720 : $rate480;
|
||||
$noVideoCost := $dur * $rate * $noVideoPricePer1K / 1000;
|
||||
$minVideoFactor := $ceil($dur * 5 / 3);
|
||||
$minVideoCost := $minVideoFactor * $rate * $videoPricePer1K / 1000;
|
||||
$maxVideoCost := (15 + $dur) * $rate * $videoPricePer1K / 1000;
|
||||
$hasVideo
|
||||
? {
|
||||
"type": "range_usd",
|
||||
"min_usd": $minVideoCost,
|
||||
"max_usd": $maxVideoCost,
|
||||
"format": {"approximate": true}
|
||||
}
|
||||
: {
|
||||
"type": "usd",
|
||||
"usd": $noVideoCost,
|
||||
"format": {"approximate": true}
|
||||
}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: dict,
|
||||
seed: int,
|
||||
watermark: bool,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(model["prompt"], strip_whitespace=True, min_length=1)
|
||||
|
||||
reference_images = model.get("reference_images", {})
|
||||
reference_videos = model.get("reference_videos", {})
|
||||
reference_audios = model.get("reference_audios", {})
|
||||
|
||||
if not reference_images and not reference_videos:
|
||||
raise ValueError("At least one reference image or video is required.")
|
||||
|
||||
model_id = SEEDANCE_MODELS[model["model"]]
|
||||
has_video_input = len(reference_videos) > 0
|
||||
total_video_duration = 0.0
|
||||
for i, key in enumerate(reference_videos, 1):
|
||||
video = reference_videos[key]
|
||||
_validate_ref_video_pixels(video, model_id, i)
|
||||
try:
|
||||
dur = video.get_duration()
|
||||
if dur < 1.8:
|
||||
raise ValueError(f"Reference video {i} is too short: {dur:.1f}s. Minimum duration is 1.8 seconds.")
|
||||
total_video_duration += dur
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception:
|
||||
pass
|
||||
if total_video_duration > 15.1:
|
||||
raise ValueError(f"Total reference video duration is {total_video_duration:.1f}s. Maximum is 15.1 seconds.")
|
||||
|
||||
total_audio_duration = 0.0
|
||||
for i, key in enumerate(reference_audios, 1):
|
||||
audio = reference_audios[key]
|
||||
dur = int(audio["waveform"].shape[-1]) / int(audio["sample_rate"])
|
||||
if dur < 1.8:
|
||||
raise ValueError(f"Reference audio {i} is too short: {dur:.1f}s. Minimum duration is 1.8 seconds.")
|
||||
total_audio_duration += dur
|
||||
if total_audio_duration > 15.1:
|
||||
raise ValueError(f"Total reference audio duration is {total_audio_duration:.1f}s. Maximum is 15.1 seconds.")
|
||||
|
||||
content: list[TaskTextContent | TaskImageContent | TaskVideoContent | TaskAudioContent] = [
|
||||
TaskTextContent(text=model["prompt"]),
|
||||
]
|
||||
for i, key in enumerate(reference_images, 1):
|
||||
content.append(
|
||||
TaskImageContent(
|
||||
image_url=TaskImageContentUrl(
|
||||
url=await upload_image_to_comfyapi(
|
||||
cls,
|
||||
image=reference_images[key],
|
||||
wait_label=f"Uploading image {i}",
|
||||
),
|
||||
),
|
||||
role="reference_image",
|
||||
),
|
||||
)
|
||||
for i, key in enumerate(reference_videos, 1):
|
||||
content.append(
|
||||
TaskVideoContent(
|
||||
video_url=TaskVideoContentUrl(
|
||||
url=await upload_video_to_comfyapi(
|
||||
cls,
|
||||
reference_videos[key],
|
||||
wait_label=f"Uploading video {i}",
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
for key in reference_audios:
|
||||
content.append(
|
||||
TaskAudioContent(
|
||||
audio_url=TaskAudioContentUrl(
|
||||
url=await upload_audio_to_comfyapi(
|
||||
cls,
|
||||
reference_audios[key],
|
||||
container_format="mp3",
|
||||
codec_name="libmp3lame",
|
||||
mime_type="audio/mpeg",
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
|
||||
data=Seedance2TaskCreationRequest(
|
||||
model=model_id,
|
||||
content=content,
|
||||
generate_audio=model["generate_audio"],
|
||||
resolution=model["resolution"],
|
||||
ratio=model["ratio"],
|
||||
duration=model["duration"],
|
||||
seed=seed,
|
||||
watermark=watermark,
|
||||
),
|
||||
response_model=TaskCreationResponse,
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: r.status,
|
||||
price_extractor=_seedance2_price_extractor(model_id, has_video_input=has_video_input),
|
||||
poll_interval=9,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
||||
|
||||
|
||||
async def process_video_task(
|
||||
cls: type[IO.ComfyNode],
|
||||
payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest,
|
||||
estimated_duration: int | None,
|
||||
) -> IO.NodeOutput:
|
||||
if payload.model in DEPRECATED_MODELS:
|
||||
logger.warning(
|
||||
"Model '%s' is deprecated and will be deactivated on May 13, 2026. "
|
||||
"Please switch to a newer model. Recommended: seedance-1-0-pro-fast-251015.",
|
||||
payload.model,
|
||||
)
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
|
||||
data=payload,
|
||||
response_model=TaskCreationResponse,
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
|
||||
status_extractor=lambda r: r.status,
|
||||
estimated_duration=estimated_duration,
|
||||
response_model=TaskStatusResponse,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
||||
|
||||
|
||||
class ByteDanceExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@@ -1050,6 +1600,9 @@ class ByteDanceExtension(ComfyExtension):
|
||||
ByteDanceImageToVideoNode,
|
||||
ByteDanceFirstLastFrameNode,
|
||||
ByteDanceImageReferenceNode,
|
||||
ByteDance2TextToVideoNode,
|
||||
ByteDance2FirstLastFrameNode,
|
||||
ByteDance2ReferenceNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -558,7 +558,7 @@ class GrokVideoReferenceNode(IO.ComfyNode):
|
||||
(
|
||||
$res := $lookup(widgets, "model.resolution");
|
||||
$dur := $lookup(widgets, "model.duration");
|
||||
$refs := inputGroups["model.reference_images"];
|
||||
$refs := $lookup(inputGroups, "model.reference_images");
|
||||
$rate := $res = "720p" ? 0.07 : 0.05;
|
||||
$price := ($rate * $dur + 0.002 * $refs) * 1.43;
|
||||
{"type":"usd","usd": $price}
|
||||
|
||||
@@ -32,10 +32,12 @@ class RTDETR_detect(io.ComfyNode):
|
||||
def execute(cls, model, image, threshold, class_name, max_detections) -> io.NodeOutput:
|
||||
B, H, W, C = image.shape
|
||||
|
||||
image_in = comfy.utils.common_upscale(image.movedim(-1, 1), 640, 640, "bilinear", crop="disabled")
|
||||
|
||||
comfy.model_management.load_model_gpu(model)
|
||||
results = model.model.diffusion_model(image_in, (W, H)) # list of B dicts
|
||||
results = []
|
||||
for i in range(0, B, 32):
|
||||
batch = image[i:i + 32]
|
||||
image_in = comfy.utils.common_upscale(batch.movedim(-1, 1), 640, 640, "bilinear", crop="disabled")
|
||||
results.extend(model.model.diffusion_model(image_in, (W, H)))
|
||||
|
||||
all_bbox_dicts = []
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
import numpy as np
|
||||
import math
|
||||
import colorsys
|
||||
@@ -410,7 +411,9 @@ class SDPoseDrawKeypoints(io.ComfyNode):
|
||||
pose_outputs.append(canvas)
|
||||
|
||||
pose_outputs_np = np.stack(pose_outputs) if len(pose_outputs) > 1 else np.expand_dims(pose_outputs[0], 0)
|
||||
final_pose_output = torch.from_numpy(pose_outputs_np).float() / 255.0
|
||||
final_pose_output = torch.from_numpy(pose_outputs_np).to(
|
||||
device=comfy.model_management.intermediate_device(),
|
||||
dtype=comfy.model_management.intermediate_dtype()) / 255.0
|
||||
return io.NodeOutput(final_pose_output)
|
||||
|
||||
class SDPoseKeypointExtractor(io.ComfyNode):
|
||||
@@ -459,6 +462,27 @@ class SDPoseKeypointExtractor(io.ComfyNode):
|
||||
model_h = int(head.heatmap_size[0]) * 4 # e.g. 192 * 4 = 768
|
||||
model_w = int(head.heatmap_size[1]) * 4 # e.g. 256 * 4 = 1024
|
||||
|
||||
def _resize_to_model(imgs):
|
||||
"""Aspect-preserving resize + zero-pad BHWC images to (model_h, model_w). Returns (resized_bhwc, scale, pad_top, pad_left)."""
|
||||
h, w = imgs.shape[-3], imgs.shape[-2]
|
||||
scale = min(model_h / h, model_w / w)
|
||||
sh, sw = int(round(h * scale)), int(round(w * scale))
|
||||
pt, pl = (model_h - sh) // 2, (model_w - sw) // 2
|
||||
chw = imgs.permute(0, 3, 1, 2).float()
|
||||
scaled = comfy.utils.common_upscale(chw, sw, sh, upscale_method="bilinear", crop="disabled")
|
||||
padded = torch.zeros(scaled.shape[0], scaled.shape[1], model_h, model_w, dtype=scaled.dtype, device=scaled.device)
|
||||
padded[:, :, pt:pt + sh, pl:pl + sw] = scaled
|
||||
return padded.permute(0, 2, 3, 1), scale, pt, pl
|
||||
|
||||
def _remap_keypoints(kp, scale, pad_top, pad_left, offset_x=0, offset_y=0):
|
||||
"""Remap keypoints from model space back to original image space."""
|
||||
kp = kp.copy() if isinstance(kp, np.ndarray) else np.array(kp, dtype=np.float32)
|
||||
invalid = kp[..., 0] < 0
|
||||
kp[..., 0] = (kp[..., 0] - pad_left) / scale + offset_x
|
||||
kp[..., 1] = (kp[..., 1] - pad_top) / scale + offset_y
|
||||
kp[invalid] = -1
|
||||
return kp
|
||||
|
||||
def _run_on_latent(latent_batch):
|
||||
"""Run one forward pass and return (keypoints_list, scores_list) for the batch."""
|
||||
nonlocal captured_feat
|
||||
@@ -504,36 +528,19 @@ class SDPoseKeypointExtractor(io.ComfyNode):
|
||||
if x2 <= x1 or y2 <= y1:
|
||||
continue
|
||||
|
||||
crop_h_px, crop_w_px = y2 - y1, x2 - x1
|
||||
crop = img[:, y1:y2, x1:x2, :] # (1, crop_h, crop_w, C)
|
||||
|
||||
# scale to fit inside (model_h, model_w) while preserving aspect ratio, then pad to exact model size.
|
||||
scale = min(model_h / crop_h_px, model_w / crop_w_px)
|
||||
scaled_h, scaled_w = int(round(crop_h_px * scale)), int(round(crop_w_px * scale))
|
||||
pad_top, pad_left = (model_h - scaled_h) // 2, (model_w - scaled_w) // 2
|
||||
|
||||
crop_chw = crop.permute(0, 3, 1, 2).float() # BHWC → BCHW
|
||||
scaled = comfy.utils.common_upscale(crop_chw, scaled_w, scaled_h, upscale_method="bilinear", crop="disabled")
|
||||
padded = torch.zeros(1, scaled.shape[1], model_h, model_w, dtype=scaled.dtype, device=scaled.device)
|
||||
padded[:, :, pad_top:pad_top + scaled_h, pad_left:pad_left + scaled_w] = scaled
|
||||
crop_resized = padded.permute(0, 2, 3, 1) # BCHW → BHWC
|
||||
crop_resized, scale, pad_top, pad_left = _resize_to_model(crop)
|
||||
|
||||
latent_crop = vae.encode(crop_resized)
|
||||
kp_batch, sc_batch = _run_on_latent(latent_crop)
|
||||
kp, sc = kp_batch[0], sc_batch[0] # (K, 2), coords in model pixel space
|
||||
|
||||
# remove padding offset, undo scale, offset to full-image coordinates.
|
||||
kp = kp.copy() if isinstance(kp, np.ndarray) else np.array(kp, dtype=np.float32)
|
||||
kp[..., 0] = (kp[..., 0] - pad_left) / scale + x1
|
||||
kp[..., 1] = (kp[..., 1] - pad_top) / scale + y1
|
||||
|
||||
kp = _remap_keypoints(kp_batch[0], scale, pad_top, pad_left, x1, y1)
|
||||
img_keypoints.append(kp)
|
||||
img_scores.append(sc)
|
||||
img_scores.append(sc_batch[0])
|
||||
else:
|
||||
# No bboxes for this image – run on the full image
|
||||
latent_img = vae.encode(img)
|
||||
img_resized, scale, pad_top, pad_left = _resize_to_model(img)
|
||||
latent_img = vae.encode(img_resized)
|
||||
kp_batch, sc_batch = _run_on_latent(latent_img)
|
||||
img_keypoints.append(kp_batch[0])
|
||||
img_keypoints.append(_remap_keypoints(kp_batch[0], scale, pad_top, pad_left))
|
||||
img_scores.append(sc_batch[0])
|
||||
|
||||
all_keypoints.append(img_keypoints)
|
||||
@@ -541,19 +548,16 @@ class SDPoseKeypointExtractor(io.ComfyNode):
|
||||
pbar.update(1)
|
||||
|
||||
else: # full-image mode, batched
|
||||
tqdm_pbar = tqdm(total=total_images, desc="Extracting keypoints")
|
||||
for batch_start in range(0, total_images, batch_size):
|
||||
batch_end = min(batch_start + batch_size, total_images)
|
||||
latent_batch = vae.encode(image[batch_start:batch_end])
|
||||
|
||||
for batch_start in tqdm(range(0, total_images, batch_size), desc="Extracting keypoints"):
|
||||
batch_resized, scale, pad_top, pad_left = _resize_to_model(image[batch_start:batch_start + batch_size])
|
||||
latent_batch = vae.encode(batch_resized)
|
||||
kp_batch, sc_batch = _run_on_latent(latent_batch)
|
||||
|
||||
for kp, sc in zip(kp_batch, sc_batch):
|
||||
all_keypoints.append([kp])
|
||||
all_keypoints.append([_remap_keypoints(kp, scale, pad_top, pad_left)])
|
||||
all_scores.append([sc])
|
||||
tqdm_pbar.update(1)
|
||||
|
||||
pbar.update(batch_end - batch_start)
|
||||
pbar.update(len(kp_batch))
|
||||
|
||||
openpose_frames = _to_openpose_frames(all_keypoints, all_scores, height, width)
|
||||
return io.NodeOutput(openpose_frames)
|
||||
|
||||
@@ -6,6 +6,7 @@ import comfy.utils
|
||||
import folder_paths
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
import comfy.model_management
|
||||
|
||||
try:
|
||||
from spandrel_extra_arches import EXTRA_REGISTRY
|
||||
@@ -78,13 +79,15 @@ class ImageUpscaleWithModel(io.ComfyNode):
|
||||
tile = 512
|
||||
overlap = 32
|
||||
|
||||
output_device = comfy.model_management.intermediate_device()
|
||||
|
||||
oom = True
|
||||
try:
|
||||
while oom:
|
||||
try:
|
||||
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
|
||||
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a.float()), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar, output_device=output_device)
|
||||
oom = False
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
@@ -94,7 +97,7 @@ class ImageUpscaleWithModel(io.ComfyNode):
|
||||
finally:
|
||||
upscale_model.to("cpu")
|
||||
|
||||
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
|
||||
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0).to(comfy.model_management.intermediate_dtype())
|
||||
return io.NodeOutput(s)
|
||||
|
||||
upscale = execute # TODO: remove
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.42.8
|
||||
comfyui-workflow-templates==0.9.44
|
||||
comfyui-frontend-package==1.42.10
|
||||
comfyui-workflow-templates==0.9.45
|
||||
comfyui-embedded-docs==0.4.3
|
||||
torch
|
||||
torchsde
|
||||
|
||||
Reference in New Issue
Block a user