mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-14 12:10:02 +00:00
Compare commits
5 Commits
v0.3.48
...
js/drafts/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0254d9cc11 | ||
|
|
92f9a10782 | ||
|
|
a6a6b615f4 | ||
|
|
50bf72f852 | ||
|
|
46c8311d14 |
@@ -4,9 +4,6 @@ if you have a NVIDIA gpu:
|
||||
|
||||
run_nvidia_gpu.bat
|
||||
|
||||
if you want to enable the fast fp16 accumulation (faster for fp16 models with slightly less quality):
|
||||
|
||||
run_nvidia_gpu_fast_fp16_accumulation.bat
|
||||
|
||||
|
||||
To run it in slow CPU mode:
|
||||
|
||||
40
.github/workflows/check-line-endings.yml
vendored
40
.github/workflows/check-line-endings.yml
vendored
@@ -1,40 +0,0 @@
|
||||
name: Check for Windows Line Endings
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: ['*'] # Trigger on all pull requests to any branch
|
||||
|
||||
jobs:
|
||||
check-line-endings:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0 # Fetch all history to compare changes
|
||||
|
||||
- name: Check for Windows line endings (CRLF)
|
||||
run: |
|
||||
# Get the list of changed files in the PR
|
||||
CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }})
|
||||
|
||||
# Flag to track if CRLF is found
|
||||
CRLF_FOUND=false
|
||||
|
||||
# Loop through each changed file
|
||||
for FILE in $CHANGED_FILES; do
|
||||
# Check if the file exists and is a text file
|
||||
if [ -f "$FILE" ] && file "$FILE" | grep -q "text"; then
|
||||
# Check for CRLF line endings
|
||||
if grep -UP '\r$' "$FILE"; then
|
||||
echo "Error: Windows line endings (CRLF) detected in $FILE"
|
||||
CRLF_FOUND=true
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
# Exit with error if CRLF was found
|
||||
if [ "$CRLF_FOUND" = true ]; then
|
||||
exit 1
|
||||
fi
|
||||
4
.github/workflows/test-unit.yml
vendored
4
.github/workflows/test-unit.yml
vendored
@@ -28,3 +28,7 @@ jobs:
|
||||
run: |
|
||||
pip install -r tests-unit/requirements.txt
|
||||
python -m pytest tests-unit
|
||||
- name: Run Execution Model Tests
|
||||
run: |
|
||||
python -m pytest tests/inference/test_execution.py
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ on:
|
||||
description: 'cuda version'
|
||||
required: true
|
||||
type: string
|
||||
default: "129"
|
||||
default: "128"
|
||||
|
||||
python_minor:
|
||||
description: 'python minor version'
|
||||
@@ -19,7 +19,7 @@ on:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "5"
|
||||
default: "2"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
@@ -53,8 +53,6 @@ jobs:
|
||||
ls ../temp_wheel_dir
|
||||
./python.exe -s -m pip install --pre ../temp_wheel_dir/*
|
||||
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
|
||||
|
||||
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
|
||||
cd ..
|
||||
|
||||
git clone --depth 1 https://github.com/comfyanonymous/taesd
|
||||
|
||||
26
README.md
26
README.md
@@ -55,7 +55,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
## Features
|
||||
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
||||
- Image Models
|
||||
- SD1.x, SD2.x ([unCLIP](https://comfyanonymous.github.io/ComfyUI_examples/unclip/))
|
||||
- SD1.x, SD2.x,
|
||||
- [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
|
||||
- [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/)
|
||||
- [SD3 and SD3.5](https://comfyanonymous.github.io/ComfyUI_examples/sd3/)
|
||||
@@ -69,7 +69,6 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
- Image Editing Models
|
||||
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
|
||||
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
|
||||
- [HiDream E1.1](https://comfyanonymous.github.io/ComfyUI_examples/hidream/#hidream-e11)
|
||||
- Video Models
|
||||
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
|
||||
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
|
||||
@@ -77,7 +76,6 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
|
||||
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) and [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
|
||||
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
|
||||
- [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/)
|
||||
- Audio Models
|
||||
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||
- [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||
@@ -85,10 +83,9 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
- [Hunyuan3D 2.0](https://docs.comfy.org/tutorials/3d/hunyuan3D-2)
|
||||
- Asynchronous Queue system
|
||||
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
|
||||
- Smart memory management: can automatically run large models on GPUs with as low as 1GB vram with smart offloading.
|
||||
- Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
|
||||
- Works even if you don't have a GPU with: ```--cpu``` (slow)
|
||||
- Can load ckpt and safetensors: All in one checkpoints or standalone diffusion models, VAEs and CLIP models.
|
||||
- Safe loading of ckpt, pt, pth, etc.. files.
|
||||
- Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models.
|
||||
- Embeddings/Textual inversion
|
||||
- [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
|
||||
- [Hypernetworks](https://comfyanonymous.github.io/ComfyUI_examples/hypernetworks/)
|
||||
@@ -99,10 +96,12 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
- [Inpainting](https://comfyanonymous.github.io/ComfyUI_examples/inpaint/) with both regular and inpainting models.
|
||||
- [ControlNet and T2I-Adapter](https://comfyanonymous.github.io/ComfyUI_examples/controlnet/)
|
||||
- [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/)
|
||||
- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
|
||||
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
|
||||
- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
|
||||
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
|
||||
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
|
||||
- Starts up very fast.
|
||||
- Works fully offline: core will never download anything unless you want to.
|
||||
- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview).
|
||||
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.
|
||||
@@ -111,7 +110,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
||||
|
||||
## Release Process
|
||||
|
||||
ComfyUI follows a weekly release cycle targeting Friday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
|
||||
ComfyUI follows a weekly release cycle every Friday, with three interconnected repositories:
|
||||
|
||||
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
|
||||
- Releases a new stable version (e.g., v0.7.0)
|
||||
@@ -179,6 +178,10 @@ If you have trouble extracting it, right click the file -> properties -> unblock
|
||||
|
||||
See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor.
|
||||
|
||||
## Jupyter Notebook
|
||||
|
||||
To run it on services like paperspace, kaggle or colab you can use my [Jupyter Notebook](notebooks/comfyui_colab.ipynb)
|
||||
|
||||
|
||||
## [comfy-cli](https://docs.comfy.org/comfy-cli/getting-started)
|
||||
|
||||
@@ -240,7 +243,7 @@ Nvidia users should install stable pytorch using this command:
|
||||
|
||||
This is the command to install pytorch nightly instead which might have performance improvements.
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu129```
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128```
|
||||
|
||||
#### Troubleshooting
|
||||
|
||||
@@ -294,13 +297,6 @@ For models compatible with Cambricon Extension for PyTorch (torch_mlu). Here's a
|
||||
2. Next, install the PyTorch(torch_mlu) following the instructions on the [Installation](https://www.cambricon.com/docs/sdk_1.15.0/cambricon_pytorch_1.17.0/user_guide_1.9/index.html)
|
||||
3. Launch ComfyUI by running `python main.py`
|
||||
|
||||
#### Iluvatar Corex
|
||||
|
||||
For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step guide tailored to your platform and installation method:
|
||||
|
||||
1. Install the Iluvatar Corex Toolkit by adhering to the platform-specific instructions on the [Installation](https://support.iluvatar.com/#/DocumentCentre?id=1&nameCenter=2&productId=520117912052801536)
|
||||
2. Launch ComfyUI by running `python main.py`
|
||||
|
||||
# Running
|
||||
|
||||
```python main.py```
|
||||
|
||||
@@ -29,48 +29,18 @@ def frontend_install_warning_message():
|
||||
This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.
|
||||
""".strip()
|
||||
|
||||
def parse_version(version: str) -> tuple[int, int, int]:
|
||||
return tuple(map(int, version.split(".")))
|
||||
|
||||
def is_valid_version(version: str) -> bool:
|
||||
"""Validate if a string is a valid semantic version (X.Y.Z format)."""
|
||||
pattern = r"^(\d+)\.(\d+)\.(\d+)$"
|
||||
return bool(re.match(pattern, version))
|
||||
|
||||
def get_installed_frontend_version():
|
||||
"""Get the currently installed frontend package version."""
|
||||
frontend_version_str = version("comfyui-frontend-package")
|
||||
return frontend_version_str
|
||||
|
||||
def get_required_frontend_version():
|
||||
"""Get the required frontend version from requirements.txt."""
|
||||
try:
|
||||
with open(requirements_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line.startswith("comfyui-frontend-package=="):
|
||||
version_str = line.split("==")[-1]
|
||||
if not is_valid_version(version_str):
|
||||
logging.error(f"Invalid version format in requirements.txt: {version_str}")
|
||||
return None
|
||||
return version_str
|
||||
logging.error("comfyui-frontend-package not found in requirements.txt")
|
||||
return None
|
||||
except FileNotFoundError:
|
||||
logging.error("requirements.txt not found. Cannot determine required frontend version.")
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"Error reading requirements.txt: {e}")
|
||||
return None
|
||||
|
||||
def check_frontend_version():
|
||||
"""Check if the frontend version is up to date."""
|
||||
|
||||
def parse_version(version: str) -> tuple[int, int, int]:
|
||||
return tuple(map(int, version.split(".")))
|
||||
|
||||
try:
|
||||
frontend_version_str = get_installed_frontend_version()
|
||||
frontend_version_str = version("comfyui-frontend-package")
|
||||
frontend_version = parse_version(frontend_version_str)
|
||||
required_frontend_str = get_required_frontend_version()
|
||||
required_frontend = parse_version(required_frontend_str)
|
||||
with open(requirements_path, "r", encoding="utf-8") as f:
|
||||
required_frontend = parse_version(f.readline().split("=")[-1])
|
||||
if frontend_version < required_frontend:
|
||||
app.logger.log_startup_warning(
|
||||
f"""
|
||||
@@ -198,11 +168,6 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
|
||||
class FrontendManager:
|
||||
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
|
||||
|
||||
@classmethod
|
||||
def get_required_frontend_version(cls) -> str:
|
||||
"""Get the required frontend package version."""
|
||||
return get_required_frontend_version()
|
||||
|
||||
@classmethod
|
||||
def default_frontend_path(cls) -> str:
|
||||
try:
|
||||
|
||||
@@ -49,8 +49,7 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co
|
||||
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
|
||||
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
|
||||
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
|
||||
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use. All other devices will not be visible.")
|
||||
parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
|
||||
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
|
||||
cm_group = parser.add_mutually_exclusive_group()
|
||||
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
|
||||
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
|
||||
@@ -145,7 +144,6 @@ class PerformanceFeature(enum.Enum):
|
||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
|
||||
|
||||
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
|
||||
parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")
|
||||
|
||||
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
|
||||
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
|
||||
|
||||
@@ -1,10 +1,55 @@
|
||||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
from .ldm.modules.attention import CrossAttention, FeedForward
|
||||
from .ldm.modules.attention import CrossAttention
|
||||
from inspect import isfunction
|
||||
import comfy.ops
|
||||
ops = comfy.ops.manual_cast
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def uniq(arr):
|
||||
return{el: True for el in arr}.keys()
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
# feedforward
|
||||
class GEGLU(nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.proj = ops.Linear(dim_in, dim_out * 2)
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||
return x * torch.nn.functional.gelu(gate)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = default(dim_out, dim)
|
||||
project_in = nn.Sequential(
|
||||
ops.Linear(dim, inner_dim),
|
||||
nn.GELU()
|
||||
) if not glu else GEGLU(dim, inner_dim)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
project_in,
|
||||
nn.Dropout(dropout),
|
||||
ops.Linear(inner_dim, dim_out)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class GatedCrossAttentionDense(nn.Module):
|
||||
def __init__(self, query_dim, context_dim, n_heads, d_head):
|
||||
|
||||
@@ -1,121 +0,0 @@
|
||||
# SA-Solver: Stochastic Adams Solver (NeurIPS 2023, arXiv:2309.05019)
|
||||
# Conference: https://proceedings.neurips.cc/paper_files/paper/2023/file/f4a6806490d31216a3ba667eb240c897-Paper-Conference.pdf
|
||||
# Codebase ref: https://github.com/scxue/SA-Solver
|
||||
|
||||
import math
|
||||
from typing import Union, Callable
|
||||
import torch
|
||||
|
||||
|
||||
def compute_exponential_coeffs(s: torch.Tensor, t: torch.Tensor, solver_order: int, tau_t: float) -> torch.Tensor:
|
||||
"""Compute (1 + tau^2) * integral of exp((1 + tau^2) * x) * x^p dx from s to t with exp((1 + tau^2) * t) factored out, using integration by parts.
|
||||
|
||||
Integral of exp((1 + tau^2) * x) * x^p dx
|
||||
= product_terms[p] - (p / (1 + tau^2)) * integral of exp((1 + tau^2) * x) * x^(p-1) dx,
|
||||
with base case p=0 where integral equals product_terms[0].
|
||||
|
||||
where
|
||||
product_terms[p] = x^p * exp((1 + tau^2) * x) / (1 + tau^2).
|
||||
|
||||
Construct a recursive coefficient matrix following the above recursive relation to compute all integral terms up to p = (solver_order - 1).
|
||||
Return coefficients used by the SA-Solver in data prediction mode.
|
||||
|
||||
Args:
|
||||
s: Start time s.
|
||||
t: End time t.
|
||||
solver_order: Current order of the solver.
|
||||
tau_t: Stochastic strength parameter in the SDE.
|
||||
|
||||
Returns:
|
||||
Exponential coefficients used in data prediction, with exp((1 + tau^2) * t) factored out, ordered from p=0 to p=solver_order−1, shape (solver_order,).
|
||||
"""
|
||||
tau_mul = 1 + tau_t ** 2
|
||||
h = t - s
|
||||
p = torch.arange(solver_order, dtype=s.dtype, device=s.device)
|
||||
|
||||
# product_terms after factoring out exp((1 + tau^2) * t)
|
||||
# Includes (1 + tau^2) factor from outside the integral
|
||||
product_terms_factored = (t ** p - s ** p * (-tau_mul * h).exp())
|
||||
|
||||
# Lower triangular recursive coefficient matrix
|
||||
# Accumulates recursive coefficients based on p / (1 + tau^2)
|
||||
recursive_depth_mat = p.unsqueeze(1) - p.unsqueeze(0)
|
||||
log_factorial = (p + 1).lgamma()
|
||||
recursive_coeff_mat = log_factorial.unsqueeze(1) - log_factorial.unsqueeze(0)
|
||||
if tau_t > 0:
|
||||
recursive_coeff_mat = recursive_coeff_mat - (recursive_depth_mat * math.log(tau_mul))
|
||||
signs = torch.where(recursive_depth_mat % 2 == 0, 1.0, -1.0)
|
||||
recursive_coeff_mat = (recursive_coeff_mat.exp() * signs).tril()
|
||||
|
||||
return recursive_coeff_mat @ product_terms_factored
|
||||
|
||||
|
||||
def compute_simple_stochastic_adams_b_coeffs(sigma_next: torch.Tensor, curr_lambdas: torch.Tensor, lambda_s: torch.Tensor, lambda_t: torch.Tensor, tau_t: float, is_corrector_step: bool = False) -> torch.Tensor:
|
||||
"""Compute simple order-2 b coefficients from SA-Solver paper (Appendix D. Implementation Details)."""
|
||||
tau_mul = 1 + tau_t ** 2
|
||||
h = lambda_t - lambda_s
|
||||
alpha_t = sigma_next * lambda_t.exp()
|
||||
if is_corrector_step:
|
||||
# Simplified 1-step (order-2) corrector
|
||||
b_1 = alpha_t * (0.5 * tau_mul * h)
|
||||
b_2 = alpha_t * (-h * tau_mul).expm1().neg() - b_1
|
||||
else:
|
||||
# Simplified 2-step predictor
|
||||
b_2 = alpha_t * (0.5 * tau_mul * h ** 2) / (curr_lambdas[-2] - lambda_s)
|
||||
b_1 = alpha_t * (-h * tau_mul).expm1().neg() - b_2
|
||||
return torch.stack([b_2, b_1])
|
||||
|
||||
|
||||
def compute_stochastic_adams_b_coeffs(sigma_next: torch.Tensor, curr_lambdas: torch.Tensor, lambda_s: torch.Tensor, lambda_t: torch.Tensor, tau_t: float, simple_order_2: bool = False, is_corrector_step: bool = False) -> torch.Tensor:
|
||||
"""Compute b_i coefficients for the SA-Solver (see eqs. 15 and 18).
|
||||
|
||||
The solver order corresponds to the number of input lambdas (half-logSNR points).
|
||||
|
||||
Args:
|
||||
sigma_next: Sigma at end time t.
|
||||
curr_lambdas: Lambda time points used to construct the Lagrange basis, shape (N,).
|
||||
lambda_s: Lambda at start time s.
|
||||
lambda_t: Lambda at end time t.
|
||||
tau_t: Stochastic strength parameter in the SDE.
|
||||
simple_order_2: Whether to enable the simple order-2 scheme.
|
||||
is_corrector_step: Flag for corrector step in simple order-2 mode.
|
||||
|
||||
Returns:
|
||||
b_i coefficients for the SA-Solver, shape (N,), where N is the solver order.
|
||||
"""
|
||||
num_timesteps = curr_lambdas.shape[0]
|
||||
|
||||
if simple_order_2 and num_timesteps == 2:
|
||||
return compute_simple_stochastic_adams_b_coeffs(sigma_next, curr_lambdas, lambda_s, lambda_t, tau_t, is_corrector_step)
|
||||
|
||||
# Compute coefficients by solving a linear system from Lagrange basis interpolation
|
||||
exp_integral_coeffs = compute_exponential_coeffs(lambda_s, lambda_t, num_timesteps, tau_t)
|
||||
vandermonde_matrix_T = torch.vander(curr_lambdas, num_timesteps, increasing=True).T
|
||||
lagrange_integrals = torch.linalg.solve(vandermonde_matrix_T, exp_integral_coeffs)
|
||||
|
||||
# (sigma_t * exp(-tau^2 * lambda_t)) * exp((1 + tau^2) * lambda_t)
|
||||
# = sigma_t * exp(lambda_t) = alpha_t
|
||||
# exp((1 + tau^2) * lambda_t) is extracted from the integral
|
||||
alpha_t = sigma_next * lambda_t.exp()
|
||||
return alpha_t * lagrange_integrals
|
||||
|
||||
|
||||
def get_tau_interval_func(start_sigma: float, end_sigma: float, eta: float = 1.0) -> Callable[[Union[torch.Tensor, float]], float]:
|
||||
"""Return a function that controls the stochasticity of SA-Solver.
|
||||
|
||||
When eta = 0, SA-Solver runs as ODE. The official approach uses
|
||||
time t to determine the SDE interval, while here we use sigma instead.
|
||||
|
||||
See:
|
||||
https://github.com/scxue/SA-Solver/blob/main/README.md
|
||||
"""
|
||||
|
||||
def tau_func(sigma: Union[torch.Tensor, float]) -> float:
|
||||
if eta <= 0:
|
||||
return 0.0 # ODE
|
||||
|
||||
if isinstance(sigma, torch.Tensor):
|
||||
sigma = sigma.item()
|
||||
return eta if start_sigma >= sigma >= end_sigma else 0.0
|
||||
|
||||
return tau_func
|
||||
@@ -9,7 +9,6 @@ from tqdm.auto import trange, tqdm
|
||||
|
||||
from . import utils
|
||||
from . import deis
|
||||
from . import sa_solver
|
||||
import comfy.model_patcher
|
||||
import comfy.model_sampling
|
||||
|
||||
@@ -413,13 +412,9 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, o
|
||||
ds.pop(0)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
if sigmas[i + 1] == 0:
|
||||
# Denoising step
|
||||
x = denoised
|
||||
else:
|
||||
cur_order = min(i + 1, order)
|
||||
coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
|
||||
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
|
||||
cur_order = min(i + 1, order)
|
||||
coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
|
||||
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
|
||||
return x
|
||||
|
||||
|
||||
@@ -1072,9 +1067,7 @@ def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
d_cur = (x_cur - denoised) / t_cur
|
||||
|
||||
order = min(max_order, i+1)
|
||||
if t_next == 0: # Denoising step
|
||||
x_next = denoised
|
||||
elif order == 1: # First Euler step.
|
||||
if order == 1: # First Euler step.
|
||||
x_next = x_cur + (t_next - t_cur) * d_cur
|
||||
elif order == 2: # Use one history point.
|
||||
x_next = x_cur + (t_next - t_cur) * (3 * d_cur - buffer_model[-1]) / 2
|
||||
@@ -1092,7 +1085,6 @@ def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
|
||||
return x_next
|
||||
|
||||
|
||||
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
|
||||
#under Apache 2 license
|
||||
def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4):
|
||||
@@ -1116,9 +1108,7 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
||||
d_cur = (x_cur - denoised) / t_cur
|
||||
|
||||
order = min(max_order, i+1)
|
||||
if t_next == 0: # Denoising step
|
||||
x_next = denoised
|
||||
elif order == 1: # First Euler step.
|
||||
if order == 1: # First Euler step.
|
||||
x_next = x_cur + (t_next - t_cur) * d_cur
|
||||
elif order == 2: # Use one history point.
|
||||
h_n = (t_next - t_cur)
|
||||
@@ -1158,7 +1148,6 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
||||
|
||||
return x_next
|
||||
|
||||
|
||||
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
|
||||
#under Apache 2 license
|
||||
@torch.no_grad()
|
||||
@@ -1209,22 +1198,39 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
|
||||
return x_next
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
|
||||
temp = [0]
|
||||
def post_cfg_function(args):
|
||||
temp[0] = args["uncond_denoised"]
|
||||
return args["denoised"]
|
||||
|
||||
model_options = extra_args.get("model_options", {}).copy()
|
||||
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
sigma_hat = sigmas[i]
|
||||
denoised = model(x, sigma_hat * s_in, **extra_args)
|
||||
d = to_d(x, sigma_hat, temp[0])
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
||||
# Euler method
|
||||
x = denoised + d * sigmas[i + 1]
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with Euler method steps (CFG++)."""
|
||||
"""Ancestral sampling with Euler method steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
|
||||
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
|
||||
|
||||
uncond_denoised = None
|
||||
|
||||
temp = [0]
|
||||
def post_cfg_function(args):
|
||||
nonlocal uncond_denoised
|
||||
uncond_denoised = args["uncond_denoised"]
|
||||
temp[0] = args["uncond_denoised"]
|
||||
return args["denoised"]
|
||||
|
||||
model_options = extra_args.get("model_options", {}).copy()
|
||||
@@ -1233,33 +1239,15 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
if sigmas[i + 1] == 0:
|
||||
# Denoising step
|
||||
x = denoised
|
||||
else:
|
||||
alpha_s = sigmas[i] * lambda_fn(sigmas[i]).exp()
|
||||
alpha_t = sigmas[i + 1] * lambda_fn(sigmas[i + 1]).exp()
|
||||
d = to_d(x, sigmas[i], alpha_s * uncond_denoised) # to noise
|
||||
|
||||
# DDIM stochastic sampling
|
||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i] / alpha_s, sigmas[i + 1] / alpha_t, eta=eta)
|
||||
sigma_down = alpha_t * sigma_down
|
||||
|
||||
# Euler method
|
||||
x = alpha_t * denoised + sigma_down * d
|
||||
if eta > 0 and s_noise > 0:
|
||||
x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
d = to_d(x, sigmas[i], temp[0])
|
||||
# Euler method
|
||||
x = denoised + d * sigma_down
|
||||
if sigmas[i + 1] > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
||||
"""Euler method steps (CFG++)."""
|
||||
return sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
||||
@@ -1416,7 +1404,6 @@ def sample_res_multistep_ancestral(model, x, sigmas, extra_args=None, callback=N
|
||||
def sample_res_multistep_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=True)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2., cfg_pp=False):
|
||||
"""Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK"""
|
||||
@@ -1443,19 +1430,19 @@ def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None,
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
dt = sigmas[i + 1] - sigmas[i]
|
||||
if sigmas[i + 1] == 0:
|
||||
# Denoising step
|
||||
x = denoised
|
||||
else:
|
||||
if i == 0:
|
||||
# Euler method
|
||||
if cfg_pp:
|
||||
x = denoised + d * sigmas[i + 1]
|
||||
else:
|
||||
x = x + d * dt
|
||||
|
||||
if i >= 1:
|
||||
# Gradient estimation
|
||||
else:
|
||||
# Gradient estimation
|
||||
if cfg_pp:
|
||||
d_bar = (ge_gamma - 1) * (d - old_d)
|
||||
x = denoised + d * sigmas[i + 1] + d_bar * dt
|
||||
else:
|
||||
d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
|
||||
x = x + d_bar * dt
|
||||
old_d = d
|
||||
return x
|
||||
@@ -1649,113 +1636,3 @@ def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
||||
if inject_noise:
|
||||
x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, use_pece=False, simple_order_2=False):
|
||||
"""Stochastic Adams Solver with predictor-corrector method (NeurIPS 2023)."""
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
|
||||
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||
lambdas = sigma_to_half_log_snr(sigmas, model_sampling=model_sampling)
|
||||
|
||||
if tau_func is None:
|
||||
# Use default interval for stochastic sampling
|
||||
start_sigma = model_sampling.percent_to_sigma(0.2)
|
||||
end_sigma = model_sampling.percent_to_sigma(0.8)
|
||||
tau_func = sa_solver.get_tau_interval_func(start_sigma, end_sigma, eta=1.0)
|
||||
|
||||
max_used_order = max(predictor_order, corrector_order)
|
||||
x_pred = x # x: current state, x_pred: predicted next state
|
||||
|
||||
h = 0.0
|
||||
tau_t = 0.0
|
||||
noise = 0.0
|
||||
pred_list = []
|
||||
|
||||
# Lower order near the end to improve stability
|
||||
lower_order_to_end = sigmas[-1].item() == 0
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
# Evaluation
|
||||
denoised = model(x_pred, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({"x": x_pred, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
|
||||
pred_list.append(denoised)
|
||||
pred_list = pred_list[-max_used_order:]
|
||||
|
||||
predictor_order_used = min(predictor_order, len(pred_list))
|
||||
if i == 0 or (sigmas[i + 1] == 0 and not use_pece):
|
||||
corrector_order_used = 0
|
||||
else:
|
||||
corrector_order_used = min(corrector_order, len(pred_list))
|
||||
|
||||
if lower_order_to_end:
|
||||
predictor_order_used = min(predictor_order_used, len(sigmas) - 2 - i)
|
||||
corrector_order_used = min(corrector_order_used, len(sigmas) - 1 - i)
|
||||
|
||||
# Corrector
|
||||
if corrector_order_used == 0:
|
||||
# Update by the predicted state
|
||||
x = x_pred
|
||||
else:
|
||||
curr_lambdas = lambdas[i - corrector_order_used + 1:i + 1]
|
||||
b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs(
|
||||
sigmas[i],
|
||||
curr_lambdas,
|
||||
lambdas[i - 1],
|
||||
lambdas[i],
|
||||
tau_t,
|
||||
simple_order_2,
|
||||
is_corrector_step=True,
|
||||
)
|
||||
pred_mat = torch.stack(pred_list[-corrector_order_used:], dim=1) # (B, K, ...)
|
||||
corr_res = torch.tensordot(pred_mat, b_coeffs, dims=([1], [0])) # (B, ...)
|
||||
x = sigmas[i] / sigmas[i - 1] * (-(tau_t ** 2) * h).exp() * x + corr_res
|
||||
|
||||
if tau_t > 0 and s_noise > 0:
|
||||
# The noise from the previous predictor step
|
||||
x = x + noise
|
||||
|
||||
if use_pece:
|
||||
# Evaluate the corrected state
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
pred_list[-1] = denoised
|
||||
|
||||
# Predictor
|
||||
if sigmas[i + 1] == 0:
|
||||
# Denoising step
|
||||
x = denoised
|
||||
else:
|
||||
tau_t = tau_func(sigmas[i + 1])
|
||||
curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1]
|
||||
b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs(
|
||||
sigmas[i + 1],
|
||||
curr_lambdas,
|
||||
lambdas[i],
|
||||
lambdas[i + 1],
|
||||
tau_t,
|
||||
simple_order_2,
|
||||
is_corrector_step=False,
|
||||
)
|
||||
pred_mat = torch.stack(pred_list[-predictor_order_used:], dim=1) # (B, K, ...)
|
||||
pred_res = torch.tensordot(pred_mat, b_coeffs, dims=([1], [0])) # (B, ...)
|
||||
h = lambdas[i + 1] - lambdas[i]
|
||||
x_pred = sigmas[i + 1] / sigmas[i] * (-(tau_t ** 2) * h).exp() * x + pred_res
|
||||
|
||||
if tau_t > 0 and s_noise > 0:
|
||||
noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise
|
||||
x_pred = x_pred + noise
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False):
|
||||
"""Stochastic Adams Solver with PECE (Predict–Evaluate–Correct–Evaluate) mode (NeurIPS 2023)."""
|
||||
return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2)
|
||||
|
||||
@@ -457,82 +457,6 @@ class Wan21(LatentFormat):
|
||||
latents_std = self.latents_std.to(latent.device, latent.dtype)
|
||||
return latent * latents_std / self.scale_factor + latents_mean
|
||||
|
||||
class Wan22(Wan21):
|
||||
latent_channels = 48
|
||||
latent_dimensions = 3
|
||||
|
||||
latent_rgb_factors = [
|
||||
[ 0.0119, 0.0103, 0.0046],
|
||||
[-0.1062, -0.0504, 0.0165],
|
||||
[ 0.0140, 0.0409, 0.0491],
|
||||
[-0.0813, -0.0677, 0.0607],
|
||||
[ 0.0656, 0.0851, 0.0808],
|
||||
[ 0.0264, 0.0463, 0.0912],
|
||||
[ 0.0295, 0.0326, 0.0590],
|
||||
[-0.0244, -0.0270, 0.0025],
|
||||
[ 0.0443, -0.0102, 0.0288],
|
||||
[-0.0465, -0.0090, -0.0205],
|
||||
[ 0.0359, 0.0236, 0.0082],
|
||||
[-0.0776, 0.0854, 0.1048],
|
||||
[ 0.0564, 0.0264, 0.0561],
|
||||
[ 0.0006, 0.0594, 0.0418],
|
||||
[-0.0319, -0.0542, -0.0637],
|
||||
[-0.0268, 0.0024, 0.0260],
|
||||
[ 0.0539, 0.0265, 0.0358],
|
||||
[-0.0359, -0.0312, -0.0287],
|
||||
[-0.0285, -0.1032, -0.1237],
|
||||
[ 0.1041, 0.0537, 0.0622],
|
||||
[-0.0086, -0.0374, -0.0051],
|
||||
[ 0.0390, 0.0670, 0.2863],
|
||||
[ 0.0069, 0.0144, 0.0082],
|
||||
[ 0.0006, -0.0167, 0.0079],
|
||||
[ 0.0313, -0.0574, -0.0232],
|
||||
[-0.1454, -0.0902, -0.0481],
|
||||
[ 0.0714, 0.0827, 0.0447],
|
||||
[-0.0304, -0.0574, -0.0196],
|
||||
[ 0.0401, 0.0384, 0.0204],
|
||||
[-0.0758, -0.0297, -0.0014],
|
||||
[ 0.0568, 0.1307, 0.1372],
|
||||
[-0.0055, -0.0310, -0.0380],
|
||||
[ 0.0239, -0.0305, 0.0325],
|
||||
[-0.0663, -0.0673, -0.0140],
|
||||
[-0.0416, -0.0047, -0.0023],
|
||||
[ 0.0166, 0.0112, -0.0093],
|
||||
[-0.0211, 0.0011, 0.0331],
|
||||
[ 0.1833, 0.1466, 0.2250],
|
||||
[-0.0368, 0.0370, 0.0295],
|
||||
[-0.3441, -0.3543, -0.2008],
|
||||
[-0.0479, -0.0489, -0.0420],
|
||||
[-0.0660, -0.0153, 0.0800],
|
||||
[-0.0101, 0.0068, 0.0156],
|
||||
[-0.0690, -0.0452, -0.0927],
|
||||
[-0.0145, 0.0041, 0.0015],
|
||||
[ 0.0421, 0.0451, 0.0373],
|
||||
[ 0.0504, -0.0483, -0.0356],
|
||||
[-0.0837, 0.0168, 0.0055]
|
||||
]
|
||||
|
||||
latent_rgb_factors_bias = [0.0317, -0.0878, -0.1388]
|
||||
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.0
|
||||
self.latents_mean = torch.tensor([
|
||||
-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
|
||||
-0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
|
||||
-0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
|
||||
-0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
|
||||
-0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
|
||||
0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
|
||||
]).view(1, self.latent_channels, 1, 1, 1)
|
||||
self.latents_std = torch.tensor([
|
||||
0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
|
||||
0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
|
||||
0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
|
||||
0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
|
||||
0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
|
||||
0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744
|
||||
]).view(1, self.latent_channels, 1, 1, 1)
|
||||
|
||||
class Hunyuan3Dv2(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
|
||||
@@ -254,12 +254,13 @@ class Chroma(nn.Module):
|
||||
|
||||
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
||||
patch_size = 2
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||
|
||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=self.patch_size, pw=self.patch_size)
|
||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||
|
||||
h_len = ((h + (self.patch_size // 2)) // self.patch_size)
|
||||
w_len = ((w + (self.patch_size // 2)) // self.patch_size)
|
||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
@@ -267,4 +268,4 @@ class Chroma(nn.Module):
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h,:w]
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
|
||||
|
||||
@@ -58,8 +58,7 @@ def is_odd(n: int) -> bool:
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
# x * sigmoid(x)
|
||||
return torch.nn.functional.silu(x)
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32):
|
||||
|
||||
@@ -973,7 +973,7 @@ class VideoVAE(nn.Module):
|
||||
norm_layer=config.get("norm_layer", "group_norm"),
|
||||
causal=config.get("causal_decoder", False),
|
||||
timestep_conditioning=self.timestep_conditioning,
|
||||
spatial_padding_mode=config.get("spatial_padding_mode", "reflect"),
|
||||
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
|
||||
)
|
||||
|
||||
self.per_channel_statistics = processor()
|
||||
|
||||
@@ -36,7 +36,7 @@ def get_timestep_embedding(timesteps, embedding_dim):
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return torch.nn.functional.silu(x)
|
||||
return x*torch.sigmoid(x)
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32):
|
||||
|
||||
@@ -1,256 +1,256 @@
|
||||
# Based on:
|
||||
# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
|
||||
# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .blocks import (
|
||||
t2i_modulate,
|
||||
CaptionEmbedder,
|
||||
AttentionKVCompress,
|
||||
MultiHeadCrossAttention,
|
||||
T2IFinalLayer,
|
||||
SizeEmbedder,
|
||||
)
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, PatchEmbed, Mlp, get_1d_sincos_pos_embed_from_grid_torch
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_torch(embed_dim, w, h, pe_interpolation=1.0, base_size=16, device=None, dtype=torch.float32):
|
||||
grid_h, grid_w = torch.meshgrid(
|
||||
torch.arange(h, device=device, dtype=dtype) / (h/base_size) / pe_interpolation,
|
||||
torch.arange(w, device=device, dtype=dtype) / (w/base_size) / pe_interpolation,
|
||||
indexing='ij'
|
||||
)
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
|
||||
emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
|
||||
return emb
|
||||
|
||||
class PixArtMSBlock(nn.Module):
|
||||
"""
|
||||
A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning.
|
||||
"""
|
||||
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None,
|
||||
sampling=None, sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **block_kwargs):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.attn = AttentionKVCompress(
|
||||
hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
|
||||
qk_norm=qk_norm, dtype=dtype, device=device, operations=operations, **block_kwargs
|
||||
)
|
||||
self.cross_attn = MultiHeadCrossAttention(
|
||||
hidden_size, num_heads, dtype=dtype, device=device, operations=operations, **block_kwargs
|
||||
)
|
||||
self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
# to be compatible with lower version pytorch
|
||||
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
||||
self.mlp = Mlp(
|
||||
in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
|
||||
|
||||
def forward(self, x, y, t, mask=None, HW=None, **kwargs):
|
||||
B, N, C = x.shape
|
||||
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t.reshape(B, 6, -1)).chunk(6, dim=1)
|
||||
x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
|
||||
x = x + self.cross_attn(x, y, mask)
|
||||
x = x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
### Core PixArt Model ###
|
||||
class PixArtMS(nn.Module):
|
||||
"""
|
||||
Diffusion model with a Transformer backbone.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
input_size=32,
|
||||
patch_size=2,
|
||||
in_channels=4,
|
||||
hidden_size=1152,
|
||||
depth=28,
|
||||
num_heads=16,
|
||||
mlp_ratio=4.0,
|
||||
class_dropout_prob=0.1,
|
||||
learn_sigma=True,
|
||||
pred_sigma=True,
|
||||
drop_path: float = 0.,
|
||||
caption_channels=4096,
|
||||
pe_interpolation=None,
|
||||
pe_precision=None,
|
||||
config=None,
|
||||
model_max_length=120,
|
||||
micro_condition=True,
|
||||
qk_norm=False,
|
||||
kv_compress_config=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**kwargs,
|
||||
):
|
||||
nn.Module.__init__(self)
|
||||
self.dtype = dtype
|
||||
self.pred_sigma = pred_sigma
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels * 2 if pred_sigma else in_channels
|
||||
self.patch_size = patch_size
|
||||
self.num_heads = num_heads
|
||||
self.pe_interpolation = pe_interpolation
|
||||
self.pe_precision = pe_precision
|
||||
self.hidden_size = hidden_size
|
||||
self.depth = depth
|
||||
|
||||
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
||||
self.t_block = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
self.x_embedder = PatchEmbed(
|
||||
patch_size=patch_size,
|
||||
in_chans=in_channels,
|
||||
embed_dim=hidden_size,
|
||||
bias=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
self.t_embedder = TimestepEmbedder(
|
||||
hidden_size, dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
self.y_embedder = CaptionEmbedder(
|
||||
in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob,
|
||||
act_layer=approx_gelu, token_num=model_max_length,
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
|
||||
self.micro_conditioning = micro_condition
|
||||
if self.micro_conditioning:
|
||||
self.csize_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
|
||||
self.ar_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
# For fixed sin-cos embedding:
|
||||
# num_patches = (input_size // patch_size) * (input_size // patch_size)
|
||||
# self.base_size = input_size // self.patch_size
|
||||
# self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))
|
||||
|
||||
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
|
||||
if kv_compress_config is None:
|
||||
kv_compress_config = {
|
||||
'sampling': None,
|
||||
'scale_factor': 1,
|
||||
'kv_compress_layer': [],
|
||||
}
|
||||
self.blocks = nn.ModuleList([
|
||||
PixArtMSBlock(
|
||||
hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
|
||||
sampling=kv_compress_config['sampling'],
|
||||
sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1,
|
||||
qk_norm=qk_norm,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
for i in range(depth)
|
||||
])
|
||||
self.final_layer = T2IFinalLayer(
|
||||
hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
def forward_orig(self, x, timestep, y, mask=None, c_size=None, c_ar=None, **kwargs):
|
||||
"""
|
||||
Original forward pass of PixArt.
|
||||
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||
t: (N,) tensor of diffusion timesteps
|
||||
y: (N, 1, 120, C) conditioning
|
||||
ar: (N, 1): aspect ratio
|
||||
cs: (N ,2) size conditioning for height/width
|
||||
"""
|
||||
B, C, H, W = x.shape
|
||||
c_res = (H + W) // 2
|
||||
pe_interpolation = self.pe_interpolation
|
||||
if pe_interpolation is None or self.pe_precision is not None:
|
||||
# calculate pe_interpolation on-the-fly
|
||||
pe_interpolation = round(c_res / (512/8.0), self.pe_precision or 0)
|
||||
|
||||
pos_embed = get_2d_sincos_pos_embed_torch(
|
||||
self.hidden_size,
|
||||
h=(H // self.patch_size),
|
||||
w=(W // self.patch_size),
|
||||
pe_interpolation=pe_interpolation,
|
||||
base_size=((round(c_res / 64) * 64) // self.patch_size),
|
||||
device=x.device,
|
||||
dtype=x.dtype,
|
||||
).unsqueeze(0)
|
||||
|
||||
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
|
||||
t = self.t_embedder(timestep, x.dtype) # (N, D)
|
||||
|
||||
if self.micro_conditioning and (c_size is not None and c_ar is not None):
|
||||
bs = x.shape[0]
|
||||
c_size = self.csize_embedder(c_size, bs) # (N, D)
|
||||
c_ar = self.ar_embedder(c_ar, bs) # (N, D)
|
||||
t = t + torch.cat([c_size, c_ar], dim=1)
|
||||
|
||||
t0 = self.t_block(t)
|
||||
y = self.y_embedder(y, self.training) # (N, D)
|
||||
|
||||
if mask is not None:
|
||||
if mask.shape[0] != y.shape[0]:
|
||||
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
|
||||
mask = mask.squeeze(1).squeeze(1)
|
||||
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
|
||||
y_lens = mask.sum(dim=1).tolist()
|
||||
else:
|
||||
y_lens = None
|
||||
y = y.squeeze(1).view(1, -1, x.shape[-1])
|
||||
for block in self.blocks:
|
||||
x = block(x, y, t0, y_lens, (H, W), **kwargs) # (N, T, D)
|
||||
|
||||
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
|
||||
x = self.unpatchify(x, H, W) # (N, out_channels, H, W)
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, x, timesteps, context, c_size=None, c_ar=None, **kwargs):
|
||||
B, C, H, W = x.shape
|
||||
|
||||
# Fallback for missing microconds
|
||||
if self.micro_conditioning:
|
||||
if c_size is None:
|
||||
c_size = torch.tensor([H*8, W*8], dtype=x.dtype, device=x.device).repeat(B, 1)
|
||||
|
||||
if c_ar is None:
|
||||
c_ar = torch.tensor([H/W], dtype=x.dtype, device=x.device).repeat(B, 1)
|
||||
|
||||
## Still accepts the input w/o that dim but returns garbage
|
||||
if len(context.shape) == 3:
|
||||
context = context.unsqueeze(1)
|
||||
|
||||
## run original forward pass
|
||||
out = self.forward_orig(x, timesteps, context, c_size=c_size, c_ar=c_ar)
|
||||
|
||||
## only return EPS
|
||||
if self.pred_sigma:
|
||||
return out[:, :self.in_channels]
|
||||
return out
|
||||
|
||||
def unpatchify(self, x, h, w):
|
||||
"""
|
||||
x: (N, T, patch_size**2 * C)
|
||||
imgs: (N, H, W, C)
|
||||
"""
|
||||
c = self.out_channels
|
||||
p = self.x_embedder.patch_size[0]
|
||||
h = h // self.patch_size
|
||||
w = w // self.patch_size
|
||||
assert h * w == x.shape[1]
|
||||
|
||||
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
|
||||
x = torch.einsum('nhwpqc->nchpwq', x)
|
||||
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
|
||||
return imgs
|
||||
# Based on:
|
||||
# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
|
||||
# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .blocks import (
|
||||
t2i_modulate,
|
||||
CaptionEmbedder,
|
||||
AttentionKVCompress,
|
||||
MultiHeadCrossAttention,
|
||||
T2IFinalLayer,
|
||||
SizeEmbedder,
|
||||
)
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, PatchEmbed, Mlp, get_1d_sincos_pos_embed_from_grid_torch
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_torch(embed_dim, w, h, pe_interpolation=1.0, base_size=16, device=None, dtype=torch.float32):
|
||||
grid_h, grid_w = torch.meshgrid(
|
||||
torch.arange(h, device=device, dtype=dtype) / (h/base_size) / pe_interpolation,
|
||||
torch.arange(w, device=device, dtype=dtype) / (w/base_size) / pe_interpolation,
|
||||
indexing='ij'
|
||||
)
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
|
||||
emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
|
||||
return emb
|
||||
|
||||
class PixArtMSBlock(nn.Module):
|
||||
"""
|
||||
A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning.
|
||||
"""
|
||||
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None,
|
||||
sampling=None, sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **block_kwargs):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.attn = AttentionKVCompress(
|
||||
hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
|
||||
qk_norm=qk_norm, dtype=dtype, device=device, operations=operations, **block_kwargs
|
||||
)
|
||||
self.cross_attn = MultiHeadCrossAttention(
|
||||
hidden_size, num_heads, dtype=dtype, device=device, operations=operations, **block_kwargs
|
||||
)
|
||||
self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
# to be compatible with lower version pytorch
|
||||
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
||||
self.mlp = Mlp(
|
||||
in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
|
||||
|
||||
def forward(self, x, y, t, mask=None, HW=None, **kwargs):
|
||||
B, N, C = x.shape
|
||||
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t.reshape(B, 6, -1)).chunk(6, dim=1)
|
||||
x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
|
||||
x = x + self.cross_attn(x, y, mask)
|
||||
x = x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
### Core PixArt Model ###
|
||||
class PixArtMS(nn.Module):
|
||||
"""
|
||||
Diffusion model with a Transformer backbone.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
input_size=32,
|
||||
patch_size=2,
|
||||
in_channels=4,
|
||||
hidden_size=1152,
|
||||
depth=28,
|
||||
num_heads=16,
|
||||
mlp_ratio=4.0,
|
||||
class_dropout_prob=0.1,
|
||||
learn_sigma=True,
|
||||
pred_sigma=True,
|
||||
drop_path: float = 0.,
|
||||
caption_channels=4096,
|
||||
pe_interpolation=None,
|
||||
pe_precision=None,
|
||||
config=None,
|
||||
model_max_length=120,
|
||||
micro_condition=True,
|
||||
qk_norm=False,
|
||||
kv_compress_config=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**kwargs,
|
||||
):
|
||||
nn.Module.__init__(self)
|
||||
self.dtype = dtype
|
||||
self.pred_sigma = pred_sigma
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels * 2 if pred_sigma else in_channels
|
||||
self.patch_size = patch_size
|
||||
self.num_heads = num_heads
|
||||
self.pe_interpolation = pe_interpolation
|
||||
self.pe_precision = pe_precision
|
||||
self.hidden_size = hidden_size
|
||||
self.depth = depth
|
||||
|
||||
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
||||
self.t_block = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
self.x_embedder = PatchEmbed(
|
||||
patch_size=patch_size,
|
||||
in_chans=in_channels,
|
||||
embed_dim=hidden_size,
|
||||
bias=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
self.t_embedder = TimestepEmbedder(
|
||||
hidden_size, dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
self.y_embedder = CaptionEmbedder(
|
||||
in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob,
|
||||
act_layer=approx_gelu, token_num=model_max_length,
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
|
||||
self.micro_conditioning = micro_condition
|
||||
if self.micro_conditioning:
|
||||
self.csize_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
|
||||
self.ar_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
# For fixed sin-cos embedding:
|
||||
# num_patches = (input_size // patch_size) * (input_size // patch_size)
|
||||
# self.base_size = input_size // self.patch_size
|
||||
# self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))
|
||||
|
||||
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
|
||||
if kv_compress_config is None:
|
||||
kv_compress_config = {
|
||||
'sampling': None,
|
||||
'scale_factor': 1,
|
||||
'kv_compress_layer': [],
|
||||
}
|
||||
self.blocks = nn.ModuleList([
|
||||
PixArtMSBlock(
|
||||
hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
|
||||
sampling=kv_compress_config['sampling'],
|
||||
sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1,
|
||||
qk_norm=qk_norm,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
for i in range(depth)
|
||||
])
|
||||
self.final_layer = T2IFinalLayer(
|
||||
hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
def forward_orig(self, x, timestep, y, mask=None, c_size=None, c_ar=None, **kwargs):
|
||||
"""
|
||||
Original forward pass of PixArt.
|
||||
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||
t: (N,) tensor of diffusion timesteps
|
||||
y: (N, 1, 120, C) conditioning
|
||||
ar: (N, 1): aspect ratio
|
||||
cs: (N ,2) size conditioning for height/width
|
||||
"""
|
||||
B, C, H, W = x.shape
|
||||
c_res = (H + W) // 2
|
||||
pe_interpolation = self.pe_interpolation
|
||||
if pe_interpolation is None or self.pe_precision is not None:
|
||||
# calculate pe_interpolation on-the-fly
|
||||
pe_interpolation = round(c_res / (512/8.0), self.pe_precision or 0)
|
||||
|
||||
pos_embed = get_2d_sincos_pos_embed_torch(
|
||||
self.hidden_size,
|
||||
h=(H // self.patch_size),
|
||||
w=(W // self.patch_size),
|
||||
pe_interpolation=pe_interpolation,
|
||||
base_size=((round(c_res / 64) * 64) // self.patch_size),
|
||||
device=x.device,
|
||||
dtype=x.dtype,
|
||||
).unsqueeze(0)
|
||||
|
||||
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
|
||||
t = self.t_embedder(timestep, x.dtype) # (N, D)
|
||||
|
||||
if self.micro_conditioning and (c_size is not None and c_ar is not None):
|
||||
bs = x.shape[0]
|
||||
c_size = self.csize_embedder(c_size, bs) # (N, D)
|
||||
c_ar = self.ar_embedder(c_ar, bs) # (N, D)
|
||||
t = t + torch.cat([c_size, c_ar], dim=1)
|
||||
|
||||
t0 = self.t_block(t)
|
||||
y = self.y_embedder(y, self.training) # (N, D)
|
||||
|
||||
if mask is not None:
|
||||
if mask.shape[0] != y.shape[0]:
|
||||
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
|
||||
mask = mask.squeeze(1).squeeze(1)
|
||||
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
|
||||
y_lens = mask.sum(dim=1).tolist()
|
||||
else:
|
||||
y_lens = None
|
||||
y = y.squeeze(1).view(1, -1, x.shape[-1])
|
||||
for block in self.blocks:
|
||||
x = block(x, y, t0, y_lens, (H, W), **kwargs) # (N, T, D)
|
||||
|
||||
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
|
||||
x = self.unpatchify(x, H, W) # (N, out_channels, H, W)
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, x, timesteps, context, c_size=None, c_ar=None, **kwargs):
|
||||
B, C, H, W = x.shape
|
||||
|
||||
# Fallback for missing microconds
|
||||
if self.micro_conditioning:
|
||||
if c_size is None:
|
||||
c_size = torch.tensor([H*8, W*8], dtype=x.dtype, device=x.device).repeat(B, 1)
|
||||
|
||||
if c_ar is None:
|
||||
c_ar = torch.tensor([H/W], dtype=x.dtype, device=x.device).repeat(B, 1)
|
||||
|
||||
## Still accepts the input w/o that dim but returns garbage
|
||||
if len(context.shape) == 3:
|
||||
context = context.unsqueeze(1)
|
||||
|
||||
## run original forward pass
|
||||
out = self.forward_orig(x, timesteps, context, c_size=c_size, c_ar=c_ar)
|
||||
|
||||
## only return EPS
|
||||
if self.pred_sigma:
|
||||
return out[:, :self.in_channels]
|
||||
return out
|
||||
|
||||
def unpatchify(self, x, h, w):
|
||||
"""
|
||||
x: (N, T, patch_size**2 * C)
|
||||
imgs: (N, H, W, C)
|
||||
"""
|
||||
c = self.out_channels
|
||||
p = self.x_embedder.patch_size[0]
|
||||
h = h // self.patch_size
|
||||
w = w // self.patch_size
|
||||
assert h * w == x.shape[1]
|
||||
|
||||
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
|
||||
x = torch.einsum('nhwpqc->nchpwq', x)
|
||||
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
|
||||
return imgs
|
||||
|
||||
@@ -146,15 +146,6 @@ WAN_CROSSATTENTION_CLASSES = {
|
||||
}
|
||||
|
||||
|
||||
def repeat_e(e, x):
|
||||
repeats = 1
|
||||
if e.shape[1] > 1:
|
||||
repeats = x.shape[1] // e.shape[1]
|
||||
if repeats == 1:
|
||||
return e
|
||||
return torch.repeat_interleave(e, repeats, dim=1)
|
||||
|
||||
|
||||
class WanAttentionBlock(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
@@ -211,23 +202,20 @@ class WanAttentionBlock(nn.Module):
|
||||
"""
|
||||
# assert e.dtype == torch.float32
|
||||
|
||||
if e.ndim < 4:
|
||||
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
|
||||
else:
|
||||
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
|
||||
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
|
||||
# assert e[0].dtype == torch.float32
|
||||
|
||||
# self-attention
|
||||
y = self.self_attn(
|
||||
self.norm1(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x),
|
||||
self.norm1(x) * (1 + e[1]) + e[0],
|
||||
freqs)
|
||||
|
||||
x = x + y * repeat_e(e[2], x)
|
||||
x = x + y * e[2]
|
||||
|
||||
# cross-attention & ffn
|
||||
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
|
||||
y = self.ffn(self.norm2(x) * (1 + repeat_e(e[4], x)) + repeat_e(e[3], x))
|
||||
x = x + y * repeat_e(e[5], x)
|
||||
y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
|
||||
x = x + y * e[5]
|
||||
return x
|
||||
|
||||
|
||||
@@ -337,12 +325,8 @@ class Head(nn.Module):
|
||||
e(Tensor): Shape [B, C]
|
||||
"""
|
||||
# assert e.dtype == torch.float32
|
||||
if e.ndim < 3:
|
||||
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
|
||||
else:
|
||||
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2)
|
||||
|
||||
x = (self.head(self.norm(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x)))
|
||||
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
|
||||
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
|
||||
return x
|
||||
|
||||
|
||||
@@ -522,9 +506,8 @@ class WanModel(torch.nn.Module):
|
||||
|
||||
# time embeddings
|
||||
e = self.time_embedding(
|
||||
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
|
||||
e = e.reshape(t.shape[0], -1, e.shape[-1])
|
||||
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
||||
sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
|
||||
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
|
||||
|
||||
# context
|
||||
context = self.text_embedding(context)
|
||||
@@ -769,7 +752,8 @@ class CameraWanModel(WanModel):
|
||||
# embeddings
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
if self.control_adapter is not None and camera_conditions is not None:
|
||||
x = x + self.control_adapter(camera_conditions).to(x.dtype)
|
||||
x_camera = self.control_adapter(camera_conditions).to(x.dtype)
|
||||
x = x + x_camera
|
||||
grid_sizes = x.shape[2:]
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
|
||||
@@ -24,17 +24,12 @@ class CausalConv3d(ops.Conv3d):
|
||||
self.padding[1], 2 * self.padding[0], 0)
|
||||
self.padding = (0, 0, 0)
|
||||
|
||||
def forward(self, x, cache_x=None, cache_list=None, cache_idx=None):
|
||||
if cache_list is not None:
|
||||
cache_x = cache_list[cache_idx]
|
||||
cache_list[cache_idx] = None
|
||||
|
||||
def forward(self, x, cache_x=None):
|
||||
padding = list(self._padding)
|
||||
if cache_x is not None and self._padding[4] > 0:
|
||||
cache_x = cache_x.to(x.device)
|
||||
x = torch.cat([cache_x, x], dim=2)
|
||||
padding[4] -= cache_x.shape[2]
|
||||
del cache_x
|
||||
x = F.pad(x, padding)
|
||||
|
||||
return super().forward(x)
|
||||
@@ -57,6 +52,15 @@ class RMS_norm(nn.Module):
|
||||
x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma.to(x) + (self.bias.to(x) if self.bias is not None else 0)
|
||||
|
||||
|
||||
class Upsample(nn.Upsample):
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Fix bfloat16 support for nearest neighbor interpolation.
|
||||
"""
|
||||
return super().forward(x.float()).type_as(x)
|
||||
|
||||
|
||||
class Resample(nn.Module):
|
||||
|
||||
def __init__(self, dim, mode):
|
||||
@@ -69,11 +73,11 @@ class Resample(nn.Module):
|
||||
# layers
|
||||
if mode == 'upsample2d':
|
||||
self.resample = nn.Sequential(
|
||||
nn.Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
||||
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
||||
ops.Conv2d(dim, dim // 2, 3, padding=1))
|
||||
elif mode == 'upsample3d':
|
||||
self.resample = nn.Sequential(
|
||||
nn.Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
||||
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
||||
ops.Conv2d(dim, dim // 2, 3, padding=1))
|
||||
self.time_conv = CausalConv3d(
|
||||
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
||||
@@ -153,6 +157,29 @@ class Resample(nn.Module):
|
||||
feat_idx[0] += 1
|
||||
return x
|
||||
|
||||
def init_weight(self, conv):
|
||||
conv_weight = conv.weight
|
||||
nn.init.zeros_(conv_weight)
|
||||
c1, c2, t, h, w = conv_weight.size()
|
||||
one_matrix = torch.eye(c1, c2)
|
||||
init_matrix = one_matrix
|
||||
nn.init.zeros_(conv_weight)
|
||||
#conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
|
||||
conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5
|
||||
conv.weight.data.copy_(conv_weight)
|
||||
nn.init.zeros_(conv.bias.data)
|
||||
|
||||
def init_weight2(self, conv):
|
||||
conv_weight = conv.weight.data
|
||||
nn.init.zeros_(conv_weight)
|
||||
c1, c2, t, h, w = conv_weight.size()
|
||||
init_matrix = torch.eye(c1 // 2, c2)
|
||||
#init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
|
||||
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
|
||||
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
|
||||
conv.weight.data.copy_(conv_weight)
|
||||
nn.init.zeros_(conv.bias.data)
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
|
||||
@@ -171,7 +198,7 @@ class ResidualBlock(nn.Module):
|
||||
if in_dim != out_dim else nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
old_x = x
|
||||
h = self.shortcut(x)
|
||||
for layer in self.residual:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
@@ -183,12 +210,12 @@ class ResidualBlock(nn.Module):
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = layer(x, cache_list=feat_cache, cache_idx=idx)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x + self.shortcut(old_x)
|
||||
return x + h
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
@@ -467,6 +494,12 @@ class WanVAE(nn.Module):
|
||||
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_upsample, dropout)
|
||||
|
||||
def forward(self, x):
|
||||
mu, log_var = self.encode(x)
|
||||
z = self.reparameterize(mu, log_var)
|
||||
x_recon = self.decode(z)
|
||||
return x_recon, mu, log_var
|
||||
|
||||
def encode(self, x):
|
||||
self.clear_cache()
|
||||
## cache
|
||||
@@ -512,6 +545,18 @@ class WanVAE(nn.Module):
|
||||
self.clear_cache()
|
||||
return out
|
||||
|
||||
def reparameterize(self, mu, log_var):
|
||||
std = torch.exp(0.5 * log_var)
|
||||
eps = torch.randn_like(std)
|
||||
return eps * std + mu
|
||||
|
||||
def sample(self, imgs, deterministic=False):
|
||||
mu, log_var = self.encode(imgs)
|
||||
if deterministic:
|
||||
return mu
|
||||
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
|
||||
return mu + std * torch.randn_like(std)
|
||||
|
||||
def clear_cache(self):
|
||||
self._conv_num = count_conv3d(self.decoder)
|
||||
self._conv_idx = [0]
|
||||
|
||||
@@ -1,726 +0,0 @@
|
||||
# original version: https://github.com/Wan-Video/Wan2.2/blob/main/wan/modules/vae2_2.py
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from .vae import AttentionBlock, CausalConv3d, RMS_norm
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
CACHE_T = 2
|
||||
|
||||
|
||||
class Resample(nn.Module):
|
||||
|
||||
def __init__(self, dim, mode):
|
||||
assert mode in (
|
||||
"none",
|
||||
"upsample2d",
|
||||
"upsample3d",
|
||||
"downsample2d",
|
||||
"downsample3d",
|
||||
)
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.mode = mode
|
||||
|
||||
# layers
|
||||
if mode == "upsample2d":
|
||||
self.resample = nn.Sequential(
|
||||
nn.Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
|
||||
ops.Conv2d(dim, dim, 3, padding=1),
|
||||
)
|
||||
elif mode == "upsample3d":
|
||||
self.resample = nn.Sequential(
|
||||
nn.Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
|
||||
ops.Conv2d(dim, dim, 3, padding=1),
|
||||
# ops.Conv2d(dim, dim//2, 3, padding=1)
|
||||
)
|
||||
self.time_conv = CausalConv3d(
|
||||
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
||||
elif mode == "downsample2d":
|
||||
self.resample = nn.Sequential(
|
||||
nn.ZeroPad2d((0, 1, 0, 1)),
|
||||
ops.Conv2d(dim, dim, 3, stride=(2, 2)))
|
||||
elif mode == "downsample3d":
|
||||
self.resample = nn.Sequential(
|
||||
nn.ZeroPad2d((0, 1, 0, 1)),
|
||||
ops.Conv2d(dim, dim, 3, stride=(2, 2)))
|
||||
self.time_conv = CausalConv3d(
|
||||
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
|
||||
else:
|
||||
self.resample = nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
b, c, t, h, w = x.size()
|
||||
if self.mode == "upsample3d":
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
feat_cache[idx] = "Rep"
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and
|
||||
feat_cache[idx] != "Rep"):
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat(
|
||||
[
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device),
|
||||
cache_x,
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and
|
||||
feat_cache[idx] == "Rep"):
|
||||
cache_x = torch.cat(
|
||||
[
|
||||
torch.zeros_like(cache_x).to(cache_x.device),
|
||||
cache_x
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
if feat_cache[idx] == "Rep":
|
||||
x = self.time_conv(x)
|
||||
else:
|
||||
x = self.time_conv(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
x = x.reshape(b, 2, c, t, h, w)
|
||||
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
|
||||
3)
|
||||
x = x.reshape(b, c, t * 2, h, w)
|
||||
t = x.shape[2]
|
||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||
x = self.resample(x)
|
||||
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
|
||||
|
||||
if self.mode == "downsample3d":
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
feat_cache[idx] = x.clone()
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
cache_x = x[:, :, -1:, :, :].clone()
|
||||
x = self.time_conv(
|
||||
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
return x
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
|
||||
def __init__(self, in_dim, out_dim, dropout=0.0):
|
||||
super().__init__()
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
|
||||
# layers
|
||||
self.residual = nn.Sequential(
|
||||
RMS_norm(in_dim, images=False),
|
||||
nn.SiLU(),
|
||||
CausalConv3d(in_dim, out_dim, 3, padding=1),
|
||||
RMS_norm(out_dim, images=False),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(dropout),
|
||||
CausalConv3d(out_dim, out_dim, 3, padding=1),
|
||||
)
|
||||
self.shortcut = (
|
||||
CausalConv3d(in_dim, out_dim, 1)
|
||||
if in_dim != out_dim else nn.Identity())
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
old_x = x
|
||||
for layer in self.residual:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat(
|
||||
[
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device),
|
||||
cache_x,
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
x = layer(x, cache_list=feat_cache, cache_idx=idx)
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x + self.shortcut(old_x)
|
||||
|
||||
|
||||
def patchify(x, patch_size):
|
||||
if patch_size == 1:
|
||||
return x
|
||||
if x.dim() == 4:
|
||||
x = rearrange(
|
||||
x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
|
||||
elif x.dim() == 5:
|
||||
x = rearrange(
|
||||
x,
|
||||
"b c f (h q) (w r) -> b (c r q) f h w",
|
||||
q=patch_size,
|
||||
r=patch_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid input shape: {x.shape}")
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def unpatchify(x, patch_size):
|
||||
if patch_size == 1:
|
||||
return x
|
||||
|
||||
if x.dim() == 4:
|
||||
x = rearrange(
|
||||
x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
|
||||
elif x.dim() == 5:
|
||||
x = rearrange(
|
||||
x,
|
||||
"b (c r q) f h w -> b c f (h q) (w r)",
|
||||
q=patch_size,
|
||||
r=patch_size,
|
||||
)
|
||||
return x
|
||||
|
||||
|
||||
class AvgDown3D(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
factor_t,
|
||||
factor_s=1,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.factor_t = factor_t
|
||||
self.factor_s = factor_s
|
||||
self.factor = self.factor_t * self.factor_s * self.factor_s
|
||||
|
||||
assert in_channels * self.factor % out_channels == 0
|
||||
self.group_size = in_channels * self.factor // out_channels
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
|
||||
pad = (0, 0, 0, 0, pad_t, 0)
|
||||
x = F.pad(x, pad)
|
||||
B, C, T, H, W = x.shape
|
||||
x = x.view(
|
||||
B,
|
||||
C,
|
||||
T // self.factor_t,
|
||||
self.factor_t,
|
||||
H // self.factor_s,
|
||||
self.factor_s,
|
||||
W // self.factor_s,
|
||||
self.factor_s,
|
||||
)
|
||||
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
|
||||
x = x.view(
|
||||
B,
|
||||
C * self.factor,
|
||||
T // self.factor_t,
|
||||
H // self.factor_s,
|
||||
W // self.factor_s,
|
||||
)
|
||||
x = x.view(
|
||||
B,
|
||||
self.out_channels,
|
||||
self.group_size,
|
||||
T // self.factor_t,
|
||||
H // self.factor_s,
|
||||
W // self.factor_s,
|
||||
)
|
||||
x = x.mean(dim=2)
|
||||
return x
|
||||
|
||||
|
||||
class DupUp3D(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
factor_t,
|
||||
factor_s=1,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.factor_t = factor_t
|
||||
self.factor_s = factor_s
|
||||
self.factor = self.factor_t * self.factor_s * self.factor_s
|
||||
|
||||
assert out_channels * self.factor % in_channels == 0
|
||||
self.repeats = out_channels * self.factor // in_channels
|
||||
|
||||
def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
|
||||
x = x.repeat_interleave(self.repeats, dim=1)
|
||||
x = x.view(
|
||||
x.size(0),
|
||||
self.out_channels,
|
||||
self.factor_t,
|
||||
self.factor_s,
|
||||
self.factor_s,
|
||||
x.size(2),
|
||||
x.size(3),
|
||||
x.size(4),
|
||||
)
|
||||
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
|
||||
x = x.view(
|
||||
x.size(0),
|
||||
self.out_channels,
|
||||
x.size(2) * self.factor_t,
|
||||
x.size(4) * self.factor_s,
|
||||
x.size(6) * self.factor_s,
|
||||
)
|
||||
if first_chunk:
|
||||
x = x[:, :, self.factor_t - 1:, :, :]
|
||||
return x
|
||||
|
||||
|
||||
class Down_ResidualBlock(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_dim,
|
||||
out_dim,
|
||||
dropout,
|
||||
mult,
|
||||
temperal_downsample=False,
|
||||
down_flag=False):
|
||||
super().__init__()
|
||||
|
||||
# Shortcut path with downsample
|
||||
self.avg_shortcut = AvgDown3D(
|
||||
in_dim,
|
||||
out_dim,
|
||||
factor_t=2 if temperal_downsample else 1,
|
||||
factor_s=2 if down_flag else 1,
|
||||
)
|
||||
|
||||
# Main path with residual blocks and downsample
|
||||
downsamples = []
|
||||
for _ in range(mult):
|
||||
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
||||
in_dim = out_dim
|
||||
|
||||
# Add the final downsample block
|
||||
if down_flag:
|
||||
mode = "downsample3d" if temperal_downsample else "downsample2d"
|
||||
downsamples.append(Resample(out_dim, mode=mode))
|
||||
|
||||
self.downsamples = nn.Sequential(*downsamples)
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
x_copy = x
|
||||
for module in self.downsamples:
|
||||
x = module(x, feat_cache, feat_idx)
|
||||
|
||||
return x + self.avg_shortcut(x_copy)
|
||||
|
||||
|
||||
class Up_ResidualBlock(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_dim,
|
||||
out_dim,
|
||||
dropout,
|
||||
mult,
|
||||
temperal_upsample=False,
|
||||
up_flag=False):
|
||||
super().__init__()
|
||||
# Shortcut path with upsample
|
||||
if up_flag:
|
||||
self.avg_shortcut = DupUp3D(
|
||||
in_dim,
|
||||
out_dim,
|
||||
factor_t=2 if temperal_upsample else 1,
|
||||
factor_s=2 if up_flag else 1,
|
||||
)
|
||||
else:
|
||||
self.avg_shortcut = None
|
||||
|
||||
# Main path with residual blocks and upsample
|
||||
upsamples = []
|
||||
for _ in range(mult):
|
||||
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
||||
in_dim = out_dim
|
||||
|
||||
# Add the final upsample block
|
||||
if up_flag:
|
||||
mode = "upsample3d" if temperal_upsample else "upsample2d"
|
||||
upsamples.append(Resample(out_dim, mode=mode))
|
||||
|
||||
self.upsamples = nn.Sequential(*upsamples)
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
|
||||
x_main = x
|
||||
for module in self.upsamples:
|
||||
x_main = module(x_main, feat_cache, feat_idx)
|
||||
if self.avg_shortcut is not None:
|
||||
x_shortcut = self.avg_shortcut(x, first_chunk)
|
||||
return x_main + x_shortcut
|
||||
else:
|
||||
return x_main
|
||||
|
||||
|
||||
class Encoder3d(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[True, True, False],
|
||||
dropout=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_downsample = temperal_downsample
|
||||
|
||||
# dimensions
|
||||
dims = [dim * u for u in [1] + dim_mult]
|
||||
scale = 1.0
|
||||
|
||||
# init block
|
||||
self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)
|
||||
|
||||
# downsample blocks
|
||||
downsamples = []
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
t_down_flag = (
|
||||
temperal_downsample[i]
|
||||
if i < len(temperal_downsample) else False)
|
||||
downsamples.append(
|
||||
Down_ResidualBlock(
|
||||
in_dim=in_dim,
|
||||
out_dim=out_dim,
|
||||
dropout=dropout,
|
||||
mult=num_res_blocks,
|
||||
temperal_downsample=t_down_flag,
|
||||
down_flag=i != len(dim_mult) - 1,
|
||||
))
|
||||
scale /= 2.0
|
||||
self.downsamples = nn.Sequential(*downsamples)
|
||||
|
||||
# middle blocks
|
||||
self.middle = nn.Sequential(
|
||||
ResidualBlock(out_dim, out_dim, dropout),
|
||||
AttentionBlock(out_dim),
|
||||
ResidualBlock(out_dim, out_dim, dropout),
|
||||
)
|
||||
|
||||
# # output blocks
|
||||
self.head = nn.Sequential(
|
||||
RMS_norm(out_dim, images=False),
|
||||
nn.SiLU(),
|
||||
CausalConv3d(out_dim, z_dim, 3, padding=1),
|
||||
)
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = torch.cat(
|
||||
[
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device),
|
||||
cache_x,
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv1(x)
|
||||
|
||||
## downsamples
|
||||
for layer in self.downsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## head
|
||||
for layer in self.head:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = torch.cat(
|
||||
[
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device),
|
||||
cache_x,
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Decoder3d(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_upsample=[False, True, True],
|
||||
dropout=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_upsample = temperal_upsample
|
||||
|
||||
# dimensions
|
||||
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
||||
# init block
|
||||
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
|
||||
|
||||
# middle blocks
|
||||
self.middle = nn.Sequential(
|
||||
ResidualBlock(dims[0], dims[0], dropout),
|
||||
AttentionBlock(dims[0]),
|
||||
ResidualBlock(dims[0], dims[0], dropout),
|
||||
)
|
||||
|
||||
# upsample blocks
|
||||
upsamples = []
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
t_up_flag = temperal_upsample[i] if i < len(
|
||||
temperal_upsample) else False
|
||||
upsamples.append(
|
||||
Up_ResidualBlock(
|
||||
in_dim=in_dim,
|
||||
out_dim=out_dim,
|
||||
dropout=dropout,
|
||||
mult=num_res_blocks + 1,
|
||||
temperal_upsample=t_up_flag,
|
||||
up_flag=i != len(dim_mult) - 1,
|
||||
))
|
||||
self.upsamples = nn.Sequential(*upsamples)
|
||||
|
||||
# output blocks
|
||||
self.head = nn.Sequential(
|
||||
RMS_norm(out_dim, images=False),
|
||||
nn.SiLU(),
|
||||
CausalConv3d(out_dim, 12, 3, padding=1),
|
||||
)
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = torch.cat(
|
||||
[
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device),
|
||||
cache_x,
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv1(x)
|
||||
|
||||
for layer in self.middle:
|
||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## upsamples
|
||||
for layer in self.upsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx, first_chunk)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## head
|
||||
for layer in self.head:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = torch.cat(
|
||||
[
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device),
|
||||
cache_x,
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
def count_conv3d(model):
|
||||
count = 0
|
||||
for m in model.modules():
|
||||
if isinstance(m, CausalConv3d):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
class WanVAE(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim=160,
|
||||
dec_dim=256,
|
||||
z_dim=16,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[True, True, False],
|
||||
dropout=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_downsample = temperal_downsample
|
||||
self.temperal_upsample = temperal_downsample[::-1]
|
||||
|
||||
# modules
|
||||
self.encoder = Encoder3d(
|
||||
dim,
|
||||
z_dim * 2,
|
||||
dim_mult,
|
||||
num_res_blocks,
|
||||
attn_scales,
|
||||
self.temperal_downsample,
|
||||
dropout,
|
||||
)
|
||||
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
||||
self.decoder = Decoder3d(
|
||||
dec_dim,
|
||||
z_dim,
|
||||
dim_mult,
|
||||
num_res_blocks,
|
||||
attn_scales,
|
||||
self.temperal_upsample,
|
||||
dropout,
|
||||
)
|
||||
|
||||
def encode(self, x):
|
||||
self.clear_cache()
|
||||
x = patchify(x, patch_size=2)
|
||||
t = x.shape[2]
|
||||
iter_ = 1 + (t - 1) // 4
|
||||
for i in range(iter_):
|
||||
self._enc_conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.encoder(
|
||||
x[:, :, :1, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx,
|
||||
)
|
||||
else:
|
||||
out_ = self.encoder(
|
||||
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx,
|
||||
)
|
||||
out = torch.cat([out, out_], 2)
|
||||
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
||||
self.clear_cache()
|
||||
return mu
|
||||
|
||||
def decode(self, z):
|
||||
self.clear_cache()
|
||||
iter_ = z.shape[2]
|
||||
x = self.conv2(z)
|
||||
for i in range(iter_):
|
||||
self._conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.decoder(
|
||||
x[:, :, i:i + 1, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx,
|
||||
first_chunk=True,
|
||||
)
|
||||
else:
|
||||
out_ = self.decoder(
|
||||
x[:, :, i:i + 1, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx,
|
||||
)
|
||||
out = torch.cat([out, out_], 2)
|
||||
out = unpatchify(out, patch_size=2)
|
||||
self.clear_cache()
|
||||
return out
|
||||
|
||||
def reparameterize(self, mu, log_var):
|
||||
std = torch.exp(0.5 * log_var)
|
||||
eps = torch.randn_like(std)
|
||||
return eps * std + mu
|
||||
|
||||
def sample(self, imgs, deterministic=False):
|
||||
mu, log_var = self.encode(imgs)
|
||||
if deterministic:
|
||||
return mu
|
||||
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
|
||||
return mu + std * torch.randn_like(std)
|
||||
|
||||
def clear_cache(self):
|
||||
self._conv_num = count_conv3d(self.decoder)
|
||||
self._conv_idx = [0]
|
||||
self._feat_map = [None] * self._conv_num
|
||||
# cache encode
|
||||
self._enc_conv_num = count_conv3d(self.encoder)
|
||||
self._enc_conv_idx = [0]
|
||||
self._enc_feat_map = [None] * self._enc_conv_num
|
||||
@@ -1097,9 +1097,8 @@ class WAN21(BaseModel):
|
||||
image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16])
|
||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||
|
||||
if extra_channels != image.shape[1] + 4:
|
||||
if not self.image_to_video or extra_channels == image.shape[1]:
|
||||
return image
|
||||
if not self.image_to_video or extra_channels == image.shape[1]:
|
||||
return image
|
||||
|
||||
if image.shape[1] > (extra_channels - 4):
|
||||
image = image[:, :(extra_channels - 4)]
|
||||
@@ -1183,31 +1182,6 @@ class WAN21_Camera(WAN21):
|
||||
out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions)
|
||||
return out
|
||||
|
||||
class WAN22(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
||||
self.image_to_video = image_to_video
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
if denoise_mask is not None:
|
||||
out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
|
||||
return out
|
||||
|
||||
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
|
||||
if denoise_mask is None:
|
||||
return timestep
|
||||
temp_ts = (torch.mean(denoise_mask[:, :, :, :, :], dim=(1, 3, 4), keepdim=True) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1))).reshape(timestep.shape[0], -1)
|
||||
return temp_ts
|
||||
|
||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||
return latent_image
|
||||
|
||||
class Hunyuan3Dv2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
|
||||
|
||||
@@ -346,9 +346,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "wan2.1"
|
||||
dim = state_dict['{}head.modulation'.format(key_prefix)].shape[-1]
|
||||
out_dim = state_dict['{}head.head.weight'.format(key_prefix)].shape[0] // 4
|
||||
dit_config["dim"] = dim
|
||||
dit_config["out_dim"] = out_dim
|
||||
dit_config["num_heads"] = dim // 128
|
||||
dit_config["ffn_dim"] = state_dict['{}blocks.0.ffn.0.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
|
||||
|
||||
@@ -101,7 +101,7 @@ if args.directml is not None:
|
||||
lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex # noqa: F401
|
||||
import intel_extension_for_pytorch as ipex
|
||||
_ = torch.xpu.device_count()
|
||||
xpu_available = xpu_available or torch.xpu.is_available()
|
||||
except:
|
||||
@@ -128,11 +128,6 @@ try:
|
||||
except:
|
||||
mlu_available = False
|
||||
|
||||
try:
|
||||
ixuca_available = hasattr(torch, "corex")
|
||||
except:
|
||||
ixuca_available = False
|
||||
|
||||
if args.cpu:
|
||||
cpu_state = CPUState.CPU
|
||||
|
||||
@@ -156,12 +151,6 @@ def is_mlu():
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_ixuca():
|
||||
global ixuca_available
|
||||
if ixuca_available:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_torch_device():
|
||||
global directml_enabled
|
||||
global cpu_state
|
||||
@@ -197,9 +186,8 @@ def get_total_memory(dev=None, torch_total_too=False):
|
||||
elif is_intel_xpu():
|
||||
stats = torch.xpu.memory_stats(dev)
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_total_xpu = torch.xpu.get_device_properties(dev).total_memory
|
||||
mem_total_torch = mem_reserved
|
||||
mem_total = mem_total_xpu
|
||||
mem_total = torch.xpu.get_device_properties(dev).total_memory
|
||||
elif is_ascend_npu():
|
||||
stats = torch.npu.memory_stats(dev)
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
@@ -300,7 +288,7 @@ try:
|
||||
if torch_version_numeric[0] >= 2:
|
||||
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
if is_intel_xpu() or is_ascend_npu() or is_mlu() or is_ixuca():
|
||||
if is_intel_xpu() or is_ascend_npu() or is_mlu():
|
||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
except:
|
||||
@@ -319,10 +307,7 @@ try:
|
||||
logging.info("ROCm version: {}".format(rocm_version))
|
||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
|
||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
if torch_version_numeric >= (2, 8):
|
||||
if any((a in arch) for a in ["gfx1201"]):
|
||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx1201 and gfx950
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
|
||||
if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches
|
||||
@@ -392,8 +377,6 @@ def get_torch_device_name(device):
|
||||
except:
|
||||
allocator_backend = ""
|
||||
return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
|
||||
elif device.type == "xpu":
|
||||
return "{} {}".format(device, torch.xpu.get_device_name(device))
|
||||
else:
|
||||
return "{}".format(device.type)
|
||||
elif is_intel_xpu():
|
||||
@@ -529,8 +512,6 @@ WINDOWS = any(platform.win32_ver())
|
||||
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
|
||||
if WINDOWS:
|
||||
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
|
||||
if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards
|
||||
EXTRA_RESERVED_VRAM += 100 * 1024 * 1024
|
||||
|
||||
if args.reserve_vram is not None:
|
||||
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
|
||||
@@ -895,7 +876,6 @@ def vae_dtype(device=None, allowed_dtypes=[]):
|
||||
return d
|
||||
|
||||
# NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
|
||||
# slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3
|
||||
if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device):
|
||||
return d
|
||||
|
||||
@@ -949,7 +929,7 @@ def device_supports_non_blocking(device):
|
||||
if is_device_mps(device):
|
||||
return False #pytorch bug? mps doesn't support non blocking
|
||||
if is_intel_xpu():
|
||||
return True
|
||||
return False
|
||||
if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
|
||||
return False
|
||||
if directml_enabled:
|
||||
@@ -988,8 +968,6 @@ def get_offload_stream(device):
|
||||
stream_counter = (stream_counter + 1) % len(ss)
|
||||
if is_device_cuda(device):
|
||||
ss[stream_counter].wait_stream(torch.cuda.current_stream())
|
||||
elif is_device_xpu(device):
|
||||
ss[stream_counter].wait_stream(torch.xpu.current_stream())
|
||||
stream_counters[device] = stream_counter
|
||||
return s
|
||||
elif is_device_cuda(device):
|
||||
@@ -1001,15 +979,6 @@ def get_offload_stream(device):
|
||||
stream_counter = (stream_counter + 1) % len(ss)
|
||||
stream_counters[device] = stream_counter
|
||||
return s
|
||||
elif is_device_xpu(device):
|
||||
ss = []
|
||||
for k in range(NUM_STREAMS):
|
||||
ss.append(torch.xpu.Stream(device=device, priority=0))
|
||||
STREAMS[device] = ss
|
||||
s = ss[stream_counter]
|
||||
stream_counter = (stream_counter + 1) % len(ss)
|
||||
stream_counters[device] = stream_counter
|
||||
return s
|
||||
return None
|
||||
|
||||
def sync_stream(device, stream):
|
||||
@@ -1017,8 +986,6 @@ def sync_stream(device, stream):
|
||||
return
|
||||
if is_device_cuda(device):
|
||||
torch.cuda.current_stream().wait_stream(stream)
|
||||
elif is_device_xpu(device):
|
||||
torch.xpu.current_stream().wait_stream(stream)
|
||||
|
||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
|
||||
if device is None or weight.device == device:
|
||||
@@ -1060,8 +1027,6 @@ def xformers_enabled():
|
||||
return False
|
||||
if is_mlu():
|
||||
return False
|
||||
if is_ixuca():
|
||||
return False
|
||||
if directml_enabled:
|
||||
return False
|
||||
return XFORMERS_IS_AVAILABLE
|
||||
@@ -1097,8 +1062,6 @@ def pytorch_attention_flash_attention():
|
||||
return True
|
||||
if is_amd():
|
||||
return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
|
||||
if is_ixuca():
|
||||
return True
|
||||
return False
|
||||
|
||||
def force_upcast_attention_dtype():
|
||||
@@ -1129,8 +1092,8 @@ def get_free_memory(dev=None, torch_free_too=False):
|
||||
stats = torch.xpu.memory_stats(dev)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
|
||||
mem_free_total = mem_free_xpu + mem_free_torch
|
||||
elif is_ascend_npu():
|
||||
stats = torch.npu.memory_stats(dev)
|
||||
@@ -1179,9 +1142,6 @@ def is_device_cpu(device):
|
||||
def is_device_mps(device):
|
||||
return is_device_type(device, 'mps')
|
||||
|
||||
def is_device_xpu(device):
|
||||
return is_device_type(device, 'xpu')
|
||||
|
||||
def is_device_cuda(device):
|
||||
return is_device_type(device, 'cuda')
|
||||
|
||||
@@ -1213,10 +1173,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
return False
|
||||
|
||||
if is_intel_xpu():
|
||||
if torch_version_numeric < (2, 3):
|
||||
return True
|
||||
else:
|
||||
return torch.xpu.get_device_properties(device).has_fp16
|
||||
return True
|
||||
|
||||
if is_ascend_npu():
|
||||
return True
|
||||
@@ -1224,9 +1181,6 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
if is_mlu():
|
||||
return True
|
||||
|
||||
if is_ixuca():
|
||||
return True
|
||||
|
||||
if torch.version.hip:
|
||||
return True
|
||||
|
||||
@@ -1282,15 +1236,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
return False
|
||||
|
||||
if is_intel_xpu():
|
||||
if torch_version_numeric < (2, 6):
|
||||
return True
|
||||
else:
|
||||
return torch.xpu.get_device_capability(device)['has_bfloat16_conversions']
|
||||
|
||||
if is_ascend_npu():
|
||||
return True
|
||||
|
||||
if is_ixuca():
|
||||
if is_ascend_npu():
|
||||
return True
|
||||
|
||||
if is_amd():
|
||||
|
||||
@@ -379,9 +379,6 @@ class ModelPatcher:
|
||||
def set_model_sampler_pre_cfg_function(self, pre_cfg_function, disable_cfg1_optimization=False):
|
||||
self.model_options = set_model_options_pre_cfg_function(self.model_options, pre_cfg_function, disable_cfg1_optimization)
|
||||
|
||||
def set_model_sampler_calc_cond_batch_function(self, sampler_calc_cond_batch_function):
|
||||
self.model_options["sampler_calc_cond_batch_function"] = sampler_calc_cond_batch_function
|
||||
|
||||
def set_model_unet_function_wrapper(self, unet_wrapper_function: UnetWrapperFunction):
|
||||
self.model_options["model_function_wrapper"] = unet_wrapper_function
|
||||
|
||||
|
||||
@@ -336,12 +336,9 @@ class fp8_ops(manual_cast):
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
try:
|
||||
out = fp8_linear(self, input)
|
||||
if out is not None:
|
||||
return out
|
||||
except Exception as e:
|
||||
logging.info("Exception during fp8 op: {}".format(e))
|
||||
out = fp8_linear(self, input)
|
||||
if out is not None:
|
||||
return out
|
||||
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
@@ -373,11 +373,7 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
|
||||
uncond_ = uncond
|
||||
|
||||
conds = [cond, uncond_]
|
||||
if "sampler_calc_cond_batch_function" in model_options:
|
||||
args = {"conds": conds, "input": x, "sigma": timestep, "model": model, "model_options": model_options}
|
||||
out = model_options["sampler_calc_cond_batch_function"](args)
|
||||
else:
|
||||
out = calc_cond_batch(model, conds, x, timestep, model_options)
|
||||
out = calc_cond_batch(model, conds, x, timestep, model_options)
|
||||
|
||||
for fn in model_options.get("sampler_pre_cfg_function", []):
|
||||
args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep,
|
||||
@@ -720,7 +716,7 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c
|
||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
|
||||
"gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece"]
|
||||
"gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3"]
|
||||
|
||||
class KSAMPLER(Sampler):
|
||||
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
||||
|
||||
47
comfy/sd.py
47
comfy/sd.py
@@ -14,12 +14,10 @@ import comfy.ldm.genmo.vae.model
|
||||
import comfy.ldm.lightricks.vae.causal_video_autoencoder
|
||||
import comfy.ldm.cosmos.vae
|
||||
import comfy.ldm.wan.vae
|
||||
import comfy.ldm.wan.vae2_2
|
||||
import comfy.ldm.hunyuan3d.vae
|
||||
import comfy.ldm.ace.vae.music_dcae_pipeline
|
||||
import yaml
|
||||
import math
|
||||
import os
|
||||
|
||||
import comfy.utils
|
||||
|
||||
@@ -421,30 +419,17 @@ class VAE:
|
||||
self.memory_used_encode = lambda shape, dtype: (50 * (round((shape[2] + 7) / 8) * 8) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||
elif "decoder.middle.0.residual.0.gamma" in sd:
|
||||
if "decoder.upsamples.0.upsamples.0.residual.2.weight" in sd: # Wan 2.2 VAE
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
|
||||
self.upscale_index_formula = (4, 16, 16)
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
|
||||
self.downscale_index_formula = (4, 16, 16)
|
||||
self.latent_dim = 3
|
||||
self.latent_channels = 48
|
||||
ddconfig = {"dim": 160, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
|
||||
self.first_stage_model = comfy.ldm.wan.vae2_2.WanVAE(**ddconfig)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
self.memory_used_encode = lambda shape, dtype: 3300 * shape[3] * shape[4] * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: 8000 * shape[3] * shape[4] * (16 * 16) * model_management.dtype_size(dtype)
|
||||
else: # Wan 2.1 VAE
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
|
||||
self.upscale_index_formula = (4, 8, 8)
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
|
||||
self.downscale_index_formula = (4, 8, 8)
|
||||
self.latent_dim = 3
|
||||
self.latent_channels = 16
|
||||
ddconfig = {"dim": 96, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
|
||||
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
|
||||
self.upscale_index_formula = (4, 8, 8)
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
|
||||
self.downscale_index_formula = (4, 8, 8)
|
||||
self.latent_dim = 3
|
||||
self.latent_channels = 16
|
||||
ddconfig = {"dim": 96, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
|
||||
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
|
||||
elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:
|
||||
self.latent_dim = 1
|
||||
ln_post = "geo_decoder.ln_post.weight" in sd
|
||||
@@ -992,12 +977,6 @@ def load_gligen(ckpt_path):
|
||||
model = model.half()
|
||||
return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
|
||||
|
||||
def model_detection_error_hint(path, state_dict):
|
||||
filename = os.path.basename(path)
|
||||
if 'lora' in filename.lower():
|
||||
return "\nHINT: This seems to be a Lora file and Lora files should be put in the lora folder and loaded with a lora loader node.."
|
||||
return ""
|
||||
|
||||
def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
|
||||
logging.warning("Warning: The load checkpoint with config function is deprecated and will eventually be removed, please use the other one.")
|
||||
model, clip, vae, _ = load_checkpoint_guess_config(ckpt_path, output_vae=output_vae, output_clip=output_clip, output_clipvision=False, embedding_directory=embedding_directory, output_model=True)
|
||||
@@ -1026,7 +1005,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
|
||||
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata)
|
||||
if out is None:
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
|
||||
return out
|
||||
|
||||
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None):
|
||||
@@ -1198,7 +1177,7 @@ def load_diffusion_model(unet_path, model_options={}):
|
||||
model = load_diffusion_model_state_dict(sd, model_options=model_options)
|
||||
if model is None:
|
||||
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
||||
return model
|
||||
|
||||
def load_unet(unet_path, dtype=None):
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
"single_word": false
|
||||
},
|
||||
"errors": "replace",
|
||||
"model_max_length": 8192,
|
||||
"model_max_length": 77,
|
||||
"name_or_path": "openai/clip-vit-large-patch14",
|
||||
"pad_token": "<|endoftext|>",
|
||||
"special_tokens_map_file": "./special_tokens_map.json",
|
||||
|
||||
@@ -1059,19 +1059,6 @@ class WAN21_Vace(WAN21_T2V):
|
||||
out = model_base.WAN21_Vace(self, image_to_video=False, device=device)
|
||||
return out
|
||||
|
||||
class WAN22_T2V(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
"model_type": "t2v",
|
||||
"out_dim": 48,
|
||||
}
|
||||
|
||||
latent_format = latent_formats.Wan22
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.WAN22(self, image_to_video=True, device=device)
|
||||
return out
|
||||
|
||||
class Hunyuan3Dv2(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan3d2",
|
||||
@@ -1227,9 +1214,9 @@ class Omnigen2(supported_models_base.BASE):
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.LuminaTokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
|
||||
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2]
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
@@ -1,42 +1,42 @@
|
||||
import os
|
||||
|
||||
from comfy import sd1_clip
|
||||
import comfy.text_encoders.t5
|
||||
import comfy.text_encoders.sd3_clip
|
||||
from comfy.sd1_clip import gen_empty_tokens
|
||||
|
||||
from transformers import T5TokenizerFast
|
||||
|
||||
class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def gen_empty_tokens(self, special_tokens, *args, **kwargs):
|
||||
# PixArt expects the negative to be all pad tokens
|
||||
special_tokens = special_tokens.copy()
|
||||
special_tokens.pop("end")
|
||||
return gen_empty_tokens(special_tokens, *args, **kwargs)
|
||||
|
||||
class PixArtT5XXL(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) # no padding
|
||||
|
||||
class PixArtTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
|
||||
|
||||
def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None):
|
||||
class PixArtTEModel_(PixArtT5XXL):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||
if dtype is None:
|
||||
dtype = dtype_t5
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return PixArtTEModel_
|
||||
import os
|
||||
|
||||
from comfy import sd1_clip
|
||||
import comfy.text_encoders.t5
|
||||
import comfy.text_encoders.sd3_clip
|
||||
from comfy.sd1_clip import gen_empty_tokens
|
||||
|
||||
from transformers import T5TokenizerFast
|
||||
|
||||
class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def gen_empty_tokens(self, special_tokens, *args, **kwargs):
|
||||
# PixArt expects the negative to be all pad tokens
|
||||
special_tokens = special_tokens.copy()
|
||||
special_tokens.pop("end")
|
||||
return gen_empty_tokens(special_tokens, *args, **kwargs)
|
||||
|
||||
class PixArtT5XXL(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) # no padding
|
||||
|
||||
class PixArtTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
|
||||
|
||||
def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None):
|
||||
class PixArtTEModel_(PixArtT5XXL):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||
if dtype is None:
|
||||
dtype = dtype_t5
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return PixArtTEModel_
|
||||
|
||||
@@ -31,7 +31,6 @@ from einops import rearrange
|
||||
from comfy.cli_args import args
|
||||
|
||||
MMAP_TORCH_FILES = args.mmap_torch_files
|
||||
DISABLE_MMAP = args.disable_mmap
|
||||
|
||||
ALWAYS_SAFE_LOAD = False
|
||||
if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
|
||||
@@ -59,10 +58,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
|
||||
sd = {}
|
||||
for k in f.keys():
|
||||
tensor = f.get_tensor(k)
|
||||
if DISABLE_MMAP: # TODO: Not sure if this is the best way to bypass the mmap issues
|
||||
tensor = tensor.to(device=device, copy=True)
|
||||
sd[k] = tensor
|
||||
sd[k] = f.get_tensor(k)
|
||||
if return_metadata:
|
||||
metadata = f.metadata()
|
||||
except Exception as e:
|
||||
@@ -81,7 +77,6 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
if safe_load or ALWAYS_SAFE_LOAD:
|
||||
pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
|
||||
else:
|
||||
logging.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt))
|
||||
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
|
||||
if "state_dict" in pl_sd:
|
||||
sd = pl_sd["state_dict"]
|
||||
@@ -698,26 +693,6 @@ def resize_to_batch_size(tensor, batch_size):
|
||||
|
||||
return output
|
||||
|
||||
def resize_list_to_batch_size(l, batch_size):
|
||||
in_batch_size = len(l)
|
||||
if in_batch_size == batch_size or in_batch_size == 0:
|
||||
return l
|
||||
|
||||
if batch_size <= 1:
|
||||
return l[:batch_size]
|
||||
|
||||
output = []
|
||||
if batch_size < in_batch_size:
|
||||
scale = (in_batch_size - 1) / (batch_size - 1)
|
||||
for i in range(batch_size):
|
||||
output.append(l[min(round(i * scale), in_batch_size - 1)])
|
||||
else:
|
||||
scale = in_batch_size / batch_size
|
||||
for i in range(batch_size):
|
||||
output.append(l[min(math.floor((i + 0.5) * scale), in_batch_size - 1)])
|
||||
|
||||
return output
|
||||
|
||||
def convert_sd_to(state_dict, dtype):
|
||||
keys = list(state_dict.keys())
|
||||
for k in keys:
|
||||
|
||||
@@ -15,20 +15,9 @@ adapters: list[type[WeightAdapterBase]] = [
|
||||
OFTAdapter,
|
||||
BOFTAdapter,
|
||||
]
|
||||
adapter_maps: dict[str, type[WeightAdapterBase]] = {
|
||||
"LoRA": LoRAAdapter,
|
||||
"LoHa": LoHaAdapter,
|
||||
"LoKr": LoKrAdapter,
|
||||
"OFT": OFTAdapter,
|
||||
## We disable not implemented algo for now
|
||||
# "GLoRA": GLoRAAdapter,
|
||||
# "BOFT": BOFTAdapter,
|
||||
}
|
||||
|
||||
|
||||
__all__ = [
|
||||
"WeightAdapterBase",
|
||||
"WeightAdapterTrainBase",
|
||||
"adapters",
|
||||
"adapter_maps",
|
||||
"adapters"
|
||||
] + [a.__name__ for a in adapters]
|
||||
|
||||
@@ -133,43 +133,3 @@ def tucker_weight_from_conv(up, down, mid):
|
||||
def tucker_weight(wa, wb, t):
|
||||
temp = torch.einsum("i j ..., j r -> i r ...", t, wb)
|
||||
return torch.einsum("i j ..., i r -> r j ...", temp, wa)
|
||||
|
||||
|
||||
def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
|
||||
"""
|
||||
return a tuple of two value of input dimension decomposed by the number closest to factor
|
||||
second value is higher or equal than first value.
|
||||
|
||||
examples)
|
||||
factor
|
||||
-1 2 4 8 16 ...
|
||||
127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127
|
||||
128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16
|
||||
250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25
|
||||
360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30
|
||||
512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32
|
||||
1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64
|
||||
"""
|
||||
|
||||
if factor > 0 and (dimension % factor) == 0 and dimension >= factor**2:
|
||||
m = factor
|
||||
n = dimension // factor
|
||||
if m > n:
|
||||
n, m = m, n
|
||||
return m, n
|
||||
if factor < 0:
|
||||
factor = dimension
|
||||
m, n = 1, dimension
|
||||
length = m + n
|
||||
while m < n:
|
||||
new_m = m + 1
|
||||
while dimension % new_m != 0:
|
||||
new_m += 1
|
||||
new_n = dimension // new_m
|
||||
if new_m + new_n > length or new_m > factor:
|
||||
break
|
||||
else:
|
||||
m, n = new_m, new_n
|
||||
if m > n:
|
||||
n, m = m, n
|
||||
return m, n
|
||||
|
||||
@@ -3,120 +3,7 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
import comfy.model_management
|
||||
from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose
|
||||
|
||||
|
||||
class HadaWeight(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, w1u, w1d, w2u, w2d, scale=torch.tensor(1)):
|
||||
ctx.save_for_backward(w1d, w1u, w2d, w2u, scale)
|
||||
diff_weight = ((w1u @ w1d) * (w2u @ w2d)) * scale
|
||||
return diff_weight
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out):
|
||||
(w1d, w1u, w2d, w2u, scale) = ctx.saved_tensors
|
||||
grad_out = grad_out * scale
|
||||
temp = grad_out * (w2u @ w2d)
|
||||
grad_w1u = temp @ w1d.T
|
||||
grad_w1d = w1u.T @ temp
|
||||
|
||||
temp = grad_out * (w1u @ w1d)
|
||||
grad_w2u = temp @ w2d.T
|
||||
grad_w2d = w2u.T @ temp
|
||||
|
||||
del temp
|
||||
return grad_w1u, grad_w1d, grad_w2u, grad_w2d, None
|
||||
|
||||
|
||||
class HadaWeightTucker(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, t1, w1u, w1d, t2, w2u, w2d, scale=torch.tensor(1)):
|
||||
ctx.save_for_backward(t1, w1d, w1u, t2, w2d, w2u, scale)
|
||||
|
||||
rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1d, w1u)
|
||||
rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2d, w2u)
|
||||
|
||||
return rebuild1 * rebuild2 * scale
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out):
|
||||
(t1, w1d, w1u, t2, w2d, w2u, scale) = ctx.saved_tensors
|
||||
grad_out = grad_out * scale
|
||||
|
||||
temp = torch.einsum("i j ..., j r -> i r ...", t2, w2d)
|
||||
rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w2u)
|
||||
|
||||
grad_w = rebuild * grad_out
|
||||
del rebuild
|
||||
|
||||
grad_w1u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
|
||||
grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w1u.T)
|
||||
del grad_w, temp
|
||||
|
||||
grad_w1d = torch.einsum("i r ..., i j ... -> r j", t1, grad_temp)
|
||||
grad_t1 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w1d.T)
|
||||
del grad_temp
|
||||
|
||||
temp = torch.einsum("i j ..., j r -> i r ...", t1, w1d)
|
||||
rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w1u)
|
||||
|
||||
grad_w = rebuild * grad_out
|
||||
del rebuild
|
||||
|
||||
grad_w2u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
|
||||
grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w2u.T)
|
||||
del grad_w, temp
|
||||
|
||||
grad_w2d = torch.einsum("i r ..., i j ... -> r j", t2, grad_temp)
|
||||
grad_t2 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w2d.T)
|
||||
del grad_temp
|
||||
return grad_t1, grad_w1u, grad_w1d, grad_t2, grad_w2u, grad_w2d, None
|
||||
|
||||
|
||||
class LohaDiff(WeightAdapterTrainBase):
|
||||
def __init__(self, weights):
|
||||
super().__init__()
|
||||
# Unpack weights tuple from LoHaAdapter
|
||||
w1a, w1b, alpha, w2a, w2b, t1, t2, _ = weights
|
||||
|
||||
# Create trainable parameters
|
||||
self.hada_w1_a = torch.nn.Parameter(w1a)
|
||||
self.hada_w1_b = torch.nn.Parameter(w1b)
|
||||
self.hada_w2_a = torch.nn.Parameter(w2a)
|
||||
self.hada_w2_b = torch.nn.Parameter(w2b)
|
||||
|
||||
self.use_tucker = False
|
||||
if t1 is not None and t2 is not None:
|
||||
self.use_tucker = True
|
||||
self.hada_t1 = torch.nn.Parameter(t1)
|
||||
self.hada_t2 = torch.nn.Parameter(t2)
|
||||
else:
|
||||
# Keep the attributes for consistent access
|
||||
self.hada_t1 = None
|
||||
self.hada_t2 = None
|
||||
|
||||
# Store rank and non-trainable alpha
|
||||
self.rank = w1b.shape[0]
|
||||
self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
|
||||
|
||||
def __call__(self, w):
|
||||
org_dtype = w.dtype
|
||||
|
||||
scale = self.alpha / self.rank
|
||||
if self.use_tucker:
|
||||
diff_weight = HadaWeightTucker.apply(self.hada_t1, self.hada_w1_a, self.hada_w1_b, self.hada_t2, self.hada_w2_a, self.hada_w2_b, scale)
|
||||
else:
|
||||
diff_weight = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)
|
||||
|
||||
# Add the scaled difference to the original weight
|
||||
weight = w.to(diff_weight) + diff_weight.reshape(w.shape)
|
||||
|
||||
return weight.to(org_dtype)
|
||||
|
||||
def passive_memory_usage(self):
|
||||
"""Calculates memory usage of the trainable parameters."""
|
||||
return sum(param.numel() * param.element_size() for param in self.parameters())
|
||||
from .base import WeightAdapterBase, weight_decompose
|
||||
|
||||
|
||||
class LoHaAdapter(WeightAdapterBase):
|
||||
@@ -126,25 +13,6 @@ class LoHaAdapter(WeightAdapterBase):
|
||||
self.loaded_keys = loaded_keys
|
||||
self.weights = weights
|
||||
|
||||
@classmethod
|
||||
def create_train(cls, weight, rank=1, alpha=1.0):
|
||||
out_dim = weight.shape[0]
|
||||
in_dim = weight.shape[1:].numel()
|
||||
mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
|
||||
mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
|
||||
torch.nn.init.normal_(mat1, 0.1)
|
||||
torch.nn.init.constant_(mat2, 0.0)
|
||||
mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
|
||||
mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
|
||||
torch.nn.init.normal_(mat3, 0.1)
|
||||
torch.nn.init.normal_(mat4, 0.01)
|
||||
return LohaDiff(
|
||||
(mat1, mat2, alpha, mat3, mat4, None, None, None)
|
||||
)
|
||||
|
||||
def to_train(self):
|
||||
return LohaDiff(self.weights)
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
cls,
|
||||
|
||||
@@ -3,77 +3,7 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
import comfy.model_management
|
||||
from .base import (
|
||||
WeightAdapterBase,
|
||||
WeightAdapterTrainBase,
|
||||
weight_decompose,
|
||||
factorization,
|
||||
)
|
||||
|
||||
|
||||
class LokrDiff(WeightAdapterTrainBase):
|
||||
def __init__(self, weights):
|
||||
super().__init__()
|
||||
(lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) = weights
|
||||
self.use_tucker = False
|
||||
if lokr_w1_a is not None:
|
||||
_, rank_a = lokr_w1_a.shape[0], lokr_w1_a.shape[1]
|
||||
rank_a, _ = lokr_w1_b.shape[0], lokr_w1_b.shape[1]
|
||||
self.lokr_w1_a = torch.nn.Parameter(lokr_w1_a)
|
||||
self.lokr_w1_b = torch.nn.Parameter(lokr_w1_b)
|
||||
self.w1_rebuild = True
|
||||
self.ranka = rank_a
|
||||
|
||||
if lokr_w2_a is not None:
|
||||
_, rank_b = lokr_w2_a.shape[0], lokr_w2_a.shape[1]
|
||||
rank_b, _ = lokr_w2_b.shape[0], lokr_w2_b.shape[1]
|
||||
self.lokr_w2_a = torch.nn.Parameter(lokr_w2_a)
|
||||
self.lokr_w2_b = torch.nn.Parameter(lokr_w2_b)
|
||||
if lokr_t2 is not None:
|
||||
self.use_tucker = True
|
||||
self.lokr_t2 = torch.nn.Parameter(lokr_t2)
|
||||
self.w2_rebuild = True
|
||||
self.rankb = rank_b
|
||||
|
||||
if lokr_w1 is not None:
|
||||
self.lokr_w1 = torch.nn.Parameter(lokr_w1)
|
||||
self.w1_rebuild = False
|
||||
|
||||
if lokr_w2 is not None:
|
||||
self.lokr_w2 = torch.nn.Parameter(lokr_w2)
|
||||
self.w2_rebuild = False
|
||||
|
||||
self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
|
||||
|
||||
@property
|
||||
def w1(self):
|
||||
if self.w1_rebuild:
|
||||
return (self.lokr_w1_a @ self.lokr_w1_b) * (self.alpha / self.ranka)
|
||||
else:
|
||||
return self.lokr_w1
|
||||
|
||||
@property
|
||||
def w2(self):
|
||||
if self.w2_rebuild:
|
||||
if self.use_tucker:
|
||||
w2 = torch.einsum(
|
||||
'i j k l, j r, i p -> p r k l',
|
||||
self.lokr_t2,
|
||||
self.lokr_w2_b,
|
||||
self.lokr_w2_a
|
||||
)
|
||||
else:
|
||||
w2 = self.lokr_w2_a @ self.lokr_w2_b
|
||||
return w2 * (self.alpha / self.rankb)
|
||||
else:
|
||||
return self.lokr_w2
|
||||
|
||||
def __call__(self, w):
|
||||
diff = torch.kron(self.w1, self.w2)
|
||||
return w + diff.reshape(w.shape).to(w)
|
||||
|
||||
def passive_memory_usage(self):
|
||||
return sum(param.numel() * param.element_size() for param in self.parameters())
|
||||
from .base import WeightAdapterBase, weight_decompose
|
||||
|
||||
|
||||
class LoKrAdapter(WeightAdapterBase):
|
||||
@@ -83,20 +13,6 @@ class LoKrAdapter(WeightAdapterBase):
|
||||
self.loaded_keys = loaded_keys
|
||||
self.weights = weights
|
||||
|
||||
@classmethod
|
||||
def create_train(cls, weight, rank=1, alpha=1.0):
|
||||
out_dim = weight.shape[0]
|
||||
in_dim = weight.shape[1:].numel()
|
||||
out1, out2 = factorization(out_dim, rank)
|
||||
in1, in2 = factorization(in_dim, rank)
|
||||
mat1 = torch.empty(out1, in1, device=weight.device, dtype=weight.dtype)
|
||||
mat2 = torch.empty(out2, in2, device=weight.device, dtype=weight.dtype)
|
||||
torch.nn.init.kaiming_uniform_(mat2, a=5**0.5)
|
||||
torch.nn.init.constant_(mat1, 0.0)
|
||||
return LokrDiff(
|
||||
(mat1, mat2, alpha, None, None, None, None, None, None)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
cls,
|
||||
|
||||
@@ -3,58 +3,7 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
import comfy.model_management
|
||||
from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose, factorization
|
||||
|
||||
|
||||
class OFTDiff(WeightAdapterTrainBase):
|
||||
def __init__(self, weights):
|
||||
super().__init__()
|
||||
# Unpack weights tuple from LoHaAdapter
|
||||
blocks, rescale, alpha, _ = weights
|
||||
|
||||
# Create trainable parameters
|
||||
self.oft_blocks = torch.nn.Parameter(blocks)
|
||||
if rescale is not None:
|
||||
self.rescale = torch.nn.Parameter(rescale)
|
||||
self.rescaled = True
|
||||
else:
|
||||
self.rescaled = False
|
||||
self.block_num, self.block_size, _ = blocks.shape
|
||||
self.constraint = float(alpha)
|
||||
self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
|
||||
|
||||
def __call__(self, w):
|
||||
org_dtype = w.dtype
|
||||
I = torch.eye(self.block_size, device=self.oft_blocks.device)
|
||||
|
||||
## generate r
|
||||
# for Q = -Q^T
|
||||
q = self.oft_blocks - self.oft_blocks.transpose(1, 2)
|
||||
normed_q = q
|
||||
if self.constraint:
|
||||
q_norm = torch.norm(q) + 1e-8
|
||||
if q_norm > self.constraint:
|
||||
normed_q = q * self.constraint / q_norm
|
||||
# use float() to prevent unsupported type
|
||||
r = (I + normed_q) @ (I - normed_q).float().inverse()
|
||||
|
||||
## Apply chunked matmul on weight
|
||||
_, *shape = w.shape
|
||||
org_weight = w.to(dtype=r.dtype)
|
||||
org_weight = org_weight.unflatten(0, (self.block_num, self.block_size))
|
||||
# Init R=0, so add I on it to ensure the output of step0 is original model output
|
||||
weight = torch.einsum(
|
||||
"k n m, k n ... -> k m ...",
|
||||
r,
|
||||
org_weight,
|
||||
).flatten(0, 1)
|
||||
if self.rescaled:
|
||||
weight = self.rescale * weight
|
||||
return weight.to(org_dtype)
|
||||
|
||||
def passive_memory_usage(self):
|
||||
"""Calculates memory usage of the trainable parameters."""
|
||||
return sum(param.numel() * param.element_size() for param in self.parameters())
|
||||
from .base import WeightAdapterBase, weight_decompose
|
||||
|
||||
|
||||
class OFTAdapter(WeightAdapterBase):
|
||||
@@ -64,18 +13,6 @@ class OFTAdapter(WeightAdapterBase):
|
||||
self.loaded_keys = loaded_keys
|
||||
self.weights = weights
|
||||
|
||||
@classmethod
|
||||
def create_train(cls, weight, rank=1, alpha=1.0):
|
||||
out_dim = weight.shape[0]
|
||||
block_size, block_num = factorization(out_dim, rank)
|
||||
block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=weight.dtype)
|
||||
return OFTDiff(
|
||||
(block, None, alpha, None)
|
||||
)
|
||||
|
||||
def to_train(self):
|
||||
return OFTDiff(self.weights)
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
cls,
|
||||
@@ -123,8 +60,6 @@ class OFTAdapter(WeightAdapterBase):
|
||||
blocks = v[0]
|
||||
rescale = v[1]
|
||||
alpha = v[2]
|
||||
if alpha is None:
|
||||
alpha = 0
|
||||
dora_scale = v[3]
|
||||
|
||||
blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
|
||||
|
||||
@@ -1,69 +0,0 @@
|
||||
"""
|
||||
Feature flags module for ComfyUI WebSocket protocol negotiation.
|
||||
|
||||
This module handles capability negotiation between frontend and backend,
|
||||
allowing graceful protocol evolution while maintaining backward compatibility.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
# Default server capabilities
|
||||
SERVER_FEATURE_FLAGS: Dict[str, Any] = {
|
||||
"supports_preview_metadata": True,
|
||||
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
|
||||
}
|
||||
|
||||
|
||||
def get_connection_feature(
|
||||
sockets_metadata: Dict[str, Dict[str, Any]],
|
||||
sid: str,
|
||||
feature_name: str,
|
||||
default: Any = False
|
||||
) -> Any:
|
||||
"""
|
||||
Get a feature flag value for a specific connection.
|
||||
|
||||
Args:
|
||||
sockets_metadata: Dictionary of socket metadata
|
||||
sid: Session ID of the connection
|
||||
feature_name: Name of the feature to check
|
||||
default: Default value if feature not found
|
||||
|
||||
Returns:
|
||||
Feature value or default if not found
|
||||
"""
|
||||
if sid not in sockets_metadata:
|
||||
return default
|
||||
|
||||
return sockets_metadata[sid].get("feature_flags", {}).get(feature_name, default)
|
||||
|
||||
|
||||
def supports_feature(
|
||||
sockets_metadata: Dict[str, Dict[str, Any]],
|
||||
sid: str,
|
||||
feature_name: str
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a connection supports a specific feature.
|
||||
|
||||
Args:
|
||||
sockets_metadata: Dictionary of socket metadata
|
||||
sid: Session ID of the connection
|
||||
feature_name: Name of the feature to check
|
||||
|
||||
Returns:
|
||||
Boolean indicating if feature is supported
|
||||
"""
|
||||
return get_connection_feature(sockets_metadata, sid, feature_name, False) is True
|
||||
|
||||
|
||||
def get_server_features() -> Dict[str, Any]:
|
||||
"""
|
||||
Get the server's feature flags.
|
||||
|
||||
Returns:
|
||||
Dictionary of server feature flags
|
||||
"""
|
||||
return SERVER_FEATURE_FLAGS.copy()
|
||||
@@ -1,86 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Script to generate .pyi stub files for the synchronous API wrappers.
|
||||
This allows generating stubs without running the full ComfyUI application.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import importlib
|
||||
|
||||
# Add ComfyUI to path so we can import modules
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from comfy_api.internal.async_to_sync import AsyncToSyncConverter
|
||||
from comfy_api.version_list import supported_versions
|
||||
|
||||
|
||||
def generate_stubs_for_module(module_name: str) -> None:
|
||||
"""Generate stub files for a specific module that exports ComfyAPI and ComfyAPISync."""
|
||||
try:
|
||||
# Import the module
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
# Check if module has ComfyAPISync (the sync wrapper)
|
||||
if hasattr(module, "ComfyAPISync"):
|
||||
# Module already has a sync class
|
||||
api_class = getattr(module, "ComfyAPI", None)
|
||||
sync_class = getattr(module, "ComfyAPISync")
|
||||
|
||||
if api_class:
|
||||
# Generate the stub file
|
||||
AsyncToSyncConverter.generate_stub_file(api_class, sync_class)
|
||||
logging.info(f"Generated stub file for {module_name}")
|
||||
else:
|
||||
logging.warning(
|
||||
f"Module {module_name} has ComfyAPISync but no ComfyAPI"
|
||||
)
|
||||
|
||||
elif hasattr(module, "ComfyAPI"):
|
||||
# Module only has async API, need to create sync wrapper first
|
||||
from comfy_api.internal.async_to_sync import create_sync_class
|
||||
|
||||
api_class = getattr(module, "ComfyAPI")
|
||||
sync_class = create_sync_class(api_class)
|
||||
|
||||
# Generate the stub file
|
||||
AsyncToSyncConverter.generate_stub_file(api_class, sync_class)
|
||||
logging.info(f"Generated stub file for {module_name}")
|
||||
else:
|
||||
logging.warning(
|
||||
f"Module {module_name} does not export ComfyAPI or ComfyAPISync"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to generate stub for {module_name}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function to generate all API stub files."""
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
logging.info("Starting stub generation...")
|
||||
|
||||
# Dynamically get module names from supported_versions
|
||||
api_modules = []
|
||||
for api_class in supported_versions:
|
||||
# Extract module name from the class
|
||||
module_name = api_class.__module__
|
||||
if module_name not in api_modules:
|
||||
api_modules.append(module_name)
|
||||
|
||||
logging.info(f"Found {len(api_modules)} API modules: {api_modules}")
|
||||
|
||||
# Generate stubs for each module
|
||||
for module_name in api_modules:
|
||||
generate_stubs_for_module(module_name)
|
||||
|
||||
logging.info("Stub generation complete!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,16 +1,8 @@
|
||||
# This file only exists for backwards compatibility.
|
||||
from comfy_api.latest._input import (
|
||||
ImageInput,
|
||||
AudioInput,
|
||||
MaskInput,
|
||||
LatentInput,
|
||||
VideoInput,
|
||||
)
|
||||
from .basic_types import ImageInput, AudioInput
|
||||
from .video_types import VideoInput
|
||||
|
||||
__all__ = [
|
||||
"ImageInput",
|
||||
"AudioInput",
|
||||
"MaskInput",
|
||||
"LatentInput",
|
||||
"VideoInput",
|
||||
]
|
||||
|
||||
@@ -1,14 +1,20 @@
|
||||
# This file only exists for backwards compatibility.
|
||||
from comfy_api.latest._input.basic_types import (
|
||||
ImageInput,
|
||||
AudioInput,
|
||||
MaskInput,
|
||||
LatentInput,
|
||||
)
|
||||
import torch
|
||||
from typing import TypedDict
|
||||
|
||||
ImageInput = torch.Tensor
|
||||
"""
|
||||
An image in format [B, H, W, C] where B is the batch size, C is the number of channels,
|
||||
"""
|
||||
|
||||
class AudioInput(TypedDict):
|
||||
"""
|
||||
TypedDict representing audio input.
|
||||
"""
|
||||
|
||||
waveform: torch.Tensor
|
||||
"""
|
||||
Tensor in the format [B, C, T] where B is the batch size, C is the number of channels,
|
||||
"""
|
||||
|
||||
sample_rate: int
|
||||
|
||||
__all__ = [
|
||||
"ImageInput",
|
||||
"AudioInput",
|
||||
"MaskInput",
|
||||
"LatentInput",
|
||||
]
|
||||
|
||||
@@ -1,6 +1,55 @@
|
||||
# This file only exists for backwards compatibility.
|
||||
from comfy_api.latest._input.video_types import VideoInput
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
|
||||
|
||||
__all__ = [
|
||||
"VideoInput",
|
||||
]
|
||||
class VideoInput(ABC):
|
||||
"""
|
||||
Abstract base class for video input types.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_components(self) -> VideoComponents:
|
||||
"""
|
||||
Abstract method to get the video components (images, audio, and frame rate).
|
||||
|
||||
Returns:
|
||||
VideoComponents containing images, audio, and frame rate
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save_to(
|
||||
self,
|
||||
path: str,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
):
|
||||
"""
|
||||
Abstract method to save the video input to a file.
|
||||
"""
|
||||
pass
|
||||
|
||||
# Provide a default implementation, but subclasses can provide optimized versions
|
||||
# if possible.
|
||||
def get_dimensions(self) -> tuple[int, int]:
|
||||
"""
|
||||
Returns the dimensions of the video input.
|
||||
|
||||
Returns:
|
||||
Tuple of (width, height)
|
||||
"""
|
||||
components = self.get_components()
|
||||
return components.images.shape[2], components.images.shape[1]
|
||||
|
||||
def get_duration(self) -> float:
|
||||
"""
|
||||
Returns the duration of the video in seconds.
|
||||
|
||||
Returns:
|
||||
Duration in seconds
|
||||
"""
|
||||
components = self.get_components()
|
||||
frame_count = components.images.shape[0]
|
||||
return float(frame_count / components.frame_rate)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# This file only exists for backwards compatibility.
|
||||
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
|
||||
from .video_types import VideoFromFile, VideoFromComponents
|
||||
|
||||
__all__ = [
|
||||
# Implementations
|
||||
"VideoFromFile",
|
||||
"VideoFromComponents",
|
||||
]
|
||||
|
||||
@@ -1,2 +1,303 @@
|
||||
# This file only exists for backwards compatibility.
|
||||
from comfy_api.latest._input_impl.video_types import * # noqa: F403
|
||||
from __future__ import annotations
|
||||
from av.container import InputContainer
|
||||
from av.subtitles.stream import SubtitleStream
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
from comfy_api.input import AudioInput
|
||||
import av
|
||||
import io
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
from comfy_api.input import VideoInput
|
||||
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
|
||||
|
||||
|
||||
def container_to_output_format(container_format: str | None) -> str | None:
|
||||
"""
|
||||
A container's `format` may be a comma-separated list of formats.
|
||||
E.g., iso container's `format` may be `mov,mp4,m4a,3gp,3g2,mj2`.
|
||||
However, writing to a file/stream with `av.open` requires a single format,
|
||||
or `None` to auto-detect.
|
||||
"""
|
||||
if not container_format:
|
||||
return None # Auto-detect
|
||||
|
||||
if "," not in container_format:
|
||||
return container_format
|
||||
|
||||
formats = container_format.split(",")
|
||||
return formats[0]
|
||||
|
||||
|
||||
def get_open_write_kwargs(
|
||||
dest: str | io.BytesIO, container_format: str, to_format: str | None
|
||||
) -> dict:
|
||||
"""Get kwargs for writing a `VideoFromFile` to a file/stream with `av.open`"""
|
||||
open_kwargs = {
|
||||
"mode": "w",
|
||||
# If isobmff, preserve custom metadata tags (workflow, prompt, extra_pnginfo)
|
||||
"options": {"movflags": "use_metadata_tags"},
|
||||
}
|
||||
|
||||
is_write_to_buffer = isinstance(dest, io.BytesIO)
|
||||
if is_write_to_buffer:
|
||||
# Set output format explicitly, since it cannot be inferred from file extension
|
||||
if to_format == VideoContainer.AUTO:
|
||||
to_format = container_format.lower()
|
||||
elif isinstance(to_format, str):
|
||||
to_format = to_format.lower()
|
||||
open_kwargs["format"] = container_to_output_format(to_format)
|
||||
|
||||
return open_kwargs
|
||||
|
||||
|
||||
class VideoFromFile(VideoInput):
|
||||
"""
|
||||
Class representing video input from a file.
|
||||
"""
|
||||
|
||||
def __init__(self, file: str | io.BytesIO):
|
||||
"""
|
||||
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
|
||||
containing the file contents.
|
||||
"""
|
||||
self.__file = file
|
||||
|
||||
def get_dimensions(self) -> tuple[int, int]:
|
||||
"""
|
||||
Returns the dimensions of the video input.
|
||||
|
||||
Returns:
|
||||
Tuple of (width, height)
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
for stream in container.streams:
|
||||
if stream.type == 'video':
|
||||
assert isinstance(stream, av.VideoStream)
|
||||
return stream.width, stream.height
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
|
||||
def get_duration(self) -> float:
|
||||
"""
|
||||
Returns the duration of the video in seconds.
|
||||
|
||||
Returns:
|
||||
Duration in seconds
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
if container.duration is not None:
|
||||
return float(container.duration / av.time_base)
|
||||
|
||||
# Fallback: calculate from frame count and frame rate
|
||||
video_stream = next(
|
||||
(s for s in container.streams if s.type == "video"), None
|
||||
)
|
||||
if video_stream and video_stream.frames and video_stream.average_rate:
|
||||
return float(video_stream.frames / video_stream.average_rate)
|
||||
|
||||
# Last resort: decode frames to count them
|
||||
if video_stream and video_stream.average_rate:
|
||||
frame_count = 0
|
||||
container.seek(0)
|
||||
for packet in container.demux(video_stream):
|
||||
for _ in packet.decode():
|
||||
frame_count += 1
|
||||
if frame_count > 0:
|
||||
return float(frame_count / video_stream.average_rate)
|
||||
|
||||
raise ValueError(f"Could not determine duration for file '{self.__file}'")
|
||||
|
||||
def get_components_internal(self, container: InputContainer) -> VideoComponents:
|
||||
# Get video frames
|
||||
frames = []
|
||||
for frame in container.decode(video=0):
|
||||
img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3)
|
||||
img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3)
|
||||
frames.append(img)
|
||||
|
||||
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)
|
||||
|
||||
# Get frame rate
|
||||
video_stream = next(s for s in container.streams if s.type == 'video')
|
||||
frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)
|
||||
|
||||
# Get audio if available
|
||||
audio = None
|
||||
try:
|
||||
container.seek(0) # Reset the container to the beginning
|
||||
for stream in container.streams:
|
||||
if stream.type != 'audio':
|
||||
continue
|
||||
assert isinstance(stream, av.AudioStream)
|
||||
audio_frames = []
|
||||
for packet in container.demux(stream):
|
||||
for frame in packet.decode():
|
||||
assert isinstance(frame, av.AudioFrame)
|
||||
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
|
||||
if len(audio_frames) > 0:
|
||||
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
|
||||
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
|
||||
audio = AudioInput({
|
||||
"waveform": audio_tensor,
|
||||
"sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
|
||||
})
|
||||
except StopIteration:
|
||||
pass # No audio stream
|
||||
|
||||
metadata = container.metadata
|
||||
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
|
||||
|
||||
def get_components(self) -> VideoComponents:
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
return self.get_components_internal(container)
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
|
||||
def save_to(
|
||||
self,
|
||||
path: str | io.BytesIO,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
):
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
container_format = container.format.name
|
||||
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
|
||||
reuse_streams = True
|
||||
if format != VideoContainer.AUTO and format not in container_format.split(","):
|
||||
reuse_streams = False
|
||||
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
|
||||
reuse_streams = False
|
||||
|
||||
if not reuse_streams:
|
||||
components = self.get_components_internal(container)
|
||||
video = VideoFromComponents(components)
|
||||
return video.save_to(
|
||||
path,
|
||||
format=format,
|
||||
codec=codec,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
streams = container.streams
|
||||
|
||||
open_kwargs = get_open_write_kwargs(path, container_format, format)
|
||||
with av.open(path, **open_kwargs) as output_container:
|
||||
# Copy over the original metadata
|
||||
for key, value in container.metadata.items():
|
||||
if metadata is None or key not in metadata:
|
||||
output_container.metadata[key] = value
|
||||
|
||||
# Add our new metadata
|
||||
if metadata is not None:
|
||||
for key, value in metadata.items():
|
||||
if isinstance(value, str):
|
||||
output_container.metadata[key] = value
|
||||
else:
|
||||
output_container.metadata[key] = json.dumps(value)
|
||||
|
||||
# Add streams to the new container
|
||||
stream_map = {}
|
||||
for stream in streams:
|
||||
if isinstance(stream, (av.VideoStream, av.AudioStream, SubtitleStream)):
|
||||
out_stream = output_container.add_stream_from_template(template=stream, opaque=True)
|
||||
stream_map[stream] = out_stream
|
||||
|
||||
# Write packets to the new container
|
||||
for packet in container.demux():
|
||||
if packet.stream in stream_map and packet.dts is not None:
|
||||
packet.stream = stream_map[packet.stream]
|
||||
output_container.mux(packet)
|
||||
|
||||
class VideoFromComponents(VideoInput):
|
||||
"""
|
||||
Class representing video input from tensors.
|
||||
"""
|
||||
|
||||
def __init__(self, components: VideoComponents):
|
||||
self.__components = components
|
||||
|
||||
def get_components(self) -> VideoComponents:
|
||||
return VideoComponents(
|
||||
images=self.__components.images,
|
||||
audio=self.__components.audio,
|
||||
frame_rate=self.__components.frame_rate
|
||||
)
|
||||
|
||||
def save_to(
|
||||
self,
|
||||
path: str,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
):
|
||||
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
|
||||
raise ValueError("Only MP4 format is supported for now")
|
||||
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
|
||||
raise ValueError("Only H264 codec is supported for now")
|
||||
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output:
|
||||
# Add metadata before writing any streams
|
||||
if metadata is not None:
|
||||
for key, value in metadata.items():
|
||||
output.metadata[key] = json.dumps(value)
|
||||
|
||||
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
|
||||
# Create a video stream
|
||||
video_stream = output.add_stream('h264', rate=frame_rate)
|
||||
video_stream.width = self.__components.images.shape[2]
|
||||
video_stream.height = self.__components.images.shape[1]
|
||||
video_stream.pix_fmt = 'yuv420p'
|
||||
|
||||
# Create an audio stream
|
||||
audio_sample_rate = 1
|
||||
audio_stream: Optional[av.AudioStream] = None
|
||||
if self.__components.audio:
|
||||
audio_sample_rate = int(self.__components.audio['sample_rate'])
|
||||
audio_stream = output.add_stream('aac', rate=audio_sample_rate)
|
||||
audio_stream.sample_rate = audio_sample_rate
|
||||
audio_stream.format = 'fltp'
|
||||
|
||||
# Encode video
|
||||
for i, frame in enumerate(self.__components.images):
|
||||
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
|
||||
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
|
||||
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
|
||||
packet = video_stream.encode(frame)
|
||||
output.mux(packet)
|
||||
|
||||
# Flush video
|
||||
packet = video_stream.encode(None)
|
||||
output.mux(packet)
|
||||
|
||||
if audio_stream and self.__components.audio:
|
||||
# Encode audio
|
||||
samples_per_frame = int(audio_sample_rate / frame_rate)
|
||||
num_frames = self.__components.audio['waveform'].shape[2] // samples_per_frame
|
||||
for i in range(num_frames):
|
||||
start = i * samples_per_frame
|
||||
end = start + samples_per_frame
|
||||
# TODO(Feature) - Add support for stereo audio
|
||||
chunk = (
|
||||
self.__components.audio["waveform"][0, 0, start:end]
|
||||
.unsqueeze(0)
|
||||
.contiguous()
|
||||
.numpy()
|
||||
)
|
||||
audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono')
|
||||
audio_frame.sample_rate = audio_sample_rate
|
||||
audio_frame.pts = i * samples_per_frame
|
||||
for packet in audio_stream.encode(audio_frame):
|
||||
output.mux(packet)
|
||||
|
||||
# Flush audio
|
||||
for packet in audio_stream.encode(None):
|
||||
output.mux(packet)
|
||||
|
||||
|
||||
@@ -1,150 +0,0 @@
|
||||
# Internal infrastructure for ComfyAPI
|
||||
from .api_registry import (
|
||||
ComfyAPIBase as ComfyAPIBase,
|
||||
ComfyAPIWithVersion as ComfyAPIWithVersion,
|
||||
register_versions as register_versions,
|
||||
get_all_versions as get_all_versions,
|
||||
)
|
||||
|
||||
import asyncio
|
||||
from dataclasses import asdict
|
||||
from typing import Callable, Optional
|
||||
|
||||
|
||||
def first_real_override(cls: type, name: str, *, base: type=None) -> Optional[Callable]:
|
||||
"""Return the *callable* override of `name` visible on `cls`, or None if every
|
||||
implementation up to (and including) `base` is the placeholder defined on `base`.
|
||||
|
||||
If base is not provided, it will assume cls has a GET_BASE_CLASS
|
||||
"""
|
||||
if base is None:
|
||||
if not hasattr(cls, "GET_BASE_CLASS"):
|
||||
raise ValueError("base is required if cls does not have a GET_BASE_CLASS; is this a valid ComfyNode subclass?")
|
||||
base = cls.GET_BASE_CLASS()
|
||||
base_attr = getattr(base, name, None)
|
||||
if base_attr is None:
|
||||
return None
|
||||
base_func = base_attr.__func__
|
||||
for c in cls.mro(): # NodeB, NodeA, ComfyNode, object …
|
||||
if c is base: # reached the placeholder – we're done
|
||||
break
|
||||
if name in c.__dict__: # first class that *defines* the attr
|
||||
func = getattr(c, name).__func__
|
||||
if func is not base_func: # real override
|
||||
return getattr(cls, name) # bound to *cls*
|
||||
return None
|
||||
|
||||
|
||||
class _ComfyNodeInternal:
|
||||
"""Class that all V3-based APIs inherit from for ComfyNode.
|
||||
|
||||
This is intended to only be referenced within execution.py, as it has to handle all V3 APIs going forward."""
|
||||
@classmethod
|
||||
def GET_NODE_INFO_V1(cls):
|
||||
...
|
||||
|
||||
|
||||
class _NodeOutputInternal:
|
||||
"""Class that all V3-based APIs inherit from for NodeOutput.
|
||||
|
||||
This is intended to only be referenced within execution.py, as it has to handle all V3 APIs going forward."""
|
||||
...
|
||||
|
||||
|
||||
def as_pruned_dict(dataclass_obj):
|
||||
'''Return dict of dataclass object with pruned None values.'''
|
||||
return prune_dict(asdict(dataclass_obj))
|
||||
|
||||
def prune_dict(d: dict):
|
||||
return {k: v for k,v in d.items() if v is not None}
|
||||
|
||||
|
||||
def is_class(obj):
|
||||
'''
|
||||
Returns True if is a class type.
|
||||
Returns False if is a class instance.
|
||||
'''
|
||||
return isinstance(obj, type)
|
||||
|
||||
|
||||
def copy_class(cls: type) -> type:
|
||||
'''
|
||||
Copy a class and its attributes.
|
||||
'''
|
||||
if cls is None:
|
||||
return None
|
||||
cls_dict = {
|
||||
k: v for k, v in cls.__dict__.items()
|
||||
if k not in ('__dict__', '__weakref__', '__module__', '__doc__')
|
||||
}
|
||||
# new class
|
||||
new_cls = type(
|
||||
cls.__name__,
|
||||
(cls,),
|
||||
cls_dict
|
||||
)
|
||||
# metadata preservation
|
||||
new_cls.__module__ = cls.__module__
|
||||
new_cls.__doc__ = cls.__doc__
|
||||
return new_cls
|
||||
|
||||
|
||||
class classproperty(object):
|
||||
def __init__(self, f):
|
||||
self.f = f
|
||||
def __get__(self, obj, owner):
|
||||
return self.f(owner)
|
||||
|
||||
|
||||
# NOTE: this was ai generated and validated by hand
|
||||
def shallow_clone_class(cls, new_name=None):
|
||||
'''
|
||||
Shallow clone a class while preserving super() functionality.
|
||||
'''
|
||||
new_name = new_name or f"{cls.__name__}Clone"
|
||||
# Include the original class in the bases to maintain proper inheritance
|
||||
new_bases = (cls,) + cls.__bases__
|
||||
return type(new_name, new_bases, dict(cls.__dict__))
|
||||
|
||||
# NOTE: this was ai generated and validated by hand
|
||||
def lock_class(cls):
|
||||
'''
|
||||
Lock a class so that its top-levelattributes cannot be modified.
|
||||
'''
|
||||
# Locked instance __setattr__
|
||||
def locked_instance_setattr(self, name, value):
|
||||
raise AttributeError(
|
||||
f"Cannot set attribute '{name}' on immutable instance of {type(self).__name__}"
|
||||
)
|
||||
# Locked metaclass
|
||||
class LockedMeta(type(cls)):
|
||||
def __setattr__(cls_, name, value):
|
||||
raise AttributeError(
|
||||
f"Cannot modify class attribute '{name}' on locked class '{cls_.__name__}'"
|
||||
)
|
||||
# Rebuild class with locked behavior
|
||||
locked_dict = dict(cls.__dict__)
|
||||
locked_dict['__setattr__'] = locked_instance_setattr
|
||||
|
||||
return LockedMeta(cls.__name__, cls.__bases__, locked_dict)
|
||||
|
||||
|
||||
def make_locked_method_func(type_obj, func, class_clone):
|
||||
"""
|
||||
Returns a function that, when called with **inputs, will execute:
|
||||
getattr(type_obj, func).__func__(lock_class(class_clone), **inputs)
|
||||
|
||||
Supports both synchronous and asynchronous methods.
|
||||
"""
|
||||
locked_class = lock_class(class_clone)
|
||||
method = getattr(type_obj, func).__func__
|
||||
|
||||
# Check if the original method is async
|
||||
if asyncio.iscoroutinefunction(method):
|
||||
async def wrapped_async_func(**inputs):
|
||||
return await method(locked_class, **inputs)
|
||||
return wrapped_async_func
|
||||
else:
|
||||
def wrapped_func(**inputs):
|
||||
return method(locked_class, **inputs)
|
||||
return wrapped_func
|
||||
@@ -1,39 +0,0 @@
|
||||
from typing import Type, List, NamedTuple
|
||||
from comfy_api.internal.singleton import ProxiedSingleton
|
||||
from packaging import version as packaging_version
|
||||
|
||||
|
||||
class ComfyAPIBase(ProxiedSingleton):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
class ComfyAPIWithVersion(NamedTuple):
|
||||
version: str
|
||||
api_class: Type[ComfyAPIBase]
|
||||
|
||||
|
||||
def parse_version(version_str: str) -> packaging_version.Version:
|
||||
"""
|
||||
Parses a version string into a packaging_version.Version object.
|
||||
Raises ValueError if the version string is invalid.
|
||||
"""
|
||||
if version_str == "latest":
|
||||
return packaging_version.parse("9999999.9999999.9999999")
|
||||
return packaging_version.parse(version_str)
|
||||
|
||||
|
||||
registered_versions: List[ComfyAPIWithVersion] = []
|
||||
|
||||
|
||||
def register_versions(versions: List[ComfyAPIWithVersion]):
|
||||
versions.sort(key=lambda x: parse_version(x.version))
|
||||
global registered_versions
|
||||
registered_versions = versions
|
||||
|
||||
|
||||
def get_all_versions() -> List[ComfyAPIWithVersion]:
|
||||
"""
|
||||
Returns a list of all registered ComfyAPI versions.
|
||||
"""
|
||||
return registered_versions
|
||||
@@ -1,987 +0,0 @@
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import contextvars
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import textwrap
|
||||
import threading
|
||||
from enum import Enum
|
||||
from typing import Optional, Type, get_origin, get_args
|
||||
|
||||
|
||||
class TypeTracker:
|
||||
"""Tracks types discovered during stub generation for automatic import generation."""
|
||||
|
||||
def __init__(self):
|
||||
self.discovered_types = {} # type_name -> (module, qualname)
|
||||
self.builtin_types = {
|
||||
"Any",
|
||||
"Dict",
|
||||
"List",
|
||||
"Optional",
|
||||
"Tuple",
|
||||
"Union",
|
||||
"Set",
|
||||
"Sequence",
|
||||
"cast",
|
||||
"NamedTuple",
|
||||
"str",
|
||||
"int",
|
||||
"float",
|
||||
"bool",
|
||||
"None",
|
||||
"bytes",
|
||||
"object",
|
||||
"type",
|
||||
"dict",
|
||||
"list",
|
||||
"tuple",
|
||||
"set",
|
||||
}
|
||||
self.already_imported = (
|
||||
set()
|
||||
) # Track types already imported to avoid duplicates
|
||||
|
||||
def track_type(self, annotation):
|
||||
"""Track a type annotation and record its module/import info."""
|
||||
if annotation is None or annotation is type(None):
|
||||
return
|
||||
|
||||
# Skip builtins and typing module types we already import
|
||||
type_name = getattr(annotation, "__name__", None)
|
||||
if type_name and (
|
||||
type_name in self.builtin_types or type_name in self.already_imported
|
||||
):
|
||||
return
|
||||
|
||||
# Get module and qualname
|
||||
module = getattr(annotation, "__module__", None)
|
||||
qualname = getattr(annotation, "__qualname__", type_name or "")
|
||||
|
||||
# Skip types from typing module (they're already imported)
|
||||
if module == "typing":
|
||||
return
|
||||
|
||||
# Skip UnionType and GenericAlias from types module as they're handled specially
|
||||
if module == "types" and type_name in ("UnionType", "GenericAlias"):
|
||||
return
|
||||
|
||||
if module and module not in ["builtins", "__main__"]:
|
||||
# Store the type info
|
||||
if type_name:
|
||||
self.discovered_types[type_name] = (module, qualname)
|
||||
|
||||
def get_imports(self, main_module_name: str) -> list[str]:
|
||||
"""Generate import statements for all discovered types."""
|
||||
imports = []
|
||||
imports_by_module = {}
|
||||
|
||||
for type_name, (module, qualname) in sorted(self.discovered_types.items()):
|
||||
# Skip types from the main module (they're already imported)
|
||||
if main_module_name and module == main_module_name:
|
||||
continue
|
||||
|
||||
if module not in imports_by_module:
|
||||
imports_by_module[module] = []
|
||||
if type_name not in imports_by_module[module]: # Avoid duplicates
|
||||
imports_by_module[module].append(type_name)
|
||||
|
||||
# Generate import statements
|
||||
for module, types in sorted(imports_by_module.items()):
|
||||
if len(types) == 1:
|
||||
imports.append(f"from {module} import {types[0]}")
|
||||
else:
|
||||
imports.append(f"from {module} import {', '.join(sorted(set(types)))}")
|
||||
|
||||
return imports
|
||||
|
||||
|
||||
class AsyncToSyncConverter:
|
||||
"""
|
||||
Provides utilities to convert async classes to sync classes with proper type hints.
|
||||
"""
|
||||
|
||||
_thread_pool: Optional[concurrent.futures.ThreadPoolExecutor] = None
|
||||
_thread_pool_lock = threading.Lock()
|
||||
_thread_pool_initialized = False
|
||||
|
||||
@classmethod
|
||||
def get_thread_pool(cls, max_workers=None) -> concurrent.futures.ThreadPoolExecutor:
|
||||
"""Get or create the shared thread pool with proper thread-safe initialization."""
|
||||
# Fast path - check if already initialized without acquiring lock
|
||||
if cls._thread_pool_initialized:
|
||||
assert cls._thread_pool is not None, "Thread pool should be initialized"
|
||||
return cls._thread_pool
|
||||
|
||||
# Slow path - acquire lock and create pool if needed
|
||||
with cls._thread_pool_lock:
|
||||
if not cls._thread_pool_initialized:
|
||||
cls._thread_pool = concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=max_workers, thread_name_prefix="async_to_sync_"
|
||||
)
|
||||
cls._thread_pool_initialized = True
|
||||
|
||||
# This should never be None at this point, but add assertion for type checker
|
||||
assert cls._thread_pool is not None
|
||||
return cls._thread_pool
|
||||
|
||||
@classmethod
|
||||
def run_async_in_thread(cls, coro_func, *args, **kwargs):
|
||||
"""
|
||||
Run an async function in a separate thread from the thread pool.
|
||||
Blocks until the async function completes.
|
||||
Properly propagates contextvars between threads and manages event loops.
|
||||
"""
|
||||
# Capture current context - this includes all context variables
|
||||
context = contextvars.copy_context()
|
||||
|
||||
# Store the result and any exception that occurs
|
||||
result_container: dict = {"result": None, "exception": None}
|
||||
|
||||
# Function that runs in the thread pool
|
||||
def run_in_thread():
|
||||
# Create new event loop for this thread
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
# Create the coroutine within the context
|
||||
async def run_with_context():
|
||||
# The coroutine function might access context variables
|
||||
return await coro_func(*args, **kwargs)
|
||||
|
||||
# Run the coroutine with the captured context
|
||||
# This ensures all context variables are available in the async function
|
||||
result = context.run(loop.run_until_complete, run_with_context())
|
||||
result_container["result"] = result
|
||||
except Exception as e:
|
||||
# Store the exception to re-raise in the calling thread
|
||||
result_container["exception"] = e
|
||||
finally:
|
||||
# Ensure event loop is properly closed to prevent warnings
|
||||
try:
|
||||
# Cancel any remaining tasks
|
||||
pending = asyncio.all_tasks(loop)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
|
||||
# Run the loop briefly to handle cancellations
|
||||
if pending:
|
||||
loop.run_until_complete(
|
||||
asyncio.gather(*pending, return_exceptions=True)
|
||||
)
|
||||
except Exception:
|
||||
pass # Ignore errors during cleanup
|
||||
|
||||
# Close the event loop
|
||||
loop.close()
|
||||
|
||||
# Clear the event loop from the thread
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
# Submit to thread pool and wait for result
|
||||
thread_pool = cls.get_thread_pool()
|
||||
future = thread_pool.submit(run_in_thread)
|
||||
future.result() # Wait for completion
|
||||
|
||||
# Re-raise any exception that occurred in the thread
|
||||
if result_container["exception"] is not None:
|
||||
raise result_container["exception"]
|
||||
|
||||
return result_container["result"]
|
||||
|
||||
@classmethod
|
||||
def create_sync_class(cls, async_class: Type, thread_pool_size=10) -> Type:
|
||||
"""
|
||||
Creates a new class with synchronous versions of all async methods.
|
||||
|
||||
Args:
|
||||
async_class: The async class to convert
|
||||
thread_pool_size: Size of thread pool to use
|
||||
|
||||
Returns:
|
||||
A new class with sync versions of all async methods
|
||||
"""
|
||||
sync_class_name = "ComfyAPISyncStub"
|
||||
cls.get_thread_pool(thread_pool_size)
|
||||
|
||||
# Create a proper class with docstrings and proper base classes
|
||||
sync_class_dict = {
|
||||
"__doc__": async_class.__doc__,
|
||||
"__module__": async_class.__module__,
|
||||
"__qualname__": sync_class_name,
|
||||
"__orig_class__": async_class, # Store original class for typing references
|
||||
}
|
||||
|
||||
# Create __init__ method
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._async_instance = async_class(*args, **kwargs)
|
||||
|
||||
# Handle annotated class attributes (like execution: Execution)
|
||||
# Get all annotations from the class hierarchy
|
||||
all_annotations = {}
|
||||
for base_class in reversed(inspect.getmro(async_class)):
|
||||
if hasattr(base_class, "__annotations__"):
|
||||
all_annotations.update(base_class.__annotations__)
|
||||
|
||||
# For each annotated attribute, check if it needs to be created or wrapped
|
||||
for attr_name, attr_type in all_annotations.items():
|
||||
if hasattr(self._async_instance, attr_name):
|
||||
# Attribute exists on the instance
|
||||
attr = getattr(self._async_instance, attr_name)
|
||||
# Check if this attribute needs a sync wrapper
|
||||
if hasattr(attr, "__class__"):
|
||||
from comfy_api.internal.singleton import ProxiedSingleton
|
||||
|
||||
if isinstance(attr, ProxiedSingleton):
|
||||
# Create a sync version of this attribute
|
||||
try:
|
||||
sync_attr_class = cls.create_sync_class(attr.__class__)
|
||||
# Create instance of the sync wrapper with the async instance
|
||||
sync_attr = object.__new__(sync_attr_class) # type: ignore
|
||||
sync_attr._async_instance = attr
|
||||
setattr(self, attr_name, sync_attr)
|
||||
except Exception:
|
||||
# If we can't create a sync version, keep the original
|
||||
setattr(self, attr_name, attr)
|
||||
else:
|
||||
# Not async, just copy the reference
|
||||
setattr(self, attr_name, attr)
|
||||
else:
|
||||
# Attribute doesn't exist, but is annotated - create it
|
||||
# This handles cases like execution: Execution
|
||||
if isinstance(attr_type, type):
|
||||
# Check if the type is defined as an inner class
|
||||
if hasattr(async_class, attr_type.__name__):
|
||||
inner_class = getattr(async_class, attr_type.__name__)
|
||||
from comfy_api.internal.singleton import ProxiedSingleton
|
||||
|
||||
# Create an instance of the inner class
|
||||
try:
|
||||
# For ProxiedSingleton classes, get or create the singleton instance
|
||||
if issubclass(inner_class, ProxiedSingleton):
|
||||
async_instance = inner_class.get_instance()
|
||||
else:
|
||||
async_instance = inner_class()
|
||||
|
||||
# Create sync wrapper
|
||||
sync_attr_class = cls.create_sync_class(inner_class)
|
||||
sync_attr = object.__new__(sync_attr_class) # type: ignore
|
||||
sync_attr._async_instance = async_instance
|
||||
setattr(self, attr_name, sync_attr)
|
||||
# Also set on the async instance for consistency
|
||||
setattr(self._async_instance, attr_name, async_instance)
|
||||
except Exception as e:
|
||||
logging.warning(
|
||||
f"Failed to create instance for {attr_name}: {e}"
|
||||
)
|
||||
|
||||
# Handle other instance attributes that might not be annotated
|
||||
for name, attr in inspect.getmembers(self._async_instance):
|
||||
if name.startswith("_") or hasattr(self, name):
|
||||
continue
|
||||
|
||||
# If attribute is an instance of a class, and that class is defined in the original class
|
||||
# we need to check if it needs a sync wrapper
|
||||
if isinstance(attr, object) and not isinstance(
|
||||
attr, (str, int, float, bool, list, dict, tuple)
|
||||
):
|
||||
from comfy_api.internal.singleton import ProxiedSingleton
|
||||
|
||||
if isinstance(attr, ProxiedSingleton):
|
||||
# Create a sync version of this nested class
|
||||
try:
|
||||
sync_attr_class = cls.create_sync_class(attr.__class__)
|
||||
# Create instance of the sync wrapper with the async instance
|
||||
sync_attr = object.__new__(sync_attr_class) # type: ignore
|
||||
sync_attr._async_instance = attr
|
||||
setattr(self, name, sync_attr)
|
||||
except Exception:
|
||||
# If we can't create a sync version, keep the original
|
||||
setattr(self, name, attr)
|
||||
|
||||
sync_class_dict["__init__"] = __init__
|
||||
|
||||
# Process methods from the async class
|
||||
for name, method in inspect.getmembers(
|
||||
async_class, predicate=inspect.isfunction
|
||||
):
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
|
||||
# Extract the actual return type from a coroutine
|
||||
if inspect.iscoroutinefunction(method):
|
||||
# Create sync version of async method with proper signature
|
||||
@functools.wraps(method)
|
||||
def sync_method(self, *args, _method_name=name, **kwargs):
|
||||
async_method = getattr(self._async_instance, _method_name)
|
||||
return AsyncToSyncConverter.run_async_in_thread(
|
||||
async_method, *args, **kwargs
|
||||
)
|
||||
|
||||
# Add to the class dict
|
||||
sync_class_dict[name] = sync_method
|
||||
else:
|
||||
# For regular methods, create a proxy method
|
||||
@functools.wraps(method)
|
||||
def proxy_method(self, *args, _method_name=name, **kwargs):
|
||||
method = getattr(self._async_instance, _method_name)
|
||||
return method(*args, **kwargs)
|
||||
|
||||
# Add to the class dict
|
||||
sync_class_dict[name] = proxy_method
|
||||
|
||||
# Handle property access
|
||||
for name, prop in inspect.getmembers(
|
||||
async_class, lambda x: isinstance(x, property)
|
||||
):
|
||||
|
||||
def make_property(name, prop_obj):
|
||||
def getter(self):
|
||||
value = getattr(self._async_instance, name)
|
||||
if inspect.iscoroutinefunction(value):
|
||||
|
||||
def sync_fn(*args, **kwargs):
|
||||
return AsyncToSyncConverter.run_async_in_thread(
|
||||
value, *args, **kwargs
|
||||
)
|
||||
|
||||
return sync_fn
|
||||
return value
|
||||
|
||||
def setter(self, value):
|
||||
setattr(self._async_instance, name, value)
|
||||
|
||||
return property(getter, setter if prop_obj.fset else None)
|
||||
|
||||
sync_class_dict[name] = make_property(name, prop)
|
||||
|
||||
# Create the class
|
||||
sync_class = type(sync_class_name, (object,), sync_class_dict)
|
||||
|
||||
return sync_class
|
||||
|
||||
@classmethod
|
||||
def _format_type_annotation(
|
||||
cls, annotation, type_tracker: Optional[TypeTracker] = None
|
||||
) -> str:
|
||||
"""Convert a type annotation to its string representation for stub files."""
|
||||
if (
|
||||
annotation is inspect.Parameter.empty
|
||||
or annotation is inspect.Signature.empty
|
||||
):
|
||||
return "Any"
|
||||
|
||||
# Handle None type
|
||||
if annotation is type(None):
|
||||
return "None"
|
||||
|
||||
# Track the type if we have a tracker
|
||||
if type_tracker:
|
||||
type_tracker.track_type(annotation)
|
||||
|
||||
# Try using typing.get_origin/get_args for Python 3.8+
|
||||
try:
|
||||
origin = get_origin(annotation)
|
||||
args = get_args(annotation)
|
||||
|
||||
if origin is not None:
|
||||
# Track the origin type
|
||||
if type_tracker:
|
||||
type_tracker.track_type(origin)
|
||||
|
||||
# Get the origin name
|
||||
origin_name = getattr(origin, "__name__", str(origin))
|
||||
if "." in origin_name:
|
||||
origin_name = origin_name.split(".")[-1]
|
||||
|
||||
# Special handling for types.UnionType (Python 3.10+ pipe operator)
|
||||
# Convert to old-style Union for compatibility
|
||||
if str(origin) == "<class 'types.UnionType'>" or origin_name == "UnionType":
|
||||
origin_name = "Union"
|
||||
|
||||
# Format arguments recursively
|
||||
if args:
|
||||
formatted_args = []
|
||||
for arg in args:
|
||||
# Track each type in the union
|
||||
if type_tracker:
|
||||
type_tracker.track_type(arg)
|
||||
formatted_args.append(cls._format_type_annotation(arg, type_tracker))
|
||||
return f"{origin_name}[{', '.join(formatted_args)}]"
|
||||
else:
|
||||
return origin_name
|
||||
except (AttributeError, TypeError):
|
||||
# Fallback for older Python versions or non-generic types
|
||||
pass
|
||||
|
||||
# Handle generic types the old way for compatibility
|
||||
if hasattr(annotation, "__origin__") and hasattr(annotation, "__args__"):
|
||||
origin = annotation.__origin__
|
||||
origin_name = (
|
||||
origin.__name__
|
||||
if hasattr(origin, "__name__")
|
||||
else str(origin).split("'")[1]
|
||||
)
|
||||
|
||||
# Format each type argument
|
||||
args = []
|
||||
for arg in annotation.__args__:
|
||||
args.append(cls._format_type_annotation(arg, type_tracker))
|
||||
|
||||
return f"{origin_name}[{', '.join(args)}]"
|
||||
|
||||
# Handle regular types with __name__
|
||||
if hasattr(annotation, "__name__"):
|
||||
return annotation.__name__
|
||||
|
||||
# Handle special module types (like types from typing module)
|
||||
if hasattr(annotation, "__module__") and hasattr(annotation, "__qualname__"):
|
||||
# For types like typing.Literal, typing.TypedDict, etc.
|
||||
return annotation.__qualname__
|
||||
|
||||
# Last resort: string conversion with cleanup
|
||||
type_str = str(annotation)
|
||||
|
||||
# Clean up common patterns more robustly
|
||||
if type_str.startswith("<class '") and type_str.endswith("'>"):
|
||||
type_str = type_str[8:-2] # Remove "<class '" and "'>"
|
||||
|
||||
# Remove module prefixes for common modules
|
||||
for prefix in ["typing.", "builtins.", "types."]:
|
||||
if type_str.startswith(prefix):
|
||||
type_str = type_str[len(prefix) :]
|
||||
|
||||
# Handle special cases
|
||||
if type_str in ("_empty", "inspect._empty"):
|
||||
return "None"
|
||||
|
||||
# Fix NoneType (this should rarely be needed now)
|
||||
if type_str == "NoneType":
|
||||
return "None"
|
||||
|
||||
return type_str
|
||||
|
||||
@classmethod
|
||||
def _extract_coroutine_return_type(cls, annotation):
|
||||
"""Extract the actual return type from a Coroutine annotation."""
|
||||
if hasattr(annotation, "__args__") and len(annotation.__args__) > 2:
|
||||
# Coroutine[Any, Any, ReturnType] -> extract ReturnType
|
||||
return annotation.__args__[2]
|
||||
return annotation
|
||||
|
||||
@classmethod
|
||||
def _format_parameter_default(cls, default_value) -> str:
|
||||
"""Format a parameter's default value for stub files."""
|
||||
if default_value is inspect.Parameter.empty:
|
||||
return ""
|
||||
elif default_value is None:
|
||||
return " = None"
|
||||
elif isinstance(default_value, bool):
|
||||
return f" = {default_value}"
|
||||
elif default_value == {}:
|
||||
return " = {}"
|
||||
elif default_value == []:
|
||||
return " = []"
|
||||
else:
|
||||
return f" = {default_value}"
|
||||
|
||||
@classmethod
|
||||
def _format_method_parameters(
|
||||
cls,
|
||||
sig: inspect.Signature,
|
||||
skip_self: bool = True,
|
||||
type_hints: Optional[dict] = None,
|
||||
type_tracker: Optional[TypeTracker] = None,
|
||||
) -> str:
|
||||
"""Format method parameters for stub files."""
|
||||
params = []
|
||||
if type_hints is None:
|
||||
type_hints = {}
|
||||
|
||||
for i, (param_name, param) in enumerate(sig.parameters.items()):
|
||||
if i == 0 and param_name == "self" and skip_self:
|
||||
params.append("self")
|
||||
else:
|
||||
# Get type annotation from type hints if available, otherwise from signature
|
||||
annotation = type_hints.get(param_name, param.annotation)
|
||||
type_str = cls._format_type_annotation(annotation, type_tracker)
|
||||
|
||||
# Get default value
|
||||
default_str = cls._format_parameter_default(param.default)
|
||||
|
||||
# Combine parameter parts
|
||||
if annotation is inspect.Parameter.empty:
|
||||
params.append(f"{param_name}: Any{default_str}")
|
||||
else:
|
||||
params.append(f"{param_name}: {type_str}{default_str}")
|
||||
|
||||
return ", ".join(params)
|
||||
|
||||
@classmethod
|
||||
def _generate_method_signature(
|
||||
cls,
|
||||
method_name: str,
|
||||
method,
|
||||
is_async: bool = False,
|
||||
type_tracker: Optional[TypeTracker] = None,
|
||||
) -> str:
|
||||
"""Generate a complete method signature for stub files."""
|
||||
sig = inspect.signature(method)
|
||||
|
||||
# Try to get evaluated type hints to resolve string annotations
|
||||
try:
|
||||
from typing import get_type_hints
|
||||
type_hints = get_type_hints(method)
|
||||
except Exception:
|
||||
# Fallback to empty dict if we can't get type hints
|
||||
type_hints = {}
|
||||
|
||||
# For async methods, extract the actual return type
|
||||
return_annotation = type_hints.get('return', sig.return_annotation)
|
||||
if is_async and inspect.iscoroutinefunction(method):
|
||||
return_annotation = cls._extract_coroutine_return_type(return_annotation)
|
||||
|
||||
# Format parameters with type hints
|
||||
params_str = cls._format_method_parameters(sig, type_hints=type_hints, type_tracker=type_tracker)
|
||||
|
||||
# Format return type
|
||||
return_type = cls._format_type_annotation(return_annotation, type_tracker)
|
||||
if return_annotation is inspect.Signature.empty:
|
||||
return_type = "None"
|
||||
|
||||
return f"def {method_name}({params_str}) -> {return_type}: ..."
|
||||
|
||||
@classmethod
|
||||
def _generate_imports(
|
||||
cls, async_class: Type, type_tracker: TypeTracker
|
||||
) -> list[str]:
|
||||
"""Generate import statements for the stub file."""
|
||||
imports = []
|
||||
|
||||
# Add standard typing imports
|
||||
imports.append(
|
||||
"from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple"
|
||||
)
|
||||
|
||||
# Add imports from the original module
|
||||
if async_class.__module__ != "builtins":
|
||||
module = inspect.getmodule(async_class)
|
||||
additional_types = []
|
||||
|
||||
if module:
|
||||
# Check if module has __all__ defined
|
||||
module_all = getattr(module, "__all__", None)
|
||||
|
||||
for name, obj in sorted(inspect.getmembers(module)):
|
||||
if isinstance(obj, type):
|
||||
# Skip if __all__ is defined and this name isn't in it
|
||||
# unless it's already been tracked as used in type annotations
|
||||
if module_all is not None and name not in module_all:
|
||||
# Check if this type was actually used in annotations
|
||||
if name not in type_tracker.discovered_types:
|
||||
continue
|
||||
|
||||
# Check for NamedTuple
|
||||
if issubclass(obj, tuple) and hasattr(obj, "_fields"):
|
||||
additional_types.append(name)
|
||||
# Mark as already imported
|
||||
type_tracker.already_imported.add(name)
|
||||
# Check for Enum
|
||||
elif issubclass(obj, Enum) and name != "Enum":
|
||||
additional_types.append(name)
|
||||
# Mark as already imported
|
||||
type_tracker.already_imported.add(name)
|
||||
|
||||
if additional_types:
|
||||
type_imports = ", ".join([async_class.__name__] + additional_types)
|
||||
imports.append(f"from {async_class.__module__} import {type_imports}")
|
||||
else:
|
||||
imports.append(
|
||||
f"from {async_class.__module__} import {async_class.__name__}"
|
||||
)
|
||||
|
||||
# Add imports for all discovered types
|
||||
# Pass the main module name to avoid duplicate imports
|
||||
imports.extend(
|
||||
type_tracker.get_imports(main_module_name=async_class.__module__)
|
||||
)
|
||||
|
||||
# Add base module import if needed
|
||||
if hasattr(inspect.getmodule(async_class), "__name__"):
|
||||
module_name = inspect.getmodule(async_class).__name__
|
||||
if "." in module_name:
|
||||
base_module = module_name.split(".")[0]
|
||||
# Only add if not already importing from it
|
||||
if not any(imp.startswith(f"from {base_module}") for imp in imports):
|
||||
imports.append(f"import {base_module}")
|
||||
|
||||
return imports
|
||||
|
||||
@classmethod
|
||||
def _get_class_attributes(cls, async_class: Type) -> list[tuple[str, Type]]:
|
||||
"""Extract class attributes that are classes themselves."""
|
||||
class_attributes = []
|
||||
|
||||
# Look for class attributes that are classes
|
||||
for name, attr in sorted(inspect.getmembers(async_class)):
|
||||
if isinstance(attr, type) and not name.startswith("_"):
|
||||
class_attributes.append((name, attr))
|
||||
elif (
|
||||
hasattr(async_class, "__annotations__")
|
||||
and name in async_class.__annotations__
|
||||
):
|
||||
annotation = async_class.__annotations__[name]
|
||||
if isinstance(annotation, type):
|
||||
class_attributes.append((name, annotation))
|
||||
|
||||
return class_attributes
|
||||
|
||||
@classmethod
|
||||
def _generate_inner_class_stub(
|
||||
cls,
|
||||
name: str,
|
||||
attr: Type,
|
||||
indent: str = " ",
|
||||
type_tracker: Optional[TypeTracker] = None,
|
||||
) -> list[str]:
|
||||
"""Generate stub for an inner class."""
|
||||
stub_lines = []
|
||||
stub_lines.append(f"{indent}class {name}Sync:")
|
||||
|
||||
# Add docstring if available
|
||||
if hasattr(attr, "__doc__") and attr.__doc__:
|
||||
stub_lines.extend(
|
||||
cls._format_docstring_for_stub(attr.__doc__, f"{indent} ")
|
||||
)
|
||||
|
||||
# Add __init__ if it exists
|
||||
if hasattr(attr, "__init__"):
|
||||
try:
|
||||
init_method = getattr(attr, "__init__")
|
||||
init_sig = inspect.signature(init_method)
|
||||
|
||||
# Try to get type hints
|
||||
try:
|
||||
from typing import get_type_hints
|
||||
init_hints = get_type_hints(init_method)
|
||||
except Exception:
|
||||
init_hints = {}
|
||||
|
||||
# Format parameters
|
||||
params_str = cls._format_method_parameters(
|
||||
init_sig, type_hints=init_hints, type_tracker=type_tracker
|
||||
)
|
||||
# Add __init__ docstring if available (before the method)
|
||||
if hasattr(init_method, "__doc__") and init_method.__doc__:
|
||||
stub_lines.extend(
|
||||
cls._format_docstring_for_stub(
|
||||
init_method.__doc__, f"{indent} "
|
||||
)
|
||||
)
|
||||
stub_lines.append(
|
||||
f"{indent} def __init__({params_str}) -> None: ..."
|
||||
)
|
||||
except (ValueError, TypeError):
|
||||
stub_lines.append(
|
||||
f"{indent} def __init__(self, *args, **kwargs) -> None: ..."
|
||||
)
|
||||
|
||||
# Add methods to the inner class
|
||||
has_methods = False
|
||||
for method_name, method in sorted(
|
||||
inspect.getmembers(attr, predicate=inspect.isfunction)
|
||||
):
|
||||
if method_name.startswith("_"):
|
||||
continue
|
||||
|
||||
has_methods = True
|
||||
try:
|
||||
# Add method docstring if available (before the method signature)
|
||||
if method.__doc__:
|
||||
stub_lines.extend(
|
||||
cls._format_docstring_for_stub(method.__doc__, f"{indent} ")
|
||||
)
|
||||
|
||||
method_sig = cls._generate_method_signature(
|
||||
method_name, method, is_async=True, type_tracker=type_tracker
|
||||
)
|
||||
stub_lines.append(f"{indent} {method_sig}")
|
||||
except (ValueError, TypeError):
|
||||
stub_lines.append(
|
||||
f"{indent} def {method_name}(self, *args, **kwargs): ..."
|
||||
)
|
||||
|
||||
if not has_methods:
|
||||
stub_lines.append(f"{indent} pass")
|
||||
|
||||
return stub_lines
|
||||
|
||||
@classmethod
|
||||
def _format_docstring_for_stub(
|
||||
cls, docstring: str, indent: str = " "
|
||||
) -> list[str]:
|
||||
"""Format a docstring for inclusion in a stub file with proper indentation."""
|
||||
if not docstring:
|
||||
return []
|
||||
|
||||
# First, dedent the docstring to remove any existing indentation
|
||||
dedented = textwrap.dedent(docstring).strip()
|
||||
|
||||
# Split into lines
|
||||
lines = dedented.split("\n")
|
||||
|
||||
# Build the properly indented docstring
|
||||
result = []
|
||||
result.append(f'{indent}"""')
|
||||
|
||||
for line in lines:
|
||||
if line.strip(): # Non-empty line
|
||||
result.append(f"{indent}{line}")
|
||||
else: # Empty line
|
||||
result.append("")
|
||||
|
||||
result.append(f'{indent}"""')
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _post_process_stub_content(cls, stub_content: list[str]) -> list[str]:
|
||||
"""Post-process stub content to fix any remaining issues."""
|
||||
processed = []
|
||||
|
||||
for line in stub_content:
|
||||
# Skip processing imports
|
||||
if line.startswith(("from ", "import ")):
|
||||
processed.append(line)
|
||||
continue
|
||||
|
||||
# Fix method signatures missing return types
|
||||
if (
|
||||
line.strip().startswith("def ")
|
||||
and line.strip().endswith(": ...")
|
||||
and ") -> " not in line
|
||||
):
|
||||
# Add -> None for methods without return annotation
|
||||
line = line.replace(": ...", " -> None: ...")
|
||||
|
||||
processed.append(line)
|
||||
|
||||
return processed
|
||||
|
||||
@classmethod
|
||||
def generate_stub_file(cls, async_class: Type, sync_class: Type) -> None:
|
||||
"""
|
||||
Generate a .pyi stub file for the sync class to help IDEs with type checking.
|
||||
"""
|
||||
try:
|
||||
# Only generate stub if we can determine module path
|
||||
if async_class.__module__ == "__main__":
|
||||
return
|
||||
|
||||
module = inspect.getmodule(async_class)
|
||||
if not module:
|
||||
return
|
||||
|
||||
module_path = module.__file__
|
||||
if not module_path:
|
||||
return
|
||||
|
||||
# Create stub file path in a 'generated' subdirectory
|
||||
module_dir = os.path.dirname(module_path)
|
||||
stub_dir = os.path.join(module_dir, "generated")
|
||||
|
||||
# Ensure the generated directory exists
|
||||
os.makedirs(stub_dir, exist_ok=True)
|
||||
|
||||
module_name = os.path.basename(module_path)
|
||||
if module_name.endswith(".py"):
|
||||
module_name = module_name[:-3]
|
||||
|
||||
sync_stub_path = os.path.join(stub_dir, f"{sync_class.__name__}.pyi")
|
||||
|
||||
# Create a type tracker for this stub generation
|
||||
type_tracker = TypeTracker()
|
||||
|
||||
stub_content = []
|
||||
|
||||
# We'll generate imports after processing all methods to capture all types
|
||||
# Leave a placeholder for imports
|
||||
imports_placeholder_index = len(stub_content)
|
||||
stub_content.append("") # Will be replaced with imports later
|
||||
|
||||
# Class definition
|
||||
stub_content.append(f"class {sync_class.__name__}:")
|
||||
|
||||
# Docstring
|
||||
if async_class.__doc__:
|
||||
stub_content.extend(
|
||||
cls._format_docstring_for_stub(async_class.__doc__, " ")
|
||||
)
|
||||
|
||||
# Generate __init__
|
||||
try:
|
||||
init_method = async_class.__init__
|
||||
init_signature = inspect.signature(init_method)
|
||||
|
||||
# Try to get type hints for __init__
|
||||
try:
|
||||
from typing import get_type_hints
|
||||
init_hints = get_type_hints(init_method)
|
||||
except Exception:
|
||||
init_hints = {}
|
||||
|
||||
# Format parameters
|
||||
params_str = cls._format_method_parameters(
|
||||
init_signature, type_hints=init_hints, type_tracker=type_tracker
|
||||
)
|
||||
# Add __init__ docstring if available (before the method)
|
||||
if hasattr(init_method, "__doc__") and init_method.__doc__:
|
||||
stub_content.extend(
|
||||
cls._format_docstring_for_stub(init_method.__doc__, " ")
|
||||
)
|
||||
stub_content.append(f" def __init__({params_str}) -> None: ...")
|
||||
except (ValueError, TypeError):
|
||||
stub_content.append(
|
||||
" def __init__(self, *args, **kwargs) -> None: ..."
|
||||
)
|
||||
|
||||
stub_content.append("") # Add newline after __init__
|
||||
|
||||
# Get class attributes
|
||||
class_attributes = cls._get_class_attributes(async_class)
|
||||
|
||||
# Generate inner classes
|
||||
for name, attr in class_attributes:
|
||||
inner_class_stub = cls._generate_inner_class_stub(
|
||||
name, attr, type_tracker=type_tracker
|
||||
)
|
||||
stub_content.extend(inner_class_stub)
|
||||
stub_content.append("") # Add newline after the inner class
|
||||
|
||||
# Add methods to the main class
|
||||
processed_methods = set() # Keep track of methods we've processed
|
||||
for name, method in sorted(
|
||||
inspect.getmembers(async_class, predicate=inspect.isfunction)
|
||||
):
|
||||
if name.startswith("_") or name in processed_methods:
|
||||
continue
|
||||
|
||||
processed_methods.add(name)
|
||||
|
||||
try:
|
||||
method_sig = cls._generate_method_signature(
|
||||
name, method, is_async=True, type_tracker=type_tracker
|
||||
)
|
||||
|
||||
# Add docstring if available (before the method signature for proper formatting)
|
||||
if method.__doc__:
|
||||
stub_content.extend(
|
||||
cls._format_docstring_for_stub(method.__doc__, " ")
|
||||
)
|
||||
|
||||
stub_content.append(f" {method_sig}")
|
||||
|
||||
stub_content.append("") # Add newline after each method
|
||||
|
||||
except (ValueError, TypeError):
|
||||
# If we can't get the signature, just add a simple stub
|
||||
stub_content.append(f" def {name}(self, *args, **kwargs): ...")
|
||||
stub_content.append("") # Add newline
|
||||
|
||||
# Add properties
|
||||
for name, prop in sorted(
|
||||
inspect.getmembers(async_class, lambda x: isinstance(x, property))
|
||||
):
|
||||
stub_content.append(" @property")
|
||||
stub_content.append(f" def {name}(self) -> Any: ...")
|
||||
if prop.fset:
|
||||
stub_content.append(f" @{name}.setter")
|
||||
stub_content.append(
|
||||
f" def {name}(self, value: Any) -> None: ..."
|
||||
)
|
||||
stub_content.append("") # Add newline after each property
|
||||
|
||||
# Add placeholders for the nested class instances
|
||||
# Check the actual attribute names from class annotations and attributes
|
||||
attribute_mappings = {}
|
||||
|
||||
# First check annotations for typed attributes (including from parent classes)
|
||||
# Collect all annotations from the class hierarchy
|
||||
all_annotations = {}
|
||||
for base_class in reversed(inspect.getmro(async_class)):
|
||||
if hasattr(base_class, "__annotations__"):
|
||||
all_annotations.update(base_class.__annotations__)
|
||||
|
||||
for attr_name, attr_type in sorted(all_annotations.items()):
|
||||
for class_name, class_type in class_attributes:
|
||||
# If the class type matches the annotated type
|
||||
if (
|
||||
attr_type == class_type
|
||||
or (hasattr(attr_type, "__name__") and attr_type.__name__ == class_name)
|
||||
or (isinstance(attr_type, str) and attr_type == class_name)
|
||||
):
|
||||
attribute_mappings[class_name] = attr_name
|
||||
|
||||
# Remove the extra checking - annotations should be sufficient
|
||||
|
||||
# Add the attribute declarations with proper names
|
||||
for class_name, class_type in class_attributes:
|
||||
# Check if there's a mapping from annotation
|
||||
attr_name = attribute_mappings.get(class_name, class_name)
|
||||
# Use the annotation name if it exists, even if the attribute doesn't exist yet
|
||||
# This is because the attribute might be created at runtime
|
||||
stub_content.append(f" {attr_name}: {class_name}Sync")
|
||||
|
||||
stub_content.append("") # Add a final newline
|
||||
|
||||
# Now generate imports with all discovered types
|
||||
imports = cls._generate_imports(async_class, type_tracker)
|
||||
|
||||
# Deduplicate imports while preserving order
|
||||
seen = set()
|
||||
unique_imports = []
|
||||
for imp in imports:
|
||||
if imp not in seen:
|
||||
seen.add(imp)
|
||||
unique_imports.append(imp)
|
||||
else:
|
||||
logging.warning(f"Duplicate import detected: {imp}")
|
||||
|
||||
# Replace the placeholder with actual imports
|
||||
stub_content[imports_placeholder_index : imports_placeholder_index + 1] = (
|
||||
unique_imports
|
||||
)
|
||||
|
||||
# Post-process stub content
|
||||
stub_content = cls._post_process_stub_content(stub_content)
|
||||
|
||||
# Write stub file
|
||||
with open(sync_stub_path, "w") as f:
|
||||
f.write("\n".join(stub_content))
|
||||
|
||||
logging.info(f"Generated stub file: {sync_stub_path}")
|
||||
|
||||
except Exception as e:
|
||||
# If stub generation fails, log the error but don't break the main functionality
|
||||
logging.error(
|
||||
f"Error generating stub file for {sync_class.__name__}: {str(e)}"
|
||||
)
|
||||
import traceback
|
||||
|
||||
logging.error(traceback.format_exc())
|
||||
|
||||
|
||||
def create_sync_class(async_class: Type, thread_pool_size=10) -> Type:
|
||||
"""
|
||||
Creates a sync version of an async class
|
||||
|
||||
Args:
|
||||
async_class: The async class to convert
|
||||
thread_pool_size: Size of thread pool to use
|
||||
|
||||
Returns:
|
||||
A new class with sync versions of all async methods
|
||||
"""
|
||||
return AsyncToSyncConverter.create_sync_class(async_class, thread_pool_size)
|
||||
@@ -1,33 +0,0 @@
|
||||
from typing import Type, TypeVar
|
||||
|
||||
class SingletonMetaclass(type):
|
||||
T = TypeVar("T", bound="SingletonMetaclass")
|
||||
_instances = {}
|
||||
|
||||
def __call__(cls, *args, **kwargs):
|
||||
if cls not in cls._instances:
|
||||
cls._instances[cls] = super(SingletonMetaclass, cls).__call__(
|
||||
*args, **kwargs
|
||||
)
|
||||
return cls._instances[cls]
|
||||
|
||||
def inject_instance(cls: Type[T], instance: T) -> None:
|
||||
assert cls not in SingletonMetaclass._instances, (
|
||||
"Cannot inject instance after first instantiation"
|
||||
)
|
||||
SingletonMetaclass._instances[cls] = instance
|
||||
|
||||
def get_instance(cls: Type[T], *args, **kwargs) -> T:
|
||||
"""
|
||||
Gets the singleton instance of the class, creating it if it doesn't exist.
|
||||
"""
|
||||
if cls not in SingletonMetaclass._instances:
|
||||
SingletonMetaclass._instances[cls] = super(
|
||||
SingletonMetaclass, cls
|
||||
).__call__(*args, **kwargs)
|
||||
return cls._instances[cls]
|
||||
|
||||
|
||||
class ProxiedSingleton(object, metaclass=SingletonMetaclass):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -1,124 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Type, TYPE_CHECKING
|
||||
from comfy_api.internal import ComfyAPIBase
|
||||
from comfy_api.internal.singleton import ProxiedSingleton
|
||||
from comfy_api.internal.async_to_sync import create_sync_class
|
||||
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
|
||||
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
|
||||
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
|
||||
from comfy_api.latest._io import _IO as io #noqa: F401
|
||||
from comfy_api.latest._ui import _UI as ui #noqa: F401
|
||||
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
|
||||
from comfy_execution.utils import get_executing_context
|
||||
from comfy_execution.progress import get_progress_state, PreviewImageTuple
|
||||
from PIL import Image
|
||||
from comfy.cli_args import args
|
||||
import numpy as np
|
||||
|
||||
|
||||
class ComfyAPI_latest(ComfyAPIBase):
|
||||
VERSION = "latest"
|
||||
STABLE = False
|
||||
|
||||
class Execution(ProxiedSingleton):
|
||||
async def set_progress(
|
||||
self,
|
||||
value: float,
|
||||
max_value: float,
|
||||
node_id: str | None = None,
|
||||
preview_image: Image.Image | ImageInput | None = None,
|
||||
ignore_size_limit: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Update the progress bar displayed in the ComfyUI interface.
|
||||
|
||||
This function allows custom nodes and API calls to report their progress
|
||||
back to the user interface, providing visual feedback during long operations.
|
||||
|
||||
Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK
|
||||
"""
|
||||
executing_context = get_executing_context()
|
||||
if node_id is None and executing_context is not None:
|
||||
node_id = executing_context.node_id
|
||||
if node_id is None:
|
||||
raise ValueError("node_id must be provided if not in executing context")
|
||||
|
||||
# Convert preview_image to PreviewImageTuple if needed
|
||||
to_display: PreviewImageTuple | Image.Image | ImageInput | None = preview_image
|
||||
if to_display is not None:
|
||||
# First convert to PIL Image if needed
|
||||
if isinstance(to_display, ImageInput):
|
||||
# Convert ImageInput (torch.Tensor) to PIL Image
|
||||
# Handle tensor shape [B, H, W, C] -> get first image if batch
|
||||
tensor = to_display
|
||||
if len(tensor.shape) == 4:
|
||||
tensor = tensor[0]
|
||||
|
||||
# Convert to numpy array and scale to 0-255
|
||||
image_np = (tensor.cpu().numpy() * 255).astype(np.uint8)
|
||||
to_display = Image.fromarray(image_np)
|
||||
|
||||
if isinstance(to_display, Image.Image):
|
||||
# Detect image format from PIL Image
|
||||
image_format = to_display.format if to_display.format else "JPEG"
|
||||
# Use None for preview_size if ignore_size_limit is True
|
||||
preview_size = None if ignore_size_limit else args.preview_size
|
||||
to_display = (image_format, to_display, preview_size)
|
||||
|
||||
get_progress_state().update_progress(
|
||||
node_id=node_id,
|
||||
value=value,
|
||||
max_value=max_value,
|
||||
image=to_display,
|
||||
)
|
||||
|
||||
execution: Execution
|
||||
|
||||
class ComfyExtension(ABC):
|
||||
async def on_load(self) -> None:
|
||||
"""
|
||||
Called when an extension is loaded.
|
||||
This should be used to initialize any global resources neeeded by the extension.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
"""
|
||||
Returns a list of nodes that this extension provides.
|
||||
"""
|
||||
|
||||
class Input:
|
||||
Image = ImageInput
|
||||
Audio = AudioInput
|
||||
Mask = MaskInput
|
||||
Latent = LatentInput
|
||||
Video = VideoInput
|
||||
|
||||
class InputImpl:
|
||||
VideoFromFile = VideoFromFile
|
||||
VideoFromComponents = VideoFromComponents
|
||||
|
||||
class Types:
|
||||
VideoCodec = VideoCodec
|
||||
VideoContainer = VideoContainer
|
||||
VideoComponents = VideoComponents
|
||||
|
||||
ComfyAPI = ComfyAPI_latest
|
||||
|
||||
# Create a synchronous version of the API
|
||||
if TYPE_CHECKING:
|
||||
import comfy_api.latest.generated.ComfyAPISyncStub # type: ignore
|
||||
|
||||
ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
|
||||
ComfyAPISync = create_sync_class(ComfyAPI_latest)
|
||||
|
||||
__all__ = [
|
||||
"ComfyAPI",
|
||||
"ComfyAPISync",
|
||||
"Input",
|
||||
"InputImpl",
|
||||
"Types",
|
||||
"ComfyExtension",
|
||||
]
|
||||
@@ -1,10 +0,0 @@
|
||||
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
|
||||
from .video_types import VideoInput
|
||||
|
||||
__all__ = [
|
||||
"ImageInput",
|
||||
"AudioInput",
|
||||
"VideoInput",
|
||||
"MaskInput",
|
||||
"LatentInput",
|
||||
]
|
||||
@@ -1,42 +0,0 @@
|
||||
import torch
|
||||
from typing import TypedDict, List, Optional
|
||||
|
||||
ImageInput = torch.Tensor
|
||||
"""
|
||||
An image in format [B, H, W, C] where B is the batch size, C is the number of channels,
|
||||
"""
|
||||
|
||||
MaskInput = torch.Tensor
|
||||
"""
|
||||
A mask in format [B, H, W] where B is the batch size
|
||||
"""
|
||||
|
||||
class AudioInput(TypedDict):
|
||||
"""
|
||||
TypedDict representing audio input.
|
||||
"""
|
||||
|
||||
waveform: torch.Tensor
|
||||
"""
|
||||
Tensor in the format [B, C, T] where B is the batch size, C is the number of channels,
|
||||
"""
|
||||
|
||||
sample_rate: int
|
||||
|
||||
class LatentInput(TypedDict):
|
||||
"""
|
||||
TypedDict representing latent input.
|
||||
"""
|
||||
|
||||
samples: torch.Tensor
|
||||
"""
|
||||
Tensor in the format [B, C, H, W] where B is the batch size, C is the number of channels,
|
||||
H is the height, and W is the width.
|
||||
"""
|
||||
|
||||
noise_mask: Optional[MaskInput]
|
||||
"""
|
||||
Optional noise mask tensor in the same format as samples.
|
||||
"""
|
||||
|
||||
batch_index: Optional[List[int]]
|
||||
@@ -1,85 +0,0 @@
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Union
|
||||
import io
|
||||
import av
|
||||
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
|
||||
|
||||
class VideoInput(ABC):
|
||||
"""
|
||||
Abstract base class for video input types.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_components(self) -> VideoComponents:
|
||||
"""
|
||||
Abstract method to get the video components (images, audio, and frame rate).
|
||||
|
||||
Returns:
|
||||
VideoComponents containing images, audio, and frame rate
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save_to(
|
||||
self,
|
||||
path: str,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
):
|
||||
"""
|
||||
Abstract method to save the video input to a file.
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_stream_source(self) -> Union[str, io.BytesIO]:
|
||||
"""
|
||||
Get a streamable source for the video. This allows processing without
|
||||
loading the entire video into memory.
|
||||
|
||||
Returns:
|
||||
Either a file path (str) or a BytesIO object that can be opened with av.
|
||||
|
||||
Default implementation creates a BytesIO buffer, but subclasses should
|
||||
override this for better performance when possible.
|
||||
"""
|
||||
buffer = io.BytesIO()
|
||||
self.save_to(buffer)
|
||||
buffer.seek(0)
|
||||
return buffer
|
||||
|
||||
# Provide a default implementation, but subclasses can provide optimized versions
|
||||
# if possible.
|
||||
def get_dimensions(self) -> tuple[int, int]:
|
||||
"""
|
||||
Returns the dimensions of the video input.
|
||||
|
||||
Returns:
|
||||
Tuple of (width, height)
|
||||
"""
|
||||
components = self.get_components()
|
||||
return components.images.shape[2], components.images.shape[1]
|
||||
|
||||
def get_duration(self) -> float:
|
||||
"""
|
||||
Returns the duration of the video in seconds.
|
||||
|
||||
Returns:
|
||||
Duration in seconds
|
||||
"""
|
||||
components = self.get_components()
|
||||
frame_count = components.images.shape[0]
|
||||
return float(frame_count / components.frame_rate)
|
||||
|
||||
def get_container_format(self) -> str:
|
||||
"""
|
||||
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
|
||||
|
||||
Returns:
|
||||
Container format as string
|
||||
"""
|
||||
# Default implementation - subclasses should override for better performance
|
||||
source = self.get_stream_source()
|
||||
with av.open(source, mode="r") as container:
|
||||
return container.format.name
|
||||
@@ -1,7 +0,0 @@
|
||||
from .video_types import VideoFromFile, VideoFromComponents
|
||||
|
||||
__all__ = [
|
||||
# Implementations
|
||||
"VideoFromFile",
|
||||
"VideoFromComponents",
|
||||
]
|
||||
@@ -1,324 +0,0 @@
|
||||
from __future__ import annotations
|
||||
from av.container import InputContainer
|
||||
from av.subtitles.stream import SubtitleStream
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
from comfy_api.latest._input import AudioInput, VideoInput
|
||||
import av
|
||||
import io
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents
|
||||
|
||||
|
||||
def container_to_output_format(container_format: str | None) -> str | None:
|
||||
"""
|
||||
A container's `format` may be a comma-separated list of formats.
|
||||
E.g., iso container's `format` may be `mov,mp4,m4a,3gp,3g2,mj2`.
|
||||
However, writing to a file/stream with `av.open` requires a single format,
|
||||
or `None` to auto-detect.
|
||||
"""
|
||||
if not container_format:
|
||||
return None # Auto-detect
|
||||
|
||||
if "," not in container_format:
|
||||
return container_format
|
||||
|
||||
formats = container_format.split(",")
|
||||
return formats[0]
|
||||
|
||||
|
||||
def get_open_write_kwargs(
|
||||
dest: str | io.BytesIO, container_format: str, to_format: str | None
|
||||
) -> dict:
|
||||
"""Get kwargs for writing a `VideoFromFile` to a file/stream with `av.open`"""
|
||||
open_kwargs = {
|
||||
"mode": "w",
|
||||
# If isobmff, preserve custom metadata tags (workflow, prompt, extra_pnginfo)
|
||||
"options": {"movflags": "use_metadata_tags"},
|
||||
}
|
||||
|
||||
is_write_to_buffer = isinstance(dest, io.BytesIO)
|
||||
if is_write_to_buffer:
|
||||
# Set output format explicitly, since it cannot be inferred from file extension
|
||||
if to_format == VideoContainer.AUTO:
|
||||
to_format = container_format.lower()
|
||||
elif isinstance(to_format, str):
|
||||
to_format = to_format.lower()
|
||||
open_kwargs["format"] = container_to_output_format(to_format)
|
||||
|
||||
return open_kwargs
|
||||
|
||||
|
||||
class VideoFromFile(VideoInput):
|
||||
"""
|
||||
Class representing video input from a file.
|
||||
"""
|
||||
|
||||
def __init__(self, file: str | io.BytesIO):
|
||||
"""
|
||||
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
|
||||
containing the file contents.
|
||||
"""
|
||||
self.__file = file
|
||||
|
||||
def get_stream_source(self) -> str | io.BytesIO:
|
||||
"""
|
||||
Return the underlying file source for efficient streaming.
|
||||
This avoids unnecessary memory copies when the source is already a file path.
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
return self.__file
|
||||
|
||||
def get_dimensions(self) -> tuple[int, int]:
|
||||
"""
|
||||
Returns the dimensions of the video input.
|
||||
|
||||
Returns:
|
||||
Tuple of (width, height)
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
for stream in container.streams:
|
||||
if stream.type == 'video':
|
||||
assert isinstance(stream, av.VideoStream)
|
||||
return stream.width, stream.height
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
|
||||
def get_duration(self) -> float:
|
||||
"""
|
||||
Returns the duration of the video in seconds.
|
||||
|
||||
Returns:
|
||||
Duration in seconds
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
if container.duration is not None:
|
||||
return float(container.duration / av.time_base)
|
||||
|
||||
# Fallback: calculate from frame count and frame rate
|
||||
video_stream = next(
|
||||
(s for s in container.streams if s.type == "video"), None
|
||||
)
|
||||
if video_stream and video_stream.frames and video_stream.average_rate:
|
||||
return float(video_stream.frames / video_stream.average_rate)
|
||||
|
||||
# Last resort: decode frames to count them
|
||||
if video_stream and video_stream.average_rate:
|
||||
frame_count = 0
|
||||
container.seek(0)
|
||||
for packet in container.demux(video_stream):
|
||||
for _ in packet.decode():
|
||||
frame_count += 1
|
||||
if frame_count > 0:
|
||||
return float(frame_count / video_stream.average_rate)
|
||||
|
||||
raise ValueError(f"Could not determine duration for file '{self.__file}'")
|
||||
|
||||
def get_container_format(self) -> str:
|
||||
"""
|
||||
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
|
||||
|
||||
Returns:
|
||||
Container format as string
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
return container.format.name
|
||||
|
||||
def get_components_internal(self, container: InputContainer) -> VideoComponents:
|
||||
# Get video frames
|
||||
frames = []
|
||||
for frame in container.decode(video=0):
|
||||
img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3)
|
||||
img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3)
|
||||
frames.append(img)
|
||||
|
||||
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)
|
||||
|
||||
# Get frame rate
|
||||
video_stream = next(s for s in container.streams if s.type == 'video')
|
||||
frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)
|
||||
|
||||
# Get audio if available
|
||||
audio = None
|
||||
try:
|
||||
container.seek(0) # Reset the container to the beginning
|
||||
for stream in container.streams:
|
||||
if stream.type != 'audio':
|
||||
continue
|
||||
assert isinstance(stream, av.AudioStream)
|
||||
audio_frames = []
|
||||
for packet in container.demux(stream):
|
||||
for frame in packet.decode():
|
||||
assert isinstance(frame, av.AudioFrame)
|
||||
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
|
||||
if len(audio_frames) > 0:
|
||||
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
|
||||
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
|
||||
audio = AudioInput({
|
||||
"waveform": audio_tensor,
|
||||
"sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
|
||||
})
|
||||
except StopIteration:
|
||||
pass # No audio stream
|
||||
|
||||
metadata = container.metadata
|
||||
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
|
||||
|
||||
def get_components(self) -> VideoComponents:
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
return self.get_components_internal(container)
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
|
||||
def save_to(
|
||||
self,
|
||||
path: str | io.BytesIO,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
):
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
container_format = container.format.name
|
||||
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
|
||||
reuse_streams = True
|
||||
if format != VideoContainer.AUTO and format not in container_format.split(","):
|
||||
reuse_streams = False
|
||||
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
|
||||
reuse_streams = False
|
||||
|
||||
if not reuse_streams:
|
||||
components = self.get_components_internal(container)
|
||||
video = VideoFromComponents(components)
|
||||
return video.save_to(
|
||||
path,
|
||||
format=format,
|
||||
codec=codec,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
streams = container.streams
|
||||
|
||||
open_kwargs = get_open_write_kwargs(path, container_format, format)
|
||||
with av.open(path, **open_kwargs) as output_container:
|
||||
# Copy over the original metadata
|
||||
for key, value in container.metadata.items():
|
||||
if metadata is None or key not in metadata:
|
||||
output_container.metadata[key] = value
|
||||
|
||||
# Add our new metadata
|
||||
if metadata is not None:
|
||||
for key, value in metadata.items():
|
||||
if isinstance(value, str):
|
||||
output_container.metadata[key] = value
|
||||
else:
|
||||
output_container.metadata[key] = json.dumps(value)
|
||||
|
||||
# Add streams to the new container
|
||||
stream_map = {}
|
||||
for stream in streams:
|
||||
if isinstance(stream, (av.VideoStream, av.AudioStream, SubtitleStream)):
|
||||
out_stream = output_container.add_stream_from_template(template=stream, opaque=True)
|
||||
stream_map[stream] = out_stream
|
||||
|
||||
# Write packets to the new container
|
||||
for packet in container.demux():
|
||||
if packet.stream in stream_map and packet.dts is not None:
|
||||
packet.stream = stream_map[packet.stream]
|
||||
output_container.mux(packet)
|
||||
|
||||
class VideoFromComponents(VideoInput):
|
||||
"""
|
||||
Class representing video input from tensors.
|
||||
"""
|
||||
|
||||
def __init__(self, components: VideoComponents):
|
||||
self.__components = components
|
||||
|
||||
def get_components(self) -> VideoComponents:
|
||||
return VideoComponents(
|
||||
images=self.__components.images,
|
||||
audio=self.__components.audio,
|
||||
frame_rate=self.__components.frame_rate
|
||||
)
|
||||
|
||||
def save_to(
|
||||
self,
|
||||
path: str,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
):
|
||||
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
|
||||
raise ValueError("Only MP4 format is supported for now")
|
||||
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
|
||||
raise ValueError("Only H264 codec is supported for now")
|
||||
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output:
|
||||
# Add metadata before writing any streams
|
||||
if metadata is not None:
|
||||
for key, value in metadata.items():
|
||||
output.metadata[key] = json.dumps(value)
|
||||
|
||||
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
|
||||
# Create a video stream
|
||||
video_stream = output.add_stream('h264', rate=frame_rate)
|
||||
video_stream.width = self.__components.images.shape[2]
|
||||
video_stream.height = self.__components.images.shape[1]
|
||||
video_stream.pix_fmt = 'yuv420p'
|
||||
|
||||
# Create an audio stream
|
||||
audio_sample_rate = 1
|
||||
audio_stream: Optional[av.AudioStream] = None
|
||||
if self.__components.audio:
|
||||
audio_sample_rate = int(self.__components.audio['sample_rate'])
|
||||
audio_stream = output.add_stream('aac', rate=audio_sample_rate)
|
||||
audio_stream.sample_rate = audio_sample_rate
|
||||
audio_stream.format = 'fltp'
|
||||
|
||||
# Encode video
|
||||
for i, frame in enumerate(self.__components.images):
|
||||
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
|
||||
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
|
||||
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
|
||||
packet = video_stream.encode(frame)
|
||||
output.mux(packet)
|
||||
|
||||
# Flush video
|
||||
packet = video_stream.encode(None)
|
||||
output.mux(packet)
|
||||
|
||||
if audio_stream and self.__components.audio:
|
||||
# Encode audio
|
||||
samples_per_frame = int(audio_sample_rate / frame_rate)
|
||||
num_frames = self.__components.audio['waveform'].shape[2] // samples_per_frame
|
||||
for i in range(num_frames):
|
||||
start = i * samples_per_frame
|
||||
end = start + samples_per_frame
|
||||
# TODO(Feature) - Add support for stereo audio
|
||||
chunk = (
|
||||
self.__components.audio["waveform"][0, 0, start:end]
|
||||
.unsqueeze(0)
|
||||
.contiguous()
|
||||
.numpy()
|
||||
)
|
||||
audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono')
|
||||
audio_frame.sample_rate = audio_sample_rate
|
||||
audio_frame.pts = i * samples_per_frame
|
||||
for packet in audio_stream.encode(audio_frame):
|
||||
output.mux(packet)
|
||||
|
||||
# Flush audio
|
||||
for packet in audio_stream.encode(None):
|
||||
output.mux(packet)
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,72 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
import torch
|
||||
|
||||
class ResourceKey(ABC):
|
||||
Type = Any
|
||||
def __init__(self):
|
||||
...
|
||||
|
||||
class TorchDictFolderFilename(ResourceKey):
|
||||
'''Key for requesting a torch file via file_name from a folder category.'''
|
||||
Type = dict[str, torch.Tensor]
|
||||
def __init__(self, folder_name: str, file_name: str):
|
||||
self.folder_name = folder_name
|
||||
self.file_name = file_name
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.folder_name, self.file_name))
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, TorchDictFolderFilename):
|
||||
return False
|
||||
return self.folder_name == other.folder_name and self.file_name == other.file_name
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.folder_name} -> {self.file_name}"
|
||||
|
||||
class Resources(ABC):
|
||||
def __init__(self):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get(self, key: ResourceKey, default: Any=...) -> Any:
|
||||
pass
|
||||
|
||||
class ResourcesLocal(Resources):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.local_resources: dict[ResourceKey, Any] = {}
|
||||
|
||||
def get(self, key: ResourceKey, default: Any=...) -> Any:
|
||||
cached = self.local_resources.get(key, None)
|
||||
if cached is not None:
|
||||
logging.info(f"Using cached resource '{key}'")
|
||||
return cached
|
||||
logging.info(f"Loading resource '{key}'")
|
||||
to_return = None
|
||||
if isinstance(key, TorchDictFolderFilename):
|
||||
if default is ...:
|
||||
to_return = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(key.folder_name, key.file_name), safe_load=True)
|
||||
else:
|
||||
full_path = folder_paths.get_full_path(key.folder_name, key.file_name)
|
||||
if full_path is not None:
|
||||
to_return = comfy.utils.load_torch_file(full_path, safe_load=True)
|
||||
|
||||
if to_return is not None:
|
||||
self.local_resources[key] = to_return
|
||||
return to_return
|
||||
if default is not ...:
|
||||
return default
|
||||
raise Exception(f"Unsupported resource key type: {type(key)}")
|
||||
|
||||
|
||||
class _RESOURCES:
|
||||
ResourceKey = ResourceKey
|
||||
TorchDictFolderFilename = TorchDictFolderFilename
|
||||
Resources = Resources
|
||||
ResourcesLocal = ResourcesLocal
|
||||
@@ -1,457 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from io import BytesIO
|
||||
from typing import Type
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
from PIL import Image as PILImage
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
|
||||
import folder_paths
|
||||
|
||||
# used for image preview
|
||||
from comfy.cli_args import args
|
||||
from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput
|
||||
|
||||
|
||||
class SavedResult(dict):
|
||||
def __init__(self, filename: str, subfolder: str, type: FolderType):
|
||||
super().__init__(filename=filename, subfolder=subfolder,type=type.value)
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
return self["filename"]
|
||||
|
||||
@property
|
||||
def subfolder(self) -> str:
|
||||
return self["subfolder"]
|
||||
|
||||
@property
|
||||
def type(self) -> FolderType:
|
||||
return FolderType(self["type"])
|
||||
|
||||
|
||||
class SavedImages(_UIOutput):
|
||||
"""A UI output class to represent one or more saved images, potentially animated."""
|
||||
def __init__(self, results: list[SavedResult], is_animated: bool = False):
|
||||
super().__init__()
|
||||
self.results = results
|
||||
self.is_animated = is_animated
|
||||
|
||||
def as_dict(self) -> dict:
|
||||
data = {"images": self.results}
|
||||
if self.is_animated:
|
||||
data["animated"] = (True,)
|
||||
return data
|
||||
|
||||
|
||||
class SavedAudios(_UIOutput):
|
||||
"""UI wrapper around one or more audio files on disk (FLAC / MP3 / Opus)."""
|
||||
def __init__(self, results: list[SavedResult]):
|
||||
super().__init__()
|
||||
self.results = results
|
||||
|
||||
def as_dict(self) -> dict:
|
||||
return {"audio": self.results}
|
||||
|
||||
|
||||
def _get_directory_by_folder_type(folder_type: FolderType) -> str:
|
||||
if folder_type == FolderType.input:
|
||||
return folder_paths.get_input_directory()
|
||||
if folder_type == FolderType.output:
|
||||
return folder_paths.get_output_directory()
|
||||
return folder_paths.get_temp_directory()
|
||||
|
||||
|
||||
class ImageSaveHelper:
|
||||
"""A helper class with static methods to handle image saving and metadata."""
|
||||
|
||||
@staticmethod
|
||||
def _convert_tensor_to_pil(image_tensor: torch.Tensor) -> PILImage.Image:
|
||||
"""Converts a single torch tensor to a PIL Image."""
|
||||
return PILImage.fromarray(np.clip(255.0 * image_tensor.cpu().numpy(), 0, 255).astype(np.uint8))
|
||||
|
||||
@staticmethod
|
||||
def _create_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
|
||||
"""Creates a PngInfo object with prompt and extra_pnginfo."""
|
||||
if args.disable_metadata or cls is None or not cls.hidden:
|
||||
return None
|
||||
metadata = PngInfo()
|
||||
if cls.hidden.prompt:
|
||||
metadata.add_text("prompt", json.dumps(cls.hidden.prompt))
|
||||
if cls.hidden.extra_pnginfo:
|
||||
for x in cls.hidden.extra_pnginfo:
|
||||
metadata.add_text(x, json.dumps(cls.hidden.extra_pnginfo[x]))
|
||||
return metadata
|
||||
|
||||
@staticmethod
|
||||
def _create_animated_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
|
||||
"""Creates a PngInfo object with prompt and extra_pnginfo for animated PNGs (APNG)."""
|
||||
if args.disable_metadata or cls is None or not cls.hidden:
|
||||
return None
|
||||
metadata = PngInfo()
|
||||
if cls.hidden.prompt:
|
||||
metadata.add(
|
||||
b"comf",
|
||||
"prompt".encode("latin-1", "strict")
|
||||
+ b"\0"
|
||||
+ json.dumps(cls.hidden.prompt).encode("latin-1", "strict"),
|
||||
after_idat=True,
|
||||
)
|
||||
if cls.hidden.extra_pnginfo:
|
||||
for x in cls.hidden.extra_pnginfo:
|
||||
metadata.add(
|
||||
b"comf",
|
||||
x.encode("latin-1", "strict")
|
||||
+ b"\0"
|
||||
+ json.dumps(cls.hidden.extra_pnginfo[x]).encode("latin-1", "strict"),
|
||||
after_idat=True,
|
||||
)
|
||||
return metadata
|
||||
|
||||
@staticmethod
|
||||
def _create_webp_metadata(pil_image: PILImage.Image, cls: Type[ComfyNode] | None) -> PILImage.Exif:
|
||||
"""Creates EXIF metadata bytes for WebP images."""
|
||||
exif_data = pil_image.getexif()
|
||||
if args.disable_metadata or cls is None or cls.hidden is None:
|
||||
return exif_data
|
||||
if cls.hidden.prompt is not None:
|
||||
exif_data[0x0110] = "prompt:{}".format(json.dumps(cls.hidden.prompt)) # EXIF 0x0110 = Model
|
||||
if cls.hidden.extra_pnginfo is not None:
|
||||
inital_exif_tag = 0x010F # EXIF 0x010f = Make
|
||||
for key, value in cls.hidden.extra_pnginfo.items():
|
||||
exif_data[inital_exif_tag] = "{}:{}".format(key, json.dumps(value))
|
||||
inital_exif_tag -= 1
|
||||
return exif_data
|
||||
|
||||
@staticmethod
|
||||
def save_images(
|
||||
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, compress_level = 4,
|
||||
) -> list[SavedResult]:
|
||||
"""Saves a batch of images as individual PNG files."""
|
||||
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||
filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0]
|
||||
)
|
||||
results = []
|
||||
metadata = ImageSaveHelper._create_png_metadata(cls)
|
||||
for batch_number, image_tensor in enumerate(images):
|
||||
img = ImageSaveHelper._convert_tensor_to_pil(image_tensor)
|
||||
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
||||
file = f"{filename_with_batch_num}_{counter:05}_.png"
|
||||
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level)
|
||||
results.append(SavedResult(file, subfolder, folder_type))
|
||||
counter += 1
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def get_save_images_ui(images, filename_prefix: str, cls: Type[ComfyNode] | None, compress_level=4) -> SavedImages:
|
||||
"""Saves a batch of images and returns a UI object for the node output."""
|
||||
return SavedImages(
|
||||
ImageSaveHelper.save_images(
|
||||
images,
|
||||
filename_prefix=filename_prefix,
|
||||
folder_type=FolderType.output,
|
||||
cls=cls,
|
||||
compress_level=compress_level,
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def save_animated_png(
|
||||
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, fps: float, compress_level: int
|
||||
) -> SavedResult:
|
||||
"""Saves a batch of images as a single animated PNG."""
|
||||
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||
filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0]
|
||||
)
|
||||
pil_images = [ImageSaveHelper._convert_tensor_to_pil(img) for img in images]
|
||||
metadata = ImageSaveHelper._create_animated_png_metadata(cls)
|
||||
file = f"{filename}_{counter:05}_.png"
|
||||
save_path = os.path.join(full_output_folder, file)
|
||||
pil_images[0].save(
|
||||
save_path,
|
||||
pnginfo=metadata,
|
||||
compress_level=compress_level,
|
||||
save_all=True,
|
||||
duration=int(1000.0 / fps),
|
||||
append_images=pil_images[1:],
|
||||
)
|
||||
return SavedResult(file, subfolder, folder_type)
|
||||
|
||||
@staticmethod
|
||||
def get_save_animated_png_ui(
|
||||
images, filename_prefix: str, cls: Type[ComfyNode] | None, fps: float, compress_level: int
|
||||
) -> SavedImages:
|
||||
"""Saves an animated PNG and returns a UI object for the node output."""
|
||||
result = ImageSaveHelper.save_animated_png(
|
||||
images,
|
||||
filename_prefix=filename_prefix,
|
||||
folder_type=FolderType.output,
|
||||
cls=cls,
|
||||
fps=fps,
|
||||
compress_level=compress_level,
|
||||
)
|
||||
return SavedImages([result], is_animated=len(images) > 1)
|
||||
|
||||
@staticmethod
|
||||
def save_animated_webp(
|
||||
images,
|
||||
filename_prefix: str,
|
||||
folder_type: FolderType,
|
||||
cls: Type[ComfyNode] | None,
|
||||
fps: float,
|
||||
lossless: bool,
|
||||
quality: int,
|
||||
method: int,
|
||||
) -> SavedResult:
|
||||
"""Saves a batch of images as a single animated WebP."""
|
||||
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||
filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0]
|
||||
)
|
||||
pil_images = [ImageSaveHelper._convert_tensor_to_pil(img) for img in images]
|
||||
pil_exif = ImageSaveHelper._create_webp_metadata(pil_images[0], cls)
|
||||
file = f"{filename}_{counter:05}_.webp"
|
||||
pil_images[0].save(
|
||||
os.path.join(full_output_folder, file),
|
||||
save_all=True,
|
||||
duration=int(1000.0 / fps),
|
||||
append_images=pil_images[1:],
|
||||
exif=pil_exif,
|
||||
lossless=lossless,
|
||||
quality=quality,
|
||||
method=method,
|
||||
)
|
||||
return SavedResult(file, subfolder, folder_type)
|
||||
|
||||
@staticmethod
|
||||
def get_save_animated_webp_ui(
|
||||
images,
|
||||
filename_prefix: str,
|
||||
cls: Type[ComfyNode] | None,
|
||||
fps: float,
|
||||
lossless: bool,
|
||||
quality: int,
|
||||
method: int,
|
||||
) -> SavedImages:
|
||||
"""Saves an animated WebP and returns a UI object for the node output."""
|
||||
result = ImageSaveHelper.save_animated_webp(
|
||||
images,
|
||||
filename_prefix=filename_prefix,
|
||||
folder_type=FolderType.output,
|
||||
cls=cls,
|
||||
fps=fps,
|
||||
lossless=lossless,
|
||||
quality=quality,
|
||||
method=method,
|
||||
)
|
||||
return SavedImages([result], is_animated=len(images) > 1)
|
||||
|
||||
|
||||
class AudioSaveHelper:
|
||||
"""A helper class with static methods to handle audio saving and metadata."""
|
||||
_OPUS_RATES = [8000, 12000, 16000, 24000, 48000]
|
||||
|
||||
@staticmethod
|
||||
def save_audio(
|
||||
audio: dict,
|
||||
filename_prefix: str,
|
||||
folder_type: FolderType,
|
||||
cls: Type[ComfyNode] | None,
|
||||
format: str = "flac",
|
||||
quality: str = "128k",
|
||||
) -> list[SavedResult]:
|
||||
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||
filename_prefix, _get_directory_by_folder_type(folder_type)
|
||||
)
|
||||
|
||||
metadata = {}
|
||||
if not args.disable_metadata and cls is not None:
|
||||
if cls.hidden.prompt is not None:
|
||||
metadata["prompt"] = json.dumps(cls.hidden.prompt)
|
||||
if cls.hidden.extra_pnginfo is not None:
|
||||
for x in cls.hidden.extra_pnginfo:
|
||||
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
|
||||
|
||||
results = []
|
||||
for batch_number, waveform in enumerate(audio["waveform"].cpu()):
|
||||
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
||||
file = f"{filename_with_batch_num}_{counter:05}_.{format}"
|
||||
output_path = os.path.join(full_output_folder, file)
|
||||
|
||||
# Use original sample rate initially
|
||||
sample_rate = audio["sample_rate"]
|
||||
|
||||
# Handle Opus sample rate requirements
|
||||
if format == "opus":
|
||||
if sample_rate > 48000:
|
||||
sample_rate = 48000
|
||||
elif sample_rate not in AudioSaveHelper._OPUS_RATES:
|
||||
# Find the next highest supported rate
|
||||
for rate in sorted(AudioSaveHelper._OPUS_RATES):
|
||||
if rate > sample_rate:
|
||||
sample_rate = rate
|
||||
break
|
||||
if sample_rate not in AudioSaveHelper._OPUS_RATES: # Fallback if still not supported
|
||||
sample_rate = 48000
|
||||
|
||||
# Resample if necessary
|
||||
if sample_rate != audio["sample_rate"]:
|
||||
waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
|
||||
|
||||
# Create output with specified format
|
||||
output_buffer = BytesIO()
|
||||
output_container = av.open(output_buffer, mode="w", format=format)
|
||||
|
||||
# Set metadata on the container
|
||||
for key, value in metadata.items():
|
||||
output_container.metadata[key] = value
|
||||
|
||||
# Set up the output stream with appropriate properties
|
||||
if format == "opus":
|
||||
out_stream = output_container.add_stream("libopus", rate=sample_rate)
|
||||
if quality == "64k":
|
||||
out_stream.bit_rate = 64000
|
||||
elif quality == "96k":
|
||||
out_stream.bit_rate = 96000
|
||||
elif quality == "128k":
|
||||
out_stream.bit_rate = 128000
|
||||
elif quality == "192k":
|
||||
out_stream.bit_rate = 192000
|
||||
elif quality == "320k":
|
||||
out_stream.bit_rate = 320000
|
||||
elif format == "mp3":
|
||||
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
|
||||
if quality == "V0":
|
||||
# TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
|
||||
out_stream.codec_context.qscale = 1
|
||||
elif quality == "128k":
|
||||
out_stream.bit_rate = 128000
|
||||
elif quality == "320k":
|
||||
out_stream.bit_rate = 320000
|
||||
else: # format == "flac":
|
||||
out_stream = output_container.add_stream("flac", rate=sample_rate)
|
||||
|
||||
frame = av.AudioFrame.from_ndarray(
|
||||
waveform.movedim(0, 1).reshape(1, -1).float().numpy(),
|
||||
format="flt",
|
||||
layout="mono" if waveform.shape[0] == 1 else "stereo",
|
||||
)
|
||||
frame.sample_rate = sample_rate
|
||||
frame.pts = 0
|
||||
output_container.mux(out_stream.encode(frame))
|
||||
|
||||
# Flush encoder
|
||||
output_container.mux(out_stream.encode(None))
|
||||
|
||||
# Close containers
|
||||
output_container.close()
|
||||
|
||||
# Write the output to file
|
||||
output_buffer.seek(0)
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(output_buffer.getbuffer())
|
||||
|
||||
results.append(SavedResult(file, subfolder, folder_type))
|
||||
counter += 1
|
||||
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def get_save_audio_ui(
|
||||
audio, filename_prefix: str, cls: Type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
|
||||
) -> SavedAudios:
|
||||
"""Save and instantly wrap for UI."""
|
||||
return SavedAudios(
|
||||
AudioSaveHelper.save_audio(
|
||||
audio,
|
||||
filename_prefix=filename_prefix,
|
||||
folder_type=FolderType.output,
|
||||
cls=cls,
|
||||
format=format,
|
||||
quality=quality,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class PreviewImage(_UIOutput):
|
||||
def __init__(self, image: Image.Type, animated: bool = False, cls: Type[ComfyNode] = None, **kwargs):
|
||||
self.values = ImageSaveHelper.save_images(
|
||||
image,
|
||||
filename_prefix="ComfyUI_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)),
|
||||
folder_type=FolderType.temp,
|
||||
cls=cls,
|
||||
compress_level=1,
|
||||
)
|
||||
self.animated = animated
|
||||
|
||||
def as_dict(self):
|
||||
return {
|
||||
"images": self.values,
|
||||
"animated": (self.animated,)
|
||||
}
|
||||
|
||||
|
||||
class PreviewMask(PreviewImage):
|
||||
def __init__(self, mask: PreviewMask.Type, animated: bool=False, cls: ComfyNode=None, **kwargs):
|
||||
preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
|
||||
super().__init__(preview, animated, cls, **kwargs)
|
||||
|
||||
|
||||
class PreviewAudio(_UIOutput):
|
||||
def __init__(self, audio: dict, cls: Type[ComfyNode] = None, **kwargs):
|
||||
self.values = AudioSaveHelper.save_audio(
|
||||
audio,
|
||||
filename_prefix="ComfyUI_temp_" + "".join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)),
|
||||
folder_type=FolderType.temp,
|
||||
cls=cls,
|
||||
format="flac",
|
||||
quality="128k",
|
||||
)
|
||||
|
||||
def as_dict(self) -> dict:
|
||||
return {"audio": self.values}
|
||||
|
||||
|
||||
class PreviewVideo(_UIOutput):
|
||||
def __init__(self, values: list[SavedResult | dict], **kwargs):
|
||||
self.values = values
|
||||
|
||||
def as_dict(self):
|
||||
return {"images": self.values, "animated": (True,)}
|
||||
|
||||
|
||||
class PreviewUI3D(_UIOutput):
|
||||
def __init__(self, model_file, camera_info, **kwargs):
|
||||
self.model_file = model_file
|
||||
self.camera_info = camera_info
|
||||
|
||||
def as_dict(self):
|
||||
return {"result": [self.model_file, self.camera_info]}
|
||||
|
||||
|
||||
class PreviewText(_UIOutput):
|
||||
def __init__(self, value: str, **kwargs):
|
||||
self.value = value
|
||||
|
||||
def as_dict(self):
|
||||
return {"text": (self.value,)}
|
||||
|
||||
|
||||
class _UI:
|
||||
SavedResult = SavedResult
|
||||
SavedImages = SavedImages
|
||||
SavedAudios = SavedAudios
|
||||
ImageSaveHelper = ImageSaveHelper
|
||||
AudioSaveHelper = AudioSaveHelper
|
||||
PreviewImage = PreviewImage
|
||||
PreviewMask = PreviewMask
|
||||
PreviewAudio = PreviewAudio
|
||||
PreviewVideo = PreviewVideo
|
||||
PreviewUI3D = PreviewUI3D
|
||||
PreviewText = PreviewText
|
||||
@@ -1,8 +0,0 @@
|
||||
from .video_types import VideoContainer, VideoCodec, VideoComponents
|
||||
|
||||
__all__ = [
|
||||
# Utility Types
|
||||
"VideoContainer",
|
||||
"VideoCodec",
|
||||
"VideoComponents",
|
||||
]
|
||||
@@ -1,52 +0,0 @@
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
from comfy_api.latest._input import ImageInput, AudioInput
|
||||
|
||||
class VideoCodec(str, Enum):
|
||||
AUTO = "auto"
|
||||
H264 = "h264"
|
||||
|
||||
@classmethod
|
||||
def as_input(cls) -> list[str]:
|
||||
"""
|
||||
Returns a list of codec names that can be used as node input.
|
||||
"""
|
||||
return [member.value for member in cls]
|
||||
|
||||
class VideoContainer(str, Enum):
|
||||
AUTO = "auto"
|
||||
MP4 = "mp4"
|
||||
|
||||
@classmethod
|
||||
def as_input(cls) -> list[str]:
|
||||
"""
|
||||
Returns a list of container names that can be used as node input.
|
||||
"""
|
||||
return [member.value for member in cls]
|
||||
|
||||
@classmethod
|
||||
def get_extension(cls, value) -> str:
|
||||
"""
|
||||
Returns the file extension for the container.
|
||||
"""
|
||||
if isinstance(value, str):
|
||||
value = cls(value)
|
||||
if value == VideoContainer.MP4 or value == VideoContainer.AUTO:
|
||||
return "mp4"
|
||||
return ""
|
||||
|
||||
@dataclass
|
||||
class VideoComponents:
|
||||
"""
|
||||
Dataclass representing the components of a video.
|
||||
"""
|
||||
|
||||
images: ImageInput
|
||||
frame_rate: Fraction
|
||||
audio: Optional[AudioInput] = None
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple
|
||||
from comfy_api.latest import ComfyAPI_latest
|
||||
from PIL.Image import Image
|
||||
from torch import Tensor
|
||||
class ComfyAPISyncStub:
|
||||
def __init__(self) -> None: ...
|
||||
|
||||
class ExecutionSync:
|
||||
def __init__(self) -> None: ...
|
||||
"""
|
||||
Update the progress bar displayed in the ComfyUI interface.
|
||||
|
||||
This function allows custom nodes and API calls to report their progress
|
||||
back to the user interface, providing visual feedback during long operations.
|
||||
|
||||
Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK
|
||||
"""
|
||||
def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[Image, Tensor, None] = None, ignore_size_limit: bool = False) -> None: ...
|
||||
|
||||
execution: ExecutionSync
|
||||
@@ -1,8 +0,0 @@
|
||||
# This file only exists for backwards compatibility.
|
||||
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
|
||||
|
||||
__all__ = [
|
||||
"VideoCodec",
|
||||
"VideoContainer",
|
||||
"VideoComponents",
|
||||
]
|
||||
@@ -1,7 +1,7 @@
|
||||
# This file only exists for backwards compatibility.
|
||||
from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents
|
||||
from .video_types import VideoContainer, VideoCodec, VideoComponents
|
||||
|
||||
__all__ = [
|
||||
# Utility Types
|
||||
"VideoContainer",
|
||||
"VideoCodec",
|
||||
"VideoComponents",
|
||||
|
||||
@@ -1,12 +1,51 @@
|
||||
# This file only exists for backwards compatibility.
|
||||
from comfy_api.latest._util.video_types import (
|
||||
VideoContainer,
|
||||
VideoCodec,
|
||||
VideoComponents,
|
||||
)
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
from comfy_api.input import ImageInput, AudioInput
|
||||
|
||||
class VideoCodec(str, Enum):
|
||||
AUTO = "auto"
|
||||
H264 = "h264"
|
||||
|
||||
@classmethod
|
||||
def as_input(cls) -> list[str]:
|
||||
"""
|
||||
Returns a list of codec names that can be used as node input.
|
||||
"""
|
||||
return [member.value for member in cls]
|
||||
|
||||
class VideoContainer(str, Enum):
|
||||
AUTO = "auto"
|
||||
MP4 = "mp4"
|
||||
|
||||
@classmethod
|
||||
def as_input(cls) -> list[str]:
|
||||
"""
|
||||
Returns a list of container names that can be used as node input.
|
||||
"""
|
||||
return [member.value for member in cls]
|
||||
|
||||
@classmethod
|
||||
def get_extension(cls, value) -> str:
|
||||
"""
|
||||
Returns the file extension for the container.
|
||||
"""
|
||||
if isinstance(value, str):
|
||||
value = cls(value)
|
||||
if value == VideoContainer.MP4 or value == VideoContainer.AUTO:
|
||||
return "mp4"
|
||||
return ""
|
||||
|
||||
@dataclass
|
||||
class VideoComponents:
|
||||
"""
|
||||
Dataclass representing the components of a video.
|
||||
"""
|
||||
|
||||
images: ImageInput
|
||||
frame_rate: Fraction
|
||||
audio: Optional[AudioInput] = None
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
__all__ = [
|
||||
"VideoContainer",
|
||||
"VideoCodec",
|
||||
"VideoComponents",
|
||||
]
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
from comfy_api.v0_0_2 import (
|
||||
ComfyAPIAdapter_v0_0_2,
|
||||
Input as Input_v0_0_2,
|
||||
InputImpl as InputImpl_v0_0_2,
|
||||
Types as Types_v0_0_2,
|
||||
)
|
||||
from typing import Type, TYPE_CHECKING
|
||||
from comfy_api.internal.async_to_sync import create_sync_class
|
||||
|
||||
|
||||
# This version only exists to serve as a template for future version adapters.
|
||||
# There is no reason anyone should ever use it.
|
||||
class ComfyAPIAdapter_v0_0_1(ComfyAPIAdapter_v0_0_2):
|
||||
VERSION = "0.0.1"
|
||||
STABLE = True
|
||||
|
||||
class Input(Input_v0_0_2):
|
||||
pass
|
||||
|
||||
class InputImpl(InputImpl_v0_0_2):
|
||||
pass
|
||||
|
||||
class Types(Types_v0_0_2):
|
||||
pass
|
||||
|
||||
ComfyAPI = ComfyAPIAdapter_v0_0_1
|
||||
|
||||
# Create a synchronous version of the API
|
||||
if TYPE_CHECKING:
|
||||
from comfy_api.v0_0_1.generated.ComfyAPISyncStub import ComfyAPISyncStub # type: ignore
|
||||
|
||||
ComfyAPISync: Type[ComfyAPISyncStub]
|
||||
|
||||
ComfyAPISync = create_sync_class(ComfyAPIAdapter_v0_0_1)
|
||||
|
||||
__all__ = [
|
||||
"ComfyAPI",
|
||||
"ComfyAPISync",
|
||||
"Input",
|
||||
"InputImpl",
|
||||
"Types",
|
||||
]
|
||||
@@ -1,20 +0,0 @@
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple
|
||||
from comfy_api.v0_0_1 import ComfyAPIAdapter_v0_0_1
|
||||
from PIL.Image import Image
|
||||
from torch import Tensor
|
||||
class ComfyAPISyncStub:
|
||||
def __init__(self) -> None: ...
|
||||
|
||||
class ExecutionSync:
|
||||
def __init__(self) -> None: ...
|
||||
"""
|
||||
Update the progress bar displayed in the ComfyUI interface.
|
||||
|
||||
This function allows custom nodes and API calls to report their progress
|
||||
back to the user interface, providing visual feedback during long operations.
|
||||
|
||||
Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK
|
||||
"""
|
||||
def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[Image, Tensor, None] = None, ignore_size_limit: bool = False) -> None: ...
|
||||
|
||||
execution: ExecutionSync
|
||||
@@ -1,45 +0,0 @@
|
||||
from comfy_api.latest import (
|
||||
ComfyAPI_latest,
|
||||
Input as Input_latest,
|
||||
InputImpl as InputImpl_latest,
|
||||
Types as Types_latest,
|
||||
)
|
||||
from typing import Type, TYPE_CHECKING
|
||||
from comfy_api.internal.async_to_sync import create_sync_class
|
||||
from comfy_api.latest import io, ui, ComfyExtension #noqa: F401
|
||||
|
||||
|
||||
class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest):
|
||||
VERSION = "0.0.2"
|
||||
STABLE = False
|
||||
|
||||
|
||||
class Input(Input_latest):
|
||||
pass
|
||||
|
||||
|
||||
class InputImpl(InputImpl_latest):
|
||||
pass
|
||||
|
||||
|
||||
class Types(Types_latest):
|
||||
pass
|
||||
|
||||
|
||||
ComfyAPI = ComfyAPIAdapter_v0_0_2
|
||||
|
||||
# Create a synchronous version of the API
|
||||
if TYPE_CHECKING:
|
||||
from comfy_api.v0_0_2.generated.ComfyAPISyncStub import ComfyAPISyncStub # type: ignore
|
||||
|
||||
ComfyAPISync: Type[ComfyAPISyncStub]
|
||||
ComfyAPISync = create_sync_class(ComfyAPIAdapter_v0_0_2)
|
||||
|
||||
__all__ = [
|
||||
"ComfyAPI",
|
||||
"ComfyAPISync",
|
||||
"Input",
|
||||
"InputImpl",
|
||||
"Types",
|
||||
"ComfyExtension",
|
||||
]
|
||||
@@ -1,20 +0,0 @@
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple
|
||||
from comfy_api.v0_0_2 import ComfyAPIAdapter_v0_0_2
|
||||
from PIL.Image import Image
|
||||
from torch import Tensor
|
||||
class ComfyAPISyncStub:
|
||||
def __init__(self) -> None: ...
|
||||
|
||||
class ExecutionSync:
|
||||
def __init__(self) -> None: ...
|
||||
"""
|
||||
Update the progress bar displayed in the ComfyUI interface.
|
||||
|
||||
This function allows custom nodes and API calls to report their progress
|
||||
back to the user interface, providing visual feedback during long operations.
|
||||
|
||||
Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK
|
||||
"""
|
||||
def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[Image, Tensor, None] = None, ignore_size_limit: bool = False) -> None: ...
|
||||
|
||||
execution: ExecutionSync
|
||||
@@ -1,12 +0,0 @@
|
||||
from comfy_api.latest import ComfyAPI_latest
|
||||
from comfy_api.v0_0_2 import ComfyAPIAdapter_v0_0_2
|
||||
from comfy_api.v0_0_1 import ComfyAPIAdapter_v0_0_1
|
||||
from comfy_api.internal import ComfyAPIBase
|
||||
from typing import List, Type
|
||||
|
||||
supported_versions: List[Type[ComfyAPIBase]] = [
|
||||
ComfyAPI_latest,
|
||||
ComfyAPIAdapter_v0_0_2,
|
||||
ComfyAPIAdapter_v0_0_1,
|
||||
]
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
## Introduction
|
||||
|
||||
Below are a collection of nodes that work by calling external APIs. More information available in our [docs](https://docs.comfy.org/tutorials/api-nodes/overview).
|
||||
Below are a collection of nodes that work by calling external APIs. More information available in our [docs](https://docs.comfy.org/tutorials/api-nodes/overview#api-nodes).
|
||||
|
||||
## Development
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# generated by datamodel-codegen:
|
||||
# filename: filtered-openapi.yaml
|
||||
# timestamp: 2025-07-06T09:47:31+00:00
|
||||
# timestamp: 2025-05-19T21:38:55+00:00
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -1355,158 +1355,6 @@ class ModelResponseProperties(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class Keyframes(BaseModel):
|
||||
image_url: Optional[str] = None
|
||||
|
||||
|
||||
class MoonvalleyPromptResponse(BaseModel):
|
||||
error: Optional[Dict[str, Any]] = None
|
||||
frame_conditioning: Optional[Dict[str, Any]] = None
|
||||
id: Optional[str] = None
|
||||
inference_params: Optional[Dict[str, Any]] = None
|
||||
meta: Optional[Dict[str, Any]] = None
|
||||
model_params: Optional[Dict[str, Any]] = None
|
||||
output_url: Optional[str] = None
|
||||
prompt_text: Optional[str] = None
|
||||
status: Optional[str] = None
|
||||
|
||||
|
||||
class MoonvalleyTextToVideoInferenceParams(BaseModel):
|
||||
add_quality_guidance: Optional[bool] = Field(
|
||||
True, description='Whether to add quality guidance'
|
||||
)
|
||||
caching_coefficient: Optional[float] = Field(
|
||||
0.3, description='Caching coefficient for optimization'
|
||||
)
|
||||
caching_cooldown: Optional[int] = Field(
|
||||
3, description='Number of caching cooldown steps'
|
||||
)
|
||||
caching_warmup: Optional[int] = Field(
|
||||
3, description='Number of caching warmup steps'
|
||||
)
|
||||
clip_value: Optional[float] = Field(
|
||||
3, description='CLIP value for generation control'
|
||||
)
|
||||
conditioning_frame_index: Optional[int] = Field(
|
||||
0, description='Index of the conditioning frame'
|
||||
)
|
||||
cooldown_steps: Optional[int] = Field(
|
||||
None, description='Number of cooldown steps (calculated based on num_frames)'
|
||||
)
|
||||
fps: Optional[int] = Field(
|
||||
24, description='Frames per second of the generated video'
|
||||
)
|
||||
guidance_scale: Optional[float] = Field(
|
||||
12.5, description='Guidance scale for generation control'
|
||||
)
|
||||
height: Optional[int] = Field(
|
||||
1080, description='Height of the generated video in pixels'
|
||||
)
|
||||
negative_prompt: Optional[str] = Field(None, description='Negative prompt text')
|
||||
num_frames: Optional[int] = Field(64, description='Number of frames to generate')
|
||||
seed: Optional[int] = Field(
|
||||
None, description='Random seed for generation (default: random)'
|
||||
)
|
||||
shift_value: Optional[float] = Field(
|
||||
3, description='Shift value for generation control'
|
||||
)
|
||||
steps: Optional[int] = Field(80, description='Number of denoising steps')
|
||||
use_guidance_schedule: Optional[bool] = Field(
|
||||
True, description='Whether to use guidance scheduling'
|
||||
)
|
||||
use_negative_prompts: Optional[bool] = Field(
|
||||
False, description='Whether to use negative prompts'
|
||||
)
|
||||
use_timestep_transform: Optional[bool] = Field(
|
||||
True, description='Whether to use timestep transformation'
|
||||
)
|
||||
warmup_steps: Optional[int] = Field(
|
||||
None, description='Number of warmup steps (calculated based on num_frames)'
|
||||
)
|
||||
width: Optional[int] = Field(
|
||||
1920, description='Width of the generated video in pixels'
|
||||
)
|
||||
|
||||
|
||||
class MoonvalleyTextToVideoRequest(BaseModel):
|
||||
image_url: Optional[str] = None
|
||||
inference_params: Optional[MoonvalleyTextToVideoInferenceParams] = None
|
||||
prompt_text: Optional[str] = None
|
||||
webhook_url: Optional[str] = None
|
||||
|
||||
|
||||
class MoonvalleyUploadFileRequest(BaseModel):
|
||||
file: Optional[StrictBytes] = None
|
||||
|
||||
|
||||
class MoonvalleyUploadFileResponse(BaseModel):
|
||||
access_url: Optional[str] = None
|
||||
|
||||
|
||||
class MoonvalleyVideoToVideoInferenceParams(BaseModel):
|
||||
add_quality_guidance: Optional[bool] = Field(
|
||||
True, description='Whether to add quality guidance'
|
||||
)
|
||||
caching_coefficient: Optional[float] = Field(
|
||||
0.3, description='Caching coefficient for optimization'
|
||||
)
|
||||
caching_cooldown: Optional[int] = Field(
|
||||
3, description='Number of caching cooldown steps'
|
||||
)
|
||||
caching_warmup: Optional[int] = Field(
|
||||
3, description='Number of caching warmup steps'
|
||||
)
|
||||
clip_value: Optional[float] = Field(
|
||||
3, description='CLIP value for generation control'
|
||||
)
|
||||
conditioning_frame_index: Optional[int] = Field(
|
||||
0, description='Index of the conditioning frame'
|
||||
)
|
||||
cooldown_steps: Optional[int] = Field(
|
||||
None, description='Number of cooldown steps (calculated based on num_frames)'
|
||||
)
|
||||
guidance_scale: Optional[float] = Field(
|
||||
12.5, description='Guidance scale for generation control'
|
||||
)
|
||||
negative_prompt: Optional[str] = Field(None, description='Negative prompt text')
|
||||
seed: Optional[int] = Field(
|
||||
None, description='Random seed for generation (default: random)'
|
||||
)
|
||||
shift_value: Optional[float] = Field(
|
||||
3, description='Shift value for generation control'
|
||||
)
|
||||
steps: Optional[int] = Field(80, description='Number of denoising steps')
|
||||
use_guidance_schedule: Optional[bool] = Field(
|
||||
True, description='Whether to use guidance scheduling'
|
||||
)
|
||||
use_negative_prompts: Optional[bool] = Field(
|
||||
False, description='Whether to use negative prompts'
|
||||
)
|
||||
use_timestep_transform: Optional[bool] = Field(
|
||||
True, description='Whether to use timestep transformation'
|
||||
)
|
||||
warmup_steps: Optional[int] = Field(
|
||||
None, description='Number of warmup steps (calculated based on num_frames)'
|
||||
)
|
||||
|
||||
|
||||
class ControlType(str, Enum):
|
||||
motion_control = 'motion_control'
|
||||
pose_control = 'pose_control'
|
||||
|
||||
|
||||
class MoonvalleyVideoToVideoRequest(BaseModel):
|
||||
control_type: ControlType = Field(
|
||||
..., description='Supported types for video control'
|
||||
)
|
||||
inference_params: Optional[MoonvalleyVideoToVideoInferenceParams] = None
|
||||
prompt_text: str = Field(..., description='Describes the video to generate')
|
||||
video_url: str = Field(..., description='Url to control video')
|
||||
webhook_url: Optional[str] = Field(
|
||||
None, description='Optional webhook URL for notifications'
|
||||
)
|
||||
|
||||
|
||||
class Moderation(str, Enum):
|
||||
low = 'low'
|
||||
auto = 'auto'
|
||||
@@ -3259,23 +3107,6 @@ class LumaUpscaleVideoGenerationRequest(BaseModel):
|
||||
resolution: Optional[LumaVideoModelOutputResolution] = None
|
||||
|
||||
|
||||
class MoonvalleyImageToVideoRequest(MoonvalleyTextToVideoRequest):
|
||||
keyframes: Optional[Dict[str, Keyframes]] = None
|
||||
|
||||
|
||||
class MoonvalleyResizeVideoRequest(MoonvalleyVideoToVideoRequest):
|
||||
frame_position: Optional[List[int]] = Field(None, max_length=2, min_length=2)
|
||||
frame_resolution: Optional[List[int]] = Field(None, max_length=2, min_length=2)
|
||||
scale: Optional[List[int]] = Field(None, max_length=2, min_length=2)
|
||||
|
||||
|
||||
class MoonvalleyTextToImageRequest(BaseModel):
|
||||
image_url: Optional[str] = None
|
||||
inference_params: Optional[MoonvalleyTextToVideoInferenceParams] = None
|
||||
prompt_text: Optional[str] = None
|
||||
webhook_url: Optional[str] = None
|
||||
|
||||
|
||||
class OutputContent(RootModel[Union[OutputTextContent, OutputAudioContent]]):
|
||||
root: Union[OutputTextContent, OutputAudioContent]
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import datetime
|
||||
import json
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
API Nodes for Gemini Multimodal LLM Usage via Remote API
|
||||
See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
import os
|
||||
from enum import Enum
|
||||
@@ -408,7 +406,7 @@ class GeminiInputFiles(ComfyNodeABC):
|
||||
|
||||
def create_file_part(self, file_path: str) -> GeminiPart:
|
||||
mime_type = (
|
||||
GeminiMimeType.application_pdf
|
||||
GeminiMimeType.pdf
|
||||
if file_path.endswith(".pdf")
|
||||
else GeminiMimeType.text_plain
|
||||
)
|
||||
|
||||
@@ -132,8 +132,6 @@ def poll_until_finished(
|
||||
result_url_extractor=result_url_extractor,
|
||||
estimated_duration=estimated_duration,
|
||||
node_id=node_id,
|
||||
poll_interval=16.0,
|
||||
max_poll_attempts=256,
|
||||
).execute()
|
||||
|
||||
|
||||
|
||||
@@ -1,743 +0,0 @@
|
||||
import logging
|
||||
from typing import Any, Callable, Optional, TypeVar
|
||||
import random
|
||||
import torch
|
||||
from comfy_api_nodes.util.validation_utils import (
|
||||
get_image_dimensions,
|
||||
validate_image_dimensions,
|
||||
)
|
||||
|
||||
|
||||
from comfy_api_nodes.apis import (
|
||||
MoonvalleyTextToVideoRequest,
|
||||
MoonvalleyTextToVideoInferenceParams,
|
||||
MoonvalleyVideoToVideoInferenceParams,
|
||||
MoonvalleyVideoToVideoRequest,
|
||||
MoonvalleyPromptResponse,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
download_url_to_video_output,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
)
|
||||
from comfy_api_nodes.mapper_utils import model_field_to_node_input
|
||||
|
||||
from comfy_api.input.video_types import VideoInput
|
||||
from comfy.comfy_types.node_typing import IO
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
import av
|
||||
import io
|
||||
|
||||
API_UPLOADS_ENDPOINT = "/proxy/moonvalley/uploads"
|
||||
API_PROMPTS_ENDPOINT = "/proxy/moonvalley/prompts"
|
||||
API_VIDEO2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/video-to-video"
|
||||
API_TXT2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/text-to-video"
|
||||
API_IMG2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/image-to-video"
|
||||
|
||||
MIN_WIDTH = 300
|
||||
MIN_HEIGHT = 300
|
||||
|
||||
MAX_WIDTH = 10000
|
||||
MAX_HEIGHT = 10000
|
||||
|
||||
MIN_VID_WIDTH = 300
|
||||
MIN_VID_HEIGHT = 300
|
||||
|
||||
MAX_VID_WIDTH = 10000
|
||||
MAX_VID_HEIGHT = 10000
|
||||
|
||||
MAX_VIDEO_SIZE = 1024 * 1024 * 1024 # 1 GB max for in-memory video processing
|
||||
|
||||
MOONVALLEY_MAREY_MAX_PROMPT_LENGTH = 5000
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class MoonvalleyApiError(Exception):
|
||||
"""Base exception for Moonvalley API errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def is_valid_task_creation_response(response: MoonvalleyPromptResponse) -> bool:
|
||||
"""Verifies that the initial response contains a task ID."""
|
||||
return bool(response.id)
|
||||
|
||||
|
||||
def validate_task_creation_response(response) -> None:
|
||||
if not is_valid_task_creation_response(response):
|
||||
error_msg = f"Moonvalley Marey API: Initial request failed. Code: {response.code}, Message: {response.message}, Data: {response}"
|
||||
logging.error(error_msg)
|
||||
raise MoonvalleyApiError(error_msg)
|
||||
|
||||
|
||||
def get_video_from_response(response):
|
||||
video = response.output_url
|
||||
logging.info(
|
||||
"Moonvalley Marey API: Task %s succeeded. Video URL: %s", response.id, video
|
||||
)
|
||||
return video
|
||||
|
||||
|
||||
def get_video_url_from_response(response) -> Optional[str]:
|
||||
"""Returns the first video url from the Moonvalley video generation task result.
|
||||
Will not raise an error if the response is not valid.
|
||||
"""
|
||||
if response:
|
||||
return str(get_video_from_response(response))
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def poll_until_finished(
|
||||
auth_kwargs: dict[str, str],
|
||||
api_endpoint: ApiEndpoint[Any, R],
|
||||
result_url_extractor: Optional[Callable[[R], str]] = None,
|
||||
node_id: Optional[str] = None,
|
||||
) -> R:
|
||||
"""Polls the Moonvalley API endpoint until the task reaches a terminal state, then returns the response."""
|
||||
return PollingOperation(
|
||||
poll_endpoint=api_endpoint,
|
||||
completed_statuses=[
|
||||
"completed",
|
||||
],
|
||||
max_poll_attempts=240, # 64 minutes with 16s interval
|
||||
poll_interval=16.0,
|
||||
failed_statuses=["error"],
|
||||
status_extractor=lambda response: (
|
||||
response.status if response and response.status else None
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
result_url_extractor=result_url_extractor,
|
||||
node_id=node_id,
|
||||
).execute()
|
||||
|
||||
|
||||
def validate_prompts(
|
||||
prompt: str, negative_prompt: str, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH
|
||||
):
|
||||
"""Verifies that the prompt isn't empty and that neither prompt is too long."""
|
||||
if not prompt:
|
||||
raise ValueError("Positive prompt is empty")
|
||||
if len(prompt) > max_length:
|
||||
raise ValueError(f"Positive prompt is too long: {len(prompt)} characters")
|
||||
if negative_prompt and len(negative_prompt) > max_length:
|
||||
raise ValueError(
|
||||
f"Negative prompt is too long: {len(negative_prompt)} characters"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def validate_input_media(width, height, with_frame_conditioning, num_frames_in=None):
|
||||
# inference validation
|
||||
# T = num_frames
|
||||
# in all cases, the following must be true: T divisible by 16 and H,W by 8. in addition...
|
||||
# with image conditioning: H*W must be divisible by 8192
|
||||
# without image conditioning: T divisible by 32
|
||||
if num_frames_in and not num_frames_in % 16 == 0:
|
||||
return False, ("The input video total frame count must be divisible by 16!")
|
||||
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
return False, (
|
||||
f"Height ({height}) and width ({width}) must be " "divisible by 8"
|
||||
)
|
||||
|
||||
if with_frame_conditioning:
|
||||
if (height * width) % 8192 != 0:
|
||||
return False, (
|
||||
f"Height * width ({height * width}) must be "
|
||||
"divisible by 8192 for frame conditioning"
|
||||
)
|
||||
else:
|
||||
if num_frames_in and not num_frames_in % 32 == 0:
|
||||
return False, ("The input video total frame count must be divisible by 32!")
|
||||
|
||||
|
||||
def validate_input_image(
|
||||
image: torch.Tensor, with_frame_conditioning: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Validates the input image adheres to the expectations of the API:
|
||||
- The image resolution should not be less than 300*300px
|
||||
- The aspect ratio of the image should be between 1:2.5 ~ 2.5:1
|
||||
|
||||
"""
|
||||
height, width = get_image_dimensions(image)
|
||||
validate_input_media(width, height, with_frame_conditioning)
|
||||
validate_image_dimensions(
|
||||
image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH
|
||||
)
|
||||
|
||||
|
||||
def validate_video_to_video_input(video: VideoInput) -> VideoInput:
|
||||
"""
|
||||
Validates and processes video input for Moonvalley Video-to-Video generation.
|
||||
|
||||
Args:
|
||||
video: Input video to validate
|
||||
|
||||
Returns:
|
||||
Validated and potentially trimmed video
|
||||
|
||||
Raises:
|
||||
ValueError: If video doesn't meet requirements
|
||||
MoonvalleyApiError: If video duration is too short
|
||||
"""
|
||||
width, height = _get_video_dimensions(video)
|
||||
_validate_video_dimensions(width, height)
|
||||
_validate_container_format(video)
|
||||
|
||||
return _validate_and_trim_duration(video)
|
||||
|
||||
|
||||
def _get_video_dimensions(video: VideoInput) -> tuple[int, int]:
|
||||
"""Extracts video dimensions with error handling."""
|
||||
try:
|
||||
return video.get_dimensions()
|
||||
except Exception as e:
|
||||
logging.error("Error getting dimensions of video: %s", e)
|
||||
raise ValueError(f"Cannot get video dimensions: {e}") from e
|
||||
|
||||
|
||||
def _validate_video_dimensions(width: int, height: int) -> None:
|
||||
"""Validates video dimensions meet Moonvalley V2V requirements."""
|
||||
supported_resolutions = {
|
||||
(1920, 1080), (1080, 1920), (1152, 1152),
|
||||
(1536, 1152), (1152, 1536)
|
||||
}
|
||||
|
||||
if (width, height) not in supported_resolutions:
|
||||
supported_list = ', '.join([f'{w}x{h}' for w, h in sorted(supported_resolutions)])
|
||||
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
|
||||
|
||||
|
||||
def _validate_container_format(video: VideoInput) -> None:
|
||||
"""Validates video container format is MP4."""
|
||||
container_format = video.get_container_format()
|
||||
if container_format not in ['mp4', 'mov,mp4,m4a,3gp,3g2,mj2']:
|
||||
raise ValueError(f"Only MP4 container format supported. Got: {container_format}")
|
||||
|
||||
|
||||
def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
|
||||
"""Validates video duration and trims to 5 seconds if needed."""
|
||||
duration = video.get_duration()
|
||||
_validate_minimum_duration(duration)
|
||||
return _trim_if_too_long(video, duration)
|
||||
|
||||
|
||||
def _validate_minimum_duration(duration: float) -> None:
|
||||
"""Ensures video is at least 5 seconds long."""
|
||||
if duration < 5:
|
||||
raise MoonvalleyApiError("Input video must be at least 5 seconds long.")
|
||||
|
||||
|
||||
def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
|
||||
"""Trims video to 5 seconds if longer."""
|
||||
if duration > 5:
|
||||
return trim_video(video, 5)
|
||||
return video
|
||||
|
||||
|
||||
|
||||
def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
|
||||
"""
|
||||
Returns a new VideoInput object trimmed from the beginning to the specified duration,
|
||||
using av to avoid loading entire video into memory.
|
||||
|
||||
Args:
|
||||
video: Input video to trim
|
||||
duration_sec: Duration in seconds to keep from the beginning
|
||||
|
||||
Returns:
|
||||
VideoFromFile object that owns the output buffer
|
||||
"""
|
||||
output_buffer = io.BytesIO()
|
||||
|
||||
input_container = None
|
||||
output_container = None
|
||||
|
||||
try:
|
||||
# Get the stream source - this avoids loading entire video into memory
|
||||
# when the source is already a file path
|
||||
input_source = video.get_stream_source()
|
||||
|
||||
# Open containers
|
||||
input_container = av.open(input_source, mode="r")
|
||||
output_container = av.open(output_buffer, mode="w", format="mp4")
|
||||
|
||||
# Set up output streams for re-encoding
|
||||
video_stream = None
|
||||
audio_stream = None
|
||||
|
||||
for stream in input_container.streams:
|
||||
logging.info(f"Found stream: type={stream.type}, class={type(stream)}")
|
||||
if isinstance(stream, av.VideoStream):
|
||||
# Create output video stream with same parameters
|
||||
video_stream = output_container.add_stream(
|
||||
"h264", rate=stream.average_rate
|
||||
)
|
||||
video_stream.width = stream.width
|
||||
video_stream.height = stream.height
|
||||
video_stream.pix_fmt = "yuv420p"
|
||||
logging.info(
|
||||
f"Added video stream: {stream.width}x{stream.height} @ {stream.average_rate}fps"
|
||||
)
|
||||
elif isinstance(stream, av.AudioStream):
|
||||
# Create output audio stream with same parameters
|
||||
audio_stream = output_container.add_stream(
|
||||
"aac", rate=stream.sample_rate
|
||||
)
|
||||
audio_stream.sample_rate = stream.sample_rate
|
||||
audio_stream.layout = stream.layout
|
||||
logging.info(
|
||||
f"Added audio stream: {stream.sample_rate}Hz, {stream.channels} channels"
|
||||
)
|
||||
|
||||
# Calculate target frame count that's divisible by 16
|
||||
fps = input_container.streams.video[0].average_rate
|
||||
estimated_frames = int(duration_sec * fps)
|
||||
target_frames = (estimated_frames // 16) * 16 # Round down to nearest multiple of 16
|
||||
|
||||
if target_frames == 0:
|
||||
raise ValueError("Video too short: need at least 16 frames for Moonvalley")
|
||||
|
||||
frame_count = 0
|
||||
audio_frame_count = 0
|
||||
|
||||
# Decode and re-encode video frames
|
||||
if video_stream:
|
||||
for frame in input_container.decode(video=0):
|
||||
if frame_count >= target_frames:
|
||||
break
|
||||
|
||||
# Re-encode frame
|
||||
for packet in video_stream.encode(frame):
|
||||
output_container.mux(packet)
|
||||
frame_count += 1
|
||||
|
||||
# Flush encoder
|
||||
for packet in video_stream.encode():
|
||||
output_container.mux(packet)
|
||||
|
||||
logging.info(
|
||||
f"Encoded {frame_count} video frames (target: {target_frames})"
|
||||
)
|
||||
|
||||
# Decode and re-encode audio frames
|
||||
if audio_stream:
|
||||
input_container.seek(0) # Reset to beginning for audio
|
||||
for frame in input_container.decode(audio=0):
|
||||
if frame.time >= duration_sec:
|
||||
break
|
||||
|
||||
# Re-encode frame
|
||||
for packet in audio_stream.encode(frame):
|
||||
output_container.mux(packet)
|
||||
audio_frame_count += 1
|
||||
|
||||
# Flush encoder
|
||||
for packet in audio_stream.encode():
|
||||
output_container.mux(packet)
|
||||
|
||||
logging.info(f"Encoded {audio_frame_count} audio frames")
|
||||
|
||||
# Close containers
|
||||
output_container.close()
|
||||
input_container.close()
|
||||
|
||||
# Return as VideoFromFile using the buffer
|
||||
output_buffer.seek(0)
|
||||
return VideoFromFile(output_buffer)
|
||||
|
||||
except Exception as e:
|
||||
# Clean up on error
|
||||
if input_container is not None:
|
||||
input_container.close()
|
||||
if output_container is not None:
|
||||
output_container.close()
|
||||
raise RuntimeError(f"Failed to trim video: {str(e)}") from e
|
||||
|
||||
|
||||
# --- BaseMoonvalleyVideoNode ---
|
||||
class BaseMoonvalleyVideoNode:
|
||||
def parseWidthHeightFromRes(self, resolution: str):
|
||||
# Accepts a string like "16:9 (1920 x 1080)" and returns width, height as a dict
|
||||
res_map = {
|
||||
"16:9 (1920 x 1080)": {"width": 1920, "height": 1080},
|
||||
"9:16 (1080 x 1920)": {"width": 1080, "height": 1920},
|
||||
"1:1 (1152 x 1152)": {"width": 1152, "height": 1152},
|
||||
"4:3 (1536 x 1152)": {"width": 1536, "height": 1152},
|
||||
"3:4 (1152 x 1536)": {"width": 1152, "height": 1536},
|
||||
"21:9 (2560 x 1080)": {"width": 2560, "height": 1080},
|
||||
}
|
||||
if resolution in res_map:
|
||||
return res_map[resolution]
|
||||
else:
|
||||
# Default to 1920x1080 if unknown
|
||||
return {"width": 1920, "height": 1080}
|
||||
|
||||
def parseControlParameter(self, value):
|
||||
control_map = {
|
||||
"Motion Transfer": "motion_control",
|
||||
"Canny": "canny_control",
|
||||
"Pose Transfer": "pose_control",
|
||||
"Depth": "depth_control",
|
||||
}
|
||||
if value in control_map:
|
||||
return control_map[value]
|
||||
else:
|
||||
return control_map["Motion Transfer"]
|
||||
|
||||
def get_response(
|
||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||
) -> MoonvalleyPromptResponse:
|
||||
return poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{API_PROMPTS_ENDPOINT}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
),
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": model_field_to_node_input(
|
||||
IO.STRING,
|
||||
MoonvalleyTextToVideoRequest,
|
||||
"prompt_text",
|
||||
multiline=True,
|
||||
),
|
||||
"negative_prompt": model_field_to_node_input(
|
||||
IO.STRING,
|
||||
MoonvalleyTextToVideoInferenceParams,
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="low-poly, flat shader, bad rigging, stiff animation, uncanny eyes, low-quality textures, looping glitch, cheap effect, overbloom, bloom spam, default lighting, game asset, stiff face, ugly specular, AI artifacts",
|
||||
),
|
||||
"resolution": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": [
|
||||
"16:9 (1920 x 1080)",
|
||||
"9:16 (1080 x 1920)",
|
||||
"1:1 (1152 x 1152)",
|
||||
"4:3 (1440 x 1080)",
|
||||
"3:4 (1080 x 1440)",
|
||||
"21:9 (2560 x 1080)",
|
||||
],
|
||||
"default": "16:9 (1920 x 1080)",
|
||||
"tooltip": "Resolution of the output video",
|
||||
},
|
||||
),
|
||||
# "length": (IO.COMBO,{"options":['5s','10s'], "default": '5s'}),
|
||||
"prompt_adherence": model_field_to_node_input(
|
||||
IO.FLOAT,
|
||||
MoonvalleyTextToVideoInferenceParams,
|
||||
"guidance_scale",
|
||||
default=7.0,
|
||||
step=1,
|
||||
min=1,
|
||||
max=20,
|
||||
),
|
||||
"seed": model_field_to_node_input(
|
||||
IO.INT,
|
||||
MoonvalleyTextToVideoInferenceParams,
|
||||
"seed",
|
||||
default=random.randint(0, 2**32 - 1),
|
||||
min=0,
|
||||
max=4294967295,
|
||||
step=1,
|
||||
display="number",
|
||||
tooltip="Random seed value",
|
||||
control_after_generate=True,
|
||||
),
|
||||
"steps": model_field_to_node_input(
|
||||
IO.INT,
|
||||
MoonvalleyTextToVideoInferenceParams,
|
||||
"steps",
|
||||
default=100,
|
||||
min=1,
|
||||
max=100,
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
"optional": {
|
||||
"image": model_field_to_node_input(
|
||||
IO.IMAGE,
|
||||
MoonvalleyTextToVideoRequest,
|
||||
"image_url",
|
||||
tooltip="The reference image used to generate the video",
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING",)
|
||||
FUNCTION = "generate"
|
||||
CATEGORY = "api node/video/Moonvalley Marey"
|
||||
API_NODE = True
|
||||
|
||||
def generate(self, **kwargs):
|
||||
return None
|
||||
|
||||
|
||||
# --- MoonvalleyImg2VideoNode ---
|
||||
class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return super().INPUT_TYPES()
|
||||
|
||||
RETURN_TYPES = ("VIDEO",)
|
||||
RETURN_NAMES = ("video",)
|
||||
DESCRIPTION = "Moonvalley Marey Image to Video Node"
|
||||
|
||||
def generate(
|
||||
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
|
||||
):
|
||||
image = kwargs.get("image", None)
|
||||
if image is None:
|
||||
raise MoonvalleyApiError("image is required")
|
||||
|
||||
validate_input_image(image, True)
|
||||
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
width_height = self.parseWidthHeightFromRes(kwargs.get("resolution"))
|
||||
|
||||
inference_params = MoonvalleyTextToVideoInferenceParams(
|
||||
negative_prompt=negative_prompt,
|
||||
steps=kwargs.get("steps"),
|
||||
seed=kwargs.get("seed"),
|
||||
guidance_scale=kwargs.get("prompt_adherence"),
|
||||
num_frames=128,
|
||||
width=width_height.get("width"),
|
||||
height=width_height.get("height"),
|
||||
use_negative_prompts=True,
|
||||
)
|
||||
"""Upload image to comfy backend to have a URL available for further processing"""
|
||||
# Get MIME type from tensor - assuming PNG format for image tensors
|
||||
mime_type = "image/png"
|
||||
|
||||
image_url = upload_images_to_comfyapi(
|
||||
image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type
|
||||
)[0]
|
||||
|
||||
request = MoonvalleyTextToVideoRequest(
|
||||
image_url=image_url, prompt_text=prompt, inference_params=inference_params
|
||||
)
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=API_IMG2VIDEO_ENDPOINT,
|
||||
method=HttpMethod.POST,
|
||||
request_model=MoonvalleyTextToVideoRequest,
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
),
|
||||
request=request,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
task_creation_response = initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.id
|
||||
|
||||
final_response = self.get_response(
|
||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||
)
|
||||
video = download_url_to_video_output(final_response.output_url)
|
||||
return (video,)
|
||||
|
||||
|
||||
# --- MoonvalleyVid2VidNode ---
|
||||
class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": model_field_to_node_input(
|
||||
IO.STRING, MoonvalleyVideoToVideoRequest, "prompt_text",
|
||||
multiline=True
|
||||
),
|
||||
"negative_prompt": model_field_to_node_input(
|
||||
IO.STRING,
|
||||
MoonvalleyVideoToVideoInferenceParams,
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="low-poly, flat shader, bad rigging, stiff animation, uncanny eyes, low-quality textures, looping glitch, cheap effect, overbloom, bloom spam, default lighting, game asset, stiff face, ugly specular, AI artifacts"
|
||||
),
|
||||
"seed": model_field_to_node_input(IO.INT,MoonvalleyVideoToVideoInferenceParams, "seed", default=random.randint(0, 2**32 - 1), min=0, max=4294967295, step=1, display="number", tooltip="Random seed value", control_after_generate=True),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
"optional": {
|
||||
"video": (IO.VIDEO, {"default": "", "multiline": False, "tooltip": "The reference video used to generate the output video. Must be at least 5 seconds long. Videos longer than 5s will be automatically trimmed. Only MP4 format supported."}),
|
||||
"control_type": (
|
||||
["Motion Transfer", "Pose Transfer"],
|
||||
{"default": "Motion Transfer"},
|
||||
),
|
||||
"motion_intensity": (
|
||||
"INT",
|
||||
{
|
||||
"default": 100,
|
||||
"step": 1,
|
||||
"min": 0,
|
||||
"max": 100,
|
||||
"tooltip": "Only used if control_type is 'Motion Transfer'",
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("VIDEO",)
|
||||
RETURN_NAMES = ("video",)
|
||||
|
||||
def generate(
|
||||
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
|
||||
):
|
||||
video = kwargs.get("video")
|
||||
|
||||
if not video:
|
||||
raise MoonvalleyApiError("video is required")
|
||||
|
||||
video_url = ""
|
||||
if video:
|
||||
validated_video = validate_video_to_video_input(video)
|
||||
video_url = upload_video_to_comfyapi(validated_video, auth_kwargs=kwargs)
|
||||
|
||||
control_type = kwargs.get("control_type")
|
||||
motion_intensity = kwargs.get("motion_intensity")
|
||||
|
||||
"""Validate prompts and inference input"""
|
||||
validate_prompts(prompt, negative_prompt)
|
||||
|
||||
# Only include motion_intensity for Motion Transfer
|
||||
control_params = {}
|
||||
if control_type == "Motion Transfer" and motion_intensity is not None:
|
||||
control_params['motion_intensity'] = motion_intensity
|
||||
|
||||
inference_params=MoonvalleyVideoToVideoInferenceParams(
|
||||
negative_prompt=negative_prompt,
|
||||
seed=kwargs.get("seed"),
|
||||
control_params=control_params
|
||||
)
|
||||
|
||||
control = self.parseControlParameter(control_type)
|
||||
|
||||
request = MoonvalleyVideoToVideoRequest(
|
||||
control_type=control,
|
||||
video_url=video_url,
|
||||
prompt_text=prompt,
|
||||
inference_params=inference_params,
|
||||
)
|
||||
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=API_VIDEO2VIDEO_ENDPOINT,
|
||||
method=HttpMethod.POST,
|
||||
request_model=MoonvalleyVideoToVideoRequest,
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
),
|
||||
request=request,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
task_creation_response = initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.id
|
||||
|
||||
final_response = self.get_response(
|
||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||
)
|
||||
|
||||
video = download_url_to_video_output(final_response.output_url)
|
||||
|
||||
return (video,)
|
||||
|
||||
|
||||
# --- MoonvalleyTxt2VideoNode ---
|
||||
class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
RETURN_TYPES = ("VIDEO",)
|
||||
RETURN_NAMES = ("video",)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
input_types = super().INPUT_TYPES()
|
||||
# Remove image-specific parameters
|
||||
for param in ["image"]:
|
||||
if param in input_types["optional"]:
|
||||
del input_types["optional"][param]
|
||||
return input_types
|
||||
|
||||
def generate(
|
||||
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
|
||||
):
|
||||
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
width_height = self.parseWidthHeightFromRes(kwargs.get("resolution"))
|
||||
|
||||
inference_params=MoonvalleyTextToVideoInferenceParams(
|
||||
negative_prompt=negative_prompt,
|
||||
steps=kwargs.get("steps"),
|
||||
seed=kwargs.get("seed"),
|
||||
guidance_scale=kwargs.get("prompt_adherence"),
|
||||
num_frames=128,
|
||||
width=width_height.get("width"),
|
||||
height=width_height.get("height"),
|
||||
)
|
||||
request = MoonvalleyTextToVideoRequest(
|
||||
prompt_text=prompt, inference_params=inference_params
|
||||
)
|
||||
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=API_TXT2VIDEO_ENDPOINT,
|
||||
method=HttpMethod.POST,
|
||||
request_model=MoonvalleyTextToVideoRequest,
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
),
|
||||
request=request,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
task_creation_response = initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.id
|
||||
|
||||
final_response = self.get_response(
|
||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||
)
|
||||
|
||||
video = download_url_to_video_output(final_response.output_url)
|
||||
return (video,)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"MoonvalleyImg2VideoNode": MoonvalleyImg2VideoNode,
|
||||
"MoonvalleyTxt2VideoNode": MoonvalleyTxt2VideoNode,
|
||||
"MoonvalleyVideo2VideoNode": MoonvalleyVideo2VideoNode,
|
||||
}
|
||||
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"MoonvalleyImg2VideoNode": "Moonvalley Marey Image to Video",
|
||||
"MoonvalleyTxt2VideoNode": "Moonvalley Marey Text to Video",
|
||||
"MoonvalleyVideo2VideoNode": "Moonvalley Marey Video to Video",
|
||||
}
|
||||
@@ -3,13 +3,9 @@ from typing import Type, Literal
|
||||
|
||||
import nodes
|
||||
import asyncio
|
||||
import inspect
|
||||
from comfy_execution.graph_utils import is_link, ExecutionBlocker
|
||||
from comfy_execution.graph_utils import is_link
|
||||
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
|
||||
|
||||
# NOTE: ExecutionBlocker code got moved to graph_utils.py to prevent torch being imported too soon during unit tests
|
||||
ExecutionBlocker = ExecutionBlocker
|
||||
|
||||
class DependencyCycleError(Exception):
|
||||
pass
|
||||
|
||||
@@ -243,15 +239,8 @@ class ExecutionList(TopologicalSort):
|
||||
return True
|
||||
return False
|
||||
|
||||
# If an available node is async, do that first.
|
||||
# This will execute the asynchronous function earlier, reducing the overall time.
|
||||
def is_async(node_id):
|
||||
class_type = self.dynprompt.get_node(node_id)["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
return inspect.iscoroutinefunction(getattr(class_def, class_def.FUNCTION))
|
||||
|
||||
for node_id in node_list:
|
||||
if is_output(node_id) or is_async(node_id):
|
||||
if is_output(node_id):
|
||||
return node_id
|
||||
|
||||
#This should handle the VAEDecode -> preview case
|
||||
@@ -297,3 +286,21 @@ class ExecutionList(TopologicalSort):
|
||||
del blocked_by[node_id]
|
||||
to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
|
||||
return list(blocked_by.keys())
|
||||
|
||||
class ExecutionBlocker:
|
||||
"""
|
||||
Return this from a node and any users will be blocked with the given error message.
|
||||
If the message is None, execution will be blocked silently instead.
|
||||
Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's
|
||||
possible, a lazy input will be more efficient and have a better user experience.
|
||||
This functionality is useful in two cases:
|
||||
1. You want to conditionally prevent an output node from executing. (Particularly a built-in node
|
||||
like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using
|
||||
lazy evaluation to let it conditionally disable itself.)
|
||||
2. You have a node with multiple possible outputs, some of which are invalid and should not be used.
|
||||
(I would recommend not making nodes like this in the future -- instead, make multiple nodes with
|
||||
different outputs. Unfortunately, there are several popular existing nodes using this pattern.)
|
||||
"""
|
||||
def __init__(self, message):
|
||||
self.message = message
|
||||
|
||||
|
||||
@@ -137,19 +137,3 @@ def add_graph_prefix(graph, outputs, prefix):
|
||||
|
||||
return new_graph, tuple(new_outputs)
|
||||
|
||||
class ExecutionBlocker:
|
||||
"""
|
||||
Return this from a node and any users will be blocked with the given error message.
|
||||
If the message is None, execution will be blocked silently instead.
|
||||
Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's
|
||||
possible, a lazy input will be more efficient and have a better user experience.
|
||||
This functionality is useful in two cases:
|
||||
1. You want to conditionally prevent an output node from executing. (Particularly a built-in node
|
||||
like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using
|
||||
lazy evaluation to let it conditionally disable itself.)
|
||||
2. You have a node with multiple possible outputs, some of which are invalid and should not be used.
|
||||
(I would recommend not making nodes like this in the future -- instead, make multiple nodes with
|
||||
different outputs. Unfortunately, there are several popular existing nodes using this pattern.)
|
||||
"""
|
||||
def __init__(self, message):
|
||||
self.message = message
|
||||
|
||||
@@ -1,18 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TypedDict, Dict, Optional, Tuple
|
||||
from typing import TypedDict, Dict, Optional
|
||||
from typing_extensions import override
|
||||
from PIL import Image
|
||||
from enum import Enum
|
||||
from abc import ABC
|
||||
from tqdm import tqdm
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy_execution.graph import DynamicPrompt
|
||||
from comfy_execution.graph import DynamicPrompt
|
||||
from protocol import BinaryEventTypes
|
||||
from comfy_api import feature_flags
|
||||
|
||||
PreviewImageTuple = Tuple[str, Image.Image, Optional[int]]
|
||||
|
||||
class NodeState(Enum):
|
||||
Pending = "pending"
|
||||
@@ -20,23 +13,19 @@ class NodeState(Enum):
|
||||
Finished = "finished"
|
||||
Error = "error"
|
||||
|
||||
|
||||
class NodeProgressState(TypedDict):
|
||||
"""
|
||||
A class to represent the state of a node's progress.
|
||||
"""
|
||||
|
||||
state: NodeState
|
||||
value: float
|
||||
max: float
|
||||
|
||||
|
||||
class ProgressHandler(ABC):
|
||||
"""
|
||||
Abstract base class for progress handlers.
|
||||
Progress handlers receive progress updates and display them in various ways.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.enabled = True
|
||||
@@ -48,15 +37,8 @@ class ProgressHandler(ABC):
|
||||
"""Called when a node starts processing"""
|
||||
pass
|
||||
|
||||
def update_handler(
|
||||
self,
|
||||
node_id: str,
|
||||
value: float,
|
||||
max_value: float,
|
||||
state: NodeProgressState,
|
||||
prompt_id: str,
|
||||
image: PreviewImageTuple | None = None,
|
||||
):
|
||||
def update_handler(self, node_id: str, value: float, max_value: float,
|
||||
state: NodeProgressState, prompt_id: str, image: Optional[Image.Image] = None):
|
||||
"""Called when a node's progress is updated"""
|
||||
pass
|
||||
|
||||
@@ -76,12 +58,10 @@ class ProgressHandler(ABC):
|
||||
"""Disable this handler"""
|
||||
self.enabled = False
|
||||
|
||||
|
||||
class CLIProgressHandler(ProgressHandler):
|
||||
"""
|
||||
Handler that displays progress using tqdm progress bars in the CLI.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("cli")
|
||||
self.progress_bars: Dict[str, tqdm] = {}
|
||||
@@ -95,19 +75,12 @@ class CLIProgressHandler(ProgressHandler):
|
||||
desc=f"Node {node_id}",
|
||||
unit="steps",
|
||||
leave=True,
|
||||
position=len(self.progress_bars),
|
||||
position=len(self.progress_bars)
|
||||
)
|
||||
|
||||
@override
|
||||
def update_handler(
|
||||
self,
|
||||
node_id: str,
|
||||
value: float,
|
||||
max_value: float,
|
||||
state: NodeProgressState,
|
||||
prompt_id: str,
|
||||
image: PreviewImageTuple | None = None,
|
||||
):
|
||||
def update_handler(self, node_id: str, value: float, max_value: float,
|
||||
state: NodeProgressState, prompt_id: str, image: Optional[Image.Image] = None):
|
||||
# Handle case where start_handler wasn't called
|
||||
if node_id not in self.progress_bars:
|
||||
self.progress_bars[node_id] = tqdm(
|
||||
@@ -115,7 +88,7 @@ class CLIProgressHandler(ProgressHandler):
|
||||
desc=f"Node {node_id}",
|
||||
unit="steps",
|
||||
leave=True,
|
||||
position=len(self.progress_bars),
|
||||
position=len(self.progress_bars)
|
||||
)
|
||||
self.progress_bars[node_id].update(value)
|
||||
else:
|
||||
@@ -146,12 +119,10 @@ class CLIProgressHandler(ProgressHandler):
|
||||
bar.close()
|
||||
self.progress_bars.clear()
|
||||
|
||||
|
||||
class WebUIProgressHandler(ProgressHandler):
|
||||
"""
|
||||
Handler that sends progress updates to the WebUI via WebSockets.
|
||||
"""
|
||||
|
||||
def __init__(self, server_instance):
|
||||
super().__init__("webui")
|
||||
self.server_instance = server_instance
|
||||
@@ -174,16 +145,17 @@ class WebUIProgressHandler(ProgressHandler):
|
||||
"prompt_id": prompt_id,
|
||||
"display_node_id": self.registry.dynprompt.get_display_node_id(node_id),
|
||||
"parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id),
|
||||
"real_node_id": self.registry.dynprompt.get_real_node_id(node_id),
|
||||
"real_node_id": self.registry.dynprompt.get_real_node_id(node_id)
|
||||
}
|
||||
for node_id, state in nodes.items()
|
||||
if state["state"] != NodeState.Pending
|
||||
}
|
||||
|
||||
# Send a combined progress_state message with all node states
|
||||
self.server_instance.send_sync(
|
||||
"progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}
|
||||
)
|
||||
self.server_instance.send_sync("progress_state", {
|
||||
"prompt_id": prompt_id,
|
||||
"nodes": active_nodes
|
||||
})
|
||||
|
||||
@override
|
||||
def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
|
||||
@@ -192,41 +164,21 @@ class WebUIProgressHandler(ProgressHandler):
|
||||
self._send_progress_state(prompt_id, self.registry.nodes)
|
||||
|
||||
@override
|
||||
def update_handler(
|
||||
self,
|
||||
node_id: str,
|
||||
value: float,
|
||||
max_value: float,
|
||||
state: NodeProgressState,
|
||||
prompt_id: str,
|
||||
image: PreviewImageTuple | None = None,
|
||||
):
|
||||
def update_handler(self, node_id: str, value: float, max_value: float,
|
||||
state: NodeProgressState, prompt_id: str, image: Optional[Image.Image] = None):
|
||||
# Send progress state of all nodes
|
||||
if self.registry:
|
||||
self._send_progress_state(prompt_id, self.registry.nodes)
|
||||
if image:
|
||||
# Only send new format if client supports it
|
||||
if feature_flags.supports_feature(
|
||||
self.server_instance.sockets_metadata,
|
||||
self.server_instance.client_id,
|
||||
"supports_preview_metadata",
|
||||
):
|
||||
metadata = {
|
||||
"node_id": node_id,
|
||||
"prompt_id": prompt_id,
|
||||
"display_node_id": self.registry.dynprompt.get_display_node_id(
|
||||
node_id
|
||||
),
|
||||
"parent_node_id": self.registry.dynprompt.get_parent_node_id(
|
||||
node_id
|
||||
),
|
||||
"real_node_id": self.registry.dynprompt.get_real_node_id(node_id),
|
||||
}
|
||||
self.server_instance.send_sync(
|
||||
BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA,
|
||||
(image, metadata),
|
||||
self.server_instance.client_id,
|
||||
)
|
||||
metadata = {
|
||||
"node_id": node_id,
|
||||
"prompt_id": prompt_id,
|
||||
"display_node_id": self.registry.dynprompt.get_display_node_id(node_id),
|
||||
"parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id),
|
||||
"real_node_id": self.registry.dynprompt.get_real_node_id(node_id)
|
||||
}
|
||||
self.server_instance.send_sync(BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA, (image, metadata), self.server_instance.client_id)
|
||||
|
||||
|
||||
@override
|
||||
def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
|
||||
@@ -238,8 +190,7 @@ class ProgressRegistry:
|
||||
"""
|
||||
Registry that maintains node progress state and notifies registered handlers.
|
||||
"""
|
||||
|
||||
def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt"):
|
||||
def __init__(self, prompt_id: str, dynprompt: DynamicPrompt):
|
||||
self.prompt_id = prompt_id
|
||||
self.dynprompt = dynprompt
|
||||
self.nodes: Dict[str, NodeProgressState] = {}
|
||||
@@ -270,7 +221,9 @@ class ProgressRegistry:
|
||||
"""Ensure a node entry exists"""
|
||||
if node_id not in self.nodes:
|
||||
self.nodes[node_id] = NodeProgressState(
|
||||
state=NodeState.Pending, value=0, max=1
|
||||
state = NodeState.Pending,
|
||||
value = 0,
|
||||
max = 1
|
||||
)
|
||||
return self.nodes[node_id]
|
||||
|
||||
@@ -286,9 +239,7 @@ class ProgressRegistry:
|
||||
if handler.enabled:
|
||||
handler.start_handler(node_id, entry, self.prompt_id)
|
||||
|
||||
def update_progress(
|
||||
self, node_id: str, value: float, max_value: float, image: PreviewImageTuple | None = None
|
||||
) -> None:
|
||||
def update_progress(self, node_id: str, value: float, max_value: float, image: Optional[Image.Image]) -> None:
|
||||
"""Update progress for a node"""
|
||||
entry = self.ensure_entry(node_id)
|
||||
entry["state"] = NodeState.Running
|
||||
@@ -298,9 +249,7 @@ class ProgressRegistry:
|
||||
# Notify all enabled handlers
|
||||
for handler in self.handlers.values():
|
||||
if handler.enabled:
|
||||
handler.update_handler(
|
||||
node_id, value, max_value, entry, self.prompt_id, image
|
||||
)
|
||||
handler.update_handler(node_id, value, max_value, entry, self.prompt_id, image)
|
||||
|
||||
def finish_progress(self, node_id: str) -> None:
|
||||
"""Finish progress tracking for a node"""
|
||||
@@ -319,9 +268,9 @@ class ProgressRegistry:
|
||||
handler.reset()
|
||||
|
||||
# Global registry instance
|
||||
global_progress_registry: ProgressRegistry | None = None
|
||||
global_progress_registry: ProgressRegistry = ProgressRegistry(prompt_id="", dynprompt=DynamicPrompt({}))
|
||||
|
||||
def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None:
|
||||
def reset_progress_state(prompt_id: str, dynprompt: DynamicPrompt) -> None:
|
||||
global global_progress_registry
|
||||
|
||||
# Reset existing handlers if registry exists
|
||||
@@ -331,19 +280,9 @@ def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None:
|
||||
# Create new registry
|
||||
global_progress_registry = ProgressRegistry(prompt_id, dynprompt)
|
||||
|
||||
|
||||
def add_progress_handler(handler: ProgressHandler) -> None:
|
||||
registry = get_progress_state()
|
||||
handler.set_registry(registry)
|
||||
registry.register_handler(handler)
|
||||
|
||||
handler.set_registry(global_progress_registry)
|
||||
global_progress_registry.register_handler(handler)
|
||||
|
||||
def get_progress_state() -> ProgressRegistry:
|
||||
global global_progress_registry
|
||||
if global_progress_registry is None:
|
||||
from comfy_execution.graph import DynamicPrompt
|
||||
|
||||
global_progress_registry = ProgressRegistry(
|
||||
prompt_id="", dynprompt=DynamicPrompt({})
|
||||
)
|
||||
return global_progress_registry
|
||||
|
||||
@@ -133,6 +133,14 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
|
||||
if sample_rate != audio["sample_rate"]:
|
||||
waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
|
||||
|
||||
# Create in-memory WAV buffer
|
||||
wav_buffer = io.BytesIO()
|
||||
torchaudio.save(wav_buffer, waveform, sample_rate, format="WAV")
|
||||
wav_buffer.seek(0) # Rewind for reading
|
||||
|
||||
# Use PyAV to convert and add metadata
|
||||
input_container = av.open(wav_buffer)
|
||||
|
||||
# Create output with specified format
|
||||
output_buffer = io.BytesIO()
|
||||
output_container = av.open(output_buffer, mode='w', format=format)
|
||||
@@ -142,6 +150,7 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
|
||||
output_container.metadata[key] = value
|
||||
|
||||
# Set up the output stream with appropriate properties
|
||||
input_container.streams.audio[0]
|
||||
if format == "opus":
|
||||
out_stream = output_container.add_stream("libopus", rate=sample_rate)
|
||||
if quality == "64k":
|
||||
@@ -166,16 +175,18 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
|
||||
else: #format == "flac":
|
||||
out_stream = output_container.add_stream("flac", rate=sample_rate)
|
||||
|
||||
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo')
|
||||
frame.sample_rate = sample_rate
|
||||
frame.pts = 0
|
||||
output_container.mux(out_stream.encode(frame))
|
||||
|
||||
# Copy frames from input to output
|
||||
for frame in input_container.decode(audio=0):
|
||||
frame.pts = None # Let PyAV handle timestamps
|
||||
output_container.mux(out_stream.encode(frame))
|
||||
|
||||
# Flush encoder
|
||||
output_container.mux(out_stream.encode(None))
|
||||
|
||||
# Close containers
|
||||
output_container.close()
|
||||
input_container.close()
|
||||
|
||||
# Write the output to file
|
||||
output_buffer.seek(0)
|
||||
@@ -278,42 +289,6 @@ class PreviewAudio(SaveAudio):
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
|
||||
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert audio to float 32 bits PCM format."""
|
||||
if wav.dtype.is_floating_point:
|
||||
return wav
|
||||
elif wav.dtype == torch.int16:
|
||||
return wav.float() / (2 ** 15)
|
||||
elif wav.dtype == torch.int32:
|
||||
return wav.float() / (2 ** 31)
|
||||
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
|
||||
|
||||
def load(filepath: str) -> tuple[torch.Tensor, int]:
|
||||
with av.open(filepath) as af:
|
||||
if not af.streams.audio:
|
||||
raise ValueError("No audio stream found in the file.")
|
||||
|
||||
stream = af.streams.audio[0]
|
||||
sr = stream.codec_context.sample_rate
|
||||
n_channels = stream.channels
|
||||
|
||||
frames = []
|
||||
length = 0
|
||||
for frame in af.decode(streams=stream.index):
|
||||
buf = torch.from_numpy(frame.to_ndarray())
|
||||
if buf.shape[0] != n_channels:
|
||||
buf = buf.view(-1, n_channels).t()
|
||||
|
||||
frames.append(buf)
|
||||
length += buf.shape[1]
|
||||
|
||||
if not frames:
|
||||
raise ValueError("No audio frames decoded.")
|
||||
|
||||
wav = torch.cat(frames, dim=1)
|
||||
wav = f32_pcm(wav)
|
||||
return wav, sr
|
||||
|
||||
class LoadAudio:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@@ -328,7 +303,7 @@ class LoadAudio:
|
||||
|
||||
def load(self, audio):
|
||||
audio_path = folder_paths.get_annotated_filepath(audio)
|
||||
waveform, sample_rate = load(audio_path)
|
||||
waveform, sample_rate = torchaudio.load(audio_path)
|
||||
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
|
||||
return (audio, )
|
||||
|
||||
|
||||
@@ -40,33 +40,6 @@ class CFGZeroStar:
|
||||
m.set_model_sampler_post_cfg_function(cfg_zero_star)
|
||||
return (m, )
|
||||
|
||||
class CFGNorm:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"model": ("MODEL",),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
RETURN_NAMES = ("patched_model",)
|
||||
FUNCTION = "patch"
|
||||
CATEGORY = "advanced/guidance"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def patch(self, model, strength):
|
||||
m = model.clone()
|
||||
def cfg_norm(args):
|
||||
cond_p = args['cond_denoised']
|
||||
pred_text_ = args["denoised"]
|
||||
|
||||
norm_full_cond = torch.norm(cond_p, dim=1, keepdim=True)
|
||||
norm_pred_text = torch.norm(pred_text_, dim=1, keepdim=True)
|
||||
scale = (norm_full_cond / (norm_pred_text + 1e-8)).clamp(min=0.0, max=1.0)
|
||||
return pred_text_ * scale * strength
|
||||
|
||||
m.set_model_sampler_post_cfg_function(cfg_norm)
|
||||
return (m, )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"CFGZeroStar": CFGZeroStar,
|
||||
"CFGNorm": CFGNorm,
|
||||
"CFGZeroStar": CFGZeroStar
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ import math
|
||||
import comfy.samplers
|
||||
import comfy.sample
|
||||
from comfy.k_diffusion import sampling as k_diffusion_sampling
|
||||
from comfy.k_diffusion import sa_solver
|
||||
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
|
||||
import latent_preview
|
||||
import torch
|
||||
@@ -301,35 +300,6 @@ class ExtendIntermediateSigmas:
|
||||
|
||||
return (extended_sigmas,)
|
||||
|
||||
|
||||
class SamplingPercentToSigma:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"model": (IO.MODEL, {}),
|
||||
"sampling_percent": (IO.FLOAT, {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.0001}),
|
||||
"return_actual_sigma": (IO.BOOLEAN, {"default": False, "tooltip": "Return the actual sigma value instead of the value used for interval checks.\nThis only affects results at 0.0 and 1.0."}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.FLOAT,)
|
||||
RETURN_NAMES = ("sigma_value",)
|
||||
CATEGORY = "sampling/custom_sampling/sigmas"
|
||||
|
||||
FUNCTION = "get_sigma"
|
||||
|
||||
def get_sigma(self, model, sampling_percent, return_actual_sigma):
|
||||
model_sampling = model.get_model_object("model_sampling")
|
||||
sigma_val = model_sampling.percent_to_sigma(sampling_percent)
|
||||
if return_actual_sigma:
|
||||
if sampling_percent == 0.0:
|
||||
sigma_val = model_sampling.sigma_max.item()
|
||||
elif sampling_percent == 1.0:
|
||||
sigma_val = model_sampling.sigma_min.item()
|
||||
return (sigma_val,)
|
||||
|
||||
|
||||
class KSamplerSelect:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@@ -551,49 +521,6 @@ class SamplerER_SDE(ComfyNodeABC):
|
||||
return (sampler,)
|
||||
|
||||
|
||||
class SamplerSASolver(ComfyNodeABC):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"model": (IO.MODEL, {}),
|
||||
"eta": (IO.FLOAT, {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "round": False},),
|
||||
"sde_start_percent": (IO.FLOAT, {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001},),
|
||||
"sde_end_percent": (IO.FLOAT, {"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.001},),
|
||||
"s_noise": (IO.FLOAT, {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False},),
|
||||
"predictor_order": (IO.INT, {"default": 3, "min": 1, "max": 6}),
|
||||
"corrector_order": (IO.INT, {"default": 4, "min": 0, "max": 6}),
|
||||
"use_pece": (IO.BOOLEAN, {}),
|
||||
"simple_order_2": (IO.BOOLEAN, {}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.SAMPLER,)
|
||||
CATEGORY = "sampling/custom_sampling/samplers"
|
||||
|
||||
FUNCTION = "get_sampler"
|
||||
|
||||
def get_sampler(self, model, eta, sde_start_percent, sde_end_percent, s_noise, predictor_order, corrector_order, use_pece, simple_order_2):
|
||||
model_sampling = model.get_model_object("model_sampling")
|
||||
start_sigma = model_sampling.percent_to_sigma(sde_start_percent)
|
||||
end_sigma = model_sampling.percent_to_sigma(sde_end_percent)
|
||||
tau_func = sa_solver.get_tau_interval_func(start_sigma, end_sigma, eta=eta)
|
||||
|
||||
sampler_name = "sa_solver"
|
||||
sampler = comfy.samplers.ksampler(
|
||||
sampler_name,
|
||||
{
|
||||
"tau_func": tau_func,
|
||||
"s_noise": s_noise,
|
||||
"predictor_order": predictor_order,
|
||||
"corrector_order": corrector_order,
|
||||
"use_pece": use_pece,
|
||||
"simple_order_2": simple_order_2,
|
||||
},
|
||||
)
|
||||
return (sampler,)
|
||||
|
||||
|
||||
class Noise_EmptyNoise:
|
||||
def __init__(self):
|
||||
self.seed = 0
|
||||
@@ -712,10 +639,9 @@ class CFGGuider:
|
||||
return (guider,)
|
||||
|
||||
class Guider_DualCFG(comfy.samplers.CFGGuider):
|
||||
def set_cfg(self, cfg1, cfg2, nested=False):
|
||||
def set_cfg(self, cfg1, cfg2):
|
||||
self.cfg1 = cfg1
|
||||
self.cfg2 = cfg2
|
||||
self.nested = nested
|
||||
|
||||
def set_conds(self, positive, middle, negative):
|
||||
middle = node_helpers.conditioning_set_values(middle, {"prompt_type": "negative"})
|
||||
@@ -725,20 +651,14 @@ class Guider_DualCFG(comfy.samplers.CFGGuider):
|
||||
negative_cond = self.conds.get("negative", None)
|
||||
middle_cond = self.conds.get("middle", None)
|
||||
positive_cond = self.conds.get("positive", None)
|
||||
if model_options.get("disable_cfg1_optimization", False) == False:
|
||||
if math.isclose(self.cfg2, 1.0):
|
||||
negative_cond = None
|
||||
if math.isclose(self.cfg1, 1.0):
|
||||
middle_cond = None
|
||||
|
||||
if self.nested:
|
||||
out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, positive_cond], x, timestep, model_options)
|
||||
pred_text = comfy.samplers.cfg_function(self.inner_model, out[2], out[1], self.cfg1, x, timestep, model_options=model_options, cond=positive_cond, uncond=middle_cond)
|
||||
return out[0] + self.cfg2 * (pred_text - out[0])
|
||||
else:
|
||||
if model_options.get("disable_cfg1_optimization", False) == False:
|
||||
if math.isclose(self.cfg2, 1.0):
|
||||
negative_cond = None
|
||||
if math.isclose(self.cfg1, 1.0):
|
||||
middle_cond = None
|
||||
|
||||
out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, positive_cond], x, timestep, model_options)
|
||||
return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1
|
||||
out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, positive_cond], x, timestep, model_options)
|
||||
return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1
|
||||
|
||||
class DualCFGGuider:
|
||||
@classmethod
|
||||
@@ -750,7 +670,6 @@ class DualCFGGuider:
|
||||
"negative": ("CONDITIONING", ),
|
||||
"cfg_conds": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
|
||||
"cfg_cond2_negative": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
|
||||
"style": (["regular", "nested"],),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -759,10 +678,10 @@ class DualCFGGuider:
|
||||
FUNCTION = "get_guider"
|
||||
CATEGORY = "sampling/custom_sampling/guiders"
|
||||
|
||||
def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative, style):
|
||||
def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative):
|
||||
guider = Guider_DualCFG(model)
|
||||
guider.set_conds(cond1, cond2, negative)
|
||||
guider.set_cfg(cfg_conds, cfg_cond2_negative, nested=(style == "nested"))
|
||||
guider.set_cfg(cfg_conds, cfg_cond2_negative)
|
||||
return (guider,)
|
||||
|
||||
class DisableNoise:
|
||||
@@ -910,13 +829,11 @@ NODE_CLASS_MAPPINGS = {
|
||||
"SamplerDPMPP_2S_Ancestral": SamplerDPMPP_2S_Ancestral,
|
||||
"SamplerDPMAdaptative": SamplerDPMAdaptative,
|
||||
"SamplerER_SDE": SamplerER_SDE,
|
||||
"SamplerSASolver": SamplerSASolver,
|
||||
"SplitSigmas": SplitSigmas,
|
||||
"SplitSigmasDenoise": SplitSigmasDenoise,
|
||||
"FlipSigmas": FlipSigmas,
|
||||
"SetFirstSigma": SetFirstSigma,
|
||||
"ExtendIntermediateSigmas": ExtendIntermediateSigmas,
|
||||
"SamplingPercentToSigma": SamplingPercentToSigma,
|
||||
|
||||
"CFGGuider": CFGGuider,
|
||||
"DualCFGGuider": DualCFGGuider,
|
||||
|
||||
@@ -71,11 +71,8 @@ class FreSca:
|
||||
DESCRIPTION = "Applies frequency-dependent scaling to the guidance"
|
||||
def patch(self, model, scale_low, scale_high, freq_cutoff):
|
||||
def custom_cfg_function(args):
|
||||
conds_out = args["conds_out"]
|
||||
if len(conds_out) <= 1 or None in args["conds"][:2]:
|
||||
return conds_out
|
||||
cond = conds_out[0]
|
||||
uncond = conds_out[1]
|
||||
cond = args["conds_out"][0]
|
||||
uncond = args["conds_out"][1]
|
||||
|
||||
guidance = cond - uncond
|
||||
filtered_guidance = Fourier_filter(
|
||||
@@ -86,7 +83,7 @@ class FreSca:
|
||||
)
|
||||
filtered_cond = filtered_guidance + uncond
|
||||
|
||||
return [filtered_cond, uncond] + conds_out[2:]
|
||||
return [filtered_cond, uncond]
|
||||
|
||||
m = model.clone()
|
||||
m.set_model_sampler_pre_cfg_function(custom_cfg_function)
|
||||
|
||||
@@ -583,49 +583,6 @@ class GetImageSize:
|
||||
|
||||
return width, height, batch_size
|
||||
|
||||
class ImageRotate:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "image": (IO.IMAGE,),
|
||||
"rotation": (["none", "90 degrees", "180 degrees", "270 degrees"],),
|
||||
}}
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
FUNCTION = "rotate"
|
||||
|
||||
CATEGORY = "image/transform"
|
||||
|
||||
def rotate(self, image, rotation):
|
||||
rotate_by = 0
|
||||
if rotation.startswith("90"):
|
||||
rotate_by = 1
|
||||
elif rotation.startswith("180"):
|
||||
rotate_by = 2
|
||||
elif rotation.startswith("270"):
|
||||
rotate_by = 3
|
||||
|
||||
image = torch.rot90(image, k=rotate_by, dims=[2, 1])
|
||||
return (image,)
|
||||
|
||||
class ImageFlip:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "image": (IO.IMAGE,),
|
||||
"flip_method": (["x-axis: vertically", "y-axis: horizontally"],),
|
||||
}}
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
FUNCTION = "flip"
|
||||
|
||||
CATEGORY = "image/transform"
|
||||
|
||||
def flip(self, image, flip_method):
|
||||
if flip_method.startswith("x"):
|
||||
image = torch.flip(image, dims=[1])
|
||||
elif flip_method.startswith("y"):
|
||||
image = torch.flip(image, dims=[2])
|
||||
|
||||
return (image,)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ImageCrop": ImageCrop,
|
||||
"RepeatImageBatch": RepeatImageBatch,
|
||||
@@ -637,6 +594,4 @@ NODE_CLASS_MAPPINGS = {
|
||||
"ImageStitch": ImageStitch,
|
||||
"ResizeAndPadImage": ResizeAndPadImage,
|
||||
"GetImageSize": GetImageSize,
|
||||
"ImageRotate": ImageRotate,
|
||||
"ImageFlip": ImageFlip,
|
||||
}
|
||||
|
||||
@@ -5,8 +5,6 @@ import os
|
||||
from comfy.comfy_types import IO
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def normalize_path(path):
|
||||
return path.replace('\\', '/')
|
||||
@@ -18,14 +16,7 @@ class Load3D():
|
||||
|
||||
os.makedirs(input_dir, exist_ok=True)
|
||||
|
||||
input_path = Path(input_dir)
|
||||
base_path = Path(folder_paths.get_input_directory())
|
||||
|
||||
files = [
|
||||
normalize_path(str(file_path.relative_to(base_path)))
|
||||
for file_path in input_path.rglob("*")
|
||||
if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl'}
|
||||
]
|
||||
files = [normalize_path(os.path.join("3d", f)) for f in os.listdir(input_dir) if f.endswith(('.gltf', '.glb', '.obj', '.fbx', '.stl'))]
|
||||
|
||||
return {"required": {
|
||||
"model_file": (sorted(files), {"file_upload": True}),
|
||||
@@ -70,14 +61,7 @@ class Load3DAnimation():
|
||||
|
||||
os.makedirs(input_dir, exist_ok=True)
|
||||
|
||||
input_path = Path(input_dir)
|
||||
base_path = Path(folder_paths.get_input_directory())
|
||||
|
||||
files = [
|
||||
normalize_path(str(file_path.relative_to(base_path)))
|
||||
for file_path in input_path.rglob("*")
|
||||
if file_path.suffix.lower() in {'.gltf', '.glb', '.fbx'}
|
||||
]
|
||||
files = [normalize_path(os.path.join("3d", f)) for f in os.listdir(input_dir) if f.endswith(('.gltf', '.glb', '.fbx'))]
|
||||
|
||||
return {"required": {
|
||||
"model_file": (sorted(files), {"file_upload": True}),
|
||||
|
||||
@@ -134,8 +134,8 @@ class LTXVAddGuide:
|
||||
_, num_keyframes = get_keyframe_idxs(cond)
|
||||
latent_count = latent_length - num_keyframes
|
||||
frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0)
|
||||
if guide_length > 1 and frame_idx != 0:
|
||||
frame_idx = (frame_idx - 1) // time_scale_factor * time_scale_factor + 1 # frame index - 1 must be divisible by 8 or frame_idx == 0
|
||||
if guide_length > 1:
|
||||
frame_idx = frame_idx // time_scale_factor * time_scale_factor # frame index must be divisible by 8
|
||||
|
||||
latent_idx = (frame_idx + time_scale_factor - 1) // time_scale_factor
|
||||
|
||||
@@ -144,7 +144,7 @@ class LTXVAddGuide:
|
||||
def add_keyframe_index(self, cond, frame_idx, guiding_latent, scale_factors):
|
||||
keyframe_idxs, _ = get_keyframe_idxs(cond)
|
||||
_, latent_coords = self._patchifier.patchify(guiding_latent)
|
||||
pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=frame_idx == 0) # we need the causal fix only if we're placing the new latents at index 0
|
||||
pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, True)
|
||||
pixel_coords[:, 0] += frame_idx
|
||||
if keyframe_idxs is None:
|
||||
keyframe_idxs = pixel_coords
|
||||
|
||||
@@ -152,7 +152,7 @@ class ImageColorToMask:
|
||||
def image_to_mask(self, image, color):
|
||||
temp = (torch.clamp(image, 0, 1.0) * 255.0).round().to(torch.int)
|
||||
temp = torch.bitwise_left_shift(temp[:,:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,:,1], 8) + temp[:,:,:,2]
|
||||
mask = torch.where(temp == color, 1.0, 0).float()
|
||||
mask = torch.where(temp == color, 255, 0).float()
|
||||
return (mask,)
|
||||
|
||||
class SolidMask:
|
||||
@@ -247,7 +247,7 @@ class MaskComposite:
|
||||
visible_width, visible_height = (right - left, bottom - top,)
|
||||
|
||||
source_portion = source[:, :visible_height, :visible_width]
|
||||
destination_portion = output[:, top:bottom, left:right]
|
||||
destination_portion = destination[:, top:bottom, left:right]
|
||||
|
||||
if operation == "multiply":
|
||||
output[:, top:bottom, left:right] = destination_portion * source_portion
|
||||
|
||||
@@ -1,24 +1,24 @@
|
||||
from nodes import MAX_RESOLUTION
|
||||
|
||||
class CLIPTextEncodePixArtAlpha:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
|
||||
"height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
|
||||
# "aspect_ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "encode"
|
||||
CATEGORY = "advanced/conditioning"
|
||||
DESCRIPTION = "Encodes text and sets the resolution conditioning for PixArt Alpha. Does not apply to PixArt Sigma."
|
||||
|
||||
def encode(self, clip, width, height, text):
|
||||
tokens = clip.tokenize(text)
|
||||
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height}),)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"CLIPTextEncodePixArtAlpha": CLIPTextEncodePixArtAlpha,
|
||||
}
|
||||
from nodes import MAX_RESOLUTION
|
||||
|
||||
class CLIPTextEncodePixArtAlpha:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
|
||||
"height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
|
||||
# "aspect_ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "encode"
|
||||
CATEGORY = "advanced/conditioning"
|
||||
DESCRIPTION = "Encodes text and sets the resolution conditioning for PixArt Alpha. Does not apply to PixArt Sigma."
|
||||
|
||||
def encode(self, clip, width, height, text):
|
||||
tokens = clip.tokenize(text)
|
||||
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height}),)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"CLIPTextEncodePixArtAlpha": CLIPTextEncodePixArtAlpha,
|
||||
}
|
||||
|
||||
@@ -78,75 +78,7 @@ class SkipLayerGuidanceDiT:
|
||||
|
||||
return (m, )
|
||||
|
||||
class SkipLayerGuidanceDiTSimple:
|
||||
'''
|
||||
Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass.
|
||||
'''
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"model": ("MODEL", ),
|
||||
"double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
|
||||
"single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
|
||||
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "skip_guidance"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
DESCRIPTION = "Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass."
|
||||
|
||||
CATEGORY = "advanced/guidance"
|
||||
|
||||
def skip_guidance(self, model, start_percent, end_percent, double_layers="", single_layers=""):
|
||||
def skip(args, extra_args):
|
||||
return args
|
||||
|
||||
model_sampling = model.get_model_object("model_sampling")
|
||||
sigma_start = model_sampling.percent_to_sigma(start_percent)
|
||||
sigma_end = model_sampling.percent_to_sigma(end_percent)
|
||||
|
||||
double_layers = re.findall(r'\d+', double_layers)
|
||||
double_layers = [int(i) for i in double_layers]
|
||||
|
||||
single_layers = re.findall(r'\d+', single_layers)
|
||||
single_layers = [int(i) for i in single_layers]
|
||||
|
||||
if len(double_layers) == 0 and len(single_layers) == 0:
|
||||
return (model, )
|
||||
|
||||
def calc_cond_batch_function(args):
|
||||
x = args["input"]
|
||||
model = args["model"]
|
||||
conds = args["conds"]
|
||||
sigma = args["sigma"]
|
||||
|
||||
model_options = args["model_options"]
|
||||
slg_model_options = model_options.copy()
|
||||
|
||||
for layer in double_layers:
|
||||
slg_model_options = comfy.model_patcher.set_model_options_patch_replace(slg_model_options, skip, "dit", "double_block", layer)
|
||||
|
||||
for layer in single_layers:
|
||||
slg_model_options = comfy.model_patcher.set_model_options_patch_replace(slg_model_options, skip, "dit", "single_block", layer)
|
||||
|
||||
cond, uncond = conds
|
||||
sigma_ = sigma[0].item()
|
||||
if sigma_ >= sigma_end and sigma_ <= sigma_start and uncond is not None:
|
||||
cond_out, _ = comfy.samplers.calc_cond_batch(model, [cond, None], x, sigma, model_options)
|
||||
_, uncond_out = comfy.samplers.calc_cond_batch(model, [None, uncond], x, sigma, slg_model_options)
|
||||
out = [cond_out, uncond_out]
|
||||
else:
|
||||
out = comfy.samplers.calc_cond_batch(model, conds, x, sigma, model_options)
|
||||
|
||||
return out
|
||||
|
||||
m = model.clone()
|
||||
m.set_model_sampler_calc_cond_batch_function(calc_cond_batch_function)
|
||||
|
||||
return (m, )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"SkipLayerGuidanceDiT": SkipLayerGuidanceDiT,
|
||||
"SkipLayerGuidanceDiTSimple": SkipLayerGuidanceDiTSimple,
|
||||
}
|
||||
|
||||
@@ -20,82 +20,41 @@ import folder_paths
|
||||
import node_helpers
|
||||
from comfy.cli_args import args
|
||||
from comfy.comfy_types.node_typing import IO
|
||||
from comfy.weight_adapter import adapters, adapter_maps
|
||||
|
||||
|
||||
def make_batch_extra_option_dict(d, indicies, full_size=None):
|
||||
new_dict = {}
|
||||
for k, v in d.items():
|
||||
newv = v
|
||||
if isinstance(v, dict):
|
||||
newv = make_batch_extra_option_dict(v, indicies, full_size=full_size)
|
||||
elif isinstance(v, torch.Tensor):
|
||||
if full_size is None or v.size(0) == full_size:
|
||||
newv = v[indicies]
|
||||
elif isinstance(v, (list, tuple)) and len(v) == full_size:
|
||||
newv = [v[i] for i in indicies]
|
||||
new_dict[k] = newv
|
||||
return new_dict
|
||||
from comfy.weight_adapter import adapters
|
||||
|
||||
|
||||
class TrainSampler(comfy.samplers.Sampler):
|
||||
def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
|
||||
|
||||
def __init__(self, loss_fn, optimizer, loss_callback=None):
|
||||
self.loss_fn = loss_fn
|
||||
self.optimizer = optimizer
|
||||
self.loss_callback = loss_callback
|
||||
self.batch_size = batch_size
|
||||
self.total_steps = total_steps
|
||||
self.grad_acc = grad_acc
|
||||
self.seed = seed
|
||||
self.training_dtype = training_dtype
|
||||
|
||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||
cond = model_wrap.conds["positive"]
|
||||
dataset_size = sigmas.size(0)
|
||||
torch.cuda.empty_cache()
|
||||
for i in (pbar:=tqdm.trange(self.total_steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
|
||||
noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(self.seed + i * 1000)
|
||||
indicies = torch.randperm(dataset_size)[:self.batch_size].tolist()
|
||||
self.optimizer.zero_grad()
|
||||
noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas, noise, latent_image, False)
|
||||
latent = model_wrap.inner_model.model_sampling.noise_scaling(
|
||||
torch.zeros_like(sigmas),
|
||||
torch.zeros_like(noise, requires_grad=True),
|
||||
latent_image,
|
||||
False
|
||||
)
|
||||
|
||||
batch_latent = torch.stack([latent_image[i] for i in indicies])
|
||||
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(batch_latent.device)
|
||||
batch_sigmas = [
|
||||
model_wrap.inner_model.model_sampling.percent_to_sigma(
|
||||
torch.rand((1,)).item()
|
||||
) for _ in range(min(self.batch_size, dataset_size))
|
||||
]
|
||||
batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)
|
||||
# Ensure model is in training mode and computing gradients
|
||||
# x0 pred
|
||||
denoised = model_wrap(noise, sigmas, **extra_args)
|
||||
try:
|
||||
loss = self.loss_fn(denoised, latent.clone())
|
||||
except RuntimeError as e:
|
||||
if "does not require grad and does not have a grad_fn" in str(e):
|
||||
logging.info("WARNING: This is likely due to the model is loaded in inference mode.")
|
||||
loss.backward()
|
||||
if self.loss_callback:
|
||||
self.loss_callback(loss.item())
|
||||
|
||||
xt = model_wrap.inner_model.model_sampling.noise_scaling(
|
||||
batch_sigmas,
|
||||
batch_noise,
|
||||
batch_latent,
|
||||
False
|
||||
)
|
||||
x0 = model_wrap.inner_model.model_sampling.noise_scaling(
|
||||
torch.zeros_like(batch_sigmas),
|
||||
torch.zeros_like(batch_noise),
|
||||
batch_latent,
|
||||
False
|
||||
)
|
||||
|
||||
model_wrap.conds["positive"] = [
|
||||
cond[i] for i in indicies
|
||||
]
|
||||
batch_extra_args = make_batch_extra_option_dict(extra_args, indicies, full_size=dataset_size)
|
||||
|
||||
with torch.autocast(xt.device.type, dtype=self.training_dtype):
|
||||
x0_pred = model_wrap(xt, batch_sigmas, **batch_extra_args)
|
||||
loss = self.loss_fn(x0_pred, x0)
|
||||
loss.backward()
|
||||
if self.loss_callback:
|
||||
self.loss_callback(loss.item())
|
||||
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
|
||||
|
||||
if (i+1) % self.grad_acc == 0:
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
torch.cuda.empty_cache()
|
||||
self.optimizer.step()
|
||||
# torch.cuda.memory._dump_snapshot("trainn.pickle")
|
||||
# torch.cuda.memory._record_memory_history(enabled=None)
|
||||
return torch.zeros_like(latent_image)
|
||||
|
||||
|
||||
@@ -116,7 +75,7 @@ class BiasDiff(torch.nn.Module):
|
||||
return self.passive_memory_usage()
|
||||
|
||||
|
||||
def load_and_process_images(image_files, input_dir, resize_method="None", w=None, h=None):
|
||||
def load_and_process_images(image_files, input_dir, resize_method="None"):
|
||||
"""Utility function to load and process a list of images.
|
||||
|
||||
Args:
|
||||
@@ -131,6 +90,7 @@ def load_and_process_images(image_files, input_dir, resize_method="None", w=None
|
||||
raise ValueError("No valid images found in input")
|
||||
|
||||
output_images = []
|
||||
w, h = None, None
|
||||
|
||||
for file in image_files:
|
||||
image_path = os.path.join(input_dir, file)
|
||||
@@ -246,103 +206,6 @@ class LoadImageSetFromFolderNode:
|
||||
return (output_tensor,)
|
||||
|
||||
|
||||
class LoadImageTextSetFromFolderNode:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."}),
|
||||
"clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."}),
|
||||
},
|
||||
"optional": {
|
||||
"resize_method": (
|
||||
["None", "Stretch", "Crop", "Pad"],
|
||||
{"default": "None"},
|
||||
),
|
||||
"width": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": -1,
|
||||
"min": -1,
|
||||
"max": 10000,
|
||||
"step": 1,
|
||||
"tooltip": "The width to resize the images to. -1 means use the original width.",
|
||||
},
|
||||
),
|
||||
"height": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": -1,
|
||||
"min": -1,
|
||||
"max": 10000,
|
||||
"step": 1,
|
||||
"tooltip": "The height to resize the images to. -1 means use the original height.",
|
||||
},
|
||||
)
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", IO.CONDITIONING,)
|
||||
FUNCTION = "load_images"
|
||||
CATEGORY = "loaders"
|
||||
EXPERIMENTAL = True
|
||||
DESCRIPTION = "Loads a batch of images and caption from a directory for training."
|
||||
|
||||
def load_images(self, folder, clip, resize_method, width=None, height=None):
|
||||
if clip is None:
|
||||
raise RuntimeError("ERROR: clip input is invalid: None\n\nIf the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model.")
|
||||
|
||||
logging.info(f"Loading images from folder: {folder}")
|
||||
|
||||
sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
|
||||
valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
|
||||
|
||||
image_files = []
|
||||
for item in os.listdir(sub_input_dir):
|
||||
path = os.path.join(sub_input_dir, item)
|
||||
if any(item.lower().endswith(ext) for ext in valid_extensions):
|
||||
image_files.append(path)
|
||||
elif os.path.isdir(path):
|
||||
# Support kohya-ss/sd-scripts folder structure
|
||||
repeat = 1
|
||||
if item.split("_")[0].isdigit():
|
||||
repeat = int(item.split("_")[0])
|
||||
image_files.extend([
|
||||
os.path.join(path, f) for f in os.listdir(path) if any(f.lower().endswith(ext) for ext in valid_extensions)
|
||||
] * repeat)
|
||||
|
||||
caption_file_path = [
|
||||
f.replace(os.path.splitext(f)[1], ".txt")
|
||||
for f in image_files
|
||||
]
|
||||
captions = []
|
||||
for caption_file in caption_file_path:
|
||||
caption_path = os.path.join(sub_input_dir, caption_file)
|
||||
if os.path.exists(caption_path):
|
||||
with open(caption_path, "r", encoding="utf-8") as f:
|
||||
caption = f.read().strip()
|
||||
captions.append(caption)
|
||||
else:
|
||||
captions.append("")
|
||||
|
||||
width = width if width != -1 else None
|
||||
height = height if height != -1 else None
|
||||
output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method, width, height)
|
||||
|
||||
logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.")
|
||||
|
||||
logging.info(f"Encoding captions from {sub_input_dir}.")
|
||||
conditions = []
|
||||
empty_cond = clip.encode_from_tokens_scheduled(clip.tokenize(""))
|
||||
for text in captions:
|
||||
if text == "":
|
||||
conditions.append(empty_cond)
|
||||
tokens = clip.tokenize(text)
|
||||
conditions.extend(clip.encode_from_tokens_scheduled(tokens))
|
||||
logging.info(f"Encoded {len(conditions)} captions from {sub_input_dir}.")
|
||||
return (output_tensor, conditions)
|
||||
|
||||
|
||||
def draw_loss_graph(loss_map, steps):
|
||||
width, height = 500, 300
|
||||
img = Image.new("RGB", (width, height), "white")
|
||||
@@ -420,16 +283,6 @@ class TrainLoraNode:
|
||||
"tooltip": "The batch size to use for training.",
|
||||
},
|
||||
),
|
||||
"grad_accumulation_steps": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 1,
|
||||
"min": 1,
|
||||
"max": 1024,
|
||||
"step": 1,
|
||||
"tooltip": "The number of gradient accumulation steps to use for training.",
|
||||
}
|
||||
),
|
||||
"steps": (
|
||||
IO.INT,
|
||||
{
|
||||
@@ -489,17 +342,6 @@ class TrainLoraNode:
|
||||
["bf16", "fp32"],
|
||||
{"default": "bf16", "tooltip": "The dtype to use for lora."},
|
||||
),
|
||||
"algorithm": (
|
||||
list(adapter_maps.keys()),
|
||||
{"default": list(adapter_maps.keys())[0], "tooltip": "The algorithm to use for training."},
|
||||
),
|
||||
"gradient_checkpointing": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": True,
|
||||
"tooltip": "Use gradient checkpointing for training.",
|
||||
}
|
||||
),
|
||||
"existing_lora": (
|
||||
folder_paths.get_filename_list("loras") + ["[None]"],
|
||||
{
|
||||
@@ -523,7 +365,6 @@ class TrainLoraNode:
|
||||
positive,
|
||||
batch_size,
|
||||
steps,
|
||||
grad_accumulation_steps,
|
||||
learning_rate,
|
||||
rank,
|
||||
optimizer,
|
||||
@@ -531,8 +372,6 @@ class TrainLoraNode:
|
||||
seed,
|
||||
training_dtype,
|
||||
lora_dtype,
|
||||
algorithm,
|
||||
gradient_checkpointing,
|
||||
existing_lora,
|
||||
):
|
||||
mp = model.clone()
|
||||
@@ -542,13 +381,6 @@ class TrainLoraNode:
|
||||
|
||||
latents = latents["samples"].to(dtype)
|
||||
num_images = latents.shape[0]
|
||||
logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
|
||||
if len(positive) == 1 and num_images > 1:
|
||||
positive = positive * num_images
|
||||
elif len(positive) != num_images:
|
||||
raise ValueError(
|
||||
f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
|
||||
)
|
||||
|
||||
with torch.inference_mode(False):
|
||||
lora_sd = {}
|
||||
@@ -583,8 +415,10 @@ class TrainLoraNode:
|
||||
if existing_adapter is not None:
|
||||
break
|
||||
else:
|
||||
# If no existing adapter found, use LoRA
|
||||
# We will add algo option in the future
|
||||
existing_adapter = None
|
||||
adapter_cls = adapter_maps[algorithm]
|
||||
adapter_cls = adapters[0]
|
||||
|
||||
if existing_adapter is not None:
|
||||
train_adapter = existing_adapter.to_train().to(lora_dtype)
|
||||
@@ -638,45 +472,45 @@ class TrainLoraNode:
|
||||
criterion = torch.nn.SmoothL1Loss()
|
||||
|
||||
# setup models
|
||||
if gradient_checkpointing:
|
||||
for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
|
||||
patch(m)
|
||||
mp.model.requires_grad_(False)
|
||||
for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
|
||||
patch(m)
|
||||
comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True)
|
||||
|
||||
# Setup sampler and guider like in test script
|
||||
loss_map = {"loss": []}
|
||||
def loss_callback(loss):
|
||||
loss_map["loss"].append(loss)
|
||||
pbar.set_postfix({"loss": f"{loss:.4f}"})
|
||||
train_sampler = TrainSampler(
|
||||
criterion,
|
||||
optimizer,
|
||||
loss_callback=loss_callback,
|
||||
batch_size=batch_size,
|
||||
grad_acc=grad_accumulation_steps,
|
||||
total_steps=steps*grad_accumulation_steps,
|
||||
seed=seed,
|
||||
training_dtype=dtype
|
||||
criterion, optimizer, loss_callback=loss_callback
|
||||
)
|
||||
guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
|
||||
guider.set_conds(positive) # Set conditioning from input
|
||||
ss = comfy_extras.nodes_custom_sampler.SamplerCustomAdvanced()
|
||||
|
||||
# yoland: this currently resize to the first image in the dataset
|
||||
|
||||
# Training loop
|
||||
torch.cuda.empty_cache()
|
||||
try:
|
||||
# Generate dummy sigmas and noise
|
||||
sigmas = torch.tensor(range(num_images))
|
||||
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
|
||||
guider.sample(
|
||||
noise.generate_noise({"samples": latents}),
|
||||
latents,
|
||||
train_sampler,
|
||||
sigmas,
|
||||
seed=noise.seed
|
||||
)
|
||||
for step in (pbar:=tqdm.trange(steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
|
||||
# Generate random sigma
|
||||
sigma = mp.model.model_sampling.percent_to_sigma(
|
||||
torch.rand((1,)).item()
|
||||
)
|
||||
sigma = torch.tensor([sigma])
|
||||
|
||||
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(step * 1000 + seed)
|
||||
|
||||
indices = torch.randperm(num_images)[:batch_size]
|
||||
ss.sample(
|
||||
noise, guider, train_sampler, sigma, {"samples": latents[indices].clone()}
|
||||
)
|
||||
finally:
|
||||
for m in mp.model.modules():
|
||||
unpatch(m)
|
||||
del train_sampler, optimizer
|
||||
del ss, train_sampler, optimizer
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
for adapter in all_weight_adapters:
|
||||
adapter.requires_grad_(False)
|
||||
@@ -863,7 +697,6 @@ NODE_CLASS_MAPPINGS = {
|
||||
"SaveLoRANode": SaveLoRA,
|
||||
"LoraModelLoader": LoraModelLoader,
|
||||
"LoadImageSetFromFolderNode": LoadImageSetFromFolderNode,
|
||||
"LoadImageTextSetFromFolderNode": LoadImageTextSetFromFolderNode,
|
||||
"LossGraphNode": LossGraphNode,
|
||||
}
|
||||
|
||||
@@ -872,6 +705,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"SaveLoRANode": "Save LoRA Weights",
|
||||
"LoraModelLoader": "Load LoRA Model",
|
||||
"LoadImageSetFromFolderNode": "Load Image Dataset from Folder",
|
||||
"LoadImageTextSetFromFolderNode": "Load Image and Text Dataset from Folder",
|
||||
"LossGraphNode": "Plot Loss Graph",
|
||||
}
|
||||
|
||||
@@ -8,7 +8,9 @@ import json
|
||||
from typing import Optional, Literal
|
||||
from fractions import Fraction
|
||||
from comfy.comfy_types import IO, FileLocator, ComfyNodeABC
|
||||
from comfy_api.latest import Input, InputImpl, Types
|
||||
from comfy_api.input import ImageInput, AudioInput, VideoInput
|
||||
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
|
||||
from comfy_api.input_impl import VideoFromFile, VideoFromComponents
|
||||
from comfy.cli_args import args
|
||||
|
||||
class SaveWEBM:
|
||||
@@ -89,8 +91,8 @@ class SaveVideo(ComfyNodeABC):
|
||||
"required": {
|
||||
"video": (IO.VIDEO, {"tooltip": "The video to save."}),
|
||||
"filename_prefix": ("STRING", {"default": "video/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}),
|
||||
"format": (Types.VideoContainer.as_input(), {"default": "auto", "tooltip": "The format to save the video as."}),
|
||||
"codec": (Types.VideoCodec.as_input(), {"default": "auto", "tooltip": "The codec to use for the video."}),
|
||||
"format": (VideoContainer.as_input(), {"default": "auto", "tooltip": "The format to save the video as."}),
|
||||
"codec": (VideoCodec.as_input(), {"default": "auto", "tooltip": "The codec to use for the video."}),
|
||||
},
|
||||
"hidden": {
|
||||
"prompt": "PROMPT",
|
||||
@@ -106,7 +108,7 @@ class SaveVideo(ComfyNodeABC):
|
||||
CATEGORY = "image/video"
|
||||
DESCRIPTION = "Saves the input images to your ComfyUI output directory."
|
||||
|
||||
def save_video(self, video: Input.Video, filename_prefix, format, codec, prompt=None, extra_pnginfo=None):
|
||||
def save_video(self, video: VideoInput, filename_prefix, format, codec, prompt=None, extra_pnginfo=None):
|
||||
filename_prefix += self.prefix_append
|
||||
width, height = video.get_dimensions()
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
|
||||
@@ -125,7 +127,7 @@ class SaveVideo(ComfyNodeABC):
|
||||
metadata["prompt"] = prompt
|
||||
if len(metadata) > 0:
|
||||
saved_metadata = metadata
|
||||
file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}"
|
||||
file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}"
|
||||
video.save_to(
|
||||
os.path.join(full_output_folder, file),
|
||||
format=format,
|
||||
@@ -161,9 +163,9 @@ class CreateVideo(ComfyNodeABC):
|
||||
CATEGORY = "image/video"
|
||||
DESCRIPTION = "Create a video from images."
|
||||
|
||||
def create_video(self, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None):
|
||||
return (InputImpl.VideoFromComponents(
|
||||
Types.VideoComponents(
|
||||
def create_video(self, images: ImageInput, fps: float, audio: Optional[AudioInput] = None):
|
||||
return (VideoFromComponents(
|
||||
VideoComponents(
|
||||
images=images,
|
||||
audio=audio,
|
||||
frame_rate=Fraction(fps),
|
||||
@@ -185,7 +187,7 @@ class GetVideoComponents(ComfyNodeABC):
|
||||
CATEGORY = "image/video"
|
||||
DESCRIPTION = "Extracts all components from a video: frames, audio, and framerate."
|
||||
|
||||
def get_components(self, video: Input.Video):
|
||||
def get_components(self, video: VideoInput):
|
||||
components = video.get_components()
|
||||
|
||||
return (components.images, components.audio, float(components.frame_rate))
|
||||
@@ -206,7 +208,7 @@ class LoadVideo(ComfyNodeABC):
|
||||
FUNCTION = "load_video"
|
||||
def load_video(self, file):
|
||||
video_path = folder_paths.get_annotated_filepath(file)
|
||||
return (InputImpl.VideoFromFile(video_path),)
|
||||
return (VideoFromFile(video_path),)
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(cls, file):
|
||||
@@ -237,4 +239,3 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"GetVideoComponents": "Get Video Components",
|
||||
"LoadVideo": "Load Video",
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import math
|
||||
import nodes
|
||||
import node_helpers
|
||||
import torch
|
||||
@@ -6,9 +5,7 @@ import comfy.model_management
|
||||
import comfy.utils
|
||||
import comfy.latent_formats
|
||||
import comfy.clip_vision
|
||||
import json
|
||||
import numpy as np
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
class WanImageToVideo:
|
||||
@classmethod
|
||||
@@ -149,7 +146,6 @@ class WanFirstLastFrameToVideo:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
|
||||
|
||||
clip_vision_output = None
|
||||
if clip_vision_start_image is not None:
|
||||
clip_vision_output = clip_vision_start_image
|
||||
|
||||
@@ -387,350 +383,7 @@ class WanPhantomSubjectToVideo:
|
||||
out_latent["samples"] = latent
|
||||
return (positive, cond2, negative, out_latent)
|
||||
|
||||
def parse_json_tracks(tracks):
|
||||
"""Parse JSON track data into a standardized format"""
|
||||
tracks_data = []
|
||||
try:
|
||||
# If tracks is a string, try to parse it as JSON
|
||||
if isinstance(tracks, str):
|
||||
parsed = json.loads(tracks.replace("'", '"'))
|
||||
tracks_data.extend(parsed)
|
||||
else:
|
||||
# If tracks is a list of strings, parse each one
|
||||
for track_str in tracks:
|
||||
parsed = json.loads(track_str.replace("'", '"'))
|
||||
tracks_data.append(parsed)
|
||||
|
||||
# Check if we have a single track (dict with x,y) or a list of tracks
|
||||
if tracks_data and isinstance(tracks_data[0], dict) and 'x' in tracks_data[0]:
|
||||
# Single track detected, wrap it in a list
|
||||
tracks_data = [tracks_data]
|
||||
elif tracks_data and isinstance(tracks_data[0], list) and tracks_data[0] and isinstance(tracks_data[0][0], dict) and 'x' in tracks_data[0][0]:
|
||||
# Already a list of tracks, nothing to do
|
||||
pass
|
||||
else:
|
||||
# Unexpected format
|
||||
pass
|
||||
|
||||
except json.JSONDecodeError:
|
||||
tracks_data = []
|
||||
return tracks_data
|
||||
|
||||
def process_tracks(tracks_np: np.ndarray, frame_size: Tuple[int, int], num_frames, quant_multi: int = 8, **kwargs):
|
||||
# tracks: shape [t, h, w, 3] => samples align with 24 fps, model trained with 16 fps.
|
||||
# frame_size: tuple (W, H)
|
||||
tracks = torch.from_numpy(tracks_np).float()
|
||||
|
||||
if tracks.shape[1] == 121:
|
||||
tracks = torch.permute(tracks, (1, 0, 2, 3))
|
||||
|
||||
tracks, visibles = tracks[..., :2], tracks[..., 2:3]
|
||||
|
||||
short_edge = min(*frame_size)
|
||||
|
||||
frame_center = torch.tensor([*frame_size]).type_as(tracks) / 2
|
||||
tracks = tracks - frame_center
|
||||
|
||||
tracks = tracks / short_edge * 2
|
||||
|
||||
visibles = visibles * 2 - 1
|
||||
|
||||
trange = torch.linspace(-1, 1, tracks.shape[0]).view(-1, 1, 1, 1).expand(*visibles.shape)
|
||||
|
||||
out_ = torch.cat([trange, tracks, visibles], dim=-1).view(121, -1, 4)
|
||||
|
||||
out_0 = out_[:1]
|
||||
|
||||
out_l = out_[1:] # 121 => 120 | 1
|
||||
a = 120 // math.gcd(120, num_frames)
|
||||
b = num_frames // math.gcd(120, num_frames)
|
||||
out_l = torch.repeat_interleave(out_l, b, dim=0)[1::a] # 120 => 120 * b => 120 * b / a == F
|
||||
|
||||
final_result = torch.cat([out_0, out_l], dim=0)
|
||||
|
||||
return final_result
|
||||
|
||||
FIXED_LENGTH = 121
|
||||
def pad_pts(tr):
|
||||
"""Convert list of {x,y} to (FIXED_LENGTH,1,3) array, padding/truncating."""
|
||||
pts = np.array([[p['x'], p['y'], 1] for p in tr], dtype=np.float32)
|
||||
n = pts.shape[0]
|
||||
if n < FIXED_LENGTH:
|
||||
pad = np.zeros((FIXED_LENGTH - n, 3), dtype=np.float32)
|
||||
pts = np.vstack((pts, pad))
|
||||
else:
|
||||
pts = pts[:FIXED_LENGTH]
|
||||
return pts.reshape(FIXED_LENGTH, 1, 3)
|
||||
|
||||
def ind_sel(target: torch.Tensor, ind: torch.Tensor, dim: int = 1):
|
||||
"""Index selection utility function"""
|
||||
assert (
|
||||
len(ind.shape) > dim
|
||||
), "Index must have the target dim, but get dim: %d, ind shape: %s" % (dim, str(ind.shape))
|
||||
|
||||
target = target.expand(
|
||||
*tuple(
|
||||
[ind.shape[k] if target.shape[k] == 1 else -1 for k in range(dim)]
|
||||
+ [
|
||||
-1,
|
||||
]
|
||||
* (len(target.shape) - dim)
|
||||
)
|
||||
)
|
||||
|
||||
ind_pad = ind
|
||||
|
||||
if len(target.shape) > dim + 1:
|
||||
for _ in range(len(target.shape) - (dim + 1)):
|
||||
ind_pad = ind_pad.unsqueeze(-1)
|
||||
ind_pad = ind_pad.expand(*(-1,) * (dim + 1), *target.shape[(dim + 1) : :])
|
||||
|
||||
return torch.gather(target, dim=dim, index=ind_pad)
|
||||
|
||||
def merge_final(vert_attr: torch.Tensor, weight: torch.Tensor, vert_assign: torch.Tensor):
|
||||
"""Merge vertex attributes with weights"""
|
||||
target_dim = len(vert_assign.shape) - 1
|
||||
if len(vert_attr.shape) == 2:
|
||||
assert vert_attr.shape[0] > vert_assign.max()
|
||||
new_shape = [1] * target_dim + list(vert_attr.shape)
|
||||
tensor = vert_attr.reshape(new_shape)
|
||||
sel_attr = ind_sel(tensor, vert_assign.type(torch.long), dim=target_dim)
|
||||
else:
|
||||
assert vert_attr.shape[1] > vert_assign.max()
|
||||
new_shape = [vert_attr.shape[0]] + [1] * (target_dim - 1) + list(vert_attr.shape[1:])
|
||||
tensor = vert_attr.reshape(new_shape)
|
||||
sel_attr = ind_sel(tensor, vert_assign.type(torch.long), dim=target_dim)
|
||||
|
||||
final_attr = torch.sum(sel_attr * weight.unsqueeze(-1), dim=-2)
|
||||
return final_attr
|
||||
|
||||
|
||||
def _patch_motion_single(
|
||||
tracks: torch.FloatTensor, # (B, T, N, 4)
|
||||
vid: torch.FloatTensor, # (C, T, H, W)
|
||||
temperature: float,
|
||||
vae_divide: tuple,
|
||||
topk: int,
|
||||
):
|
||||
"""Apply motion patching based on tracks"""
|
||||
_, T, H, W = vid.shape
|
||||
N = tracks.shape[2]
|
||||
_, tracks_xy, visible = torch.split(
|
||||
tracks, [1, 2, 1], dim=-1
|
||||
) # (B, T, N, 2) | (B, T, N, 1)
|
||||
tracks_n = tracks_xy / torch.tensor([W / min(H, W), H / min(H, W)], device=tracks_xy.device)
|
||||
tracks_n = tracks_n.clamp(-1, 1)
|
||||
visible = visible.clamp(0, 1)
|
||||
|
||||
xx = torch.linspace(-W / min(H, W), W / min(H, W), W)
|
||||
yy = torch.linspace(-H / min(H, W), H / min(H, W), H)
|
||||
|
||||
grid = torch.stack(torch.meshgrid(yy, xx, indexing="ij")[::-1], dim=-1).to(
|
||||
tracks_xy.device
|
||||
)
|
||||
|
||||
tracks_pad = tracks_xy[:, 1:]
|
||||
visible_pad = visible[:, 1:]
|
||||
|
||||
visible_align = visible_pad.view(T - 1, 4, *visible_pad.shape[2:]).sum(1)
|
||||
tracks_align = (tracks_pad * visible_pad).view(T - 1, 4, *tracks_pad.shape[2:]).sum(
|
||||
1
|
||||
) / (visible_align + 1e-5)
|
||||
dist_ = (
|
||||
(tracks_align[:, None, None] - grid[None, :, :, None]).pow(2).sum(-1)
|
||||
) # T, H, W, N
|
||||
weight = torch.exp(-dist_ * temperature) * visible_align.clamp(0, 1).view(
|
||||
T - 1, 1, 1, N
|
||||
)
|
||||
vert_weight, vert_index = torch.topk(
|
||||
weight, k=min(topk, weight.shape[-1]), dim=-1
|
||||
)
|
||||
|
||||
grid_mode = "bilinear"
|
||||
point_feature = torch.nn.functional.grid_sample(
|
||||
vid.permute(1, 0, 2, 3)[:1],
|
||||
tracks_n[:, :1].type(vid.dtype),
|
||||
mode=grid_mode,
|
||||
padding_mode="zeros",
|
||||
align_corners=False,
|
||||
)
|
||||
point_feature = point_feature.squeeze(0).squeeze(1).permute(1, 0) # N, C=16
|
||||
|
||||
out_feature = merge_final(point_feature, vert_weight, vert_index).permute(3, 0, 1, 2) # T - 1, H, W, C => C, T - 1, H, W
|
||||
out_weight = vert_weight.sum(-1) # T - 1, H, W
|
||||
|
||||
# out feature -> already soft weighted
|
||||
mix_feature = out_feature + vid[:, 1:] * (1 - out_weight.clamp(0, 1))
|
||||
|
||||
out_feature_full = torch.cat([vid[:, :1], mix_feature], dim=1) # C, T, H, W
|
||||
out_mask_full = torch.cat([torch.ones_like(out_weight[:1]), out_weight], dim=0) # T, H, W
|
||||
|
||||
return out_mask_full[None].expand(vae_divide[0], -1, -1, -1), out_feature_full
|
||||
|
||||
|
||||
def patch_motion(
|
||||
tracks: torch.FloatTensor, # (B, TB, T, N, 4)
|
||||
vid: torch.FloatTensor, # (C, T, H, W)
|
||||
temperature: float = 220.0,
|
||||
vae_divide: tuple = (4, 16),
|
||||
topk: int = 2,
|
||||
):
|
||||
B = len(tracks)
|
||||
|
||||
# Process each batch separately
|
||||
out_masks = []
|
||||
out_features = []
|
||||
|
||||
for b in range(B):
|
||||
mask, feature = _patch_motion_single(
|
||||
tracks[b], # (T, N, 4)
|
||||
vid[b], # (C, T, H, W)
|
||||
temperature,
|
||||
vae_divide,
|
||||
topk
|
||||
)
|
||||
out_masks.append(mask)
|
||||
out_features.append(feature)
|
||||
|
||||
# Stack results: (B, C, T, H, W)
|
||||
out_mask_full = torch.stack(out_masks, dim=0)
|
||||
out_feature_full = torch.stack(out_features, dim=0)
|
||||
|
||||
return out_mask_full, out_feature_full
|
||||
|
||||
class WanTrackToVideo:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"vae": ("VAE", ),
|
||||
"tracks": ("STRING", {"multiline": True, "default": "[]"}),
|
||||
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
"temperature": ("FLOAT", {"default": 220.0, "min": 1.0, "max": 1000.0, "step": 0.1}),
|
||||
"topk": ("INT", {"default": 2, "min": 1, "max": 10}),
|
||||
"start_image": ("IMAGE", ),
|
||||
},
|
||||
"optional": {
|
||||
"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent")
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def encode(self, positive, negative, vae, tracks, width, height, length, batch_size,
|
||||
temperature, topk, start_image=None, clip_vision_output=None):
|
||||
|
||||
tracks_data = parse_json_tracks(tracks)
|
||||
|
||||
if not tracks_data:
|
||||
return WanImageToVideo().encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, clip_vision_output=clip_vision_output)
|
||||
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8],
|
||||
device=comfy.model_management.intermediate_device())
|
||||
|
||||
if isinstance(tracks_data[0][0], dict):
|
||||
tracks_data = [tracks_data]
|
||||
|
||||
processed_tracks = []
|
||||
for batch in tracks_data:
|
||||
arrs = []
|
||||
for track in batch:
|
||||
pts = pad_pts(track)
|
||||
arrs.append(pts)
|
||||
|
||||
tracks_np = np.stack(arrs, axis=0)
|
||||
processed_tracks.append(process_tracks(tracks_np, (width, height), length - 1).unsqueeze(0))
|
||||
|
||||
if start_image is not None:
|
||||
start_image = comfy.utils.common_upscale(start_image[:batch_size].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
videos = torch.ones((start_image.shape[0], length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
|
||||
for i in range(start_image.shape[0]):
|
||||
videos[i, 0] = start_image[i]
|
||||
|
||||
latent_videos = []
|
||||
videos = comfy.utils.resize_to_batch_size(videos, batch_size)
|
||||
for i in range(batch_size):
|
||||
latent_videos += [vae.encode(videos[i, :, :, :, :3])]
|
||||
y = torch.cat(latent_videos, dim=0)
|
||||
|
||||
# Scale latent since patch_motion is non-linear
|
||||
y = comfy.latent_formats.Wan21().process_in(y)
|
||||
|
||||
processed_tracks = comfy.utils.resize_list_to_batch_size(processed_tracks, batch_size)
|
||||
res = patch_motion(
|
||||
processed_tracks, y, temperature=temperature, topk=topk, vae_divide=(4, 16)
|
||||
)
|
||||
|
||||
mask, concat_latent_image = res
|
||||
concat_latent_image = comfy.latent_formats.Wan21().process_out(concat_latent_image)
|
||||
mask = -mask + 1.0 # Invert mask to match expected format
|
||||
positive = node_helpers.conditioning_set_values(positive,
|
||||
{"concat_mask": mask,
|
||||
"concat_latent_image": concat_latent_image})
|
||||
negative = node_helpers.conditioning_set_values(negative,
|
||||
{"concat_mask": mask,
|
||||
"concat_latent_image": concat_latent_image})
|
||||
|
||||
if clip_vision_output is not None:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return (positive, negative, out_latent)
|
||||
|
||||
|
||||
class Wan22ImageToVideoLatent:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"vae": ("VAE", ),
|
||||
"width": ("INT", {"default": 1280, "min": 32, "max": nodes.MAX_RESOLUTION, "step": 32}),
|
||||
"height": ("INT", {"default": 704, "min": 32, "max": nodes.MAX_RESOLUTION, "step": 32}),
|
||||
"length": ("INT", {"default": 49, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
},
|
||||
"optional": {"start_image": ("IMAGE", ),
|
||||
}}
|
||||
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/inpaint"
|
||||
|
||||
def encode(self, vae, width, height, length, batch_size, start_image=None):
|
||||
latent = torch.zeros([1, 48, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device())
|
||||
|
||||
if start_image is None:
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return (out_latent,)
|
||||
|
||||
mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
|
||||
|
||||
if start_image is not None:
|
||||
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
latent_temp = vae.encode(start_image)
|
||||
latent[:, :, :latent_temp.shape[-3]] = latent_temp
|
||||
mask[:, :, :latent_temp.shape[-3]] *= 0.0
|
||||
|
||||
out_latent = {}
|
||||
latent_format = comfy.latent_formats.Wan22()
|
||||
latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask)
|
||||
out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
|
||||
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
|
||||
return (out_latent,)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"WanTrackToVideo": WanTrackToVideo,
|
||||
"WanImageToVideo": WanImageToVideo,
|
||||
"WanFunControlToVideo": WanFunControlToVideo,
|
||||
"WanFunInpaintToVideo": WanFunInpaintToVideo,
|
||||
@@ -739,5 +392,4 @@ NODE_CLASS_MAPPINGS = {
|
||||
"TrimVideoLatent": TrimVideoLatent,
|
||||
"WanCameraImageToVideo": WanCameraImageToVideo,
|
||||
"WanPhantomSubjectToVideo": WanPhantomSubjectToVideo,
|
||||
"Wan22ImageToVideoLatent": Wan22ImageToVideoLatent,
|
||||
}
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.3.48"
|
||||
__version__ = "0.3.43"
|
||||
|
||||
@@ -74,8 +74,7 @@ if not args.cuda_malloc:
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
version = module.__version__
|
||||
|
||||
if int(version[0]) >= 2 and "+cu" in version: #enable by default for torch version 2.0 and up only on cuda torch
|
||||
if int(version[0]) >= 2: #enable by default for torch version 2.0 and up
|
||||
args.cuda_malloc = cuda_malloc_supported()
|
||||
except:
|
||||
pass
|
||||
|
||||
168
execution.py
168
execution.py
@@ -7,7 +7,7 @@ import threading
|
||||
import time
|
||||
import traceback
|
||||
from enum import Enum
|
||||
from typing import List, Literal, NamedTuple, Optional, Union
|
||||
from typing import List, Literal, NamedTuple, Optional
|
||||
import asyncio
|
||||
|
||||
import torch
|
||||
@@ -32,8 +32,6 @@ from comfy_execution.graph_utils import GraphBuilder, is_link
|
||||
from comfy_execution.validation import validate_node_input
|
||||
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
|
||||
from comfy_execution.utils import CurrentNodeContext
|
||||
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
|
||||
from comfy_api.latest import io
|
||||
|
||||
|
||||
class ExecutionResult(Enum):
|
||||
@@ -58,15 +56,7 @@ class IsChangedCache:
|
||||
node = self.dynprompt.get_node(node_id)
|
||||
class_type = node["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
has_is_changed = False
|
||||
is_changed_name = None
|
||||
if issubclass(class_def, _ComfyNodeInternal) and first_real_override(class_def, "fingerprint_inputs") is not None:
|
||||
has_is_changed = True
|
||||
is_changed_name = "fingerprint_inputs"
|
||||
elif hasattr(class_def, "IS_CHANGED"):
|
||||
has_is_changed = True
|
||||
is_changed_name = "IS_CHANGED"
|
||||
if not has_is_changed:
|
||||
if not hasattr(class_def, "IS_CHANGED"):
|
||||
self.is_changed[node_id] = False
|
||||
return self.is_changed[node_id]
|
||||
|
||||
@@ -75,9 +65,9 @@ class IsChangedCache:
|
||||
return self.is_changed[node_id]
|
||||
|
||||
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
|
||||
input_data_all, _, hidden_inputs = get_input_data(node["inputs"], class_def, node_id, None)
|
||||
input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None)
|
||||
try:
|
||||
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name)
|
||||
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, "IS_CHANGED")
|
||||
is_changed = await resolve_map_node_over_list_results(is_changed)
|
||||
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
|
||||
except Exception as e:
|
||||
@@ -133,17 +123,10 @@ class CacheSet:
|
||||
}
|
||||
return result
|
||||
|
||||
SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")
|
||||
|
||||
def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
|
||||
is_v3 = issubclass(class_def, _ComfyNodeInternal)
|
||||
if is_v3:
|
||||
valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True)
|
||||
else:
|
||||
valid_inputs = class_def.INPUT_TYPES()
|
||||
valid_inputs = class_def.INPUT_TYPES()
|
||||
input_data_all = {}
|
||||
missing_keys = {}
|
||||
hidden_inputs_v3 = {}
|
||||
for x in inputs:
|
||||
input_data = inputs[x]
|
||||
_, input_category, input_info = get_input_info(class_def, x, valid_inputs)
|
||||
@@ -168,37 +151,22 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
|
||||
elif input_category is not None:
|
||||
input_data_all[x] = [input_data]
|
||||
|
||||
if is_v3:
|
||||
if schema.hidden:
|
||||
if io.Hidden.prompt in schema.hidden:
|
||||
hidden_inputs_v3[io.Hidden.prompt] = dynprompt.get_original_prompt() if dynprompt is not None else {}
|
||||
if io.Hidden.dynprompt in schema.hidden:
|
||||
hidden_inputs_v3[io.Hidden.dynprompt] = dynprompt
|
||||
if io.Hidden.extra_pnginfo in schema.hidden:
|
||||
hidden_inputs_v3[io.Hidden.extra_pnginfo] = extra_data.get('extra_pnginfo', None)
|
||||
if io.Hidden.unique_id in schema.hidden:
|
||||
hidden_inputs_v3[io.Hidden.unique_id] = unique_id
|
||||
if io.Hidden.auth_token_comfy_org in schema.hidden:
|
||||
hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None)
|
||||
if io.Hidden.api_key_comfy_org in schema.hidden:
|
||||
hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None)
|
||||
else:
|
||||
if "hidden" in valid_inputs:
|
||||
h = valid_inputs["hidden"]
|
||||
for x in h:
|
||||
if h[x] == "PROMPT":
|
||||
input_data_all[x] = [dynprompt.get_original_prompt() if dynprompt is not None else {}]
|
||||
if h[x] == "DYNPROMPT":
|
||||
input_data_all[x] = [dynprompt]
|
||||
if h[x] == "EXTRA_PNGINFO":
|
||||
input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
|
||||
if h[x] == "UNIQUE_ID":
|
||||
input_data_all[x] = [unique_id]
|
||||
if h[x] == "AUTH_TOKEN_COMFY_ORG":
|
||||
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
|
||||
if h[x] == "API_KEY_COMFY_ORG":
|
||||
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
|
||||
return input_data_all, missing_keys, hidden_inputs_v3
|
||||
if "hidden" in valid_inputs:
|
||||
h = valid_inputs["hidden"]
|
||||
for x in h:
|
||||
if h[x] == "PROMPT":
|
||||
input_data_all[x] = [dynprompt.get_original_prompt() if dynprompt is not None else {}]
|
||||
if h[x] == "DYNPROMPT":
|
||||
input_data_all[x] = [dynprompt]
|
||||
if h[x] == "EXTRA_PNGINFO":
|
||||
input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
|
||||
if h[x] == "UNIQUE_ID":
|
||||
input_data_all[x] = [unique_id]
|
||||
if h[x] == "AUTH_TOKEN_COMFY_ORG":
|
||||
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
|
||||
if h[x] == "API_KEY_COMFY_ORG":
|
||||
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
|
||||
return input_data_all, missing_keys
|
||||
|
||||
map_node_over_list = None #Don't hook this please
|
||||
|
||||
@@ -214,7 +182,7 @@ async def resolve_map_node_over_list_results(results):
|
||||
raise exc
|
||||
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
|
||||
|
||||
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
|
||||
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
||||
# check if node wants the lists
|
||||
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
|
||||
|
||||
@@ -244,22 +212,7 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
|
||||
if execution_block is None:
|
||||
if pre_execute_cb is not None and index is not None:
|
||||
pre_execute_cb(index)
|
||||
# V3
|
||||
if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)):
|
||||
# if is just a class, then assign no resources or state, just create clone
|
||||
if is_class(obj):
|
||||
type_obj = obj
|
||||
obj.VALIDATE_CLASS()
|
||||
class_clone = obj.PREPARE_CLASS_CLONE(hidden_inputs)
|
||||
# otherwise, use class instance to populate/reuse some fields
|
||||
else:
|
||||
type_obj = type(obj)
|
||||
type_obj.VALIDATE_CLASS()
|
||||
class_clone = type_obj.PREPARE_CLASS_CLONE(hidden_inputs)
|
||||
f = make_locked_method_func(type_obj, func, class_clone)
|
||||
# V1
|
||||
else:
|
||||
f = getattr(obj, func)
|
||||
f = getattr(obj, func)
|
||||
if inspect.iscoroutinefunction(f):
|
||||
async def async_wrapper(f, prompt_id, unique_id, list_index, args):
|
||||
with CurrentNodeContext(prompt_id, unique_id, list_index):
|
||||
@@ -311,8 +264,8 @@ def merge_result_data(results, obj):
|
||||
output.append([o[i] for o in results])
|
||||
return output
|
||||
|
||||
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
|
||||
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
|
||||
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
|
||||
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
||||
has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
|
||||
if has_pending_task:
|
||||
return return_values, {}, False, has_pending_task
|
||||
@@ -343,26 +296,6 @@ def get_output_from_returns(return_values, obj):
|
||||
result = tuple([result] * len(obj.RETURN_TYPES))
|
||||
results.append(result)
|
||||
subgraph_results.append((None, result))
|
||||
elif isinstance(r, _NodeOutputInternal):
|
||||
# V3
|
||||
if r.ui is not None:
|
||||
if isinstance(r.ui, dict):
|
||||
uis.append(r.ui)
|
||||
else:
|
||||
uis.append(r.ui.as_dict())
|
||||
if r.expand is not None:
|
||||
has_subgraph = True
|
||||
new_graph = r.expand
|
||||
result = r.result
|
||||
if r.block_execution is not None:
|
||||
result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES))
|
||||
subgraph_results.append((new_graph, result))
|
||||
elif r.result is not None:
|
||||
result = r.result
|
||||
if r.block_execution is not None:
|
||||
result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES))
|
||||
results.append(result)
|
||||
subgraph_results.append((None, result))
|
||||
else:
|
||||
if isinstance(r, ExecutionBlocker):
|
||||
r = tuple([r] * len(obj.RETURN_TYPES))
|
||||
@@ -446,7 +379,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
has_subgraph = False
|
||||
else:
|
||||
get_progress_state().start_progress(unique_id)
|
||||
input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
|
||||
input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
|
||||
if server.client_id is not None:
|
||||
server.last_node_id = display_node_id
|
||||
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
|
||||
@@ -456,12 +389,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
obj = class_def()
|
||||
caches.objects.set(unique_id, obj)
|
||||
|
||||
if issubclass(class_def, _ComfyNodeInternal):
|
||||
lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None
|
||||
else:
|
||||
lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
|
||||
if lazy_status_present:
|
||||
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs)
|
||||
if hasattr(obj, "check_lazy_status"):
|
||||
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True)
|
||||
required_inputs = await resolve_map_node_over_list_results(required_inputs)
|
||||
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
|
||||
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
|
||||
@@ -493,7 +422,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
def pre_execute_cb(call_index):
|
||||
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
|
||||
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
|
||||
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
|
||||
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
||||
if has_pending_tasks:
|
||||
pending_async_nodes[unique_id] = output_data
|
||||
unblock = execution_list.add_external_block(unique_id)
|
||||
@@ -741,14 +670,8 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
||||
|
||||
validate_function_inputs = []
|
||||
validate_has_kwargs = False
|
||||
if issubclass(obj_class, _ComfyNodeInternal):
|
||||
validate_function_name = "validate_inputs"
|
||||
validate_function = first_real_override(obj_class, validate_function_name)
|
||||
else:
|
||||
validate_function_name = "VALIDATE_INPUTS"
|
||||
validate_function = getattr(obj_class, validate_function_name, None)
|
||||
if validate_function is not None:
|
||||
argspec = inspect.getfullargspec(validate_function)
|
||||
if hasattr(obj_class, "VALIDATE_INPUTS"):
|
||||
argspec = inspect.getfullargspec(obj_class.VALIDATE_INPUTS)
|
||||
validate_function_inputs = argspec.args
|
||||
validate_has_kwargs = argspec.varkw is not None
|
||||
received_types = {}
|
||||
@@ -923,7 +846,7 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
||||
continue
|
||||
|
||||
if len(validate_function_inputs) > 0 or validate_has_kwargs:
|
||||
input_data_all, _, hidden_inputs = get_input_data(inputs, obj_class, unique_id)
|
||||
input_data_all, _ = get_input_data(inputs, obj_class, unique_id)
|
||||
input_filtered = {}
|
||||
for x in input_data_all:
|
||||
if x in validate_function_inputs or validate_has_kwargs:
|
||||
@@ -931,7 +854,8 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
||||
if 'input_types' in validate_function_inputs:
|
||||
input_filtered['input_types'] = [received_types]
|
||||
|
||||
ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, hidden_inputs=hidden_inputs)
|
||||
#ret = obj_class.VALIDATE_INPUTS(**input_filtered)
|
||||
ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, "VALIDATE_INPUTS")
|
||||
ret = await resolve_map_node_over_list_results(ret)
|
||||
for x in input_filtered:
|
||||
for i, r in enumerate(ret):
|
||||
@@ -965,7 +889,7 @@ def full_type_name(klass):
|
||||
return klass.__qualname__
|
||||
return module + '.' + klass.__qualname__
|
||||
|
||||
async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[str], None]):
|
||||
async def validate_prompt(prompt_id, prompt):
|
||||
outputs = set()
|
||||
for x in prompt:
|
||||
if 'class_type' not in prompt[x]:
|
||||
@@ -989,8 +913,7 @@ async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[
|
||||
return (False, error, [], {})
|
||||
|
||||
if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
|
||||
if partial_execution_list is None or x in partial_execution_list:
|
||||
outputs.add(x)
|
||||
outputs.add(x)
|
||||
|
||||
if len(outputs) == 0:
|
||||
error = {
|
||||
@@ -1122,11 +1045,6 @@ class PromptQueue:
|
||||
if status is not None:
|
||||
status_dict = copy.deepcopy(status._asdict())
|
||||
|
||||
# Remove sensitive data from extra_data before storing in history
|
||||
for sensitive_val in SENSITIVE_EXTRA_DATA_KEYS:
|
||||
if sensitive_val in prompt[3]:
|
||||
prompt[3].pop(sensitive_val)
|
||||
|
||||
self.history[prompt[1]] = {
|
||||
"prompt": prompt,
|
||||
"outputs": {},
|
||||
@@ -1172,7 +1090,7 @@ class PromptQueue:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_history(self, prompt_id=None, max_items=None, offset=-1, map_function=None):
|
||||
def get_history(self, prompt_id=None, max_items=None, offset=-1):
|
||||
with self.mutex:
|
||||
if prompt_id is None:
|
||||
out = {}
|
||||
@@ -1181,21 +1099,13 @@ class PromptQueue:
|
||||
offset = len(self.history) - max_items
|
||||
for k in self.history:
|
||||
if i >= offset:
|
||||
p = self.history[k]
|
||||
if map_function is not None:
|
||||
p = map_function(p)
|
||||
out[k] = p
|
||||
out[k] = self.history[k]
|
||||
if max_items is not None and len(out) >= max_items:
|
||||
break
|
||||
i += 1
|
||||
return out
|
||||
elif prompt_id in self.history:
|
||||
p = self.history[prompt_id]
|
||||
if map_function is None:
|
||||
p = copy.deepcopy(p)
|
||||
else:
|
||||
p = map_function(p)
|
||||
return {prompt_id: p}
|
||||
return {prompt_id: copy.deepcopy(self.history[prompt_id])}
|
||||
else:
|
||||
return {}
|
||||
|
||||
|
||||
31
main.py
31
main.py
@@ -13,7 +13,6 @@ import logging
|
||||
import sys
|
||||
from comfy_execution.progress import get_progress_state
|
||||
from comfy_execution.utils import get_executing_context
|
||||
from comfy_api import feature_flags
|
||||
|
||||
if __name__ == "__main__":
|
||||
#NOTE: These do not do anything on core ComfyUI, they are for custom nodes.
|
||||
@@ -115,15 +114,6 @@ if os.name == "nt":
|
||||
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
||||
|
||||
if __name__ == "__main__":
|
||||
if args.default_device is not None:
|
||||
default_dev = args.default_device
|
||||
devices = list(range(32))
|
||||
devices.remove(default_dev)
|
||||
devices.insert(0, default_dev)
|
||||
devices = ','.join(map(str, devices))
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(devices)
|
||||
os.environ['HIP_VISIBLE_DEVICES'] = str(devices)
|
||||
|
||||
if args.cuda_device is not None:
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
|
||||
os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)
|
||||
@@ -139,9 +129,6 @@ if __name__ == "__main__":
|
||||
|
||||
import cuda_malloc
|
||||
|
||||
if 'torch' in sys.modules:
|
||||
logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")
|
||||
|
||||
import comfy.utils
|
||||
|
||||
import execution
|
||||
@@ -259,17 +246,9 @@ def hijack_progress(server_instance):
|
||||
|
||||
server_instance.send_sync("progress", progress, server_instance.client_id)
|
||||
if preview_image is not None:
|
||||
# Only send old method if client doesn't support preview metadata
|
||||
if not feature_flags.supports_feature(
|
||||
server_instance.sockets_metadata,
|
||||
server_instance.client_id,
|
||||
"supports_preview_metadata",
|
||||
):
|
||||
server_instance.send_sync(
|
||||
BinaryEventTypes.UNENCODED_PREVIEW_IMAGE,
|
||||
preview_image,
|
||||
server_instance.client_id,
|
||||
)
|
||||
# Also send old method for backward compatibility
|
||||
# TODO - Remove after this repo is updated to frontend with metadata support
|
||||
server_instance.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server_instance.client_id)
|
||||
|
||||
comfy.utils.set_progress_bar_global_hook(hook)
|
||||
|
||||
@@ -313,10 +292,10 @@ def start_comfyui(asyncio_loop=None):
|
||||
prompt_server = server.PromptServer(asyncio_loop)
|
||||
|
||||
hook_breaker_ac10a0.save_functions()
|
||||
asyncio_loop.run_until_complete(nodes.init_extra_nodes(
|
||||
nodes.init_extra_nodes(
|
||||
init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0,
|
||||
init_api_nodes=not args.disable_api_nodes
|
||||
))
|
||||
)
|
||||
hook_breaker_ac10a0.restore_functions()
|
||||
|
||||
cuda_malloc_warning()
|
||||
|
||||
74
nodes.py
74
nodes.py
@@ -1,12 +1,10 @@
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import hashlib
|
||||
import inspect
|
||||
import traceback
|
||||
import math
|
||||
import time
|
||||
@@ -28,9 +26,6 @@ import comfy.sd
|
||||
import comfy.utils
|
||||
import comfy.controlnet
|
||||
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator
|
||||
from comfy_api.internal import register_versions, ComfyAPIWithVersion
|
||||
from comfy_api.version_list import supported_versions
|
||||
from comfy_api.latest import io, ComfyExtension
|
||||
|
||||
import comfy.clip_vision
|
||||
|
||||
@@ -2106,7 +2101,7 @@ def get_module_name(module_path: str) -> str:
|
||||
return base_path
|
||||
|
||||
|
||||
async def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool:
|
||||
def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool:
|
||||
module_name = get_module_name(module_path)
|
||||
if os.path.isfile(module_path):
|
||||
sp = os.path.splitext(module_path)
|
||||
@@ -2154,7 +2149,6 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
|
||||
if os.path.isdir(web_dir):
|
||||
EXTENSION_WEB_DIRS[module_name] = web_dir
|
||||
|
||||
# V1 node definition
|
||||
if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None:
|
||||
for name, node_cls in module.NODE_CLASS_MAPPINGS.items():
|
||||
if name not in ignore:
|
||||
@@ -2163,45 +2157,15 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
|
||||
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
|
||||
return True
|
||||
# V3 Extension Definition
|
||||
elif hasattr(module, "comfy_entrypoint"):
|
||||
entrypoint = getattr(module, "comfy_entrypoint")
|
||||
if not callable(entrypoint):
|
||||
logging.warning(f"comfy_entrypoint in {module_path} is not callable, skipping.")
|
||||
return False
|
||||
try:
|
||||
if inspect.iscoroutinefunction(entrypoint):
|
||||
extension = await entrypoint()
|
||||
else:
|
||||
extension = entrypoint()
|
||||
if not isinstance(extension, ComfyExtension):
|
||||
logging.warning(f"comfy_entrypoint in {module_path} did not return a ComfyExtension, skipping.")
|
||||
return False
|
||||
node_list = await extension.get_node_list()
|
||||
if not isinstance(node_list, list):
|
||||
logging.warning(f"comfy_entrypoint in {module_path} did not return a list of nodes, skipping.")
|
||||
return False
|
||||
for node_cls in node_list:
|
||||
node_cls: io.ComfyNode
|
||||
schema = node_cls.GET_SCHEMA()
|
||||
if schema.node_id not in ignore:
|
||||
NODE_CLASS_MAPPINGS[schema.node_id] = node_cls
|
||||
node_cls.RELATIVE_PYTHON_MODULE = "{}.{}".format(module_parent, get_module_name(module_path))
|
||||
if schema.display_name is not None:
|
||||
NODE_DISPLAY_NAME_MAPPINGS[schema.node_id] = schema.display_name
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.warning(f"Error while calling comfy_entrypoint in {module_path}: {e}")
|
||||
return False
|
||||
else:
|
||||
logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS or NODES_LIST (need one).")
|
||||
logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
|
||||
return False
|
||||
except Exception as e:
|
||||
logging.warning(traceback.format_exc())
|
||||
logging.warning(f"Cannot import {module_path} module for custom nodes: {e}")
|
||||
return False
|
||||
|
||||
async def init_external_custom_nodes():
|
||||
def init_external_custom_nodes():
|
||||
"""
|
||||
Initializes the external custom nodes.
|
||||
|
||||
@@ -2227,7 +2191,7 @@ async def init_external_custom_nodes():
|
||||
logging.info(f"Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes")
|
||||
continue
|
||||
time_before = time.perf_counter()
|
||||
success = await load_custom_node(module_path, base_node_names, module_parent="custom_nodes")
|
||||
success = load_custom_node(module_path, base_node_names, module_parent="custom_nodes")
|
||||
node_import_times.append((time.perf_counter() - time_before, module_path, success))
|
||||
|
||||
if len(node_import_times) > 0:
|
||||
@@ -2240,7 +2204,7 @@ async def init_external_custom_nodes():
|
||||
logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1]))
|
||||
logging.info("")
|
||||
|
||||
async def init_builtin_extra_nodes():
|
||||
def init_builtin_extra_nodes():
|
||||
"""
|
||||
Initializes the built-in extra nodes in ComfyUI.
|
||||
|
||||
@@ -2319,18 +2283,18 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_string.py",
|
||||
"nodes_camera_trajectory.py",
|
||||
"nodes_edit_model.py",
|
||||
"nodes_tcfg.py",
|
||||
"nodes_tcfg.py"
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
for node_file in extras_files:
|
||||
if not await load_custom_node(os.path.join(extras_dir, node_file), module_parent="comfy_extras"):
|
||||
if not load_custom_node(os.path.join(extras_dir, node_file), module_parent="comfy_extras"):
|
||||
import_failed.append(node_file)
|
||||
|
||||
return import_failed
|
||||
|
||||
|
||||
async def init_builtin_api_nodes():
|
||||
def init_builtin_api_nodes():
|
||||
api_nodes_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_api_nodes")
|
||||
api_nodes_files = [
|
||||
"nodes_ideogram.py",
|
||||
@@ -2346,40 +2310,30 @@ async def init_builtin_api_nodes():
|
||||
"nodes_pika.py",
|
||||
"nodes_runway.py",
|
||||
"nodes_tripo.py",
|
||||
"nodes_moonvalley.py",
|
||||
"nodes_rodin.py",
|
||||
"nodes_gemini.py",
|
||||
]
|
||||
|
||||
if not await load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"):
|
||||
if not load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"):
|
||||
return api_nodes_files
|
||||
|
||||
import_failed = []
|
||||
for node_file in api_nodes_files:
|
||||
if not await load_custom_node(os.path.join(api_nodes_dir, node_file), module_parent="comfy_api_nodes"):
|
||||
if not load_custom_node(os.path.join(api_nodes_dir, node_file), module_parent="comfy_api_nodes"):
|
||||
import_failed.append(node_file)
|
||||
|
||||
return import_failed
|
||||
|
||||
async def init_public_apis():
|
||||
register_versions([
|
||||
ComfyAPIWithVersion(
|
||||
version=getattr(v, "VERSION"),
|
||||
api_class=v
|
||||
) for v in supported_versions
|
||||
])
|
||||
|
||||
async def init_extra_nodes(init_custom_nodes=True, init_api_nodes=True):
|
||||
await init_public_apis()
|
||||
|
||||
import_failed = await init_builtin_extra_nodes()
|
||||
def init_extra_nodes(init_custom_nodes=True, init_api_nodes=True):
|
||||
import_failed = init_builtin_extra_nodes()
|
||||
|
||||
import_failed_api = []
|
||||
if init_api_nodes:
|
||||
import_failed_api = await init_builtin_api_nodes()
|
||||
import_failed_api = init_builtin_api_nodes()
|
||||
|
||||
if init_custom_nodes:
|
||||
await init_external_custom_nodes()
|
||||
init_external_custom_nodes()
|
||||
else:
|
||||
logging.info("Skipping loading of custom nodes")
|
||||
|
||||
|
||||
322
notebooks/comfyui_colab.ipynb
Normal file
322
notebooks/comfyui_colab.ipynb
Normal file
@@ -0,0 +1,322 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "aaaaaaaaaa"
|
||||
},
|
||||
"source": [
|
||||
"Git clone the repo and install the requirements. (ignore the pip errors about protobuf)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "bbbbbbbbbb"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title Environment Setup\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"OPTIONS = {}\n",
|
||||
"\n",
|
||||
"USE_GOOGLE_DRIVE = False #@param {type:\"boolean\"}\n",
|
||||
"UPDATE_COMFY_UI = True #@param {type:\"boolean\"}\n",
|
||||
"WORKSPACE = 'ComfyUI'\n",
|
||||
"OPTIONS['USE_GOOGLE_DRIVE'] = USE_GOOGLE_DRIVE\n",
|
||||
"OPTIONS['UPDATE_COMFY_UI'] = UPDATE_COMFY_UI\n",
|
||||
"\n",
|
||||
"if OPTIONS['USE_GOOGLE_DRIVE']:\n",
|
||||
" !echo \"Mounting Google Drive...\"\n",
|
||||
" %cd /\n",
|
||||
" \n",
|
||||
" from google.colab import drive\n",
|
||||
" drive.mount('/content/drive')\n",
|
||||
"\n",
|
||||
" WORKSPACE = \"/content/drive/MyDrive/ComfyUI\"\n",
|
||||
" %cd /content/drive/MyDrive\n",
|
||||
"\n",
|
||||
"![ ! -d $WORKSPACE ] && echo -= Initial setup ComfyUI =- && git clone https://github.com/comfyanonymous/ComfyUI\n",
|
||||
"%cd $WORKSPACE\n",
|
||||
"\n",
|
||||
"if OPTIONS['UPDATE_COMFY_UI']:\n",
|
||||
" !echo -= Updating ComfyUI =-\n",
|
||||
" !git pull\n",
|
||||
"\n",
|
||||
"!echo -= Install dependencies =-\n",
|
||||
"!pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu117"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "cccccccccc"
|
||||
},
|
||||
"source": [
|
||||
"Download some models/checkpoints/vae or custom comfyui nodes (uncomment the commands for the ones you want)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "dddddddddd"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Checkpoints\n",
|
||||
"\n",
|
||||
"### SDXL\n",
|
||||
"### I recommend these workflow examples: https://comfyanonymous.github.io/ComfyUI_examples/sdxl/\n",
|
||||
"\n",
|
||||
"#!wget -c https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors -P ./models/checkpoints/\n",
|
||||
"#!wget -c https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/resolve/main/sd_xl_refiner_1.0.safetensors -P ./models/checkpoints/\n",
|
||||
"\n",
|
||||
"# SDXL ReVision\n",
|
||||
"#!wget -c https://huggingface.co/comfyanonymous/clip_vision_g/resolve/main/clip_vision_g.safetensors -P ./models/clip_vision/\n",
|
||||
"\n",
|
||||
"# SD1.5\n",
|
||||
"!wget -c https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/resolve/main/v1-5-pruned-emaonly-fp16.safetensors -P ./models/checkpoints/\n",
|
||||
"\n",
|
||||
"# SD2\n",
|
||||
"#!wget -c https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.safetensors -P ./models/checkpoints/\n",
|
||||
"#!wget -c https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.safetensors -P ./models/checkpoints/\n",
|
||||
"\n",
|
||||
"# Some SD1.5 anime style\n",
|
||||
"#!wget -c https://huggingface.co/WarriorMama777/OrangeMixs/resolve/main/Models/AbyssOrangeMix2/AbyssOrangeMix2_hard.safetensors -P ./models/checkpoints/\n",
|
||||
"#!wget -c https://huggingface.co/WarriorMama777/OrangeMixs/resolve/main/Models/AbyssOrangeMix3/AOM3A1_orangemixs.safetensors -P ./models/checkpoints/\n",
|
||||
"#!wget -c https://huggingface.co/WarriorMama777/OrangeMixs/resolve/main/Models/AbyssOrangeMix3/AOM3A3_orangemixs.safetensors -P ./models/checkpoints/\n",
|
||||
"#!wget -c https://huggingface.co/Linaqruf/anything-v3.0/resolve/main/anything-v3-fp16-pruned.safetensors -P ./models/checkpoints/\n",
|
||||
"\n",
|
||||
"# Waifu Diffusion 1.5 (anime style SD2.x 768-v)\n",
|
||||
"#!wget -c https://huggingface.co/waifu-diffusion/wd-1-5-beta3/resolve/main/wd-illusion-fp16.safetensors -P ./models/checkpoints/\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# unCLIP models\n",
|
||||
"#!wget -c https://huggingface.co/comfyanonymous/illuminatiDiffusionV1_v11_unCLIP/resolve/main/illuminatiDiffusionV1_v11-unclip-h-fp16.safetensors -P ./models/checkpoints/\n",
|
||||
"#!wget -c https://huggingface.co/comfyanonymous/wd-1.5-beta2_unCLIP/resolve/main/wd-1-5-beta2-aesthetic-unclip-h-fp16.safetensors -P ./models/checkpoints/\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# VAE\n",
|
||||
"!wget -c https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.safetensors -P ./models/vae/\n",
|
||||
"#!wget -c https://huggingface.co/WarriorMama777/OrangeMixs/resolve/main/VAEs/orangemix.vae.pt -P ./models/vae/\n",
|
||||
"#!wget -c https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/vae/kl-f8-anime2.ckpt -P ./models/vae/\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Loras\n",
|
||||
"#!wget -c https://civitai.com/api/download/models/10350 -O ./models/loras/theovercomer8sContrastFix_sd21768.safetensors #theovercomer8sContrastFix SD2.x 768-v\n",
|
||||
"#!wget -c https://civitai.com/api/download/models/10638 -O ./models/loras/theovercomer8sContrastFix_sd15.safetensors #theovercomer8sContrastFix SD1.x\n",
|
||||
"#!wget -c https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_offset_example-lora_1.0.safetensors -P ./models/loras/ #SDXL offset noise lora\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# T2I-Adapter\n",
|
||||
"#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_depth_sd14v1.pth -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_seg_sd14v1.pth -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_sketch_sd14v1.pth -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_keypose_sd14v1.pth -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_openpose_sd14v1.pth -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_color_sd14v1.pth -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_canny_sd14v1.pth -P ./models/controlnet/\n",
|
||||
"\n",
|
||||
"# T2I Styles Model\n",
|
||||
"#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_style_sd14v1.pth -P ./models/style_models/\n",
|
||||
"\n",
|
||||
"# CLIPVision model (needed for styles model)\n",
|
||||
"#!wget -c https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/pytorch_model.bin -O ./models/clip_vision/clip_vit14.bin\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# ControlNet\n",
|
||||
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11e_sd15_ip2p_fp16.safetensors -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11e_sd15_shuffle_fp16.safetensors -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_canny_fp16.safetensors -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11f1p_sd15_depth_fp16.safetensors -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_inpaint_fp16.safetensors -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_lineart_fp16.safetensors -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_mlsd_fp16.safetensors -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_normalbae_fp16.safetensors -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_openpose_fp16.safetensors -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_scribble_fp16.safetensors -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_seg_fp16.safetensors -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_softedge_fp16.safetensors -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15s2_lineart_anime_fp16.safetensors -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11u_sd15_tile_fp16.safetensors -P ./models/controlnet/\n",
|
||||
"\n",
|
||||
"# ControlNet SDXL\n",
|
||||
"#!wget -c https://huggingface.co/stabilityai/control-lora/resolve/main/control-LoRAs-rank256/control-lora-canny-rank256.safetensors -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/stabilityai/control-lora/resolve/main/control-LoRAs-rank256/control-lora-depth-rank256.safetensors -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/stabilityai/control-lora/resolve/main/control-LoRAs-rank256/control-lora-recolor-rank256.safetensors -P ./models/controlnet/\n",
|
||||
"#!wget -c https://huggingface.co/stabilityai/control-lora/resolve/main/control-LoRAs-rank256/control-lora-sketch-rank256.safetensors -P ./models/controlnet/\n",
|
||||
"\n",
|
||||
"# Controlnet Preprocessor nodes by Fannovel16\n",
|
||||
"#!cd custom_nodes && git clone https://github.com/Fannovel16/comfy_controlnet_preprocessors; cd comfy_controlnet_preprocessors && python install.py\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# GLIGEN\n",
|
||||
"#!wget -c https://huggingface.co/comfyanonymous/GLIGEN_pruned_safetensors/resolve/main/gligen_sd14_textbox_pruned_fp16.safetensors -P ./models/gligen/\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# ESRGAN upscale model\n",
|
||||
"#!wget -c https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P ./models/upscale_models/\n",
|
||||
"#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth -P ./models/upscale_models/\n",
|
||||
"#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x4.pth -P ./models/upscale_models/\n",
|
||||
"\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "kkkkkkkkkkkkkkk"
|
||||
},
|
||||
"source": [
|
||||
"### Run ComfyUI with cloudflared (Recommended Way)\n",
|
||||
"\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "jjjjjjjjjjjjjj"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!wget https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64.deb\n",
|
||||
"!dpkg -i cloudflared-linux-amd64.deb\n",
|
||||
"\n",
|
||||
"import subprocess\n",
|
||||
"import threading\n",
|
||||
"import time\n",
|
||||
"import socket\n",
|
||||
"import urllib.request\n",
|
||||
"\n",
|
||||
"def iframe_thread(port):\n",
|
||||
" while True:\n",
|
||||
" time.sleep(0.5)\n",
|
||||
" sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n",
|
||||
" result = sock.connect_ex(('127.0.0.1', port))\n",
|
||||
" if result == 0:\n",
|
||||
" break\n",
|
||||
" sock.close()\n",
|
||||
" print(\"\\nComfyUI finished loading, trying to launch cloudflared (if it gets stuck here cloudflared is having issues)\\n\")\n",
|
||||
"\n",
|
||||
" p = subprocess.Popen([\"cloudflared\", \"tunnel\", \"--url\", \"http://127.0.0.1:{}\".format(port)], stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n",
|
||||
" for line in p.stderr:\n",
|
||||
" l = line.decode()\n",
|
||||
" if \"trycloudflare.com \" in l:\n",
|
||||
" print(\"This is the URL to access ComfyUI:\", l[l.find(\"http\"):], end='')\n",
|
||||
" #print(l, end='')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"threading.Thread(target=iframe_thread, daemon=True, args=(8188,)).start()\n",
|
||||
"\n",
|
||||
"!python main.py --dont-print-server"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "kkkkkkkkkkkkkk"
|
||||
},
|
||||
"source": [
|
||||
"### Run ComfyUI with localtunnel\n",
|
||||
"\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "jjjjjjjjjjjjj"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!npm install -g localtunnel\n",
|
||||
"\n",
|
||||
"import threading\n",
|
||||
"\n",
|
||||
"def iframe_thread(port):\n",
|
||||
" while True:\n",
|
||||
" time.sleep(0.5)\n",
|
||||
" sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n",
|
||||
" result = sock.connect_ex(('127.0.0.1', port))\n",
|
||||
" if result == 0:\n",
|
||||
" break\n",
|
||||
" sock.close()\n",
|
||||
" print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\\n\")\n",
|
||||
"\n",
|
||||
" print(\"The password/enpoint ip for localtunnel is:\", urllib.request.urlopen('https://ipv4.icanhazip.com').read().decode('utf8').strip(\"\\n\"))\n",
|
||||
" p = subprocess.Popen([\"lt\", \"--port\", \"{}\".format(port)], stdout=subprocess.PIPE)\n",
|
||||
" for line in p.stdout:\n",
|
||||
" print(line.decode(), end='')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"threading.Thread(target=iframe_thread, daemon=True, args=(8188,)).start()\n",
|
||||
"\n",
|
||||
"!python main.py --dont-print-server"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "gggggggggg"
|
||||
},
|
||||
"source": [
|
||||
"### Run ComfyUI with colab iframe (use only in case the previous way with localtunnel doesn't work)\n",
|
||||
"\n",
|
||||
"You should see the ui appear in an iframe. If you get a 403 error, it's your firefox settings or an extension that's messing things up.\n",
|
||||
"\n",
|
||||
"If you want to open it in another window use the link.\n",
|
||||
"\n",
|
||||
"Note that some UI features like live image previews won't work because the colab iframe blocks websockets."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "hhhhhhhhhh"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import threading\n",
|
||||
"def iframe_thread(port):\n",
|
||||
" while True:\n",
|
||||
" time.sleep(0.5)\n",
|
||||
" sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n",
|
||||
" result = sock.connect_ex(('127.0.0.1', port))\n",
|
||||
" if result == 0:\n",
|
||||
" break\n",
|
||||
" sock.close()\n",
|
||||
" from google.colab import output\n",
|
||||
" output.serve_kernel_port_as_iframe(port, height=1024)\n",
|
||||
" print(\"to open it in a window you can open this link here:\")\n",
|
||||
" output.serve_kernel_port_as_window(port)\n",
|
||||
"\n",
|
||||
"threading.Thread(target=iframe_thread, daemon=True, args=(8188,)).start()\n",
|
||||
"\n",
|
||||
"!python main.py --dont-print-server"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"provenance": []
|
||||
},
|
||||
"gpuClass": "standard",
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.3.48"
|
||||
version = "0.3.43"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.9"
|
||||
@@ -21,4 +21,4 @@ lint.select = [
|
||||
# See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
|
||||
"F",
|
||||
]
|
||||
exclude = ["*.ipynb", "**/generated/*.pyi"]
|
||||
exclude = ["*.ipynb"]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
comfyui-frontend-package==1.23.4
|
||||
comfyui-workflow-templates==0.1.45
|
||||
comfyui-embedded-docs==0.2.4
|
||||
comfyui-workflow-templates==0.1.31
|
||||
comfyui-embedded-docs==0.2.3
|
||||
torch
|
||||
torchsde
|
||||
torchvision
|
||||
|
||||
@@ -10,11 +10,11 @@ import urllib.parse
|
||||
server_address = "127.0.0.1:8188"
|
||||
client_id = str(uuid.uuid4())
|
||||
|
||||
def queue_prompt(prompt, prompt_id):
|
||||
p = {"prompt": prompt, "client_id": client_id, "prompt_id": prompt_id}
|
||||
def queue_prompt(prompt):
|
||||
p = {"prompt": prompt, "client_id": client_id}
|
||||
data = json.dumps(p).encode('utf-8')
|
||||
req = urllib.request.Request("http://{}/prompt".format(server_address), data=data)
|
||||
urllib.request.urlopen(req).read()
|
||||
req = urllib.request.Request("http://{}/prompt".format(server_address), data=data)
|
||||
return json.loads(urllib.request.urlopen(req).read())
|
||||
|
||||
def get_image(filename, subfolder, folder_type):
|
||||
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
|
||||
@@ -27,8 +27,7 @@ def get_history(prompt_id):
|
||||
return json.loads(response.read())
|
||||
|
||||
def get_images(ws, prompt):
|
||||
prompt_id = str(uuid.uuid4())
|
||||
queue_prompt(prompt, prompt_id)
|
||||
prompt_id = queue_prompt(prompt)['prompt_id']
|
||||
output_images = {}
|
||||
while True:
|
||||
out = ws.recv()
|
||||
|
||||
61
server.py
61
server.py
@@ -26,11 +26,9 @@ import mimetypes
|
||||
from comfy.cli_args import args
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
from comfy_api import feature_flags
|
||||
import node_helpers
|
||||
from comfyui_version import __version__
|
||||
from app.frontend_management import FrontendManager
|
||||
from comfy_api.internal import _ComfyNodeInternal
|
||||
|
||||
from app.user_manager import UserManager
|
||||
from app.model_manager import ModelFileManager
|
||||
@@ -176,7 +174,6 @@ class PromptServer():
|
||||
max_upload_size = round(args.max_upload_size * 1024 * 1024)
|
||||
self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
|
||||
self.sockets = dict()
|
||||
self.sockets_metadata = dict()
|
||||
self.web_root = (
|
||||
FrontendManager.init_frontend(args.front_end_version)
|
||||
if args.front_end_root is None
|
||||
@@ -201,53 +198,20 @@ class PromptServer():
|
||||
else:
|
||||
sid = uuid.uuid4().hex
|
||||
|
||||
# Store WebSocket for backward compatibility
|
||||
self.sockets[sid] = ws
|
||||
# Store metadata separately
|
||||
self.sockets_metadata[sid] = {"feature_flags": {}}
|
||||
|
||||
try:
|
||||
# Send initial state to the new client
|
||||
await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid)
|
||||
await self.send("status", { "status": self.get_queue_info(), 'sid': sid }, sid)
|
||||
# On reconnect if we are the currently executing client send the current node
|
||||
if self.client_id == sid and self.last_node_id is not None:
|
||||
await self.send("executing", { "node": self.last_node_id }, sid)
|
||||
|
||||
# Flag to track if we've received the first message
|
||||
first_message = True
|
||||
|
||||
async for msg in ws:
|
||||
if msg.type == aiohttp.WSMsgType.ERROR:
|
||||
logging.warning('ws connection closed with exception %s' % ws.exception())
|
||||
elif msg.type == aiohttp.WSMsgType.TEXT:
|
||||
try:
|
||||
data = json.loads(msg.data)
|
||||
# Check if first message is feature flags
|
||||
if first_message and data.get("type") == "feature_flags":
|
||||
# Store client feature flags
|
||||
client_flags = data.get("data", {})
|
||||
self.sockets_metadata[sid]["feature_flags"] = client_flags
|
||||
|
||||
# Send server feature flags in response
|
||||
await self.send(
|
||||
"feature_flags",
|
||||
feature_flags.get_server_features(),
|
||||
sid,
|
||||
)
|
||||
|
||||
logging.info(
|
||||
f"Feature flags negotiated for client {sid}: {client_flags}"
|
||||
)
|
||||
first_message = False
|
||||
except json.JSONDecodeError:
|
||||
logging.warning(
|
||||
f"Invalid JSON received from client {sid}: {msg.data}"
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error processing WebSocket message: {e}")
|
||||
finally:
|
||||
self.sockets.pop(sid, None)
|
||||
self.sockets_metadata.pop(sid, None)
|
||||
return ws
|
||||
|
||||
@routes.get("/")
|
||||
@@ -554,7 +518,6 @@ class PromptServer():
|
||||
ram_free = comfy.model_management.get_free_memory(cpu_device)
|
||||
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
|
||||
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
|
||||
required_frontend_version = FrontendManager.get_required_frontend_version()
|
||||
|
||||
system_stats = {
|
||||
"system": {
|
||||
@@ -562,7 +525,6 @@ class PromptServer():
|
||||
"ram_total": ram_total,
|
||||
"ram_free": ram_free,
|
||||
"comfyui_version": __version__,
|
||||
"required_frontend_version": required_frontend_version,
|
||||
"python_version": sys.version,
|
||||
"pytorch_version": comfy.model_management.torch_version,
|
||||
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
|
||||
@@ -582,18 +544,12 @@ class PromptServer():
|
||||
}
|
||||
return web.json_response(system_stats)
|
||||
|
||||
@routes.get("/features")
|
||||
async def get_features(request):
|
||||
return web.json_response(feature_flags.get_server_features())
|
||||
|
||||
@routes.get("/prompt")
|
||||
async def get_prompt(request):
|
||||
return web.json_response(self.get_queue_info())
|
||||
|
||||
def node_info(node_class):
|
||||
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
|
||||
if issubclass(obj_class, _ComfyNodeInternal):
|
||||
return obj_class.GET_NODE_INFO_V1()
|
||||
info = {}
|
||||
info['input'] = obj_class.INPUT_TYPES()
|
||||
info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
|
||||
@@ -683,13 +639,8 @@ class PromptServer():
|
||||
|
||||
if "prompt" in json_data:
|
||||
prompt = json_data["prompt"]
|
||||
prompt_id = str(json_data.get("prompt_id", uuid.uuid4()))
|
||||
|
||||
partial_execution_targets = None
|
||||
if "partial_execution_targets" in json_data:
|
||||
partial_execution_targets = json_data["partial_execution_targets"]
|
||||
|
||||
valid = await execution.validate_prompt(prompt_id, prompt, partial_execution_targets)
|
||||
prompt_id = str(uuid.uuid4())
|
||||
valid = await execution.validate_prompt(prompt_id, prompt)
|
||||
extra_data = {}
|
||||
if "extra_data" in json_data:
|
||||
extra_data = json_data["extra_data"]
|
||||
@@ -931,10 +882,10 @@ class PromptServer():
|
||||
ssl_ctx = None
|
||||
scheme = "http"
|
||||
if args.tls_keyfile and args.tls_certfile:
|
||||
ssl_ctx = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_SERVER, verify_mode=ssl.CERT_NONE)
|
||||
ssl_ctx.load_cert_chain(certfile=args.tls_certfile,
|
||||
ssl_ctx = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_SERVER, verify_mode=ssl.CERT_NONE)
|
||||
ssl_ctx.load_cert_chain(certfile=args.tls_certfile,
|
||||
keyfile=args.tls_keyfile)
|
||||
scheme = "https"
|
||||
scheme = "https"
|
||||
|
||||
if verbose:
|
||||
logging.info("Starting server\n")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import argparse
|
||||
import pytest
|
||||
from requests.exceptions import HTTPError
|
||||
from unittest.mock import patch, mock_open
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.frontend_management import (
|
||||
FrontendManager,
|
||||
@@ -172,36 +172,3 @@ def test_init_frontend_fallback_on_error():
|
||||
# Assert
|
||||
assert frontend_path == "/default/path"
|
||||
mock_check.assert_called_once()
|
||||
|
||||
|
||||
def test_get_frontend_version():
|
||||
# Arrange
|
||||
expected_version = "1.25.0"
|
||||
mock_requirements_content = """torch
|
||||
torchsde
|
||||
comfyui-frontend-package==1.25.0
|
||||
other-package==1.0.0
|
||||
numpy"""
|
||||
|
||||
# Act
|
||||
with patch("builtins.open", mock_open(read_data=mock_requirements_content)):
|
||||
version = FrontendManager.get_required_frontend_version()
|
||||
|
||||
# Assert
|
||||
assert version == expected_version
|
||||
|
||||
|
||||
def test_get_frontend_version_invalid_semver():
|
||||
# Arrange
|
||||
mock_requirements_content = """torch
|
||||
torchsde
|
||||
comfyui-frontend-package==1.29.3.75
|
||||
other-package==1.0.0
|
||||
numpy"""
|
||||
|
||||
# Act
|
||||
with patch("builtins.open", mock_open(read_data=mock_requirements_content)):
|
||||
version = FrontendManager.get_required_frontend_version()
|
||||
|
||||
# Assert
|
||||
assert version is None
|
||||
|
||||
@@ -1,98 +0,0 @@
|
||||
"""Tests for feature flags functionality."""
|
||||
|
||||
from comfy_api.feature_flags import (
|
||||
get_connection_feature,
|
||||
supports_feature,
|
||||
get_server_features,
|
||||
SERVER_FEATURE_FLAGS,
|
||||
)
|
||||
|
||||
|
||||
class TestFeatureFlags:
|
||||
"""Test suite for feature flags functions."""
|
||||
|
||||
def test_get_server_features_returns_copy(self):
|
||||
"""Test that get_server_features returns a copy of the server flags."""
|
||||
features = get_server_features()
|
||||
# Verify it's a copy by modifying it
|
||||
features["test_flag"] = True
|
||||
# Original should be unchanged
|
||||
assert "test_flag" not in SERVER_FEATURE_FLAGS
|
||||
|
||||
def test_get_server_features_contains_expected_flags(self):
|
||||
"""Test that server features contain expected flags."""
|
||||
features = get_server_features()
|
||||
assert "supports_preview_metadata" in features
|
||||
assert features["supports_preview_metadata"] is True
|
||||
assert "max_upload_size" in features
|
||||
assert isinstance(features["max_upload_size"], (int, float))
|
||||
|
||||
def test_get_connection_feature_with_missing_sid(self):
|
||||
"""Test getting feature for non-existent session ID."""
|
||||
sockets_metadata = {}
|
||||
result = get_connection_feature(sockets_metadata, "missing_sid", "some_feature")
|
||||
assert result is False # Default value
|
||||
|
||||
def test_get_connection_feature_with_custom_default(self):
|
||||
"""Test getting feature with custom default value."""
|
||||
sockets_metadata = {}
|
||||
result = get_connection_feature(
|
||||
sockets_metadata, "missing_sid", "some_feature", default="custom_default"
|
||||
)
|
||||
assert result == "custom_default"
|
||||
|
||||
def test_get_connection_feature_with_feature_flags(self):
|
||||
"""Test getting feature from connection with feature flags."""
|
||||
sockets_metadata = {
|
||||
"sid1": {
|
||||
"feature_flags": {
|
||||
"supports_preview_metadata": True,
|
||||
"custom_feature": "value",
|
||||
},
|
||||
}
|
||||
}
|
||||
result = get_connection_feature(sockets_metadata, "sid1", "supports_preview_metadata")
|
||||
assert result is True
|
||||
|
||||
result = get_connection_feature(sockets_metadata, "sid1", "custom_feature")
|
||||
assert result == "value"
|
||||
|
||||
def test_get_connection_feature_missing_feature(self):
|
||||
"""Test getting non-existent feature from connection."""
|
||||
sockets_metadata = {
|
||||
"sid1": {"feature_flags": {"existing_feature": True}}
|
||||
}
|
||||
result = get_connection_feature(sockets_metadata, "sid1", "missing_feature")
|
||||
assert result is False
|
||||
|
||||
def test_supports_feature_returns_boolean(self):
|
||||
"""Test that supports_feature always returns boolean."""
|
||||
sockets_metadata = {
|
||||
"sid1": {
|
||||
"feature_flags": {
|
||||
"bool_feature": True,
|
||||
"string_feature": "value",
|
||||
"none_feature": None,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
# True boolean feature
|
||||
assert supports_feature(sockets_metadata, "sid1", "bool_feature") is True
|
||||
|
||||
# Non-boolean values should return False
|
||||
assert supports_feature(sockets_metadata, "sid1", "string_feature") is False
|
||||
assert supports_feature(sockets_metadata, "sid1", "none_feature") is False
|
||||
assert supports_feature(sockets_metadata, "sid1", "missing_feature") is False
|
||||
|
||||
def test_supports_feature_with_missing_connection(self):
|
||||
"""Test supports_feature with missing connection."""
|
||||
sockets_metadata = {}
|
||||
assert supports_feature(sockets_metadata, "missing_sid", "any_feature") is False
|
||||
|
||||
def test_empty_feature_flags_dict(self):
|
||||
"""Test connection with empty feature flags dictionary."""
|
||||
sockets_metadata = {"sid1": {"feature_flags": {}}}
|
||||
result = get_connection_feature(sockets_metadata, "sid1", "any_feature")
|
||||
assert result is False
|
||||
assert supports_feature(sockets_metadata, "sid1", "any_feature") is False
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user